完善插件

This commit is contained in:
李寻欢
2023-12-11 10:44:23 +08:00
parent 7e545cef95
commit e6c0bfe2cc
12 changed files with 320 additions and 72 deletions

143
plugin/plugin.go Normal file
View File

@@ -0,0 +1,143 @@
package plugin
import (
"go-wechat/model"
)
// MessageHandler 消息处理函数
type MessageHandler func(msg *model.Message)
// MessageDispatcher 消息分发处理接口
// 跟 DispatchMessage 结合封装成 MessageHandler
type MessageDispatcher interface {
Dispatch(msg *model.Message)
}
// DispatchMessage 跟 MessageDispatcher 结合封装成 MessageHandler
func DispatchMessage(dispatcher MessageDispatcher) func(msg *model.Message) {
return func(msg *model.Message) { dispatcher.Dispatch(msg) }
}
// MessageDispatcher impl
// MessageContextHandler 消息处理函数
type MessageContextHandler func(ctx *MessageContext)
type MessageContextHandlerGroup []MessageContextHandler
// MessageContext 消息处理上下文对象
type MessageContext struct {
index int
abortIndex int
messageHandlers MessageContextHandlerGroup
*model.Message
}
// Next 主动调用下一个消息处理函数(或开始调用)
func (c *MessageContext) Next() {
c.index++
for c.index <= len(c.messageHandlers) {
if c.IsAbort() {
return
}
handle := c.messageHandlers[c.index-1]
handle(c)
c.index++
}
}
// IsAbort 判断是否被中断
func (c *MessageContext) IsAbort() bool {
return c.abortIndex > 0
}
// Abort 中断当前消息处理, 不会调用下一个消息处理函数, 但是不会中断当前的处理函数
func (c *MessageContext) Abort() {
c.abortIndex = c.index
}
// AbortHandler 获取当前中断的消息处理函数
func (c *MessageContext) AbortHandler() MessageContextHandler {
if c.abortIndex > 0 {
return c.messageHandlers[c.abortIndex-1]
}
return nil
}
// MatchFunc 消息匹配函数,返回为true则表示匹配
type MatchFunc func(*model.Message) bool
// MatchFuncList 将多个MatchFunc封装成一个MatchFunc
func MatchFuncList(matchFuncs ...MatchFunc) MatchFunc {
return func(message *model.Message) bool {
for _, matchFunc := range matchFuncs {
if !matchFunc(message) {
return false
}
}
return true
}
}
type matchNode struct {
matchFunc MatchFunc
group MessageContextHandlerGroup
}
type matchNodes []*matchNode
// MessageMatchDispatcher impl MessageDispatcher interface
//
// dispatcher := NewMessageMatchDispatcher()
// dispatcher.OnText(func(msg *model.Message){
// msg.ReplyText("hello")
// })
// bot := DefaultBot()
// bot.MessageHandler = DispatchMessage(dispatcher)
type MessageMatchDispatcher struct {
async bool
matchNodes matchNodes
}
// NewMessageMatchDispatcher Constructor
func NewMessageMatchDispatcher() *MessageMatchDispatcher {
return &MessageMatchDispatcher{}
}
// SetAsync 设置是否异步处理
func (m *MessageMatchDispatcher) SetAsync(async bool) {
m.async = async
}
// Dispatch impl MessageDispatcher
// 遍历 MessageMatchDispatcher 所有的消息处理函数
// 获取所有匹配上的函数
// 执行处理的消息处理方法
func (m *MessageMatchDispatcher) Dispatch(msg *model.Message) {
var group MessageContextHandlerGroup
for _, node := range m.matchNodes {
if node.matchFunc(msg) {
group = append(group, node.group...)
}
}
ctx := &MessageContext{Message: msg, messageHandlers: group}
if m.async {
go m.do(ctx)
} else {
m.do(ctx)
}
}
func (m *MessageMatchDispatcher) do(ctx *MessageContext) {
ctx.Next()
}
// RegisterHandler 注册消息处理函数, 根据自己的需求自定义
// matchFunc返回true则表示处理对应的handlers
func (m *MessageMatchDispatcher) RegisterHandler(matchFunc MatchFunc, handlers ...MessageContextHandler) {
if matchFunc == nil {
panic("MatchFunc can not be nil")
}
node := &matchNode{matchFunc: matchFunc, group: handlers}
m.matchNodes = append(m.matchNodes, node)
}

