1988 lines
60 KiB
Go
1988 lines
60 KiB
Go
package app
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"git.echol.cn/loser/st/server/global"
|
||
"git.echol.cn/loser/st/server/model/app"
|
||
"git.echol.cn/loser/st/server/model/app/request"
|
||
"git.echol.cn/loser/st/server/model/app/response"
|
||
"gorm.io/datatypes"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// ConversationService implements the conversation and message operations of
// the app: creating, listing, updating and deleting conversations, paging
// through messages, and sending messages to the configured AI provider
// (blocking or streamed). It is an empty struct and holds no state of its own.
type ConversationService struct{}
|
||
|
||
// CreateConversation 创建对话
|
||
func (s *ConversationService) CreateConversation(userID uint, req *request.CreateConversationRequest) (*response.ConversationResponse, error) {
|
||
// 验证角色卡是否存在且有权访问
|
||
var character app.AICharacter
|
||
err := global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", req.CharacterID, userID, true).
|
||
First(&character).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("角色卡不存在或无权访问")
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
// 生成对话标题
|
||
title := req.Title
|
||
if title == "" {
|
||
title = "与 " + character.Name + " 的对话"
|
||
}
|
||
|
||
// 获取默认 AI 配置
|
||
var aiConfig app.AIConfig
|
||
err = global.GVA_DB.Where("is_active = ?", true).
|
||
Order("is_default DESC, created_at DESC").
|
||
First(&aiConfig).Error
|
||
|
||
// 设置 AI 配置
|
||
aiProvider := req.AIProvider
|
||
model := req.Model
|
||
|
||
if err == nil {
|
||
// 如果找到了默认配置,使用它
|
||
if aiProvider == "" {
|
||
aiProvider = aiConfig.Provider
|
||
}
|
||
if model == "" {
|
||
model = aiConfig.DefaultModel
|
||
}
|
||
global.GVA_LOG.Info(fmt.Sprintf("创建对话使用 AI 配置: %s (Provider: %s, Model: %s)", aiConfig.Name, aiProvider, model))
|
||
} else {
|
||
// 如果没有找到配置,使用默认值
|
||
if aiProvider == "" {
|
||
aiProvider = "openai"
|
||
}
|
||
if model == "" {
|
||
model = "gpt-4"
|
||
}
|
||
global.GVA_LOG.Warn("未找到默认 AI 配置,使用硬编码默认值")
|
||
}
|
||
|
||
// 创建对话
|
||
conversation := app.Conversation{
|
||
UserID: userID,
|
||
CharacterID: req.CharacterID,
|
||
Title: title,
|
||
PresetID: req.PresetID,
|
||
WorldbookID: req.WorldbookID,
|
||
WorldbookEnabled: req.WorldbookEnabled,
|
||
AIProvider: aiProvider,
|
||
Model: model,
|
||
Settings: datatypes.JSON("{}"),
|
||
}
|
||
|
||
err = global.GVA_DB.Create(&conversation).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 如果角色有开场白,创建开场白消息
|
||
if character.FirstMes != "" {
|
||
// 获取用户信息
|
||
var user app.AppUser
|
||
err = global.GVA_DB.Where("id = ?", userID).First(&user).Error
|
||
if err != nil {
|
||
global.GVA_LOG.Warn(fmt.Sprintf("获取用户信息失败: %v", err))
|
||
}
|
||
userName := user.Username
|
||
if userName == "" {
|
||
userName = user.NickName
|
||
}
|
||
|
||
// 【重要】不再应用正则脚本处理开场白,保留原始内容
|
||
// 让前端来处理 <Status_block> 和 <maintext> 的渲染
|
||
processedFirstMes := character.FirstMes
|
||
global.GVA_LOG.Info(fmt.Sprintf("[开场白] 保留原始内容,长度=%d", len(processedFirstMes)))
|
||
|
||
firstMessage := app.Message{
|
||
ConversationID: conversation.ID,
|
||
Role: "assistant",
|
||
Content: processedFirstMes,
|
||
TokenCount: len(processedFirstMes) / 4,
|
||
}
|
||
err = global.GVA_DB.Create(&firstMessage).Error
|
||
if err != nil {
|
||
global.GVA_LOG.Warn(fmt.Sprintf("创建开场白消息失败: %v", err))
|
||
} else {
|
||
// 更新对话统计
|
||
conversation.MessageCount = 1
|
||
conversation.TokenCount = firstMessage.TokenCount
|
||
global.GVA_DB.Model(&conversation).Updates(map[string]interface{}{
|
||
"message_count": 1,
|
||
"token_count": firstMessage.TokenCount,
|
||
})
|
||
}
|
||
}
|
||
|
||
resp := response.ToConversationResponse(&conversation)
|
||
return &resp, nil
|
||
}
|
||
|
||
// GetConversationList 获取对话列表
|
||
func (s *ConversationService) GetConversationList(userID uint, req *request.GetConversationListRequest) (*response.ConversationListResponse, error) {
|
||
var conversations []app.Conversation
|
||
var total int64
|
||
|
||
db := global.GVA_DB.Model(&app.Conversation{}).Where("user_id = ?", userID)
|
||
|
||
// 统计总数
|
||
err := db.Count(&total).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 分页查询
|
||
offset := (req.Page - 1) * req.PageSize
|
||
err = db.Order("updated_at DESC").Offset(offset).Limit(req.PageSize).Find(&conversations).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 收集所有角色ID
|
||
characterIDs := make([]uint, 0, len(conversations))
|
||
for _, conv := range conversations {
|
||
characterIDs = append(characterIDs, conv.CharacterID)
|
||
}
|
||
|
||
// 批量查询角色信息(只查询必要字段)
|
||
var characters []app.AICharacter
|
||
if len(characterIDs) > 0 {
|
||
err = global.GVA_DB.Select("id, name, avatar, description, created_at, updated_at").
|
||
Where("id IN ?", characterIDs).
|
||
Find(&characters).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
// 创建角色ID到角色的映射
|
||
characterMap := make(map[uint]*app.AICharacter)
|
||
for i := range characters {
|
||
characterMap[characters[i].ID] = &characters[i]
|
||
}
|
||
|
||
// 转换响应(使用轻量级结构)
|
||
list := make([]response.ConversationListItemResponse, len(conversations))
|
||
for i, conv := range conversations {
|
||
character := characterMap[conv.CharacterID]
|
||
list[i] = response.ToConversationListItemResponse(&conv, character)
|
||
}
|
||
|
||
return &response.ConversationListResponse{
|
||
List: list,
|
||
Total: total,
|
||
Page: req.Page,
|
||
PageSize: req.PageSize,
|
||
}, nil
|
||
}
|
||
|
||
// GetConversationByID 获取对话详情
|
||
func (s *ConversationService) GetConversationByID(userID, conversationID uint) (*response.ConversationResponse, error) {
|
||
var conversation app.Conversation
|
||
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", conversationID, userID).
|
||
First(&conversation).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("对话不存在或无权访问")
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
resp := response.ToConversationResponse(&conversation)
|
||
return &resp, nil
|
||
}
|
||
|
||
// UpdateConversationSettings 更新对话设置
|
||
func (s *ConversationService) UpdateConversationSettings(userID, conversationID uint, req *request.UpdateConversationSettingsRequest) error {
|
||
var conversation app.Conversation
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", conversationID, userID).First(&conversation).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("对话不存在或无权访问")
|
||
}
|
||
return err
|
||
}
|
||
|
||
updates := make(map[string]interface{})
|
||
|
||
// 更新设置
|
||
if req.Settings != nil {
|
||
settingsJSON, err := json.Marshal(req.Settings)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
updates["settings"] = datatypes.JSON(settingsJSON)
|
||
}
|
||
|
||
// 更新世界书ID
|
||
if req.WorldbookID != nil {
|
||
updates["worldbook_id"] = req.WorldbookID
|
||
}
|
||
|
||
// 更新世界书启用状态
|
||
if req.WorldbookEnabled != nil {
|
||
updates["worldbook_enabled"] = *req.WorldbookEnabled
|
||
}
|
||
|
||
if len(updates) == 0 {
|
||
return nil
|
||
}
|
||
|
||
return global.GVA_DB.Model(&conversation).Updates(updates).Error
|
||
}
|
||
|
||
// DeleteConversation 删除对话
|
||
func (s *ConversationService) DeleteConversation(userID, conversationID uint) error {
|
||
// 开启事务
|
||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||
// 删除对话的所有消息
|
||
err := tx.Where("conversation_id = ?", conversationID).Delete(&app.Message{}).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 删除对话
|
||
result := tx.Where("id = ? AND user_id = ?", conversationID, userID).Delete(&app.Conversation{})
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
if result.RowsAffected == 0 {
|
||
return errors.New("对话不存在或无权删除")
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// GetMessageList 获取消息列表
|
||
func (s *ConversationService) GetMessageList(userID, conversationID uint, req *request.GetMessageListRequest) (*response.MessageListResponse, error) {
|
||
// 验证对话权限
|
||
var conversation app.Conversation
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", conversationID, userID).
|
||
First(&conversation).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("对话不存在或无权访问")
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
var messages []app.Message
|
||
var total int64
|
||
|
||
db := global.GVA_DB.Model(&app.Message{}).Where("conversation_id = ?", conversationID)
|
||
|
||
// 统计总数
|
||
err = db.Count(&total).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 分页查询
|
||
offset := (req.Page - 1) * req.PageSize
|
||
err = db.Order("created_at ASC").Offset(offset).Limit(req.PageSize).Find(&messages).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 转换响应
|
||
list := make([]response.MessageResponse, len(messages))
|
||
for i, msg := range messages {
|
||
list[i] = response.ToMessageResponse(&msg)
|
||
}
|
||
|
||
return &response.MessageListResponse{
|
||
List: list,
|
||
Total: total,
|
||
Page: req.Page,
|
||
PageSize: req.PageSize,
|
||
}, nil
|
||
}
|
||
|
||
// SendMessage 发送消息并获取 AI 回复
|
||
func (s *ConversationService) SendMessage(userID, conversationID uint, req *request.SendMessageRequest) (*response.MessageResponse, error) {
|
||
// 验证对话权限
|
||
var conversation app.Conversation
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", conversationID, userID).
|
||
First(&conversation).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("对话不存在或无权访问")
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
// 获取角色卡信息
|
||
var character app.AICharacter
|
||
err = global.GVA_DB.Where("id = ?", conversation.CharacterID).First(&character).Error
|
||
if err != nil {
|
||
return nil, errors.New("角色卡不存在")
|
||
}
|
||
|
||
// 获取用户信息
|
||
var user app.AppUser
|
||
err = global.GVA_DB.Where("id = ?", userID).First(&user).Error
|
||
if err != nil {
|
||
global.GVA_LOG.Warn(fmt.Sprintf("获取用户信息失败: %v", err))
|
||
}
|
||
userName := user.Username
|
||
if userName == "" {
|
||
userName = user.NickName
|
||
}
|
||
|
||
// 应用输入阶段的正则脚本 (Placement 0)
|
||
processedContent := req.Content
|
||
var regexService RegexScriptService
|
||
global.GVA_LOG.Info(fmt.Sprintf("查询输入阶段正则脚本: userID=%d, placement=0, charID=%d", userID, conversation.CharacterID))
|
||
inputScripts, err := regexService.GetScriptsForPlacement(userID, 0, &conversation.CharacterID, nil)
|
||
if err != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("查询输入阶段正则脚本失败: %v", err))
|
||
} else {
|
||
global.GVA_LOG.Info(fmt.Sprintf("找到 %d 个输入阶段正则脚本", len(inputScripts)))
|
||
if len(inputScripts) > 0 {
|
||
processedContent = regexService.ExecuteScripts(inputScripts, processedContent, userName, character.Name)
|
||
global.GVA_LOG.Info(fmt.Sprintf("应用了 %d 个输入阶段正则脚本,原文: %s, 处理后: %s", len(inputScripts), req.Content, processedContent))
|
||
}
|
||
}
|
||
|
||
// 保存用户消息
|
||
userMessage := app.Message{
|
||
ConversationID: conversationID,
|
||
Role: "user",
|
||
Content: processedContent,
|
||
TokenCount: len(processedContent) / 4, // 简单估算
|
||
}
|
||
|
||
err = global.GVA_DB.Create(&userMessage).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取完整对话历史(context 管理由 callAIService 内部处理)
|
||
var messages []app.Message
|
||
err = global.GVA_DB.Where("conversation_id = ?", conversationID).
|
||
Order("created_at ASC").
|
||
Find(&messages).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 调用 AI 服务获取回复
|
||
aiResponse, err := s.callAIService(conversation, character, messages)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 保存 AI 回复
|
||
assistantMessage := app.Message{
|
||
ConversationID: conversationID,
|
||
Role: "assistant",
|
||
Content: aiResponse,
|
||
TokenCount: len(aiResponse) / 4, // 简单估算
|
||
}
|
||
|
||
err = global.GVA_DB.Create(&assistantMessage).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 更新对话统计
|
||
err = global.GVA_DB.Model(&conversation).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("message_count + ?", 2),
|
||
"token_count": gorm.Expr("token_count + ?", userMessage.TokenCount+assistantMessage.TokenCount),
|
||
}).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 提取并保存变量 (从 AI 回复中提取 {{setvar::key::value}})
|
||
newVars, cleanedContent := regexService.ExtractSetVars(assistantMessage.Content)
|
||
if len(newVars) > 0 {
|
||
// 加载现有变量
|
||
var existingVars map[string]string
|
||
if len(conversation.Variables) > 0 {
|
||
json.Unmarshal(conversation.Variables, &existingVars)
|
||
}
|
||
if existingVars == nil {
|
||
existingVars = make(map[string]string)
|
||
}
|
||
|
||
// 合并新变量
|
||
for k, v := range newVars {
|
||
existingVars[k] = v
|
||
}
|
||
|
||
// 保存回数据库
|
||
varsJSON, _ := json.Marshal(existingVars)
|
||
global.GVA_DB.Model(&conversation).Update("variables", datatypes.JSON(varsJSON))
|
||
global.GVA_LOG.Info(fmt.Sprintf("提取并保存了 %d 个变量: %v", len(newVars), newVars))
|
||
}
|
||
|
||
// 先替换 {{getvar::}} 为实际变量值(在应用正则脚本之前)
|
||
var currentVars map[string]string
|
||
if len(conversation.Variables) > 0 {
|
||
json.Unmarshal(conversation.Variables, ¤tVars)
|
||
}
|
||
displayContent := cleanedContent // 使用清理后的内容(移除了 {{setvar::}})
|
||
if currentVars != nil {
|
||
displayContent = regexService.SubstituteGetVars(displayContent, currentVars)
|
||
global.GVA_LOG.Info(fmt.Sprintf("替换了 {{getvar::}} 变量"))
|
||
}
|
||
|
||
// 注意:此时 displayContent 中的 <Status_block> 已经被 Placement 1 正则脚本
|
||
// 替换成了包含 YAML 数据的 HTML 模板,所以不需要再提取和保护
|
||
// 直接返回给前端即可
|
||
|
||
resp := response.ToMessageResponse(&assistantMessage)
|
||
resp.Content = displayContent // 使用处理后的显示内容
|
||
return &resp, nil
|
||
}
|
||
|
||
// callAIService 调用 AI 服务
|
||
func (s *ConversationService) callAIService(conversation app.Conversation, character app.AICharacter, messages []app.Message) (string, error) {
|
||
// 获取 AI 配置
|
||
|
||
var aiConfig app.AIConfig
|
||
var err error
|
||
|
||
// 1. 尝试从对话设置中获取指定的 AI 配置 ID
|
||
var configID uint
|
||
if len(conversation.Settings) > 0 {
|
||
var settings map[string]interface{}
|
||
if err := json.Unmarshal(conversation.Settings, &settings); err == nil {
|
||
if id, ok := settings["aiConfigId"].(float64); ok {
|
||
configID = uint(id)
|
||
}
|
||
}
|
||
}
|
||
|
||
if configID > 0 {
|
||
// 使用用户指定的 AI 配置
|
||
global.GVA_LOG.Info(fmt.Sprintf("使用用户指定的 AI 配置 ID: %d", configID))
|
||
err = global.GVA_DB.Where("id = ? AND is_active = ?", configID, true).First(&aiConfig).Error
|
||
if err != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("未找到指定的 AI 配置 ID: %d, 错误: %v", configID, err))
|
||
}
|
||
}
|
||
|
||
if err != nil || configID == 0 {
|
||
// 使用默认 AI 配置
|
||
global.GVA_LOG.Info("尝试使用默认 AI 配置")
|
||
err = global.GVA_DB.Where("is_active = ?", true).
|
||
Order("is_default DESC, created_at DESC").
|
||
First(&aiConfig).Error
|
||
if err != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("未找到默认 AI 配置, 错误: %v", err))
|
||
}
|
||
}
|
||
|
||
if err != nil {
|
||
return "", errors.New("未找到可用的 AI 配置,请在管理后台添加并激活 AI 配置")
|
||
}
|
||
|
||
global.GVA_LOG.Info(fmt.Sprintf("使用 AI 配置: %s (Provider: %s, Model: %s)", aiConfig.Name, aiConfig.Provider, aiConfig.DefaultModel))
|
||
|
||
// 2. 尝试从对话设置中获取预设 ID 并加载预设参数
|
||
var preset *app.AIPreset
|
||
var presetID uint
|
||
if len(conversation.Settings) > 0 {
|
||
var settings map[string]interface{}
|
||
if err := json.Unmarshal(conversation.Settings, &settings); err == nil {
|
||
if id, ok := settings["presetId"].(float64); ok {
|
||
presetID = uint(id)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 加载预设
|
||
if presetID > 0 {
|
||
var loadedPreset app.AIPreset
|
||
if err := global.GVA_DB.First(&loadedPreset, presetID).Error; err == nil {
|
||
preset = &loadedPreset
|
||
global.GVA_LOG.Info(fmt.Sprintf("使用预设: %s (Temperature: %.2f, TopP: %.2f)", preset.Name, preset.Temperature, preset.TopP))
|
||
|
||
// 增加预设使用次数
|
||
global.GVA_DB.Model(&preset).Update("use_count", gorm.Expr("use_count + ?", 1))
|
||
} else {
|
||
global.GVA_LOG.Warn(fmt.Sprintf("未找到预设 ID: %d, 使用默认参数", presetID))
|
||
}
|
||
}
|
||
|
||
// 构建消息列表(含 context 预算管理)
|
||
var presetSysPrompt string
|
||
if preset != nil {
|
||
presetSysPrompt = preset.SystemPrompt
|
||
}
|
||
wbEngine := &WorldbookEngine{}
|
||
apiMessages := s.buildAPIMessagesWithContextManagement(
|
||
messages, character, presetSysPrompt, wbEngine, conversation, &aiConfig, preset,
|
||
)
|
||
|
||
// 从 apiMessages 中提取 systemPrompt,供 Anthropic 独立参数使用
|
||
systemPrompt := ""
|
||
if len(apiMessages) > 0 && apiMessages[0]["role"] == "system" {
|
||
systemPrompt = apiMessages[0]["content"]
|
||
}
|
||
|
||
// 打印发送给AI的完整内容
|
||
global.GVA_LOG.Info("========== 发送给AI的完整内容 ==========")
|
||
global.GVA_LOG.Info(fmt.Sprintf("系统提示词长度: %d 字符", len(systemPrompt)))
|
||
global.GVA_LOG.Info(fmt.Sprintf("历史消息条数: %d", len(apiMessages)-1))
|
||
global.GVA_LOG.Info("==========================================")
|
||
|
||
// 确定使用的模型:如果用户在设置中指定了AI配置,则使用该配置的默认模型
|
||
// 否则使用对话创建时的模型(向后兼容)
|
||
model := aiConfig.DefaultModel
|
||
if model == "" {
|
||
// 如果AI配置没有默认模型,才使用对话表中的模型
|
||
model = conversation.Model
|
||
}
|
||
if model == "" {
|
||
// 最后的兜底
|
||
model = "gpt-4"
|
||
}
|
||
|
||
global.GVA_LOG.Info(fmt.Sprintf("使用模型: %s (来源: AI配置 %s)", model, aiConfig.Name))
|
||
|
||
// 根据提供商调用不同的 API
|
||
var aiResponse string
|
||
|
||
switch aiConfig.Provider {
|
||
case "openai", "custom":
|
||
aiResponse, err = s.callOpenAIAPI(&aiConfig, model, apiMessages, preset)
|
||
case "anthropic":
|
||
aiResponse, err = s.callAnthropicAPI(&aiConfig, model, apiMessages, systemPrompt, preset)
|
||
default:
|
||
return "", fmt.Errorf("不支持的 AI 提供商: %s", aiConfig.Provider)
|
||
}
|
||
|
||
// 打印AI返回的完整内容
|
||
if err != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("========== AI返回错误 ==========\n%v\n==========================================", err))
|
||
return "", err
|
||
}
|
||
global.GVA_LOG.Info(fmt.Sprintf("========== AI返回的完整内容 ==========\n%s\n==========================================", aiResponse))
|
||
|
||
// 应用输出阶段的正则脚本 (Placement 1)
|
||
// 这里会把 <Status_block> 替换成 HTML 模板,并注入 YAML 数据
|
||
var regexService RegexScriptService
|
||
global.GVA_LOG.Info(fmt.Sprintf("查询输出阶段正则脚本: userID=%d, placement=1, charID=%d", conversation.UserID, conversation.CharacterID))
|
||
outputScripts, err := regexService.GetScriptsForPlacement(conversation.UserID, 1, &conversation.CharacterID, nil)
|
||
if err != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("查询输出阶段正则脚本失败: %v", err))
|
||
} else {
|
||
global.GVA_LOG.Info(fmt.Sprintf("找到 %d 个输出阶段正则脚本", len(outputScripts)))
|
||
if len(outputScripts) > 0 {
|
||
// 获取用户信息
|
||
var user app.AppUser
|
||
err = global.GVA_DB.Where("id = ?", conversation.UserID).First(&user).Error
|
||
userName := ""
|
||
if err == nil {
|
||
userName = user.Username
|
||
if userName == "" {
|
||
userName = user.NickName
|
||
}
|
||
}
|
||
|
||
originalResponse := aiResponse
|
||
aiResponse = regexService.ExecuteScripts(outputScripts, aiResponse, userName, character.Name)
|
||
global.GVA_LOG.Info(fmt.Sprintf("应用了 %d 个输出阶段正则脚本,原始长度: %d, 处理后长度: %d", len(outputScripts), len(originalResponse), len(aiResponse)))
|
||
}
|
||
}
|
||
|
||
return aiResponse, nil
|
||
}
|
||
|
||
// buildSystemPrompt 构建系统提示词
|
||
func (s *ConversationService) buildSystemPrompt(character app.AICharacter) string {
|
||
prompt := fmt.Sprintf("你是 %s。", character.Name)
|
||
|
||
if character.Description != "" {
|
||
prompt += fmt.Sprintf("\n\n描述:%s", character.Description)
|
||
}
|
||
|
||
if character.Personality != "" {
|
||
prompt += fmt.Sprintf("\n\n性格:%s", character.Personality)
|
||
}
|
||
|
||
if character.Scenario != "" {
|
||
prompt += fmt.Sprintf("\n\n场景:%s", character.Scenario)
|
||
}
|
||
|
||
if character.FirstMes != "" {
|
||
prompt += fmt.Sprintf("\n\n开场白:%s", character.FirstMes)
|
||
}
|
||
|
||
if character.MesExample != "" {
|
||
prompt += fmt.Sprintf("\n\n对话示例:\n%s", character.MesExample)
|
||
}
|
||
|
||
if character.SystemPrompt != "" {
|
||
prompt += fmt.Sprintf("\n\n系统提示:%s", character.SystemPrompt)
|
||
}
|
||
|
||
// 处理世界书 (Character Book)
|
||
if len(character.CharacterBook) > 0 {
|
||
var characterBook map[string]interface{}
|
||
if err := json.Unmarshal(character.CharacterBook, &characterBook); err == nil {
|
||
if entries, ok := characterBook["entries"].([]interface{}); ok && len(entries) > 0 {
|
||
prompt += "\n\n世界设定:"
|
||
for _, entry := range entries {
|
||
if entryMap, ok := entry.(map[string]interface{}); ok {
|
||
// 默认启用,除非明确设置为false
|
||
enabled := true
|
||
if enabledVal, ok := entryMap["enabled"].(bool); ok {
|
||
enabled = enabledVal
|
||
}
|
||
if !enabled {
|
||
continue
|
||
}
|
||
// 添加世界书条目内容
|
||
if content, ok := entryMap["content"].(string); ok && content != "" {
|
||
prompt += fmt.Sprintf("\n- %s", content)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
prompt += "\n\n请根据以上设定进行角色扮演,保持角色的性格和说话方式。"
|
||
|
||
// 应用MVU变量替换
|
||
prompt = s.applyMacroVariables(prompt, character)
|
||
|
||
return prompt
|
||
}
|
||
|
||
// applyMacroVariables 应用宏变量替换 (MVU功能)
|
||
func (s *ConversationService) applyMacroVariables(text string, character app.AICharacter) string {
|
||
// 获取当前时间
|
||
now := time.Now()
|
||
|
||
// 基础变量
|
||
replacements := map[string]string{
|
||
"{{char}}": character.Name,
|
||
"{{user}}": "用户", // 可以从用户信息中获取
|
||
"{{time}}": now.Format("15:04"),
|
||
"{{date}}": now.Format("2006-01-02"),
|
||
"{{datetime}}": now.Format("2006-01-02 15:04:05"),
|
||
"{{weekday}}": s.getWeekdayInChinese(now.Weekday()),
|
||
"{{idle_duration}}": "0分钟",
|
||
}
|
||
|
||
// 执行替换
|
||
result := text
|
||
for macro, value := range replacements {
|
||
result = strings.ReplaceAll(result, macro, value)
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// getWeekdayInChinese 获取中文星期
|
||
func (s *ConversationService) getWeekdayInChinese(weekday time.Weekday) string {
|
||
weekdays := map[time.Weekday]string{
|
||
time.Sunday: "星期日",
|
||
time.Monday: "星期一",
|
||
time.Tuesday: "星期二",
|
||
time.Wednesday: "星期三",
|
||
time.Thursday: "星期四",
|
||
time.Friday: "星期五",
|
||
time.Saturday: "星期六",
|
||
}
|
||
return weekdays[weekday]
|
||
}
|
||
|
||
// SendMessageStream 流式发送消息并获取 AI 回复
|
||
func (s *ConversationService) SendMessageStream(ctx context.Context, userID, conversationID uint, req *request.SendMessageRequest, streamChan chan string, doneChan chan bool) error {
|
||
defer close(streamChan)
|
||
defer close(doneChan)
|
||
|
||
// 验证对话权限
|
||
var conversation app.Conversation
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", conversationID, userID).
|
||
First(&conversation).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("对话不存在或无权访问")
|
||
}
|
||
return err
|
||
}
|
||
|
||
// 获取角色卡信息
|
||
var character app.AICharacter
|
||
err = global.GVA_DB.Where("id = ?", conversation.CharacterID).First(&character).Error
|
||
if err != nil {
|
||
return errors.New("角色卡不存在")
|
||
}
|
||
|
||
// 获取用户信息
|
||
var user app.AppUser
|
||
err = global.GVA_DB.Where("id = ?", userID).First(&user).Error
|
||
if err != nil {
|
||
global.GVA_LOG.Warn(fmt.Sprintf("[流式传输] 获取用户信息失败: %v", err))
|
||
}
|
||
userName := user.Username
|
||
if userName == "" {
|
||
userName = user.NickName
|
||
}
|
||
|
||
// 应用输入阶段的正则脚本 (Placement 0)
|
||
processedContent := req.Content
|
||
var regexService RegexScriptService
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 查询输入阶段正则脚本: userID=%d, placement=0, charID=%d", userID, conversation.CharacterID))
|
||
inputScripts, err := regexService.GetScriptsForPlacement(userID, 0, &conversation.CharacterID, nil)
|
||
if err != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("[流式传输] 查询输入阶段正则脚本失败: %v", err))
|
||
} else {
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 找到 %d 个输入阶段正则脚本", len(inputScripts)))
|
||
if len(inputScripts) > 0 {
|
||
processedContent = regexService.ExecuteScripts(inputScripts, processedContent, userName, character.Name)
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 应用了 %d 个输入阶段正则脚本", len(inputScripts)))
|
||
}
|
||
}
|
||
|
||
// 保存用户消息
|
||
userMessage := app.Message{
|
||
ConversationID: conversationID,
|
||
Role: "user",
|
||
Content: processedContent,
|
||
TokenCount: len(processedContent) / 4,
|
||
}
|
||
|
||
err = global.GVA_DB.Create(&userMessage).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取完整对话历史(context 管理由 buildAPIMessagesWithContextManagement 处理)
|
||
var messages []app.Message
|
||
err = global.GVA_DB.Where("conversation_id = ?", conversationID).
|
||
Order("created_at ASC").
|
||
Find(&messages).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取 AI 配置
|
||
var aiConfig app.AIConfig
|
||
var configID uint
|
||
if len(conversation.Settings) > 0 {
|
||
var settings map[string]interface{}
|
||
if err := json.Unmarshal(conversation.Settings, &settings); err == nil {
|
||
if id, ok := settings["aiConfigId"].(float64); ok {
|
||
configID = uint(id)
|
||
}
|
||
}
|
||
}
|
||
|
||
if configID > 0 {
|
||
err = global.GVA_DB.Where("id = ? AND is_active = ?", configID, true).First(&aiConfig).Error
|
||
}
|
||
|
||
if err != nil || configID == 0 {
|
||
err = global.GVA_DB.Where("is_active = ?", true).
|
||
Order("is_default DESC, created_at DESC").
|
||
First(&aiConfig).Error
|
||
}
|
||
|
||
if err != nil {
|
||
return errors.New("未找到可用的 AI 配置")
|
||
}
|
||
|
||
// 加载预设
|
||
var streamPreset *app.AIPreset
|
||
var streamPresetID uint
|
||
if len(conversation.Settings) > 0 {
|
||
var settings map[string]interface{}
|
||
if err := json.Unmarshal(conversation.Settings, &settings); err == nil {
|
||
if id, ok := settings["presetId"].(float64); ok {
|
||
streamPresetID = uint(id)
|
||
}
|
||
}
|
||
}
|
||
if streamPresetID > 0 {
|
||
var loadedPreset app.AIPreset
|
||
if err := global.GVA_DB.First(&loadedPreset, streamPresetID).Error; err == nil {
|
||
streamPreset = &loadedPreset
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 使用预设: %s (Temperature: %.2f)", streamPreset.Name, streamPreset.Temperature))
|
||
global.GVA_DB.Model(streamPreset).Update("use_count", gorm.Expr("use_count + ?", 1))
|
||
}
|
||
}
|
||
|
||
// 构建消息列表(含 context 预算管理)
|
||
var streamPresetSysPrompt string
|
||
if streamPreset != nil {
|
||
streamPresetSysPrompt = streamPreset.SystemPrompt
|
||
}
|
||
streamWbEngine := &WorldbookEngine{}
|
||
apiMessages := s.buildAPIMessagesWithContextManagement(
|
||
messages, character, streamPresetSysPrompt, streamWbEngine, conversation, &aiConfig, streamPreset,
|
||
)
|
||
|
||
// 从 apiMessages 中提取 systemPrompt,供 Anthropic 独立参数使用
|
||
systemPrompt := ""
|
||
if len(apiMessages) > 0 && apiMessages[0]["role"] == "system" {
|
||
systemPrompt = apiMessages[0]["content"]
|
||
}
|
||
|
||
// 打印发送给AI的完整内容(流式传输)
|
||
global.GVA_LOG.Info("========== [流式传输] 发送给AI的完整内容 ==========")
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 系统提示词长度: %d 字符", len(systemPrompt)))
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 历史消息条数: %d", len(apiMessages)-1))
|
||
global.GVA_LOG.Info("==========================================")
|
||
|
||
// 确定使用的模型
|
||
model := aiConfig.DefaultModel
|
||
if model == "" {
|
||
model = conversation.Model
|
||
}
|
||
if model == "" {
|
||
model = "gpt-4"
|
||
}
|
||
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 使用模型: %s (Provider: %s)", model, aiConfig.Provider))
|
||
|
||
// 调用流式 API
|
||
var fullContent string
|
||
switch aiConfig.Provider {
|
||
case "openai", "custom":
|
||
fullContent, err = s.callOpenAIAPIStream(ctx, &aiConfig, model, apiMessages, streamPreset, streamChan)
|
||
case "anthropic":
|
||
fullContent, err = s.callAnthropicAPIStream(ctx, &aiConfig, model, apiMessages, systemPrompt, streamPreset, streamChan)
|
||
default:
|
||
return fmt.Errorf("不支持的 AI 提供商: %s", aiConfig.Provider)
|
||
}
|
||
|
||
if err != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("========== [流式传输] AI返回错误 ==========\n%v\n==========================================", err))
|
||
// AI 调用失败,回滚已写入的用户消息,避免孤立记录残留在数据库
|
||
if delErr := global.GVA_DB.Delete(&userMessage).Error; delErr != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("[流式传输] 回滚用户消息失败: %v", delErr))
|
||
} else {
|
||
global.GVA_LOG.Info("[流式传输] 已回滚用户消息")
|
||
}
|
||
return err
|
||
}
|
||
|
||
// 打印AI返回的完整内容
|
||
global.GVA_LOG.Info(fmt.Sprintf("========== [流式传输] AI返回的完整内容 ==========\n%s\n==========================================", fullContent))
|
||
|
||
// 应用输出阶段的正则脚本 (Placement 1)
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 查询输出阶段正则脚本: userID=%d, placement=1, charID=%d", userID, conversation.CharacterID))
|
||
outputScripts, err := regexService.GetScriptsForPlacement(userID, 1, &conversation.CharacterID, nil)
|
||
if err != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("[流式传输] 查询输出阶段正则脚本失败: %v", err))
|
||
} else {
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 找到 %d 个输出阶段正则脚本", len(outputScripts)))
|
||
if len(outputScripts) > 0 {
|
||
fullContent = regexService.ExecuteScripts(outputScripts, fullContent, userName, character.Name)
|
||
global.GVA_LOG.Info(fmt.Sprintf("[流式传输] 应用了 %d 个输出阶段正则脚本", len(outputScripts)))
|
||
}
|
||
}
|
||
|
||
// 保存 AI 回复
|
||
assistantMessage := app.Message{
|
||
ConversationID: conversationID,
|
||
Role: "assistant",
|
||
Content: fullContent,
|
||
TokenCount: len(fullContent) / 4,
|
||
}
|
||
|
||
err = global.GVA_DB.Create(&assistantMessage).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 更新对话统计
|
||
err = global.GVA_DB.Model(&conversation).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("message_count + ?", 2),
|
||
"token_count": gorm.Expr("token_count + ?", userMessage.TokenCount+assistantMessage.TokenCount),
|
||
}).Error
|
||
|
||
doneChan <- true
|
||
return err
|
||
}
|
||
|
||
// callOpenAIAPIStream 调用 OpenAI API 流式传输
|
||
func (s *ConversationService) callOpenAIAPIStream(ctx context.Context, config *app.AIConfig, model string, messages []map[string]string, preset *app.AIPreset, streamChan chan string) (string, error) {
|
||
// 不设 Timeout:生命周期由调用方传入的 ctx 控制(客户端断连时自动取消)
|
||
client := &http.Client{}
|
||
|
||
if model == "" {
|
||
model = config.DefaultModel
|
||
}
|
||
if model == "" {
|
||
model = "gpt-4"
|
||
}
|
||
|
||
// 应用预设参数
|
||
temperature := 0.7
|
||
maxTokens := 2000
|
||
var topP *float64
|
||
var frequencyPenalty *float64
|
||
var presencePenalty *float64
|
||
var stopSequences []string
|
||
|
||
if preset != nil {
|
||
temperature = preset.Temperature
|
||
maxTokens = preset.MaxTokens
|
||
if preset.TopP > 0 {
|
||
topP = &preset.TopP
|
||
}
|
||
if preset.FrequencyPenalty != 0 {
|
||
frequencyPenalty = &preset.FrequencyPenalty
|
||
}
|
||
if preset.PresencePenalty != 0 {
|
||
presencePenalty = &preset.PresencePenalty
|
||
}
|
||
if len(preset.StopSequences) > 0 {
|
||
json.Unmarshal(preset.StopSequences, &stopSequences)
|
||
}
|
||
}
|
||
|
||
// 构建请求体,启用流式传输
|
||
requestBody := map[string]interface{}{
|
||
"model": model,
|
||
"messages": messages,
|
||
"temperature": temperature,
|
||
"max_tokens": maxTokens,
|
||
"stream": true,
|
||
}
|
||
|
||
if topP != nil {
|
||
requestBody["top_p"] = *topP
|
||
}
|
||
if frequencyPenalty != nil {
|
||
requestBody["frequency_penalty"] = *frequencyPenalty
|
||
}
|
||
if presencePenalty != nil {
|
||
requestBody["presence_penalty"] = *presencePenalty
|
||
}
|
||
if len(stopSequences) > 0 {
|
||
requestBody["stop"] = stopSequences
|
||
}
|
||
|
||
bodyBytes, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
endpoint := config.BaseURL + "/chat/completions"
|
||
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(bodyBytes))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
// 客户端主动断开时 ctx 被取消,不算真正的错误
|
||
if ctx.Err() != nil {
|
||
return "", nil
|
||
}
|
||
return "", fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return "", fmt.Errorf("API 返回错误 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// 读取流式响应
|
||
var fullContent strings.Builder
|
||
reader := bufio.NewReader(resp.Body)
|
||
|
||
for {
|
||
line, err := reader.ReadString('\n')
|
||
// 先处理本次读到的数据(EOF 时可能仍携带最后一行内容)
|
||
if line != "" {
|
||
trimmed := strings.TrimSpace(line)
|
||
if trimmed != "" && trimmed != "data: [DONE]" && strings.HasPrefix(trimmed, "data: ") {
|
||
data := strings.TrimPrefix(trimmed, "data: ")
|
||
|
||
var streamResp struct {
|
||
Choices []struct {
|
||
Delta struct {
|
||
Content string `json:"content"`
|
||
} `json:"delta"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
if jsonErr := json.Unmarshal([]byte(data), &streamResp); jsonErr == nil {
|
||
if len(streamResp.Choices) > 0 {
|
||
content := streamResp.Choices[0].Delta.Content
|
||
if content != "" {
|
||
fullContent.WriteString(content)
|
||
streamChan <- content
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
// 再检查读取错误
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
// ctx 被取消(客户端断开)时不算真正的流读取错误
|
||
if ctx.Err() != nil {
|
||
return fullContent.String(), nil
|
||
}
|
||
return "", fmt.Errorf("读取流失败: %v", err)
|
||
}
|
||
}
|
||
|
||
return fullContent.String(), nil
|
||
}
|
||
|
||
// callAnthropicAPIStream 调用 Anthropic API 流式传输
|
||
func (s *ConversationService) callAnthropicAPIStream(ctx context.Context, config *app.AIConfig, model string, messages []map[string]string, systemPrompt string, preset *app.AIPreset, streamChan chan string) (string, error) {
|
||
// 不设 Timeout:生命周期由调用方传入的 ctx 控制(客户端断连时自动取消)
|
||
client := &http.Client{}
|
||
|
||
if model == "" {
|
||
model = config.DefaultModel
|
||
}
|
||
if model == "" {
|
||
model = "claude-3-sonnet-20240229"
|
||
}
|
||
|
||
// Anthropic API 不支持 system role
|
||
apiMessages := make([]map[string]string, 0)
|
||
for _, msg := range messages {
|
||
if msg["role"] != "system" {
|
||
apiMessages = append(apiMessages, msg)
|
||
}
|
||
}
|
||
|
||
// 应用预设参数
|
||
maxTokens := 2000
|
||
var temperature *float64
|
||
var topP *float64
|
||
var stopSequences []string
|
||
|
||
if preset != nil {
|
||
maxTokens = preset.MaxTokens
|
||
if preset.Temperature > 0 {
|
||
temperature = &preset.Temperature
|
||
}
|
||
if preset.TopP > 0 {
|
||
topP = &preset.TopP
|
||
}
|
||
if len(preset.StopSequences) > 0 {
|
||
json.Unmarshal(preset.StopSequences, &stopSequences)
|
||
}
|
||
}
|
||
|
||
requestBody := map[string]interface{}{
|
||
"model": model,
|
||
"messages": apiMessages,
|
||
"system": systemPrompt,
|
||
"max_tokens": maxTokens,
|
||
"stream": true,
|
||
}
|
||
|
||
if temperature != nil {
|
||
requestBody["temperature"] = *temperature
|
||
}
|
||
if topP != nil {
|
||
requestBody["top_p"] = *topP
|
||
}
|
||
if len(stopSequences) > 0 {
|
||
requestBody["stop_sequences"] = stopSequences
|
||
}
|
||
|
||
bodyBytes, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
endpoint := config.BaseURL + "/messages"
|
||
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(bodyBytes))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("x-api-key", config.APIKey)
|
||
req.Header.Set("anthropic-version", "2023-06-01")
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
// 客户端主动断开时 ctx 被取消,不算真正的错误
|
||
if ctx.Err() != nil {
|
||
return "", nil
|
||
}
|
||
return "", fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return "", fmt.Errorf("API 返回错误 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// 读取流式响应
|
||
var fullContent strings.Builder
|
||
reader := bufio.NewReader(resp.Body)
|
||
|
||
for {
|
||
line, err := reader.ReadString('\n')
|
||
// 先处理本次读到的数据(EOF 时可能仍携带最后一行内容)
|
||
if line != "" {
|
||
trimmed := strings.TrimSpace(line)
|
||
if trimmed != "" && strings.HasPrefix(trimmed, "data: ") {
|
||
data := strings.TrimPrefix(trimmed, "data: ")
|
||
|
||
var streamResp struct {
|
||
Type string `json:"type"`
|
||
Delta struct {
|
||
Type string `json:"type"`
|
||
Text string `json:"text"`
|
||
} `json:"delta"`
|
||
}
|
||
|
||
if jsonErr := json.Unmarshal([]byte(data), &streamResp); jsonErr == nil {
|
||
if streamResp.Type == "content_block_delta" && streamResp.Delta.Text != "" {
|
||
fullContent.WriteString(streamResp.Delta.Text)
|
||
streamChan <- streamResp.Delta.Text
|
||
}
|
||
}
|
||
}
|
||
}
|
||
// 再检查读取错误
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
// ctx 被取消(客户端断开)时不算真正的流读取错误
|
||
if ctx.Err() != nil {
|
||
return fullContent.String(), nil
|
||
}
|
||
return "", fmt.Errorf("读取流失败: %v", err)
|
||
}
|
||
}
|
||
|
||
return fullContent.String(), nil
|
||
}
|
||
|
||
// RegenerateMessage 重新生成最后一条 AI 回复(非流式)
|
||
func (s *ConversationService) RegenerateMessage(userID, conversationID uint) (*response.MessageResponse, error) {
|
||
var conversation app.Conversation
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", conversationID, userID).First(&conversation).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("对话不存在或无权访问")
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
var character app.AICharacter
|
||
err = global.GVA_DB.Where("id = ?", conversation.CharacterID).First(&character).Error
|
||
if err != nil {
|
||
return nil, errors.New("角色卡不存在")
|
||
}
|
||
|
||
// 删除最后一条 AI 回复
|
||
var lastAssistantMsg app.Message
|
||
if err = global.GVA_DB.Where("conversation_id = ? AND role = ?", conversationID, "assistant").
|
||
Order("created_at DESC").First(&lastAssistantMsg).Error; err == nil {
|
||
global.GVA_DB.Delete(&lastAssistantMsg)
|
||
global.GVA_DB.Model(&conversation).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("GREATEST(message_count - 1, 0)"),
|
||
"token_count": gorm.Expr("GREATEST(token_count - ?, 0)", lastAssistantMsg.TokenCount),
|
||
})
|
||
}
|
||
|
||
// 获取删除后的完整消息历史(context 管理由 callAIService 内部处理)
|
||
var messages []app.Message
|
||
err = global.GVA_DB.Where("conversation_id = ?", conversationID).
|
||
Order("created_at ASC").Find(&messages).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(messages) == 0 {
|
||
return nil, errors.New("没有可用的消息历史")
|
||
}
|
||
|
||
aiResponse, err := s.callAIService(conversation, character, messages)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
assistantMessage := app.Message{
|
||
ConversationID: conversationID,
|
||
Role: "assistant",
|
||
Content: aiResponse,
|
||
TokenCount: len(aiResponse) / 4,
|
||
}
|
||
if err = global.GVA_DB.Create(&assistantMessage).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
global.GVA_DB.Model(&conversation).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("message_count + ?", 1),
|
||
"token_count": gorm.Expr("token_count + ?", assistantMessage.TokenCount),
|
||
})
|
||
|
||
resp := response.ToMessageResponse(&assistantMessage)
|
||
return &resp, nil
|
||
}
|
||
|
||
// RegenerateMessageStream 流式重新生成最后一条 AI 回复
|
||
func (s *ConversationService) RegenerateMessageStream(ctx context.Context, userID, conversationID uint, streamChan chan string, doneChan chan bool) error {
|
||
defer close(streamChan)
|
||
defer close(doneChan)
|
||
|
||
var conversation app.Conversation
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", conversationID, userID).First(&conversation).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("对话不存在或无权访问")
|
||
}
|
||
return err
|
||
}
|
||
|
||
var character app.AICharacter
|
||
err = global.GVA_DB.Where("id = ?", conversation.CharacterID).First(&character).Error
|
||
if err != nil {
|
||
return errors.New("角色卡不存在")
|
||
}
|
||
|
||
// 删除最后一条 AI 回复
|
||
var lastAssistantMsg app.Message
|
||
if err = global.GVA_DB.Where("conversation_id = ? AND role = ?", conversationID, "assistant").
|
||
Order("created_at DESC").First(&lastAssistantMsg).Error; err == nil {
|
||
global.GVA_DB.Delete(&lastAssistantMsg)
|
||
global.GVA_DB.Model(&conversation).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("GREATEST(message_count - 1, 0)"),
|
||
"token_count": gorm.Expr("GREATEST(token_count - ?, 0)", lastAssistantMsg.TokenCount),
|
||
})
|
||
}
|
||
|
||
// 获取删除后的完整消息历史(context 管理由 buildAPIMessagesWithContextManagement 处理)
|
||
var messages []app.Message
|
||
err = global.GVA_DB.Where("conversation_id = ?", conversationID).
|
||
Order("created_at ASC").Find(&messages).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if len(messages) == 0 {
|
||
return errors.New("没有可用的消息历史")
|
||
}
|
||
|
||
// 获取 AI 配置
|
||
var aiConfig app.AIConfig
|
||
var configID uint
|
||
if len(conversation.Settings) > 0 {
|
||
var settings map[string]interface{}
|
||
if err := json.Unmarshal(conversation.Settings, &settings); err == nil {
|
||
if id, ok := settings["aiConfigId"].(float64); ok {
|
||
configID = uint(id)
|
||
}
|
||
}
|
||
}
|
||
if configID > 0 {
|
||
err = global.GVA_DB.Where("id = ? AND is_active = ?", configID, true).First(&aiConfig).Error
|
||
}
|
||
if err != nil || configID == 0 {
|
||
err = global.GVA_DB.Where("is_active = ?", true).
|
||
Order("is_default DESC, created_at DESC").
|
||
First(&aiConfig).Error
|
||
}
|
||
if err != nil {
|
||
return errors.New("未找到可用的 AI 配置")
|
||
}
|
||
|
||
// 加载预设
|
||
var preset *app.AIPreset
|
||
var presetID uint
|
||
if len(conversation.Settings) > 0 {
|
||
var settings map[string]interface{}
|
||
if err := json.Unmarshal(conversation.Settings, &settings); err == nil {
|
||
if id, ok := settings["presetId"].(float64); ok {
|
||
presetID = uint(id)
|
||
}
|
||
}
|
||
}
|
||
if presetID > 0 {
|
||
var loadedPreset app.AIPreset
|
||
if err := global.GVA_DB.First(&loadedPreset, presetID).Error; err == nil {
|
||
preset = &loadedPreset
|
||
}
|
||
}
|
||
|
||
// 构建消息列表(含 context 预算管理)
|
||
var regenPresetSysPrompt string
|
||
if preset != nil {
|
||
regenPresetSysPrompt = preset.SystemPrompt
|
||
}
|
||
regenWbEngine := &WorldbookEngine{}
|
||
apiMessages := s.buildAPIMessagesWithContextManagement(
|
||
messages, character, regenPresetSysPrompt, regenWbEngine, conversation, &aiConfig, preset,
|
||
)
|
||
|
||
// 从 apiMessages 中提取 systemPrompt,供 Anthropic 独立参数使用
|
||
systemPrompt := ""
|
||
if len(apiMessages) > 0 && apiMessages[0]["role"] == "system" {
|
||
systemPrompt = apiMessages[0]["content"]
|
||
}
|
||
|
||
model := aiConfig.DefaultModel
|
||
if model == "" {
|
||
model = conversation.Model
|
||
}
|
||
if model == "" {
|
||
model = "gpt-4"
|
||
}
|
||
|
||
var fullContent string
|
||
switch aiConfig.Provider {
|
||
case "openai", "custom":
|
||
fullContent, err = s.callOpenAIAPIStream(ctx, &aiConfig, model, apiMessages, preset, streamChan)
|
||
case "anthropic":
|
||
fullContent, err = s.callAnthropicAPIStream(ctx, &aiConfig, model, apiMessages, systemPrompt, preset, streamChan)
|
||
default:
|
||
return fmt.Errorf("不支持的 AI 提供商: %s", aiConfig.Provider)
|
||
}
|
||
|
||
if err != nil {
|
||
// AI 调用失败,恢复刚才删除的 assistant 消息,避免数据永久丢失
|
||
if lastAssistantMsg.ID > 0 {
|
||
if restoreErr := global.GVA_DB.Unscoped().Model(&lastAssistantMsg).Update("deleted_at", nil).Error; restoreErr != nil {
|
||
global.GVA_LOG.Error(fmt.Sprintf("[重新生成] 恢复 assistant 消息失败: %v", restoreErr))
|
||
} else {
|
||
// 回滚 conversation 统计
|
||
global.GVA_DB.Model(&conversation).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("message_count + 1"),
|
||
"token_count": gorm.Expr("token_count + ?", lastAssistantMsg.TokenCount),
|
||
})
|
||
global.GVA_LOG.Info("[重新生成] 已恢复 assistant 消息")
|
||
}
|
||
}
|
||
return err
|
||
}
|
||
|
||
assistantMessage := app.Message{
|
||
ConversationID: conversationID,
|
||
Role: "assistant",
|
||
Content: fullContent,
|
||
TokenCount: len(fullContent) / 4,
|
||
}
|
||
if err = global.GVA_DB.Create(&assistantMessage).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
global.GVA_DB.Model(&conversation).Updates(map[string]interface{}{
|
||
"message_count": gorm.Expr("message_count + ?", 1),
|
||
"token_count": gorm.Expr("token_count + ?", assistantMessage.TokenCount),
|
||
})
|
||
|
||
doneChan <- true
|
||
return nil
|
||
}
|
||
|
||
func (s *ConversationService) buildAPIMessages(messages []app.Message, systemPrompt string) []map[string]string {
|
||
apiMessages := make([]map[string]string, 0, len(messages)+1)
|
||
|
||
// 添加系统消息(OpenAI 格式)
|
||
apiMessages = append(apiMessages, map[string]string{
|
||
"role": "system",
|
||
"content": systemPrompt,
|
||
})
|
||
|
||
// 添加历史消息
|
||
for _, msg := range messages {
|
||
if msg.Role == "system" {
|
||
continue // 跳过已有的系统消息
|
||
}
|
||
apiMessages = append(apiMessages, map[string]string{
|
||
"role": msg.Role,
|
||
"content": msg.Content,
|
||
})
|
||
}
|
||
|
||
return apiMessages
|
||
}
|
||
|
||
// estimateTokens returns a rough token estimate for text: ceil(runeCount / 2).
//
// The ratio is a midpoint between CJK text (about one token per character)
// and English text (about one token per four characters). It therefore
// over-estimates English but can under-estimate CJK-heavy text by up to 2x,
// so it is a heuristic, not an upper bound. Invalid UTF-8 bytes count as one
// character each. An empty string yields 0.
func estimateTokens(text string) int {
	// Ranging over a string visits each rune without allocating a []rune copy.
	runes := 0
	for range text {
		runes++
	}
	return (runes + 1) / 2
}
|
||
|
||
// contextConfig holds the context-window numbers used to budget a prompt.
// getContextConfig fills it from the preset and AIConfig.Settings.
type contextConfig struct {
	// contextLength is the model's context window size, in tokens.
	contextLength int
	// maxTokens is the number of tokens reserved for the model's reply.
	maxTokens int
}
|
||
|
||
// getContextConfig 从 AIConfig 中读取上下文配置,如果没有配置则使用默认值
|
||
func getContextConfig(aiConfig *app.AIConfig, preset *app.AIPreset) contextConfig {
|
||
cfg := contextConfig{
|
||
contextLength: 200000, // 保守默认值
|
||
maxTokens: 2000,
|
||
}
|
||
|
||
// 从 preset 读取 max_tokens
|
||
if preset != nil && preset.MaxTokens > 0 {
|
||
cfg.maxTokens = preset.MaxTokens
|
||
}
|
||
|
||
// 从 AIConfig.Settings 读取 context_length
|
||
if len(aiConfig.Settings) > 0 {
|
||
var settings map[string]interface{}
|
||
if err := json.Unmarshal(aiConfig.Settings, &settings); err == nil {
|
||
if cl, ok := settings["context_length"].(float64); ok && cl > 0 {
|
||
cfg.contextLength = int(cl)
|
||
}
|
||
}
|
||
}
|
||
|
||
return cfg
|
||
}
|
||
|
||
// buildContextManagedSystemPrompt 按优先级构建 system prompt,超出 budget 时截断低优先级内容
|
||
// 优先级(从高到低):
|
||
// 1. 核心人设(Name/Description/Personality/Scenario/SystemPrompt)
|
||
// 2. Preset.SystemPrompt
|
||
// 3. Worldbook 触发条目
|
||
// 4. CharacterBook 内嵌条目
|
||
// 5. MesExample(对话示例,最容易被截断)
|
||
//
|
||
// 返回构建好的 systemPrompt 以及消耗的 token 数
|
||
func (s *ConversationService) buildContextManagedSystemPrompt(
|
||
character app.AICharacter,
|
||
presetSystemPrompt string,
|
||
worldbookEngine *WorldbookEngine,
|
||
conversation app.Conversation,
|
||
messageContents []string,
|
||
budget int,
|
||
) (string, int) {
|
||
used := 0
|
||
|
||
// ── 优先级1:核心人设 ─────────────────────────────────────────────
|
||
core := fmt.Sprintf("你是 %s。", character.Name)
|
||
if character.Description != "" {
|
||
core += fmt.Sprintf("\n\n描述:%s", character.Description)
|
||
}
|
||
if character.Personality != "" {
|
||
core += fmt.Sprintf("\n\n性格:%s", character.Personality)
|
||
}
|
||
if character.Scenario != "" {
|
||
core += fmt.Sprintf("\n\n场景:%s", character.Scenario)
|
||
}
|
||
if character.SystemPrompt != "" {
|
||
core += fmt.Sprintf("\n\n系统提示:%s", character.SystemPrompt)
|
||
}
|
||
core += "\n\n请根据以上设定进行角色扮演,保持角色的性格和说话方式。"
|
||
core = s.applyMacroVariables(core, character)
|
||
|
||
coreTokens := estimateTokens(core)
|
||
if coreTokens >= budget {
|
||
// 极端情况:核心人设本身就超出 budget,截断到 budget
|
||
runes := []rune(core)
|
||
limit := budget * 2
|
||
if limit > len(runes) {
|
||
limit = len(runes)
|
||
}
|
||
core = string(runes[:limit])
|
||
global.GVA_LOG.Warn(fmt.Sprintf("[context] 核心人设超出 budget,已截断至 %d chars", limit))
|
||
return core, budget
|
||
}
|
||
used += coreTokens
|
||
prompt := core
|
||
|
||
// ── 优先级2:Preset.SystemPrompt ────────────────────────────────
|
||
if presetSystemPrompt != "" {
|
||
tokens := estimateTokens(presetSystemPrompt)
|
||
if used+tokens <= budget {
|
||
prompt += "\n\n" + presetSystemPrompt
|
||
used += tokens
|
||
} else {
|
||
// 尝试部分插入
|
||
remaining := budget - used
|
||
if remaining > 50 {
|
||
runes := []rune(presetSystemPrompt)
|
||
limit := remaining * 2
|
||
if limit > len(runes) {
|
||
limit = len(runes)
|
||
}
|
||
prompt += "\n\n" + string(runes[:limit])
|
||
used = budget
|
||
}
|
||
global.GVA_LOG.Warn(fmt.Sprintf("[context] Preset.SystemPrompt 因 budget 不足被截断(需要 %d tokens,剩余 %d)", tokens, budget-used))
|
||
}
|
||
}
|
||
|
||
if used >= budget {
|
||
return prompt, used
|
||
}
|
||
|
||
// ── 优先级3:世界书触发条目 ──────────────────────────────────────
|
||
if conversation.WorldbookEnabled && conversation.WorldbookID != nil && worldbookEngine != nil {
|
||
triggeredEntries, wbErr := worldbookEngine.ScanAndTrigger(*conversation.WorldbookID, messageContents)
|
||
if wbErr != nil {
|
||
global.GVA_LOG.Warn(fmt.Sprintf("[context] 世界书触发失败: %v", wbErr))
|
||
} else if len(triggeredEntries) > 0 {
|
||
wbHeader := "\n\n[World Information]"
|
||
wbSection := wbHeader
|
||
for _, te := range triggeredEntries {
|
||
if te.Entry == nil || te.Entry.Content == "" {
|
||
continue
|
||
}
|
||
line := fmt.Sprintf("\n- %s", te.Entry.Content)
|
||
lineTokens := estimateTokens(line)
|
||
if used+estimateTokens(wbSection)+lineTokens <= budget {
|
||
wbSection += line
|
||
used += lineTokens
|
||
} else {
|
||
global.GVA_LOG.Warn(fmt.Sprintf("[context] 世界书条目 (id=%d) 因 budget 不足被跳过", te.Entry.ID))
|
||
break
|
||
}
|
||
}
|
||
if wbSection != wbHeader {
|
||
prompt += wbSection
|
||
}
|
||
}
|
||
}
|
||
|
||
if used >= budget {
|
||
return prompt, used
|
||
}
|
||
|
||
// ── 优先级4:CharacterBook 内嵌条目 ──────────────────────────────
|
||
if len(character.CharacterBook) > 0 {
|
||
var characterBook map[string]interface{}
|
||
if err := json.Unmarshal(character.CharacterBook, &characterBook); err == nil {
|
||
if entries, ok := characterBook["entries"].([]interface{}); ok && len(entries) > 0 {
|
||
cbSection := "\n\n世界设定:"
|
||
addedAny := false
|
||
for _, entry := range entries {
|
||
entryMap, ok := entry.(map[string]interface{})
|
||
if !ok {
|
||
continue
|
||
}
|
||
enabled := true
|
||
if enabledVal, ok := entryMap["enabled"].(bool); ok {
|
||
enabled = enabledVal
|
||
}
|
||
if !enabled {
|
||
continue
|
||
}
|
||
content, ok := entryMap["content"].(string)
|
||
if !ok || content == "" {
|
||
continue
|
||
}
|
||
line := fmt.Sprintf("\n- %s", content)
|
||
lineTokens := estimateTokens(line)
|
||
if used+estimateTokens(cbSection)+lineTokens <= budget {
|
||
cbSection += line
|
||
used += lineTokens
|
||
addedAny = true
|
||
} else {
|
||
global.GVA_LOG.Warn("[context] CharacterBook 条目因 budget 不足被跳过")
|
||
break
|
||
}
|
||
}
|
||
if addedAny {
|
||
prompt += cbSection
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if used >= budget {
|
||
return prompt, used
|
||
}
|
||
|
||
// ── 优先级5:MesExample(对话示例,最低优先级)──────────────────
|
||
if character.MesExample != "" {
|
||
mesTokens := estimateTokens(character.MesExample)
|
||
prefix := "\n\n对话示例:\n"
|
||
prefixTokens := estimateTokens(prefix)
|
||
if used+prefixTokens+mesTokens <= budget {
|
||
prompt += prefix + character.MesExample
|
||
used += prefixTokens + mesTokens
|
||
} else {
|
||
// 尝试截断 MesExample
|
||
remaining := budget - used - prefixTokens
|
||
if remaining > 100 {
|
||
runes := []rune(character.MesExample)
|
||
limit := remaining * 2
|
||
if limit > len(runes) {
|
||
limit = len(runes)
|
||
}
|
||
prompt += prefix + string(runes[:limit])
|
||
used = budget
|
||
global.GVA_LOG.Warn(fmt.Sprintf("[context] MesExample 被截断(原始 %d tokens,保留约 %d tokens)", mesTokens, remaining))
|
||
} else {
|
||
global.GVA_LOG.Warn("[context] MesExample 因 budget 不足被完全跳过")
|
||
}
|
||
}
|
||
}
|
||
|
||
return prompt, used
|
||
}
|
||
|
||
// trimMessagesToBudget 从历史消息中按 token 预算选取最近的消息
|
||
// 优先保留最新的消息,从后往前丢弃旧消息直到 token 数在 budget 内
|
||
func trimMessagesToBudget(messages []app.Message, budget int) []app.Message {
|
||
if budget <= 0 {
|
||
return nil
|
||
}
|
||
|
||
// messages 已经是从旧到新的顺序
|
||
// 从最新消息开始往前累加,直到超出 budget
|
||
selected := make([]app.Message, 0, len(messages))
|
||
used := 0
|
||
|
||
for i := len(messages) - 1; i >= 0; i-- {
|
||
msg := messages[i]
|
||
if msg.Role == "system" {
|
||
continue
|
||
}
|
||
t := estimateTokens(msg.Content)
|
||
if used+t > budget {
|
||
global.GVA_LOG.Warn(fmt.Sprintf("[context] 历史消息已截断,保留最近 %d 条(共 %d 条),使用 %d tokens", len(selected), len(messages), used))
|
||
break
|
||
}
|
||
used += t
|
||
selected = append([]app.Message{msg}, selected...) // 保持时序
|
||
}
|
||
|
||
return selected
|
||
}
|
||
|
||
// buildAPIMessagesWithContextManagement 整合 context 管理,构建最终的 messages 列表
|
||
// 返回 apiMessages 及各部分 token 统计日志
|
||
func (s *ConversationService) buildAPIMessagesWithContextManagement(
|
||
allMessages []app.Message,
|
||
character app.AICharacter,
|
||
presetSystemPrompt string,
|
||
worldbookEngine *WorldbookEngine,
|
||
conversation app.Conversation,
|
||
aiConfig *app.AIConfig,
|
||
preset *app.AIPreset,
|
||
) []map[string]string {
|
||
cfg := getContextConfig(aiConfig, preset)
|
||
|
||
// 安全边际:为输出保留 max_tokens,另加 200 token 缓冲
|
||
inputBudget := cfg.contextLength - cfg.maxTokens - 200
|
||
if inputBudget <= 0 {
|
||
inputBudget = cfg.contextLength / 2
|
||
}
|
||
|
||
// 为历史消息分配预算:system prompt 最多占用 60% 的 input budget
|
||
systemBudget := inputBudget * 60 / 100
|
||
historyBudget := inputBudget - systemBudget
|
||
|
||
// 提取消息内容用于世界书扫描
|
||
var messageContents []string
|
||
for _, msg := range allMessages {
|
||
messageContents = append(messageContents, msg.Content)
|
||
}
|
||
|
||
// 构建 system prompt(含 worldbook 注入,按优先级截断)
|
||
systemPrompt, systemTokens := s.buildContextManagedSystemPrompt(
|
||
character,
|
||
presetSystemPrompt,
|
||
worldbookEngine,
|
||
conversation,
|
||
messageContents,
|
||
systemBudget,
|
||
)
|
||
|
||
// 如果 system prompt 实际用量比预算少,把节省的预算让给历史消息
|
||
if systemTokens < systemBudget {
|
||
historyBudget += systemBudget - systemTokens
|
||
}
|
||
|
||
global.GVA_LOG.Info(fmt.Sprintf("[context] 配置:context_length=%d, max_tokens=%d, input_budget=%d, system=%d tokens, history_budget=%d",
|
||
cfg.contextLength, cfg.maxTokens, inputBudget, systemTokens, historyBudget))
|
||
|
||
// 按 token 预算裁剪历史消息
|
||
trimmedMessages := trimMessagesToBudget(allMessages, historyBudget)
|
||
|
||
// 构建最终 messages
|
||
apiMessages := make([]map[string]string, 0, len(trimmedMessages)+1)
|
||
apiMessages = append(apiMessages, map[string]string{
|
||
"role": "system",
|
||
"content": systemPrompt,
|
||
})
|
||
for _, msg := range trimmedMessages {
|
||
if msg.Role == "system" {
|
||
continue
|
||
}
|
||
apiMessages = append(apiMessages, map[string]string{
|
||
"role": msg.Role,
|
||
"content": msg.Content,
|
||
})
|
||
}
|
||
|
||
return apiMessages
|
||
}
|
||
|
||
// callOpenAIAPI 调用 OpenAI API
|
||
func (s *ConversationService) callOpenAIAPI(config *app.AIConfig, model string, messages []map[string]string, preset *app.AIPreset) (string, error) {
|
||
client := &http.Client{Timeout: 10 * time.Minute}
|
||
|
||
// 使用配置的模型或默认模型
|
||
if model == "" {
|
||
model = config.DefaultModel
|
||
}
|
||
if model == "" {
|
||
model = "gpt-4"
|
||
}
|
||
|
||
// 应用预设参数(如果有预设)
|
||
temperature := 0.7
|
||
maxTokens := 2000
|
||
var topP *float64
|
||
var frequencyPenalty *float64
|
||
var presencePenalty *float64
|
||
var stopSequences []string
|
||
|
||
if preset != nil {
|
||
temperature = preset.Temperature
|
||
maxTokens = preset.MaxTokens
|
||
if preset.TopP > 0 {
|
||
topP = &preset.TopP
|
||
}
|
||
if preset.FrequencyPenalty != 0 {
|
||
frequencyPenalty = &preset.FrequencyPenalty
|
||
}
|
||
if preset.PresencePenalty != 0 {
|
||
presencePenalty = &preset.PresencePenalty
|
||
}
|
||
// 解析停止序列
|
||
if len(preset.StopSequences) > 0 {
|
||
json.Unmarshal(preset.StopSequences, &stopSequences)
|
||
}
|
||
global.GVA_LOG.Info(fmt.Sprintf("应用预设参数: Temperature=%.2f, MaxTokens=%d, TopP=%.2f", temperature, maxTokens, preset.TopP))
|
||
}
|
||
|
||
// 构建请求体
|
||
requestBody := map[string]interface{}{
|
||
"model": model,
|
||
"messages": messages,
|
||
"temperature": temperature,
|
||
"max_tokens": maxTokens,
|
||
}
|
||
|
||
// 添加可选参数
|
||
if topP != nil {
|
||
requestBody["top_p"] = *topP
|
||
}
|
||
if frequencyPenalty != nil {
|
||
requestBody["frequency_penalty"] = *frequencyPenalty
|
||
}
|
||
if presencePenalty != nil {
|
||
requestBody["presence_penalty"] = *presencePenalty
|
||
}
|
||
if len(stopSequences) > 0 {
|
||
requestBody["stop"] = stopSequences
|
||
}
|
||
|
||
bodyBytes, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
// 创建请求
|
||
endpoint := config.BaseURL + "/chat/completions"
|
||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(bodyBytes))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+config.APIKey)
|
||
|
||
// 发送请求
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取响应失败: %v", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return "", fmt.Errorf("API 返回错误 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// 解析响应
|
||
var result struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
Error *struct {
|
||
Message string `json:"message"`
|
||
} `json:"error"`
|
||
}
|
||
|
||
err = json.Unmarshal(body, &result)
|
||
if err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
if result.Error != nil {
|
||
return "", fmt.Errorf("API 错误: %s", result.Error.Message)
|
||
}
|
||
|
||
if len(result.Choices) == 0 {
|
||
return "", errors.New("API 未返回任何回复")
|
||
}
|
||
|
||
return result.Choices[0].Message.Content, nil
|
||
}
|
||
|
||
// callAnthropicAPI 调用 Anthropic API
|
||
func (s *ConversationService) callAnthropicAPI(config *app.AIConfig, model string, messages []map[string]string, systemPrompt string, preset *app.AIPreset) (string, error) {
|
||
client := &http.Client{Timeout: 10 * time.Minute}
|
||
|
||
// 使用配置的模型或默认模型
|
||
if model == "" {
|
||
model = config.DefaultModel
|
||
}
|
||
if model == "" {
|
||
model = "claude-3-sonnet-20240229"
|
||
}
|
||
|
||
// Anthropic API 不支持 system role,需要单独传递
|
||
apiMessages := make([]map[string]string, 0)
|
||
for _, msg := range messages {
|
||
if msg["role"] != "system" {
|
||
apiMessages = append(apiMessages, msg)
|
||
}
|
||
}
|
||
|
||
// 应用预设参数(如果有预设)
|
||
maxTokens := 2000
|
||
var temperature *float64
|
||
var topP *float64
|
||
var stopSequences []string
|
||
|
||
if preset != nil {
|
||
maxTokens = preset.MaxTokens
|
||
if preset.Temperature > 0 {
|
||
temperature = &preset.Temperature
|
||
}
|
||
if preset.TopP > 0 {
|
||
topP = &preset.TopP
|
||
}
|
||
// 解析停止序列
|
||
if len(preset.StopSequences) > 0 {
|
||
json.Unmarshal(preset.StopSequences, &stopSequences)
|
||
}
|
||
global.GVA_LOG.Info(fmt.Sprintf("应用预设参数: Temperature=%.2f, MaxTokens=%d, TopP=%.2f", preset.Temperature, maxTokens, preset.TopP))
|
||
}
|
||
|
||
// 构建请求体
|
||
requestBody := map[string]interface{}{
|
||
"model": model,
|
||
"messages": apiMessages,
|
||
"system": systemPrompt,
|
||
"max_tokens": maxTokens,
|
||
}
|
||
|
||
// 添加可选参数
|
||
if temperature != nil {
|
||
requestBody["temperature"] = *temperature
|
||
}
|
||
if topP != nil {
|
||
requestBody["top_p"] = *topP
|
||
}
|
||
if len(stopSequences) > 0 {
|
||
requestBody["stop_sequences"] = stopSequences
|
||
}
|
||
|
||
bodyBytes, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
// 创建请求
|
||
endpoint := config.BaseURL + "/messages"
|
||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(bodyBytes))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("x-api-key", config.APIKey)
|
||
req.Header.Set("anthropic-version", "2023-06-01")
|
||
|
||
// 发送请求
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取响应失败: %v", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return "", fmt.Errorf("API 返回错误 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// 解析响应
|
||
var result struct {
|
||
Content []struct {
|
||
Text string `json:"text"`
|
||
} `json:"content"`
|
||
Error *struct {
|
||
Message string `json:"message"`
|
||
} `json:"error"`
|
||
}
|
||
|
||
err = json.Unmarshal(body, &result)
|
||
if err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
if result.Error != nil {
|
||
return "", fmt.Errorf("API 错误: %s", result.Error.Message)
|
||
}
|
||
|
||
if len(result.Content) == 0 {
|
||
return "", errors.New("API 未返回任何回复")
|
||
}
|
||
|
||
return result.Content[0].Text, nil
|
||
}
|