package app import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "regexp" "sort" "strings" "time" "git.echol.cn/loser/ai_proxy/server/global" "git.echol.cn/loser/ai_proxy/server/model/app" "git.echol.cn/loser/ai_proxy/server/model/app/request" "git.echol.cn/loser/ai_proxy/server/model/app/response" ) type AiProxyService struct{} // ProcessChatCompletion 处理聊天补全请求 func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userId uint, req *request.ChatCompletionRequest) (resp response.ChatCompletionResponse, err error) { startTime := time.Now() // 1. 获取预设配置 var preset app.AiPreset if req.PresetID > 0 { err = global.GVA_DB.First(&preset, req.PresetID).Error if err != nil { return resp, fmt.Errorf("预设不存在: %w", err) } } // 2. 获取提供商配置 var provider app.AiProvider // 根据 binding_key 或预设绑定获取 provider if req.BindingKey != "" { // 通过 binding_key 查找绑定关系 var binding app.AiPresetBinding err = global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true). Order("priority ASC"). First(&binding).Error if err == nil { err = global.GVA_DB.First(&provider, binding.ProviderID).Error } } // 如果没有找到,使用默认的活跃提供商 if provider.ID == 0 { err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error if err != nil { return resp, fmt.Errorf("未找到可用的AI提供商: %w", err) } } // 3. 构建注入后的消息 messages, err := s.buildInjectedMessages(req, &preset) if err != nil { return resp, fmt.Errorf("构建消息失败: %w", err) } // 4. 转发到上游AI resp, err = s.forwardToAI(ctx, &provider, &preset, messages) if err != nil { // 记录失败日志 s.logRequest(userId, &preset, &provider, req.Messages[0].Content, "", err, time.Since(startTime)) return resp, err } // 5. 应用输出正则脚本 resp.Choices[0].Message.Content = s.applyOutputRegex(resp.Choices[0].Message.Content, preset.RegexScripts) // 6. 记录成功日志 s.logRequest(userId, &preset, &provider, req.Messages[0].Content, resp.Choices[0].Message.Content, nil, time.Since(startTime)) return resp, nil } // buildInjectedMessages 构建注入预设后的消息数组 func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionRequest, preset *app.AiPreset) ([]request.Message, error) { if preset == nil || preset.ID == 0 { return req.Messages, nil } // 1. 按 injection_order 排序 prompts sortedPrompts := make([]app.Prompt, len(preset.Prompts)) copy(sortedPrompts, preset.Prompts) sort.Slice(sortedPrompts, func(i, j int) bool { return sortedPrompts[i].InjectionOrder < sortedPrompts[j].InjectionOrder }) messages := make([]request.Message, 0) // 2. 根据 injection_depth 插入到对话历史中 for _, prompt := range sortedPrompts { if prompt.Marker { continue // 跳过标记提示词 } // 替换变量 content := s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard) // 根据 injection_depth 决定插入位置 // depth=0: 插入到最前面(系统提示词) // depth>0: 从对话历史末尾往前数 depth 条消息的位置插入 if prompt.InjectionDepth == 0 || prompt.SystemPrompt { messages = append(messages, request.Message{ Role: prompt.Role, Content: content, }) } else { // 先添加用户消息,稍后根据 depth 插入 // 这里简化处理,将非系统提示词也添加到前面 messages = append(messages, request.Message{ Role: prompt.Role, Content: content, }) } } // 添加用户消息 messages = append(messages, req.Messages...) // 4. 应用输入正则脚本 (placement=1) for i := range messages { messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts) } return messages, nil } // replaceVariables 替换变量 func (s *AiProxyService) replaceVariables(content string, vars map[string]string, card *request.CharacterCard) string { result := content // 替换自定义变量 for key, value := range vars { placeholder := fmt.Sprintf("{{%s}}", key) result = replaceAll(result, placeholder, value) } // 替换角色卡片变量 if card != nil { result = replaceAll(result, "{{char}}", card.Name) result = replaceAll(result, "{{char_name}}", card.Name) } return result } // applyInputRegex 应用输入正则脚本 func (s *AiProxyService) applyInputRegex(content string, scripts []app.RegexScript) string { for _, script := range scripts { if script.Disabled { continue } if !containsPlacement(script.Placement, 1) { continue } // 编译正则表达式 re, err := regexp.Compile(script.FindRegex) if err != nil { global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName)) continue } // 执行替换 content = re.ReplaceAllString(content, script.ReplaceString) } return content } // applyOutputRegex 应用输出正则脚本 func (s *AiProxyService) applyOutputRegex(content string, scripts []app.RegexScript) string { for _, script := range scripts { if script.Disabled { continue } if !containsPlacement(script.Placement, 2) { continue } // 编译正则表达式 re, err := regexp.Compile(script.FindRegex) if err != nil { global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName)) continue } // 执行替换 content = re.ReplaceAllString(content, script.ReplaceString) } return content } // forwardToAI 转发请求到上游AI func (s *AiProxyService) forwardToAI(ctx context.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message) (response.ChatCompletionResponse, error) { // 构建请求体 reqBody := map[string]interface{}{ "model": provider.Model, "messages": messages, } if preset != nil { reqBody["temperature"] = preset.Temperature reqBody["top_p"] = preset.TopP reqBody["max_tokens"] = preset.MaxTokens reqBody["frequency_penalty"] = preset.FrequencyPenalty reqBody["presence_penalty"] = preset.PresencePenalty } jsonData, err := json.Marshal(reqBody) if err != nil { return response.ChatCompletionResponse{}, err } // 创建HTTP请求 url := fmt.Sprintf("%s/chat/completions", provider.BaseURL) req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) if err != nil { return response.ChatCompletionResponse{}, err } req.Header.Set("Content-Type", "application/json") if provider.UpstreamKey != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey)) } // 发送请求 client := &http.Client{Timeout: 120 * time.Second} resp, err := client.Do(req) if err != nil { return response.ChatCompletionResponse{}, err } defer resp.Body.Close() // 读取响应 body, err := io.ReadAll(resp.Body) if err != nil { return response.ChatCompletionResponse{}, err } if resp.StatusCode != http.StatusOK { return response.ChatCompletionResponse{}, fmt.Errorf("API错误: %s - %s", resp.Status, string(body)) } // 解析响应 var aiResp response.ChatCompletionResponse if err := json.Unmarshal(body, &aiResp); err != nil { return response.ChatCompletionResponse{}, err } return aiResp, nil } // logRequest 记录请求日志 func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider *app.AiProvider, originalMsg, responseText string, err error, latency time.Duration) { log := app.AiRequestLog{ UserID: &userId, OriginalMessage: originalMsg, ResponseText: responseText, LatencyMs: int(latency.Milliseconds()), } if preset != nil { presetID := preset.ID log.PresetID = &presetID } if provider != nil { providerID := provider.ID log.ProviderID = &providerID } if err != nil { log.Status = "error" log.ErrorMessage = err.Error() } else { log.Status = "success" } global.GVA_DB.Create(&log) } // 辅助函数 func replaceAll(s, old, new string) string { return strings.ReplaceAll(s, old, new) } func containsPlacement(placements []int, target int) bool { for _, p := range placements { if p == target { return true } } return false }