172
plugin/plugins/ai.go Normal file
View File

@@ -0,0 +1,172 @@
package plugins
import (
"context"
"fmt"
"github.com/duke-git/lancet/v2/slice"
"github.com/sashabaranov/go-openai"
"go-wechat/client"
"go-wechat/common/current"
"go-wechat/config"
"go-wechat/entity"
"go-wechat/plugin"
"go-wechat/service"
"go-wechat/types"
"go-wechat/utils"
"log"
"regexp"
"strings"
"time"
)
// AI
// @description: AI消息
// @param m
func AI(m *plugin.MessageContext) {
if !config.Conf.Ai.Enable {
return
}
// 取出所有启用了AI的好友或群组
var count int64
client.MySQL.Model(&entity.Friend{}).Where("enable_ai IS TRUE").Where("wxid = ?", m.FromUser).Count(&count)
if count < 1 {
return
}
// 预处理一下发送的消息,用正则去掉@机器人的内容
re := regexp.MustCompile(`@([^| ]+)`)
matches := re.FindStringSubmatch(m.Content)
if len(matches) > 0 {
// 过滤掉第一个匹配到的
m.Content = strings.Replace(m.Content, matches[0], "", 1)
}
// 组装消息体
messages := make([]openai.ChatCompletionMessage, 0)
if config.Conf.Ai.Personality != "" {
// 填充人设
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: config.Conf.Ai.Personality,
})
}
// 查询发信人前面几条文字信息,组装进来
var oldMessages []entity.Message
if m.GroupUser == "" {
// 私聊
oldMessages = getUserPrivateMessages(m.FromUser)
} else {
// 群聊
oldMessages = getGroupUserMessages(m.MsgId, m.FromUser, m.GroupUser)
}
// 翻转数组
slice.Reverse(oldMessages)
// 循环填充消息
for _, message := range oldMessages {
// 剔除@机器人的内容
msgStr := message.Content
matches = re.FindStringSubmatch(msgStr)
if len(matches) > 0 {
// 过滤掉第一个匹配到的
msgStr = strings.Replace(msgStr, matches[0], "", 1)
}
// 填充消息
role := openai.ChatMessageRoleUser
if message.FromUser == current.GetRobotInfo().WxId {
// 如果收信人不是机器人,表示这条消息是 AI 发的
role = openai.ChatMessageRoleAssistant
}
messages = append(messages, openai.ChatCompletionMessage{
Role: role,
Content: msgStr,
})
}
// 填充用户消息
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: m.Content,
})
// 配置模型
chatModel := openai.GPT3Dot5Turbo0613
if config.Conf.Ai.Model != "" {
chatModel = config.Conf.Ai.Model
}
// 默认使用AI回复
conf := openai.DefaultConfig(config.Conf.Ai.ApiKey)
if config.Conf.Ai.BaseUrl != "" {
conf.BaseURL = fmt.Sprintf("%s/v1", config.Conf.Ai.BaseUrl)
}
ai := openai.NewClientWithConfig(conf)
resp, err := ai.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: chatModel,
Messages: messages,
},
)
if err != nil {
log.Printf("OpenAI聊天发起失败: %v", err.Error())
utils.SendMessage(m.FromUser, m.GroupUser, "AI炸啦~", 0)
return
}
// 保存一下AI 返回的消息,消息 Id 使用传入 Id 的负数
var replyMessage entity.Message
replyMessage.MsgId = -m.MsgId
replyMessage.CreateTime = int(time.Now().Local().Unix())
replyMessage.CreateAt = time.Now().Local()
replyMessage.Content = resp.Choices[0].Message.Content
replyMessage.FromUser = current.GetRobotInfo().WxId // 发信人是机器人
replyMessage.GroupUser = m.GroupUser // 群成员
replyMessage.ToUser = m.FromUser // 收信人是发信人
replyMessage.Type = types.MsgTypeText
service.SaveMessage(replyMessage) // 保存消息
// 发送消息
replyMsg := resp.Choices[0].Message.Content
if m.GroupUser != "" {
replyMsg = "\n" + resp.Choices[0].Message.Content
}
utils.SendMessage(m.FromUser, m.GroupUser, replyMsg, 0)
}
// getGroupUserMessages
// @description: 获取群成员消息
// @return records
func getGroupUserMessages(msgId int64, groupId, groupUserId string) (records []entity.Message) {
subQuery := client.MySQL.
Where("from_user = ? AND group_user = ? AND display_full_content LIKE ?", groupId, groupUserId, "%在群聊中@了你").
Or("to_user = ? AND group_user = ?", groupId, groupUserId)
client.MySQL.Model(&entity.Message{}).
Where("msg_id != ?", msgId).
Where("type = ?", types.MsgTypeText).
Where("create_at >= DATE_SUB(NOW(),INTERVAL 30 MINUTE)").
Where(subQuery).
Order("create_at desc").
Limit(4).Find(&records)
return
}
// getUserPrivateMessages
// @description: 获取用户私聊消息
// @return records
func getUserPrivateMessages(userId string) (records []entity.Message) {
subQuery := client.MySQL.
Where("from_user = ?", userId).Or("to_user = ?", userId)
client.MySQL.Model(&entity.Message{}).
Where("type = ?", types.MsgTypeText).
Where("create_at >= DATE_SUB(NOW(),INTERVAL 30 MINUTE)").
Where(subQuery).
Order("create_at desc").
Limit(4).Find(&records)
return
}

