Files
st/server/api/v1/app/chat.go

471 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/middleware"
appModel "git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/common/response"
appService "git.echol.cn/loser/st/server/service/app"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ChatApi groups the HTTP handlers for the chat and chat-message endpoints.
// It carries no state; every handler works on the gin.Context it is given.
type ChatApi struct{}
// ==================== 对话管理 ====================
// CreateChat creates a new chat for the authenticated app user.
//
// The JSON body is bound to request.CreateChatRequest and passed, together
// with the caller's user ID, to chatService.CreateChat. Binding and service
// errors are reported to the client using the error text as the message;
// service errors are also logged.
// @Tags 对话
// @Summary 创建新对话
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.CreateChatRequest true "角色卡ID"
// @Success 200 {object} response.Response{data=appResponse.ChatResponse} "创建成功"
// @Router /app/chat [post]
func (ca *ChatApi) CreateChat(c *gin.Context) {
	uid := middleware.GetAppUserID(c)

	var body request.CreateChatRequest
	bindErr := c.ShouldBindJSON(&body)
	if bindErr != nil {
		response.FailWithMessage(bindErr.Error(), c)
		return
	}

	created, createErr := chatService.CreateChat(body, uid)
	if createErr != nil {
		global.GVA_LOG.Error("创建对话失败", zap.Error(createErr))
		response.FailWithMessage(createErr.Error(), c)
		return
	}

	response.OkWithData(created, c)
}
// GetChatList returns one page of the authenticated user's chats.
//
// Query parameters are bound to request.ChatListRequest. A missing, zero or
// negative page falls back to 1 and a missing, zero or negative pageSize
// falls back to 20. Service failures are logged and answered with a generic
// message.
// @Tags 对话
// @Summary 获取对话列表
// @Security ApiKeyAuth
// @Produce application/json
// @Param page query int true "页码"
// @Param pageSize query int true "每页数量"
// @Success 200 {object} response.Response{data=appResponse.ChatListResponse} "获取成功"
// @Router /app/chat/list [get]
func (ca *ChatApi) GetChatList(c *gin.Context) {
	userID := middleware.GetAppUserID(c)

	var req request.ChatListRequest
	if err := c.ShouldBindQuery(&req); err != nil {
		response.FailWithMessage(err.Error(), c)
		return
	}

	// Clamp non-positive values instead of only the zero value: a negative
	// page or pageSize supplied in the query string must not reach the service.
	// NOTE(review): the service presumably turns these into LIMIT/OFFSET, where
	// negative numbers are invalid or disable the limit — confirm.
	if req.Page <= 0 {
		req.Page = 1
	}
	if req.PageSize <= 0 {
		req.PageSize = 20
	}

	list, err := chatService.GetChatList(req, userID)
	if err != nil {
		global.GVA_LOG.Error("获取对话列表失败", zap.Error(err))
		response.FailWithMessage("获取失败", c)
		return
	}
	response.OkWithData(list, c)
}
// GetChatDetail returns a single chat, including its messages, for the
// authenticated user.
//
// The ":id" path parameter must parse as an unsigned 32-bit decimal number;
// otherwise the request is rejected. Errors from chatService.GetChatDetail
// are returned to the client using the error text as the message.
// @Tags 对话
// @Summary 获取对话详情
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "对话ID"
// @Success 200 {object} response.Response{data=appResponse.ChatDetailResponse} "获取成功"
// @Router /app/chat/:id [get]
func (ca *ChatApi) GetChatDetail(c *gin.Context) {
	uid := middleware.GetAppUserID(c)

	chatID, parseErr := strconv.ParseUint(c.Param("id"), 10, 32)
	if parseErr != nil {
		response.FailWithMessage("无效的ID", c)
		return
	}

	result, svcErr := chatService.GetChatDetail(uint(chatID), uid)
	if svcErr != nil {
		response.FailWithMessage(svcErr.Error(), c)
		return
	}

	response.OkWithData(result, c)
}
// GetChatMessages returns one page of messages of a chat owned by the
// authenticated user.
//
// The ":id" path parameter must parse as an unsigned 32-bit decimal number.
// Query parameters are bound to request.ChatMessagesRequest. A missing, zero
// or negative page falls back to 1 and a missing, zero or negative pageSize
// falls back to 50. Service errors are returned to the client using the
// error text as the message.
// @Tags 对话
// @Summary 获取对话消息
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "对话ID"
// @Param page query int true "页码"
// @Param pageSize query int true "每页数量"
// @Success 200 {object} response.Response{data=appResponse.MessageListResponse} "获取成功"
// @Router /app/chat/:id/messages [get]
func (ca *ChatApi) GetChatMessages(c *gin.Context) {
	userID := middleware.GetAppUserID(c)

	id, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		response.FailWithMessage("无效的ID", c)
		return
	}

	var req request.ChatMessagesRequest
	if err := c.ShouldBindQuery(&req); err != nil {
		response.FailWithMessage(err.Error(), c)
		return
	}

	// Clamp non-positive values instead of only the zero value: a negative
	// page or pageSize supplied in the query string must not reach the service.
	// NOTE(review): the service presumably turns these into LIMIT/OFFSET, where
	// negative numbers are invalid or disable the limit — confirm.
	if req.Page <= 0 {
		req.Page = 1
	}
	if req.PageSize <= 0 {
		req.PageSize = 50
	}

	messages, err := chatService.GetChatMessages(uint(id), req, userID)
	if err != nil {
		response.FailWithMessage(err.Error(), c)
		return
	}
	response.OkWithData(messages, c)
}
// DeleteChat deletes a chat on behalf of the authenticated user.
//
// The ":id" path parameter must parse as an unsigned 32-bit decimal number;
// otherwise the request is rejected. Errors from chatService.DeleteChat are
// logged and returned to the client using the error text as the message.
// @Tags 对话
// @Summary 删除对话
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "对话ID"
// @Success 200 {object} response.Response{msg=string} "删除成功"
// @Router /app/chat/:id [delete]
func (ca *ChatApi) DeleteChat(c *gin.Context) {
	uid := middleware.GetAppUserID(c)

	chatID, parseErr := strconv.ParseUint(c.Param("id"), 10, 32)
	if parseErr != nil {
		response.FailWithMessage("无效的ID", c)
		return
	}

	delErr := chatService.DeleteChat(uint(chatID), uid)
	if delErr != nil {
		global.GVA_LOG.Error("删除对话失败", zap.Error(delErr))
		response.FailWithMessage(delErr.Error(), c)
		return
	}

	response.OkWithMessage("删除成功", c)
}
// ==================== 消息操作 ====================
// SendMessage stores the user's message and streams the AI reply to the
// client as Server-Sent Events (SSE).
//
// Flow: bind the JSON body, load the chat context through
// chatService.GetChatForAI, match world-info entries, resolve the provider and
// model, build the prompt, persist the user message, then relay every chunk
// produced by appService.StreamAIResponse as a "data: <json>\n\n" event.
// Events carry a "type" of "content", "error" or "done". When the final chunk
// arrives and the assembled reply is non-empty, it is saved through
// chatService.SaveAssistantMessage.
//
// Failures before streaming starts are answered with a JSON body
// {"code": 7, "msg": ...} and HTTP 400, or HTTP 500 when the user message
// cannot be saved or the response writer cannot flush.
// @Tags 对话
// @Summary 发送消息SSE流式
// @Security ApiKeyAuth
// @accept application/json
// @Produce text/event-stream
// @Param data body request.SendMessageRequest true "消息内容"
// @Router /app/chat/send [post]
func (ca *ChatApi) SendMessage(c *gin.Context) {
	userID := middleware.GetAppUserID(c)

	var req request.SendMessageRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
		return
	}

	// 1. Load the chat, its character card and the message history.
	chat, character, historyMsgs, err := chatService.GetChatForAI(req.ChatID, userID)
	if err != nil {
		c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
		return
	}

	// 2. World-info matching, based on the character card, the history and the
	// new user input. A matching failure is logged and otherwise ignored so
	// that the chat still works without world info.
	var worldInfoTexts []string
	if character != nil && chat.CharacterID != nil {
		messagesForMatch := make([]string, 0, len(historyMsgs)+1)
		for _, m := range historyMsgs {
			if m.Content != "" {
				messagesForMatch = append(messagesForMatch, m.Content)
			}
		}
		if req.Content != "" {
			messagesForMatch = append(messagesForMatch, req.Content)
		}

		matchReq := request.MatchWorldInfoRequest{
			CharacterID: *chat.CharacterID,
			Messages:    messagesForMatch,
			ScanDepth:   10,
			MaxTokens:   2000,
		}
		result, wErr := worldInfoService.MatchWorldInfo(userID, &matchReq)
		switch {
		case wErr != nil:
			global.GVA_LOG.Warn("匹配世界书失败",
				zap.Uint("chatID", chat.ID),
				zap.Uint("characterID", *chat.CharacterID),
				zap.Error(wErr))
		case result != nil && len(result.Entries) > 0:
			// Group by position. Both groups currently end up in the same
			// system-level message, "before" entries first.
			var beforeChar, afterChar []string
			for _, entry := range result.Entries {
				text := strings.TrimSpace(entry.Content)
				if text == "" {
					continue
				}
				if entry.Position == "after_char" {
					afterChar = append(afterChar, text)
				} else {
					// Every other position is treated as before_char.
					beforeChar = append(beforeChar, text)
				}
			}
			if len(beforeChar) > 0 {
				worldInfoTexts = append(worldInfoTexts, strings.Join(beforeChar, "\n\n"))
			}
			if len(afterChar) > 0 {
				worldInfoTexts = append(worldInfoTexts, strings.Join(afterChar, "\n\n"))
			}
		}
	}

	// 3. Resolve the AI provider and model.
	// NOTE(review): req.ProviderID is not used for the lookup yet. The user's
	// default provider is loaded in both cases; only the error message differs.
	provider, pErr := providerService.GetUserDefaultProvider(userID)
	if pErr != nil {
		msg := pErr.Error()
		if req.ProviderID != nil {
			msg = "指定的 AI 接口不可用"
		}
		c.JSON(400, gin.H{"code": 7, "msg": msg})
		return
	}

	modelName := req.ModelName
	if modelName == "" {
		// Fall back to the provider's first enabled chat model.
		modelName = getDefaultChatModelForProvider(provider.ID)
		if modelName == "" {
			c.JSON(400, gin.H{"code": 7, "msg": "该 AI 接口没有可用的聊天模型"})
			return
		}
	}

	// 4. Build the prompt: character card + world info + history.
	prompt := appService.BuildPrompt(character, historyMsgs)
	if len(worldInfoTexts) > 0 {
		worldInfoMsg := appService.AIMessagePayload{
			Role:    "system",
			Content: "World Info:\n" + strings.Join(worldInfoTexts, "\n\n"),
		}
		// Insert right after the first system message so the world info sits
		// next to the character definition; without a system message it
		// becomes the first message.
		insertAt := 0
		for i, msg := range prompt {
			if msg.Role == "system" {
				insertAt = i + 1
				break
			}
		}
		// +2: room for the world-info message and the user message below.
		merged := make([]appService.AIMessagePayload, 0, len(prompt)+2)
		merged = append(merged, prompt[:insertAt]...)
		merged = append(merged, worldInfoMsg)
		merged = append(merged, prompt[insertAt:]...)
		prompt = merged
	}
	prompt = append(prompt, appService.AIMessagePayload{
		Role:    "user",
		Content: req.Content,
	})

	// 5. Check streaming support before any side effect, so the JSON error is
	// sent before the SSE headers are set and before anything is persisted.
	flusher, ok := c.Writer.(interface{ Flush() })
	if !ok {
		c.JSON(500, gin.H{"code": 7, "msg": "服务器不支持流式响应"})
		return
	}

	// 6. Persist the user message only after everything above has been
	// validated; a rejected request therefore leaves no stored message behind
	// and a retry does not store it twice.
	if _, err = chatService.SaveUserMessage(chat.ID, userID, req.Content); err != nil {
		c.JSON(500, gin.H{"code": 7, "msg": "保存消息失败"})
		return
	}

	// 7. SSE response headers.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")

	// 8. Stream the AI response. The deferred cancel cancels ctx as soon as
	// this handler returns.
	ctx, cancel := context.WithCancel(c.Request.Context())
	defer cancel()

	ch := make(chan appService.AIStreamChunk, 100)
	go appService.StreamAIResponse(ctx, provider, modelName, prompt, ch)

	// sendEvent writes one SSE "data:" event and flushes it immediately.
	sendEvent := func(payload gin.H) {
		// Marshal cannot fail here: payloads only hold strings and ints.
		data, _ := json.Marshal(payload)
		// Write errors are deliberately ignored; the handler keeps reading
		// chunks from the channel.
		fmt.Fprintf(c.Writer, "data: %s\n\n", data)
		flusher.Flush()
	}

	// A Builder avoids re-copying the whole reply for every chunk.
	var reply strings.Builder

	// The loop ends when ch is closed (presumably by StreamAIResponse —
	// confirm) or when an error chunk arrives.
	for chunk := range ch {
		if chunk.Error != "" {
			sendEvent(gin.H{
				"type":  "error",
				"error": chunk.Error,
			})
			return
		}

		if chunk.Content != "" {
			reply.WriteString(chunk.Content)
			sendEvent(gin.H{
				"type":    "content",
				"content": chunk.Content,
			})
		}

		if chunk.Done {
			// Save the assembled AI reply.
			// NOTE(review): the result of SaveAssistantMessage is discarded; if
			// it returns an error, a failed save goes unnoticed — confirm.
			if reply.Len() > 0 {
				chatService.SaveAssistantMessage(
					chat.ID, chat.CharacterID, reply.String(),
					modelName, chunk.PromptTokens, chunk.CompletionTokens,
				)
			}
			sendEvent(gin.H{
				"type":             "done",
				"model":            modelName,
				"promptTokens":     chunk.PromptTokens,
				"completionTokens": chunk.CompletionTokens,
			})
		}
	}
}
// EditMessage updates the content of a message on behalf of the
// authenticated user.
//
// The JSON body is bound to request.EditMessageRequest and passed, together
// with the caller's user ID, to chatService.EditMessage. Binding and service
// errors are returned to the client using the error text as the message.
// @Tags 对话
// @Summary 编辑消息内容
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.EditMessageRequest true "编辑信息"
// @Success 200 {object} response.Response{data=appResponse.MessageResponse} "编辑成功"
// @Router /app/chat/message/edit [post]
func (ca *ChatApi) EditMessage(c *gin.Context) {
	uid := middleware.GetAppUserID(c)

	var body request.EditMessageRequest
	bindErr := c.ShouldBindJSON(&body)
	if bindErr != nil {
		response.FailWithMessage(bindErr.Error(), c)
		return
	}

	edited, editErr := chatService.EditMessage(body, uid)
	if editErr != nil {
		response.FailWithMessage(editErr.Error(), c)
		return
	}

	response.OkWithData(edited, c)
}
// DeleteMessage deletes a message on behalf of the authenticated user.
//
// The JSON body is bound to request.DeleteMessageRequest; its MessageID and
// the caller's user ID are passed to chatService.DeleteMessage. Binding and
// service errors are returned to the client using the error text as the
// message.
// @Tags 对话
// @Summary 删除消息
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.DeleteMessageRequest true "消息ID"
// @Success 200 {object} response.Response{msg=string} "删除成功"
// @Router /app/chat/message/delete [post]
func (ca *ChatApi) DeleteMessage(c *gin.Context) {
	uid := middleware.GetAppUserID(c)

	var body request.DeleteMessageRequest
	bindErr := c.ShouldBindJSON(&body)
	if bindErr != nil {
		response.FailWithMessage(bindErr.Error(), c)
		return
	}

	delErr := chatService.DeleteMessage(body.MessageID, uid)
	if delErr != nil {
		response.FailWithMessage(delErr.Error(), c)
		return
	}

	response.OkWithMessage("删除成功", c)
}
// ==================== 内部辅助 ====================
// getDefaultChatModelForProvider returns the model name of the oldest
// (by created_at) enabled model of type "chat" that belongs to the given
// provider. It returns an empty string when the query yields an error,
// including the case that no such model exists.
func getDefaultChatModelForProvider(providerID uint) string {
	var first appModel.AIModel

	query := global.GVA_DB.
		Where("provider_id = ? AND model_type = ? AND is_enabled = ?", providerID, "chat", true).
		Order("created_at ASC")

	if lookupErr := query.First(&first).Error; lookupErr != nil {
		return ""
	}
	return first.ModelName
}