409 lines
11 KiB
Go
409 lines
11 KiB
Go
package app
|
||
|
||
import (
|
||
"errors"
|
||
"time"
|
||
|
||
"git.echol.cn/loser/st/server/global"
|
||
"git.echol.cn/loser/st/server/model/app"
|
||
"git.echol.cn/loser/st/server/model/app/request"
|
||
"git.echol.cn/loser/st/server/model/app/response"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type (
	// ChatService groups the chat operations of this package: chat CRUD and
	// message persistence. It carries no state of its own; the methods in this
	// file read and write through global.GVA_DB.
	ChatService struct{}
)
|
||
|
||
// ==================== 对话 CRUD ====================
|
||
|
||
// CreateChat 创建对话
|
||
func (cs *ChatService) CreateChat(req request.CreateChatRequest, userID uint) (response.ChatResponse, error) {
|
||
// 获取角色卡信息
|
||
var character app.AICharacter
|
||
err := global.GVA_DB.First(&character, req.CharacterID).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return response.ChatResponse{}, errors.New("角色卡不存在")
|
||
}
|
||
return response.ChatResponse{}, err
|
||
}
|
||
|
||
// 对话标题
|
||
title := req.Title
|
||
if title == "" {
|
||
title = character.Name
|
||
}
|
||
|
||
now := time.Now()
|
||
chat := app.AIChat{
|
||
Title: title,
|
||
UserID: userID,
|
||
CharacterID: &req.CharacterID,
|
||
ChatType: "single",
|
||
LastMessageAt: &now,
|
||
}
|
||
|
||
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||
// 创建对话
|
||
if err := tx.Create(&chat).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 如果角色卡有 FirstMessage,自动创建第一条系统消息
|
||
if character.FirstMessage != "" {
|
||
firstMsg := app.AIMessage{
|
||
ChatID: chat.ID,
|
||
Content: character.FirstMessage,
|
||
Role: "assistant",
|
||
CharacterID: &req.CharacterID,
|
||
SequenceNumber: 1,
|
||
}
|
||
if err := tx.Create(&firstMsg).Error; err != nil {
|
||
return err
|
||
}
|
||
chat.MessageCount = 1
|
||
tx.Model(&chat).Update("message_count", 1)
|
||
}
|
||
|
||
// 更新角色卡使用次数
|
||
tx.Model(&character).Update("total_chats", gorm.Expr("total_chats + 1"))
|
||
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
return response.ChatResponse{}, err
|
||
}
|
||
|
||
return toChatResponse(&chat, &character, nil), nil
|
||
}
|
||
|
||
// GetChatList 获取用户的对话列表
|
||
func (cs *ChatService) GetChatList(req request.ChatListRequest, userID uint) (response.ChatListResponse, error) {
|
||
db := global.GVA_DB.Model(&app.AIChat{}).Where("user_id = ?", userID)
|
||
|
||
var total int64
|
||
db.Count(&total)
|
||
|
||
var chats []app.AIChat
|
||
offset := (req.Page - 1) * req.PageSize
|
||
err := db.Preload("Character").
|
||
Order("is_pinned DESC, last_message_at DESC").
|
||
Offset(offset).Limit(req.PageSize).
|
||
Find(&chats).Error
|
||
if err != nil {
|
||
return response.ChatListResponse{}, err
|
||
}
|
||
|
||
// 获取每个对话的最后一条消息
|
||
chatIDs := make([]uint, len(chats))
|
||
for i, c := range chats {
|
||
chatIDs[i] = c.ID
|
||
}
|
||
|
||
lastMessages := make(map[uint]*response.MessageBrief)
|
||
if len(chatIDs) > 0 {
|
||
var messages []app.AIMessage
|
||
// 获取每个对话的最后一条消息(通过子查询)
|
||
global.GVA_DB.Where("chat_id IN ? AND is_deleted = ?", chatIDs, false).
|
||
Order("sequence_number DESC").
|
||
Find(&messages)
|
||
|
||
// 只保留每个对话的最后一条
|
||
for _, msg := range messages {
|
||
if _, exists := lastMessages[msg.ChatID]; !exists {
|
||
content := msg.Content
|
||
if len(content) > 100 {
|
||
content = content[:100] + "..."
|
||
}
|
||
lastMessages[msg.ChatID] = &response.MessageBrief{
|
||
Content: content,
|
||
Role: msg.Role,
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
list := make([]response.ChatResponse, len(chats))
|
||
for i, c := range chats {
|
||
list[i] = toChatResponse(&c, c.Character, lastMessages[c.ID])
|
||
}
|
||
|
||
return response.ChatListResponse{
|
||
List: list,
|
||
Total: total,
|
||
Page: req.Page,
|
||
PageSize: req.PageSize,
|
||
}, nil
|
||
}
|
||
|
||
// GetChatDetail 获取对话详情(包含消息列表)
|
||
func (cs *ChatService) GetChatDetail(chatID uint, userID uint) (response.ChatDetailResponse, error) {
|
||
var chat app.AIChat
|
||
err := global.GVA_DB.Preload("Character").
|
||
Where("id = ? AND user_id = ?", chatID, userID).
|
||
First(&chat).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return response.ChatDetailResponse{}, errors.New("对话不存在")
|
||
}
|
||
return response.ChatDetailResponse{}, err
|
||
}
|
||
|
||
// 获取消息列表(最近的50条)
|
||
var messages []app.AIMessage
|
||
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
|
||
Order("sequence_number ASC").
|
||
Limit(50).
|
||
Find(&messages)
|
||
|
||
msgList := make([]response.MessageResponse, len(messages))
|
||
for i, msg := range messages {
|
||
msgList[i] = toMessageResponse(&msg, chat.Character)
|
||
}
|
||
|
||
return response.ChatDetailResponse{
|
||
Chat: toChatResponse(&chat, chat.Character, nil),
|
||
Messages: msgList,
|
||
}, nil
|
||
}
|
||
|
||
// GetChatMessages 分页获取对话消息
|
||
func (cs *ChatService) GetChatMessages(chatID uint, req request.ChatMessagesRequest, userID uint) (response.MessageListResponse, error) {
|
||
// 验证对话归属
|
||
var chat app.AIChat
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||
if err != nil {
|
||
return response.MessageListResponse{}, errors.New("对话不存在")
|
||
}
|
||
|
||
db := global.GVA_DB.Model(&app.AIMessage{}).
|
||
Where("chat_id = ? AND is_deleted = ?", chatID, false)
|
||
|
||
var total int64
|
||
db.Count(&total)
|
||
|
||
var messages []app.AIMessage
|
||
offset := (req.Page - 1) * req.PageSize
|
||
err = db.Order("sequence_number ASC").
|
||
Offset(offset).Limit(req.PageSize).
|
||
Find(&messages).Error
|
||
if err != nil {
|
||
return response.MessageListResponse{}, err
|
||
}
|
||
|
||
// 预加载角色信息
|
||
var character *app.AICharacter
|
||
if chat.CharacterID != nil {
|
||
var c app.AICharacter
|
||
global.GVA_DB.First(&c, *chat.CharacterID)
|
||
character = &c
|
||
}
|
||
|
||
list := make([]response.MessageResponse, len(messages))
|
||
for i, msg := range messages {
|
||
list[i] = toMessageResponse(&msg, character)
|
||
}
|
||
|
||
return response.MessageListResponse{
|
||
List: list,
|
||
Total: total,
|
||
Page: req.Page,
|
||
PageSize: req.PageSize,
|
||
}, nil
|
||
}
|
||
|
||
// DeleteChat 删除对话
|
||
func (cs *ChatService) DeleteChat(chatID uint, userID uint) error {
|
||
var chat app.AIChat
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("对话不存在")
|
||
}
|
||
return err
|
||
}
|
||
|
||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||
// 删除消息
|
||
if err := tx.Where("chat_id = ?", chatID).Delete(&app.AIMessage{}).Error; err != nil {
|
||
return err
|
||
}
|
||
// 删除消息变体
|
||
if err := tx.Where("message_id IN (?)",
|
||
tx.Model(&app.AIMessage{}).Select("id").Where("chat_id = ?", chatID),
|
||
).Delete(&app.AIMessageSwipe{}).Error; err != nil {
|
||
// 忽略错误,可能没有变体
|
||
}
|
||
// 删除对话
|
||
return tx.Delete(&chat).Error
|
||
})
|
||
}
|
||
|
||
// ==================== 消息操作 ====================
|
||
|
||
// SaveUserMessage 保存用户消息(内部方法)
|
||
func (cs *ChatService) SaveUserMessage(chatID uint, userID uint, content string) (*app.AIMessage, error) {
|
||
// 获取当前最大序号
|
||
var maxSeq int
|
||
global.GVA_DB.Model(&app.AIMessage{}).
|
||
Where("chat_id = ?", chatID).
|
||
Select("COALESCE(MAX(sequence_number), 0)").
|
||
Scan(&maxSeq)
|
||
|
||
msg := &app.AIMessage{
|
||
ChatID: chatID,
|
||
Content: content,
|
||
Role: "user",
|
||
SenderID: &userID,
|
||
SequenceNumber: maxSeq + 1,
|
||
}
|
||
|
||
if err := global.GVA_DB.Create(msg).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 更新对话的消息数和最后消息时间
|
||
now := time.Now()
|
||
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("message_count + 1"),
|
||
"last_message_at": now,
|
||
})
|
||
|
||
return msg, nil
|
||
}
|
||
|
||
// SaveAssistantMessage 保存AI回复消息(内部方法)
|
||
func (cs *ChatService) SaveAssistantMessage(chatID uint, characterID *uint, content string, model string, promptTokens, completionTokens int) (*app.AIMessage, error) {
|
||
var maxSeq int
|
||
global.GVA_DB.Model(&app.AIMessage{}).
|
||
Where("chat_id = ?", chatID).
|
||
Select("COALESCE(MAX(sequence_number), 0)").
|
||
Scan(&maxSeq)
|
||
|
||
msg := &app.AIMessage{
|
||
ChatID: chatID,
|
||
Content: content,
|
||
Role: "assistant",
|
||
CharacterID: characterID,
|
||
SequenceNumber: maxSeq + 1,
|
||
Model: model,
|
||
PromptTokens: promptTokens,
|
||
CompletionTokens: completionTokens,
|
||
TotalTokens: promptTokens + completionTokens,
|
||
}
|
||
|
||
if err := global.GVA_DB.Create(msg).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
now := time.Now()
|
||
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("message_count + 1"),
|
||
"last_message_at": now,
|
||
})
|
||
|
||
return msg, nil
|
||
}
|
||
|
||
// EditMessage 编辑消息
|
||
func (cs *ChatService) EditMessage(req request.EditMessageRequest, userID uint) (response.MessageResponse, error) {
|
||
var msg app.AIMessage
|
||
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
|
||
Where("ai_messages.id = ? AND ai_chats.user_id = ?", req.MessageID, userID).
|
||
First(&msg).Error
|
||
if err != nil {
|
||
return response.MessageResponse{}, errors.New("消息不存在")
|
||
}
|
||
|
||
msg.Content = req.Content
|
||
if err := global.GVA_DB.Save(&msg).Error; err != nil {
|
||
return response.MessageResponse{}, err
|
||
}
|
||
|
||
return toMessageResponse(&msg, nil), nil
|
||
}
|
||
|
||
// DeleteMessage 删除消息(软删除)
|
||
func (cs *ChatService) DeleteMessage(messageID uint, userID uint) error {
|
||
var msg app.AIMessage
|
||
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
|
||
Where("ai_messages.id = ? AND ai_chats.user_id = ?", messageID, userID).
|
||
First(&msg).Error
|
||
if err != nil {
|
||
return errors.New("消息不存在")
|
||
}
|
||
|
||
return global.GVA_DB.Model(&msg).Update("is_deleted", true).Error
|
||
}
|
||
|
||
// GetChatForAI 获取对话用于AI调用的完整上下文(内部方法)
|
||
func (cs *ChatService) GetChatForAI(chatID uint, userID uint) (*app.AIChat, *app.AICharacter, []app.AIMessage, error) {
|
||
var chat app.AIChat
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||
if err != nil {
|
||
return nil, nil, nil, errors.New("对话不存在")
|
||
}
|
||
|
||
var character *app.AICharacter
|
||
if chat.CharacterID != nil {
|
||
var c app.AICharacter
|
||
if err := global.GVA_DB.First(&c, *chat.CharacterID).Error; err == nil {
|
||
character = &c
|
||
}
|
||
}
|
||
|
||
// 获取历史消息(最近30条)
|
||
var messages []app.AIMessage
|
||
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
|
||
Order("sequence_number ASC").
|
||
Limit(30).
|
||
Find(&messages)
|
||
|
||
return &chat, character, messages, nil
|
||
}
|
||
|
||
// ==================== 辅助函数 ====================
|
||
|
||
func toChatResponse(chat *app.AIChat, character *app.AICharacter, lastMsg *response.MessageBrief) response.ChatResponse {
|
||
resp := response.ChatResponse{
|
||
ID: chat.ID,
|
||
Title: chat.Title,
|
||
CharacterID: chat.CharacterID,
|
||
ChatType: chat.ChatType,
|
||
LastMessageAt: chat.LastMessageAt,
|
||
MessageCount: chat.MessageCount,
|
||
IsPinned: chat.IsPinned,
|
||
LastMessage: lastMsg,
|
||
CreatedAt: chat.CreatedAt,
|
||
}
|
||
|
||
if character != nil {
|
||
resp.CharacterName = character.Name
|
||
resp.CharacterAvatar = character.Avatar
|
||
}
|
||
|
||
return resp
|
||
}
|
||
|
||
func toMessageResponse(msg *app.AIMessage, character *app.AICharacter) response.MessageResponse {
|
||
resp := response.MessageResponse{
|
||
ID: msg.ID,
|
||
ChatID: msg.ChatID,
|
||
Content: msg.Content,
|
||
Role: msg.Role,
|
||
CharacterID: msg.CharacterID,
|
||
Model: msg.Model,
|
||
PromptTokens: msg.PromptTokens,
|
||
CompletionTokens: msg.CompletionTokens,
|
||
TotalTokens: msg.TotalTokens,
|
||
SequenceNumber: msg.SequenceNumber,
|
||
CreatedAt: msg.CreatedAt,
|
||
}
|
||
|
||
if msg.Role == "assistant" && character != nil {
|
||
resp.CharacterName = character.Name
|
||
}
|
||
|
||
return resp
|
||
}
|