27
plugin/plugins/save2db.go Normal file
View File

@@ -0,0 +1,27 @@
package plugins
import (
"go-wechat/entity"
"go-wechat/plugin"
"go-wechat/service"
"time"
)
// SaveToDb
// @description: 保存消息到数据库
// @param m
func SaveToDb(m *plugin.MessageContext) {
var ent entity.Message
ent.MsgId = m.MsgId
ent.CreateTime = m.CreateTime
ent.CreateAt = time.Unix(int64(m.CreateTime), 0)
ent.Content = m.Content
ent.FromUser = m.FromUser
ent.GroupUser = m.GroupUser
ent.ToUser = m.ToUser
ent.Type = m.Type
ent.DisplayFullContent = m.DisplayFullContent
ent.Raw = m.Raw
// 保存入库
service.SaveMessage(ent)
}

View File

@@ -0,0 +1,39 @@
package plugins
import (
"go-wechat/client"
"go-wechat/config"
"go-wechat/entity"
"go-wechat/plugin"
"go-wechat/utils"
)
// WelcomeNew
// @description: 欢迎新成员
// @param m
func WelcomeNew(m *plugin.MessageContext) {
// 判断是否开启迎新
var count int64
client.MySQL.Model(&entity.Friend{}).Where("enable_welcome IS TRUE").Where("wxid = ?", m.FromUser).Count(&count)
if count < 1 {
return
}
// 读取欢迎新成员配置
conf, ok := config.Conf.Resource["welcome-new"]
if !ok {
// 未配置,跳过
return
}
switch conf.Type {
case "text":
// 文字类型
utils.SendMessage(m.FromUser, "", conf.Path, 0)
case "image":
// 图片类型
utils.SendImage(m.FromUser, conf.Path, 0)
case "emotion":
// 表情类型
utils.SendEmotion(m.FromUser, conf.Path, 0)
}
}