package app

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"git.echol.cn/loser/ai_proxy/server/global"
	"git.echol.cn/loser/ai_proxy/server/model/app"
	"git.echol.cn/loser/ai_proxy/server/model/app/request"
	"git.echol.cn/loser/ai_proxy/server/model/app/response"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

// AiProxyService proxies OpenAI-compatible chat-completion requests to a
// configured upstream provider, optionally applying a preset (message
// injection, parameter overrides, and regex post-processing scripts).
type AiProxyService struct{}

// ProcessChatCompletion handles a non-streaming chat completion request:
// resolve the model configuration, inject the preset, forward the request
// upstream, apply output regex processing, and log the outcome.
//
// The returned response is the upstream payload with Choices[0] content
// post-processed and StandardUsage populated when usage data is present.
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
	// Log the incoming request parameters.
	global.GVA_LOG.Info("收到 ChatCompletion 请求",
		zap.String("model", req.Model),
		zap.Any("messages", req.Messages),
		zap.Any("full_request", req),
	)

	// 1. Resolve the preset/provider configuration for the requested model.
	if req.Model == "" {
		return nil, fmt.Errorf("model 参数不能为空")
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		return nil, err
	}

	// 2. Inject the preset. The same injector instance is reused below so
	// that regex match counts accumulate across input and output passes.
	var injector *PresetInjector
	if preset != nil {
		injector = NewPresetInjector(preset)
		req.Messages = injector.InjectMessages(req.Messages)
		injector.ApplyPresetParameters(req)
	}

	// 3. Forward the request to the upstream AI service.
	resp, err := s.forwardRequest(ctx, provider, req)
	if err != nil {
		return nil, err
	}

	// Capture the AI output and apply preset output processing to it.
	aiOutput := ""
	if len(resp.Choices) > 0 {
		aiOutput = resp.Choices[0].Message.Content
		if injector != nil {
			resp.Choices[0].Message.Content = injector.ProcessResponse(aiOutput)
			aiOutput = resp.Choices[0].Message.Content
		}
	}

	// 4. Populate standard_usage so downstream consumers can rely on one
	// uniform usage format regardless of the upstream's schema.
	if resp.Usage != nil &&
		(resp.Usage.PromptTokens > 0 || resp.Usage.CompletionTokens > 0 || resp.Usage.TotalTokens > 0) {
		resp.StandardUsage = &response.ChatCompletionUsage{
			PromptTokens:     resp.Usage.PromptTokens,
			CompletionTokens: resp.Usage.CompletionTokens,
			TotalTokens:      resp.Usage.TotalTokens,
		}
	}

	// Unified response logging, including any triggered regex scripts.
	logFields := []zap.Field{
		zap.String("ai_output", aiOutput),
	}
	if resp.Usage != nil {
		logFields = append(logFields, zap.Any("usage", resp.Usage))
	}
	if resp.StandardUsage != nil {
		logFields = append(logFields, zap.Any("standard_usage", resp.StandardUsage))
	}
	logFields = append(logFields, regexLogFields(injector)...)
	logFields = append(logFields, zap.Any("full_response", resp))
	global.GVA_LOG.Info("ChatCompletion 响应", logFields...)

	return resp, nil
}

// ProcessChatCompletionStream handles a streaming chat completion request.
// Errors before the stream starts are reported as JSON; once SSE headers
// are sent, upstream failures can only be logged.
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, req *request.ChatCompletionRequest) {
	// Log the incoming request parameters.
	global.GVA_LOG.Info("收到 ChatCompletion 流式请求",
		zap.String("model", req.Model),
		zap.Any("messages", req.Messages),
		zap.Any("full_request", req),
	)

	// 1. Resolve the preset/provider configuration for the requested model.
	if req.Model == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
		return
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// 2. Inject the preset.
	var injector *PresetInjector
	if preset != nil {
		injector = NewPresetInjector(preset)
		req.Messages = injector.InjectMessages(req.Messages)
		injector.ApplyPresetParameters(req)
	}

	// 3. Set SSE response headers. X-Accel-Buffering disables proxy
	// buffering (e.g. nginx) so chunks are delivered immediately.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")

	// 4. Forward the streaming request; headers are already sent, so a
	// failure here can only be logged, not returned to the client.
	if err := s.forwardStreamRequest(c, provider, req, injector); err != nil {
		global.GVA_LOG.Error("流式请求失败", zap.Error(err))
	}
}

// getConfigByModel looks up the enabled model configuration by name and
// returns its preset (may be nil) and provider. The provider must itself
// be enabled.
func (s *AiProxyService) getConfigByModel(modelName string) (*app.AiPreset, *app.AiProvider, error) {
	var model app.AiModel
	err := global.GVA_DB.Preload("Provider").Preload("Preset").
		Where("name = ? AND enabled = ?", modelName, true).
		First(&model).Error
	if err != nil {
		// Wrap the underlying error so DB failures are distinguishable
		// from a genuinely missing model.
		return nil, nil, fmt.Errorf("未找到模型配置: %s: %w", modelName, err)
	}

	if !model.Provider.Enabled {
		return nil, nil, fmt.Errorf("提供商已禁用")
	}

	return model.Preset, &model.Provider, nil
}

// forwardRequest sends a non-streaming chat completion request to the
// upstream provider and decodes the JSON response.
func (s *AiProxyService) forwardRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
	// Fall back to the provider's default model when none is specified.
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	global.GVA_LOG.Info("转发 ChatCompletion 请求到上游",
		zap.String("provider", provider.Name),
		zap.String("model", req.Model),
		zap.Any("messages", req.Messages),
	)

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)

	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(httpResp.Body)
		return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	var resp response.ChatCompletionResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	return &resp, nil
}

