// Package app contains the AI proxy service, which resolves a user's
// preset/provider binding, injects the preset into the chat request, forwards
// it to the upstream provider and records a request log entry.
package app

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"git.echol.cn/loser/ai_proxy/server/global"
	"git.echol.cn/loser/ai_proxy/server/model/app"
	"git.echol.cn/loser/ai_proxy/server/model/app/request"
	"git.echol.cn/loser/ai_proxy/server/model/app/response"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

// AiProxyService forwards chat-completion requests to an upstream provider
// after applying the preset bound to the calling user.
type AiProxyService struct{}

// ProcessChatCompletion handles a non-streaming chat-completion request.
//
// It resolves the binding for userID, injects the preset into req (req is
// mutated), forwards the request upstream, post-processes the content of the
// first choice and writes a request log entry for both success and upstream
// failure. A failure to resolve the binding is returned without a log entry.
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userID uint, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
	startTime := time.Now()

	// 1. Resolve the binding configuration.
	binding, err := s.getBinding(userID, req)
	if err != nil {
		return nil, fmt.Errorf("获取绑定配置失败: %w", err)
	}

	// 2. Inject the preset.
	injector := NewPresetInjector(&binding.Preset)
	req.Messages = injector.InjectMessages(req.Messages)
	injector.ApplyPresetParameters(req)

	// 3. Forward the request upstream.
	resp, err := s.forwardRequest(ctx, &binding.Provider, req)
	if err != nil {
		s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
		return nil, err
	}

	// 4. Post-process the response (only the first choice is processed).
	if len(resp.Choices) > 0 {
		resp.Choices[0].Message.Content = injector.ProcessResponse(resp.Choices[0].Message.Content)
	}

	// 5. Record the request log.
	s.logRequest(userID, binding, req, resp, nil, time.Since(startTime))

	return resp, nil
}

// ProcessChatCompletionStream handles a streaming chat-completion request and
// writes the result directly to the gin response.
//
// A binding failure is answered with HTTP 400. An upstream failure that occurs
// before anything has been written to the client is answered with HTTP 502 and
// a JSON error body; a failure in the middle of the stream can only be logged.
// Both successful and failed streams produce a request log entry.
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint, req *request.ChatCompletionRequest) {
	startTime := time.Now()

	// 1. Resolve the binding configuration.
	binding, err := s.getBinding(userID, req)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// 2. Inject the preset.
	injector := NewPresetInjector(&binding.Preset)
	req.Messages = injector.InjectMessages(req.Messages)
	injector.ApplyPresetParameters(req)

	// 3. Forward the streaming request. The SSE response headers are set by
	// forwardStreamRequest once the upstream has accepted the request, so an
	// early failure can still be reported as a regular JSON error.
	if err := s.forwardStreamRequest(c, &binding.Provider, req, injector); err != nil {
		global.GVA_LOG.Error("流式请求失败", zap.Error(err))
		s.logRequest(userID, binding, req, nil, err, time.Since(startTime))

		if !c.Writer.Written() {
			// The SSE headers may already be staged; force the JSON content
			// type so the error body is labelled correctly.
			c.Header("Content-Type", "application/json; charset=utf-8")
			c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
		}
		return
	}

	// Token usage is not extracted from stream chunks, so resp is nil here.
	s.logRequest(userID, binding, req, nil, nil, time.Since(startTime))
}

// getBinding loads the enabled binding for userID selected by req, together
// with its Preset and Provider.
//
// Selection priority: BindingName, then the PresetName+ProviderName pair,
// otherwise the enabled binding with the lowest id. An error is returned when
// no binding matches or when the bound provider or preset is disabled.
func (s *AiProxyService) getBinding(userID uint, req *request.ChatCompletionRequest) (*app.AiPresetBinding, error) {
	var binding app.AiPresetBinding

	query := global.GVA_DB.Preload("Preset").Preload("Provider")

	switch {
	case req.BindingName != "":
		// Highest priority: explicit binding name.
		query = query.
			Where("user_id = ? AND enabled = ?", userID, true).
			Where("name = ?", req.BindingName)

	case req.PresetName != "" && req.ProviderName != "":
		// Select by preset name and provider name. Columns are qualified with
		// the binding table because the joined tables carry columns with the
		// same names (for example enabled), which would otherwise make the
		// WHERE clause ambiguous.
		query = query.
			Joins("JOIN ai_presets ON ai_presets.id = ai_preset_bindings.preset_id").
			Joins("JOIN ai_providers ON ai_providers.id = ai_preset_bindings.provider_id").
			Where("ai_preset_bindings.user_id = ? AND ai_preset_bindings.enabled = ?", userID, true).
			Where("ai_presets.name = ? AND ai_providers.name = ?", req.PresetName, req.ProviderName)

	default:
		// Default binding: the first enabled one.
		query = query.
			Where("user_id = ? AND enabled = ?", userID, true).
			Order("id ASC")
	}

	if err := query.First(&binding).Error; err != nil {
		// The caller-visible message stays generic; the cause is logged so a
		// database failure is not mistaken for a missing binding.
		global.GVA_LOG.Warn("未找到可用的绑定配置", zap.Error(err))
		return nil, fmt.Errorf("未找到可用的绑定配置")
	}

	if !binding.Provider.Enabled {
		return nil, fmt.Errorf("提供商已禁用")
	}

	if !binding.Preset.Enabled {
		return nil, fmt.Errorf("预设已禁用")
	}

	return &binding, nil
}

