🎨 优化扩展模块,完成ai接入和对话功能
This commit is contained in:
470
server/api/v1/app/chat.go
Normal file
470
server/api/v1/app/chat.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/middleware"
|
||||
appModel "git.echol.cn/loser/st/server/model/app"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/common/response"
|
||||
appService "git.echol.cn/loser/st/server/service/app"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ChatApi struct{}
|
||||
|
||||
// ==================== 对话管理 ====================
|
||||
|
||||
// CreateChat 创建对话
|
||||
// @Tags 对话
|
||||
// @Summary 创建新对话
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.CreateChatRequest true "角色卡ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatResponse} "创建成功"
|
||||
// @Router /app/chat [post]
|
||||
func (ca *ChatApi) CreateChat(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.CreateChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := chatService.CreateChat(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("创建对话失败", zap.Error(err))
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(chat, c)
|
||||
}
|
||||
|
||||
// GetChatList 获取对话列表
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话列表
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param page query int true "页码"
|
||||
// @Param pageSize query int true "每页数量"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatListResponse} "获取成功"
|
||||
// @Router /app/chat/list [get]
|
||||
func (ca *ChatApi) GetChatList(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.ChatListRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Page == 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize == 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
|
||||
list, err := chatService.GetChatList(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取对话列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败", c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(list, c)
|
||||
}
|
||||
|
||||
// GetChatDetail 获取对话详情(含消息)
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话详情
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatDetailResponse} "获取成功"
|
||||
// @Router /app/chat/:id [get]
|
||||
func (ca *ChatApi) GetChatDetail(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
detail, err := chatService.GetChatDetail(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(detail, c)
|
||||
}
|
||||
|
||||
// GetChatMessages 分页获取消息
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话消息
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Param page query int true "页码"
|
||||
// @Param pageSize query int true "每页数量"
|
||||
// @Success 200 {object} response.Response{data=appResponse.MessageListResponse} "获取成功"
|
||||
// @Router /app/chat/:id/messages [get]
|
||||
func (ca *ChatApi) GetChatMessages(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
var req request.ChatMessagesRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Page == 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize == 0 {
|
||||
req.PageSize = 50
|
||||
}
|
||||
|
||||
messages, err := chatService.GetChatMessages(uint(id), req, userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(messages, c)
|
||||
}
|
||||
|
||||
// DeleteChat 删除对话
|
||||
// @Tags 对话
|
||||
// @Summary 删除对话
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/chat/:id [delete]
|
||||
func (ca *ChatApi) DeleteChat(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := chatService.DeleteChat(uint(id), userID); err != nil {
|
||||
global.GVA_LOG.Error("删除对话失败", zap.Error(err))
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// ==================== 消息操作 ====================
|
||||
|
||||
// SendMessage 发送消息并获取AI回复(SSE 流式响应)
|
||||
// @Tags 对话
|
||||
// @Summary 发送消息(SSE流式)
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce text/event-stream
|
||||
// @Param data body request.SendMessageRequest true "消息内容"
|
||||
// @Router /app/chat/send [post]
|
||||
func (ca *ChatApi) SendMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.SendMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 获取对话上下文
|
||||
chat, character, historyMsgs, err := chatService.GetChatForAI(req.ChatID, userID)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 保存用户消息
|
||||
_, err = chatService.SaveUserMessage(chat.ID, userID, req.Content)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"code": 7, "msg": "保存消息失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2.1 世界书匹配(根据角色卡 + 历史消息)
|
||||
var worldInfoTexts []string
|
||||
if character != nil && chat.CharacterID != nil {
|
||||
// 收集最近的消息文本(用于关键词匹配),附加当前用户输入
|
||||
messagesForMatch := make([]string, 0, len(historyMsgs)+1)
|
||||
for _, m := range historyMsgs {
|
||||
if m.Content != "" {
|
||||
messagesForMatch = append(messagesForMatch, m.Content)
|
||||
}
|
||||
}
|
||||
if req.Content != "" {
|
||||
messagesForMatch = append(messagesForMatch, req.Content)
|
||||
}
|
||||
|
||||
matchReq := request.MatchWorldInfoRequest{
|
||||
CharacterID: *chat.CharacterID,
|
||||
Messages: messagesForMatch,
|
||||
ScanDepth: 10,
|
||||
MaxTokens: 2000,
|
||||
}
|
||||
|
||||
if result, wErr := worldInfoService.MatchWorldInfo(userID, &matchReq); wErr != nil {
|
||||
global.GVA_LOG.Warn("匹配世界书失败",
|
||||
zap.Uint("chatID", chat.ID),
|
||||
zap.Uint("characterID", *chat.CharacterID),
|
||||
zap.Error(wErr))
|
||||
} else if result != nil && len(result.Entries) > 0 {
|
||||
// 按 position 分组,暂时统一作为 System 级别补充上下文
|
||||
beforeChar := make([]string, 0)
|
||||
afterChar := make([]string, 0)
|
||||
for _, entry := range result.Entries {
|
||||
text := strings.TrimSpace(entry.Content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
if entry.Position == "after_char" {
|
||||
afterChar = append(afterChar, text)
|
||||
} else {
|
||||
// 默认 before_char
|
||||
beforeChar = append(beforeChar, text)
|
||||
}
|
||||
}
|
||||
|
||||
if len(beforeChar) > 0 {
|
||||
worldInfoTexts = append(worldInfoTexts, strings.Join(beforeChar, "\n\n"))
|
||||
}
|
||||
if len(afterChar) > 0 {
|
||||
worldInfoTexts = append(worldInfoTexts, strings.Join(afterChar, "\n\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 获取 AI 提供商和模型
|
||||
var provider *appModel.AIProvider
|
||||
var modelName string
|
||||
|
||||
if req.ProviderID != nil {
|
||||
// 指定了提供商
|
||||
p, pErr := providerService.GetUserDefaultProvider(userID) // 简化:暂时仅用默认
|
||||
if pErr != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": "指定的 AI 接口不可用"})
|
||||
return
|
||||
}
|
||||
provider = p
|
||||
} else {
|
||||
// 使用默认提供商
|
||||
p, pErr := providerService.GetUserDefaultProvider(userID)
|
||||
if pErr != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": pErr.Error()})
|
||||
return
|
||||
}
|
||||
provider = p
|
||||
}
|
||||
|
||||
if req.ModelName != "" {
|
||||
modelName = req.ModelName
|
||||
} else {
|
||||
// 使用提供商的第一个启用的聊天模型
|
||||
modelName = getDefaultChatModelForProvider(provider.ID)
|
||||
if modelName == "" {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": "该 AI 接口没有可用的聊天模型"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 构建 Prompt(角色卡 + 世界书 + 历史消息)
|
||||
prompt := appService.BuildPrompt(character, historyMsgs)
|
||||
|
||||
// 将世界书内容插入为额外 System 提示(靠近角色定义)
|
||||
if len(worldInfoTexts) > 0 {
|
||||
worldInfoContent := "World Info:\n" + strings.Join(worldInfoTexts, "\n\n")
|
||||
inserted := false
|
||||
for i, msg := range prompt {
|
||||
if msg.Role == "system" {
|
||||
// 在第一个 system 消息之后插入一条世界书消息
|
||||
newPrompt := make([]appService.AIMessagePayload, 0, len(prompt)+1)
|
||||
newPrompt = append(newPrompt, prompt[:i+1]...)
|
||||
newPrompt = append(newPrompt, appService.AIMessagePayload{
|
||||
Role: "system",
|
||||
Content: worldInfoContent,
|
||||
})
|
||||
newPrompt = append(newPrompt, prompt[i+1:]...)
|
||||
prompt = newPrompt
|
||||
inserted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !inserted {
|
||||
prompt = append([]appService.AIMessagePayload{{
|
||||
Role: "system",
|
||||
Content: worldInfoContent,
|
||||
}}, prompt...)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加用户的新消息
|
||||
prompt = append(prompt, appService.AIMessagePayload{
|
||||
Role: "user",
|
||||
Content: req.Content,
|
||||
})
|
||||
|
||||
// 5. 设置 SSE 响应头
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
// 6. 流式调用 AI
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
ch := make(chan appService.AIStreamChunk, 100)
|
||||
go appService.StreamAIResponse(ctx, provider, modelName, prompt, ch)
|
||||
|
||||
var fullContent string
|
||||
var promptTokens, completionTokens int
|
||||
|
||||
flusher, ok := c.Writer.(interface{ Flush() })
|
||||
if !ok {
|
||||
c.JSON(500, gin.H{"code": 7, "msg": "服务器不支持流式响应"})
|
||||
return
|
||||
}
|
||||
|
||||
for chunk := range ch {
|
||||
if chunk.Error != "" {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": chunk.Error,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
if chunk.Content != "" {
|
||||
fullContent += chunk.Content
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "content",
|
||||
"content": chunk.Content,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if chunk.Done {
|
||||
promptTokens = chunk.PromptTokens
|
||||
completionTokens = chunk.CompletionTokens
|
||||
|
||||
// 保存 AI 回复
|
||||
if fullContent != "" {
|
||||
chatService.SaveAssistantMessage(
|
||||
chat.ID, chat.CharacterID, fullContent,
|
||||
modelName, promptTokens, completionTokens,
|
||||
)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "done",
|
||||
"model": modelName,
|
||||
"promptTokens": promptTokens,
|
||||
"completionTokens": completionTokens,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EditMessage 编辑消息
|
||||
// @Tags 对话
|
||||
// @Summary 编辑消息内容
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.EditMessageRequest true "编辑信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.MessageResponse} "编辑成功"
|
||||
// @Router /app/chat/message/edit [post]
|
||||
func (ca *ChatApi) EditMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.EditMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := chatService.EditMessage(req, userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(msg, c)
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息
|
||||
// @Tags 对话
|
||||
// @Summary 删除消息
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.DeleteMessageRequest true "消息ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/chat/message/delete [post]
|
||||
func (ca *ChatApi) DeleteMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.DeleteMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := chatService.DeleteMessage(req.MessageID, userID); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// ==================== 内部辅助 ====================
|
||||
|
||||
// getDefaultChatModelForProvider 获取提供商的默认聊天模型
|
||||
func getDefaultChatModelForProvider(providerID uint) string {
|
||||
var model appModel.AIModel
|
||||
err := global.GVA_DB.Where("provider_id = ? AND model_type = ? AND is_enabled = ?",
|
||||
providerID, "chat", true).
|
||||
Order("created_at ASC").
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return model.ModelName
|
||||
}
|
||||
Reference in New Issue
Block a user