// forwardStreamRequest forwards a streaming chat completion request to the
// upstream provider, relaying SSE chunks to the client. Each chunk's delta
// content is run through the injector's output regex scripts before being
// forwarded; the aggregated content is logged when the stream completes.
func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ChatCompletionRequest, injector *PresetInjector) error {
	// Fall back to the provider's default model when none is specified.
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return err
	}

	global.GVA_LOG.Info("转发 ChatCompletion 流式请求到上游",
		zap.String("provider", provider.Name),
		zap.String("model", req.Model),
		zap.Any("messages", req.Messages),
	)

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)

	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(httpResp.Body)
		return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	reader := bufio.NewReader(httpResp.Body)
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return fmt.Errorf("不支持流式响应")
	}

	// Aggregate the raw AI output for logging and regex processing.
	var fullContent bytes.Buffer

	for {
		line, readErr := reader.ReadBytes('\n')

		// Process whatever was read BEFORE acting on the read error:
		// ReadBytes returns partial data together with io.EOF when the
		// stream ends without a trailing newline, and that final line
		// must not be dropped.
		if len(bytes.TrimSpace(line)) > 0 && bytes.HasPrefix(line, []byte("data: ")) {
			data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: ")))

			// Relay the terminator and stop.
			if string(data) == "[DONE]" {
				c.Writer.Write([]byte("data: [DONE]\n\n"))
				flusher.Flush()
				break
			}

			// Malformed chunks are silently skipped (best-effort relay).
			var chunk response.ChatCompletionStreamResponse
			if err := json.Unmarshal(data, &chunk); err == nil {
				// Collect the raw content and apply output regex scripts.
				if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
					originalContent := chunk.Choices[0].Delta.Content
					fullContent.WriteString(originalContent)
					if injector != nil {
						chunk.Choices[0].Delta.Content = injector.ProcessResponse(originalContent)
					}
				}

				// Re-serialize and forward the processed chunk.
				processedData, _ := json.Marshal(chunk)
				c.Writer.Write([]byte("data: "))
				c.Writer.Write(processedData)
				c.Writer.Write([]byte("\n\n"))
				flusher.Flush()
			}
		}

		if readErr != nil {
			if readErr == io.EOF {
				break
			}
			return readErr
		}
	}

	// After the stream ends, run the output regex over the complete
	// content once more — for log visibility only; the client already
	// received the per-chunk processed data.
	processedContent := fullContent.String()
	if injector != nil && processedContent != "" {
		processedContent = injector.ProcessResponse(processedContent)
	}

	logFields := []zap.Field{
		zap.String("ai_output_original", fullContent.String()),
		zap.String("ai_output_processed", processedContent),
	}
	logFields = append(logFields, regexLogFields(injector)...)
	global.GVA_LOG.Info("ChatCompletion 流式响应完成", logFields...)

	return nil
}

// regexLogFields builds zap fields describing which regex scripts were
// triggered by the injector; returns nil when there is nothing to report.
// Shared by the streaming and non-streaming paths.
func regexLogFields(injector *PresetInjector) []zap.Field {
	if injector == nil {
		return nil
	}
	regexLogs := injector.GetRegexLogs()
	if regexLogs == nil {
		return nil
	}
	if regexLogs.TotalMatches == 0 && len(regexLogs.InputScripts) == 0 && len(regexLogs.OutputScripts) == 0 {
		return nil
	}

	// Collect the names of scripts that actually matched.
	triggeredScripts := make([]string, 0)
	for _, scriptLog := range regexLogs.InputScripts {
		if scriptLog.MatchCount > 0 {
			triggeredScripts = append(triggeredScripts,
				fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
		}
	}
	for _, scriptLog := range regexLogs.OutputScripts {
		if scriptLog.MatchCount > 0 {
			triggeredScripts = append(triggeredScripts,
				fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
		}
	}
	if len(triggeredScripts) == 0 {
		return nil
	}

	return []zap.Field{
		zap.Strings("triggered_regex_scripts", triggeredScripts),
		zap.Int("total_matches", regexLogs.TotalMatches),
	}
}

// GetAvailableModels returns the list of enabled models visible to the
// given API key, in OpenAI "list" format. If the key restricts models,
// only the allowed ones are returned; duplicate names are collapsed.
func (s *AiProxyService) GetAvailableModels(apiKey *app.AiApiKey) (*response.ModelListResponse, error) {
	var models []app.AiModel
	query := global.GVA_DB.Where("enabled = ?", true)

	// Restrict to the key's allowed models when a restriction is set.
	if len(apiKey.AllowedModels) > 0 {
		query = query.Where("name IN ?", apiKey.AllowedModels)
	}

	if err := query.Find(&models).Error; err != nil {
		return nil, fmt.Errorf("查询模型列表失败: %w", err)
	}

	modelList := &response.ModelListResponse{
		Object: "list",
		Data:   make([]response.ModelInfo, 0, len(models)),
	}

	// Deduplicate: the same model name may be configured under multiple
	// providers.
	seen := make(map[string]bool)
	for _, model := range models {
		if !seen[model.Name] {
			seen[model.Name] = true
			modelList.Data = append(modelList.Data, response.ModelInfo{
				ID:      model.Name,
				Object:  "model",
				Created: model.CreatedAt.Unix(),
				OwnedBy: "system",
			})
		}
	}

	return modelList, nil
}