// doUpstreamRequest sends req to the provider's /v1/chat/completions endpoint
// and returns the raw HTTP response once the upstream answered with 200.
//
// If req.Model is empty it is filled from provider.Model (req is mutated). On
// success the caller owns the response body and must close it. On any error,
// including a non-200 status, the body is already closed and the returned
// error carries the status code and a bounded prefix of the upstream body.
//
// NOTE(review): http.Client.Timeout also covers reading the response body, so
// provider.Timeout bounds the total duration of a stream as well — confirm
// this is intended for long-running streams.
func (s *AiProxyService) doUpstreamRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*http.Response, error) {
	// Upper bound on how much of an upstream error body is read into memory.
	const maxErrorBodyBytes = 64 << 10

	// Fall back to the provider's default model when the request names none.
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)

	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}

	if httpResp.StatusCode != http.StatusOK {
		defer httpResp.Body.Close()
		// Best effort: a read failure only shortens the diagnostic text.
		body, _ := io.ReadAll(io.LimitReader(httpResp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	return httpResp, nil
}

// forwardRequest forwards a non-streaming request to the upstream AI service
// and decodes the JSON response body.
func (s *AiProxyService) forwardRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
	httpResp, err := s.doUpstreamRequest(ctx, provider, req)
	if err != nil {
		return nil, err
	}
	defer httpResp.Body.Close()

	var resp response.ChatCompletionResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	return &resp, nil
}

// forwardStreamRequest forwards a streaming request upstream and relays the
// server-sent events to the client line by line, applying the injector's
// response processing to the delta content of the first choice of each chunk.
//
// The SSE response headers are set only after the upstream returned 200, so
// the caller can still send a normal error response when this function fails
// before writing anything. Streaming stops at the upstream "[DONE]" marker, at
// EOF, or at the first read or write error.
func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ChatCompletionRequest, injector *PresetInjector) error {
	httpResp, err := s.doUpstreamRequest(c.Request.Context(), provider, req)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()

	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return fmt.Errorf("不支持流式响应")
	}

	// SSE response headers; X-Accel-Buffering disables proxy buffering.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")

	reader := bufio.NewReader(httpResp.Body)
	for {
		line, readErr := reader.ReadBytes('\n')

		// ReadBytes may return data together with an error (for example a
		// final line without a trailing newline), so handle the data first.
		if len(line) > 0 {
			done, err := s.relayStreamLine(c.Writer, flusher, line, injector)
			if err != nil {
				return err
			}
			if done {
				return nil
			}
		}

		if readErr != nil {
			if readErr == io.EOF {
				return nil
			}
			return readErr
		}
	}
}

// relayStreamLine handles one line read from the upstream event stream.
//
// Blank lines and lines that are not "data:" fields are skipped. A "[DONE]"
// payload is forwarded and reported via done=true. Any other payload is parsed
// as a stream chunk, post-processed and re-encoded; a payload that cannot be
// parsed or re-encoded is forwarded unchanged instead of being dropped.
//
// Note that response processing sees each delta chunk on its own, so a pattern
// that spans two chunks is not matched.
func (s *AiProxyService) relayStreamLine(w io.Writer, flusher http.Flusher, line []byte, injector *PresetInjector) (done bool, err error) {
	const dataField = "data:"

	if !bytes.HasPrefix(line, []byte(dataField)) {
		return false, nil
	}

	// The space after the colon is optional in SSE; TrimSpace removes it
	// together with the line terminator.
	payload := bytes.TrimSpace(line[len(dataField):])
	if len(payload) == 0 {
		return false, nil
	}

	// End-of-stream marker.
	if string(payload) == "[DONE]" {
		return true, s.writeSSEData(w, flusher, payload)
	}

	out := payload
	var chunk response.ChatCompletionStreamResponse
	if json.Unmarshal(payload, &chunk) == nil {
		// Apply the output processing to the streamed content.
		if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
			chunk.Choices[0].Delta.Content = injector.ProcessResponse(chunk.Choices[0].Delta.Content)
		}
		if encoded, marshalErr := json.Marshal(chunk); marshalErr == nil {
			out = encoded
		}
	}

	return false, s.writeSSEData(w, flusher, out)
}

// writeSSEData writes payload as a single SSE "data:" event and flushes it to
// the client. It returns the write error, if any.
func (s *AiProxyService) writeSSEData(w io.Writer, flusher http.Flusher, payload []byte) error {
	event := make([]byte, 0, len("data: ")+len(payload)+2)
	event = append(event, "data: "...)
	event = append(event, payload...)
	event = append(event, '\n', '\n')

	if _, err := w.Write(event); err != nil {
		return err
	}
	flusher.Flush()
	return nil
}

// logRequest persists one request log entry.
//
// When err is non-nil the entry is stored with status "error" and the error
// message; otherwise it is stored with status "success" and, if resp is
// non-nil, the token usage from resp. A failure to store the entry is logged
// and otherwise ignored so that logging never fails the request itself.
func (s *AiProxyService) logRequest(userID uint, binding *app.AiPresetBinding, req *request.ChatCompletionRequest, resp *response.ChatCompletionResponse, err error, duration time.Duration) {
	entry := app.AiRequestLog{
		UserID:      userID,
		BindingID:   binding.ID,
		ProviderID:  binding.ProviderID,
		PresetID:    binding.PresetID,
		Model:       req.Model,
		Duration:    duration.Milliseconds(),
		RequestTime: time.Now(),
	}

	if err != nil {
		entry.Status = "error"
		entry.ErrorMessage = err.Error()
	} else {
		entry.Status = "success"
		if resp != nil {
			entry.PromptTokens = resp.Usage.PromptTokens
			entry.CompletionTokens = resp.Usage.CompletionTokens
			entry.TotalTokens = resp.Usage.TotalTokens
		}
	}

	if dbErr := global.GVA_DB.Create(&entry).Error; dbErr != nil {
		global.GVA_LOG.Error("保存请求日志失败", zap.Error(dbErr))
	}
}