471 lines
13 KiB
Go
471 lines
13 KiB
Go
package app
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"git.echol.cn/loser/st/server/global"
|
||
"git.echol.cn/loser/st/server/middleware"
|
||
appModel "git.echol.cn/loser/st/server/model/app"
|
||
"git.echol.cn/loser/st/server/model/app/request"
|
||
"git.echol.cn/loser/st/server/model/common/response"
|
||
appService "git.echol.cn/loser/st/server/service/app"
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// ChatApi groups the gin HTTP handlers for chats and chat messages
// (create/list/detail/delete chats, send/edit/delete messages).
type ChatApi struct{}
|
||
|
||
// ==================== 对话管理 ====================
|
||
|
||
// CreateChat 创建对话
|
||
// @Tags 对话
|
||
// @Summary 创建新对话
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
// @Param data body request.CreateChatRequest true "角色卡ID"
|
||
// @Success 200 {object} response.Response{data=appResponse.ChatResponse} "创建成功"
|
||
// @Router /app/chat [post]
|
||
func (ca *ChatApi) CreateChat(c *gin.Context) {
|
||
userID := middleware.GetAppUserID(c)
|
||
|
||
var req request.CreateChatRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
chat, err := chatService.CreateChat(req, userID)
|
||
if err != nil {
|
||
global.GVA_LOG.Error("创建对话失败", zap.Error(err))
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
response.OkWithData(chat, c)
|
||
}
|
||
|
||
// GetChatList 获取对话列表
|
||
// @Tags 对话
|
||
// @Summary 获取对话列表
|
||
// @Security ApiKeyAuth
|
||
// @Produce application/json
|
||
// @Param page query int true "页码"
|
||
// @Param pageSize query int true "每页数量"
|
||
// @Success 200 {object} response.Response{data=appResponse.ChatListResponse} "获取成功"
|
||
// @Router /app/chat/list [get]
|
||
func (ca *ChatApi) GetChatList(c *gin.Context) {
|
||
userID := middleware.GetAppUserID(c)
|
||
|
||
var req request.ChatListRequest
|
||
if err := c.ShouldBindQuery(&req); err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
if req.Page == 0 {
|
||
req.Page = 1
|
||
}
|
||
if req.PageSize == 0 {
|
||
req.PageSize = 20
|
||
}
|
||
|
||
list, err := chatService.GetChatList(req, userID)
|
||
if err != nil {
|
||
global.GVA_LOG.Error("获取对话列表失败", zap.Error(err))
|
||
response.FailWithMessage("获取失败", c)
|
||
return
|
||
}
|
||
|
||
response.OkWithData(list, c)
|
||
}
|
||
|
||
// GetChatDetail 获取对话详情(含消息)
|
||
// @Tags 对话
|
||
// @Summary 获取对话详情
|
||
// @Security ApiKeyAuth
|
||
// @Produce application/json
|
||
// @Param id path uint true "对话ID"
|
||
// @Success 200 {object} response.Response{data=appResponse.ChatDetailResponse} "获取成功"
|
||
// @Router /app/chat/:id [get]
|
||
func (ca *ChatApi) GetChatDetail(c *gin.Context) {
|
||
userID := middleware.GetAppUserID(c)
|
||
|
||
idStr := c.Param("id")
|
||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||
if err != nil {
|
||
response.FailWithMessage("无效的ID", c)
|
||
return
|
||
}
|
||
|
||
detail, err := chatService.GetChatDetail(uint(id), userID)
|
||
if err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
response.OkWithData(detail, c)
|
||
}
|
||
|
||
// GetChatMessages 分页获取消息
|
||
// @Tags 对话
|
||
// @Summary 获取对话消息
|
||
// @Security ApiKeyAuth
|
||
// @Produce application/json
|
||
// @Param id path uint true "对话ID"
|
||
// @Param page query int true "页码"
|
||
// @Param pageSize query int true "每页数量"
|
||
// @Success 200 {object} response.Response{data=appResponse.MessageListResponse} "获取成功"
|
||
// @Router /app/chat/:id/messages [get]
|
||
func (ca *ChatApi) GetChatMessages(c *gin.Context) {
|
||
userID := middleware.GetAppUserID(c)
|
||
|
||
idStr := c.Param("id")
|
||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||
if err != nil {
|
||
response.FailWithMessage("无效的ID", c)
|
||
return
|
||
}
|
||
|
||
var req request.ChatMessagesRequest
|
||
if err := c.ShouldBindQuery(&req); err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
if req.Page == 0 {
|
||
req.Page = 1
|
||
}
|
||
if req.PageSize == 0 {
|
||
req.PageSize = 50
|
||
}
|
||
|
||
messages, err := chatService.GetChatMessages(uint(id), req, userID)
|
||
if err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
response.OkWithData(messages, c)
|
||
}
|
||
|
||
// DeleteChat 删除对话
|
||
// @Tags 对话
|
||
// @Summary 删除对话
|
||
// @Security ApiKeyAuth
|
||
// @Produce application/json
|
||
// @Param id path uint true "对话ID"
|
||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||
// @Router /app/chat/:id [delete]
|
||
func (ca *ChatApi) DeleteChat(c *gin.Context) {
|
||
userID := middleware.GetAppUserID(c)
|
||
|
||
idStr := c.Param("id")
|
||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||
if err != nil {
|
||
response.FailWithMessage("无效的ID", c)
|
||
return
|
||
}
|
||
|
||
if err := chatService.DeleteChat(uint(id), userID); err != nil {
|
||
global.GVA_LOG.Error("删除对话失败", zap.Error(err))
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
response.OkWithMessage("删除成功", c)
|
||
}
|
||
|
||
// ==================== 消息操作 ====================
|
||
|
||
// SendMessage 发送消息并获取AI回复(SSE 流式响应)
|
||
// @Tags 对话
|
||
// @Summary 发送消息(SSE流式)
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce text/event-stream
|
||
// @Param data body request.SendMessageRequest true "消息内容"
|
||
// @Router /app/chat/send [post]
|
||
func (ca *ChatApi) SendMessage(c *gin.Context) {
|
||
userID := middleware.GetAppUserID(c)
|
||
|
||
var req request.SendMessageRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 1. 获取对话上下文
|
||
chat, character, historyMsgs, err := chatService.GetChatForAI(req.ChatID, userID)
|
||
if err != nil {
|
||
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 2. 保存用户消息
|
||
_, err = chatService.SaveUserMessage(chat.ID, userID, req.Content)
|
||
if err != nil {
|
||
c.JSON(500, gin.H{"code": 7, "msg": "保存消息失败"})
|
||
return
|
||
}
|
||
|
||
// 2.1 世界书匹配(根据角色卡 + 历史消息)
|
||
var worldInfoTexts []string
|
||
if character != nil && chat.CharacterID != nil {
|
||
// 收集最近的消息文本(用于关键词匹配),附加当前用户输入
|
||
messagesForMatch := make([]string, 0, len(historyMsgs)+1)
|
||
for _, m := range historyMsgs {
|
||
if m.Content != "" {
|
||
messagesForMatch = append(messagesForMatch, m.Content)
|
||
}
|
||
}
|
||
if req.Content != "" {
|
||
messagesForMatch = append(messagesForMatch, req.Content)
|
||
}
|
||
|
||
matchReq := request.MatchWorldInfoRequest{
|
||
CharacterID: *chat.CharacterID,
|
||
Messages: messagesForMatch,
|
||
ScanDepth: 10,
|
||
MaxTokens: 2000,
|
||
}
|
||
|
||
if result, wErr := worldInfoService.MatchWorldInfo(userID, &matchReq); wErr != nil {
|
||
global.GVA_LOG.Warn("匹配世界书失败",
|
||
zap.Uint("chatID", chat.ID),
|
||
zap.Uint("characterID", *chat.CharacterID),
|
||
zap.Error(wErr))
|
||
} else if result != nil && len(result.Entries) > 0 {
|
||
// 按 position 分组,暂时统一作为 System 级别补充上下文
|
||
beforeChar := make([]string, 0)
|
||
afterChar := make([]string, 0)
|
||
for _, entry := range result.Entries {
|
||
text := strings.TrimSpace(entry.Content)
|
||
if text == "" {
|
||
continue
|
||
}
|
||
if entry.Position == "after_char" {
|
||
afterChar = append(afterChar, text)
|
||
} else {
|
||
// 默认 before_char
|
||
beforeChar = append(beforeChar, text)
|
||
}
|
||
}
|
||
|
||
if len(beforeChar) > 0 {
|
||
worldInfoTexts = append(worldInfoTexts, strings.Join(beforeChar, "\n\n"))
|
||
}
|
||
if len(afterChar) > 0 {
|
||
worldInfoTexts = append(worldInfoTexts, strings.Join(afterChar, "\n\n"))
|
||
}
|
||
}
|
||
}
|
||
|
||
// 3. 获取 AI 提供商和模型
|
||
var provider *appModel.AIProvider
|
||
var modelName string
|
||
|
||
if req.ProviderID != nil {
|
||
// 指定了提供商
|
||
p, pErr := providerService.GetUserDefaultProvider(userID) // 简化:暂时仅用默认
|
||
if pErr != nil {
|
||
c.JSON(400, gin.H{"code": 7, "msg": "指定的 AI 接口不可用"})
|
||
return
|
||
}
|
||
provider = p
|
||
} else {
|
||
// 使用默认提供商
|
||
p, pErr := providerService.GetUserDefaultProvider(userID)
|
||
if pErr != nil {
|
||
c.JSON(400, gin.H{"code": 7, "msg": pErr.Error()})
|
||
return
|
||
}
|
||
provider = p
|
||
}
|
||
|
||
if req.ModelName != "" {
|
||
modelName = req.ModelName
|
||
} else {
|
||
// 使用提供商的第一个启用的聊天模型
|
||
modelName = getDefaultChatModelForProvider(provider.ID)
|
||
if modelName == "" {
|
||
c.JSON(400, gin.H{"code": 7, "msg": "该 AI 接口没有可用的聊天模型"})
|
||
return
|
||
}
|
||
}
|
||
|
||
// 4. 构建 Prompt(角色卡 + 世界书 + 历史消息)
|
||
prompt := appService.BuildPrompt(character, historyMsgs)
|
||
|
||
// 将世界书内容插入为额外 System 提示(靠近角色定义)
|
||
if len(worldInfoTexts) > 0 {
|
||
worldInfoContent := "World Info:\n" + strings.Join(worldInfoTexts, "\n\n")
|
||
inserted := false
|
||
for i, msg := range prompt {
|
||
if msg.Role == "system" {
|
||
// 在第一个 system 消息之后插入一条世界书消息
|
||
newPrompt := make([]appService.AIMessagePayload, 0, len(prompt)+1)
|
||
newPrompt = append(newPrompt, prompt[:i+1]...)
|
||
newPrompt = append(newPrompt, appService.AIMessagePayload{
|
||
Role: "system",
|
||
Content: worldInfoContent,
|
||
})
|
||
newPrompt = append(newPrompt, prompt[i+1:]...)
|
||
prompt = newPrompt
|
||
inserted = true
|
||
break
|
||
}
|
||
}
|
||
if !inserted {
|
||
prompt = append([]appService.AIMessagePayload{{
|
||
Role: "system",
|
||
Content: worldInfoContent,
|
||
}}, prompt...)
|
||
}
|
||
}
|
||
|
||
// 添加用户的新消息
|
||
prompt = append(prompt, appService.AIMessagePayload{
|
||
Role: "user",
|
||
Content: req.Content,
|
||
})
|
||
|
||
// 5. 设置 SSE 响应头
|
||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||
c.Writer.Header().Set("Connection", "keep-alive")
|
||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||
|
||
// 6. 流式调用 AI
|
||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||
defer cancel()
|
||
|
||
ch := make(chan appService.AIStreamChunk, 100)
|
||
go appService.StreamAIResponse(ctx, provider, modelName, prompt, ch)
|
||
|
||
var fullContent string
|
||
var promptTokens, completionTokens int
|
||
|
||
flusher, ok := c.Writer.(interface{ Flush() })
|
||
if !ok {
|
||
c.JSON(500, gin.H{"code": 7, "msg": "服务器不支持流式响应"})
|
||
return
|
||
}
|
||
|
||
for chunk := range ch {
|
||
if chunk.Error != "" {
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"type": "error",
|
||
"error": chunk.Error,
|
||
})
|
||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||
flusher.Flush()
|
||
break
|
||
}
|
||
|
||
if chunk.Content != "" {
|
||
fullContent += chunk.Content
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"type": "content",
|
||
"content": chunk.Content,
|
||
})
|
||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||
flusher.Flush()
|
||
}
|
||
|
||
if chunk.Done {
|
||
promptTokens = chunk.PromptTokens
|
||
completionTokens = chunk.CompletionTokens
|
||
|
||
// 保存 AI 回复
|
||
if fullContent != "" {
|
||
chatService.SaveAssistantMessage(
|
||
chat.ID, chat.CharacterID, fullContent,
|
||
modelName, promptTokens, completionTokens,
|
||
)
|
||
}
|
||
|
||
data, _ := json.Marshal(map[string]interface{}{
|
||
"type": "done",
|
||
"model": modelName,
|
||
"promptTokens": promptTokens,
|
||
"completionTokens": completionTokens,
|
||
})
|
||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
}
|
||
|
||
// EditMessage 编辑消息
|
||
// @Tags 对话
|
||
// @Summary 编辑消息内容
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
// @Param data body request.EditMessageRequest true "编辑信息"
|
||
// @Success 200 {object} response.Response{data=appResponse.MessageResponse} "编辑成功"
|
||
// @Router /app/chat/message/edit [post]
|
||
func (ca *ChatApi) EditMessage(c *gin.Context) {
|
||
userID := middleware.GetAppUserID(c)
|
||
|
||
var req request.EditMessageRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
msg, err := chatService.EditMessage(req, userID)
|
||
if err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
response.OkWithData(msg, c)
|
||
}
|
||
|
||
// DeleteMessage 删除消息
|
||
// @Tags 对话
|
||
// @Summary 删除消息
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
// @Param data body request.DeleteMessageRequest true "消息ID"
|
||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||
// @Router /app/chat/message/delete [post]
|
||
func (ca *ChatApi) DeleteMessage(c *gin.Context) {
|
||
userID := middleware.GetAppUserID(c)
|
||
|
||
var req request.DeleteMessageRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
if err := chatService.DeleteMessage(req.MessageID, userID); err != nil {
|
||
response.FailWithMessage(err.Error(), c)
|
||
return
|
||
}
|
||
|
||
response.OkWithMessage("删除成功", c)
|
||
}
|
||
|
||
// ==================== 内部辅助 ====================
|
||
|
||
// getDefaultChatModelForProvider 获取提供商的默认聊天模型
|
||
func getDefaultChatModelForProvider(providerID uint) string {
|
||
var model appModel.AIModel
|
||
err := global.GVA_DB.Where("provider_id = ? AND model_type = ? AND is_enabled = ?",
|
||
providerID, "chat", true).
|
||
Order("created_at ASC").
|
||
First(&model).Error
|
||
if err != nil {
|
||
return ""
|
||
}
|
||
return model.ModelName
|
||
}
|