package app

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"git.echol.cn/loser/ai_proxy/server/global"
	"git.echol.cn/loser/ai_proxy/server/model/app"
	"git.echol.cn/loser/ai_proxy/server/model/app/request"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

// ProcessChatCompletionStream handles a streaming (SSE) chat-completion request.
//
// It loads the preset named by req.PresetID (when > 0), resolves an upstream
// provider, builds the injected message list and relays the upstream event
// stream to the client via forwardStreamToAI. userId is passed through for
// request logging.
//
// Failures before anything has been written to the client are answered with a
// JSON error body. Once streaming has started the response is already
// committed, so a later failure is only logged; writing JSON at that point
// would corrupt the event stream.
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userId uint, req *request.ChatCompletionRequest) {
	startTime := time.Now()

	// 1. Load the preset configuration, if one was requested.
	var preset app.AiPreset
	if req.PresetID > 0 {
		if err := global.GVA_DB.First(&preset, req.PresetID).Error; err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "预设不存在"})
			return
		}
	}

	// 2. Resolve the upstream provider: prefer the highest-priority active
	// binding for the preset, otherwise fall back to any active provider.
	var provider app.AiProvider
	if req.BindingKey != "" {
		// NOTE(review): BindingKey only gates this lookup; the query filters by
		// preset_id, not by the key itself. Confirm this is intended.
		var binding app.AiPresetBinding
		err := global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
			Order("priority ASC").
			First(&binding).Error
		if err == nil {
			if err := global.GVA_DB.First(&provider, binding.ProviderID).Error; err != nil {
				// Not fatal: reset and let the fallback below pick a provider.
				global.GVA_LOG.Warn("获取绑定的提供商失败", zap.Error(err))
				provider = app.AiProvider{}
			}
		}
	}
	if provider.ID == 0 {
		if err := global.GVA_DB.Where("is_active = ?", true).First(&provider).Error; err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "未找到可用的AI提供商"})
			return
		}
	}

	// 3. Build the message list with the preset's injections applied.
	messages, err := s.buildInjectedMessages(req, &preset)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "构建消息失败"})
		return
	}

	// 4. Relay the streaming request to the upstream AI provider.
	if err := s.forwardStreamToAI(c, &provider, &preset, messages, userId, startTime); err != nil {
		global.GVA_LOG.Error("流式请求失败", zap.Error(err))
		// Only report the error to the client if no stream data has been sent;
		// otherwise the headers are committed and JSON would corrupt the stream.
		if !c.Writer.Written() {
			// forwardStreamToAI may have set SSE headers without writing a body;
			// drop the SSE content type so c.JSON can set application/json.
			c.Writer.Header().Del("Content-Type")
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}
}

// forwardStreamToAI sends a streaming chat-completion request to provider and
// relays the upstream server-sent events to the client.
//
// Every upstream line starting with "data: " is re-emitted as one complete SSE
// event (terminated by a blank line) and flushed immediately. Other lines are
// not forwarded. The delta content of each chunk is accumulated; after the
// stream ends ("[DONE]" or EOF) the accumulated text is passed through the
// preset's output regex scripts and logged via logRequest together with the
// last message's content and the elapsed time since startTime.
//
// It returns an error if the request cannot be built or sent, if the upstream
// responds with a non-200 status, if the client writer cannot flush, or if
// reading from upstream or writing to the client fails mid-stream.
func (s *AiProxyService) forwardStreamToAI(c *gin.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message, userId uint, startTime time.Time) error {
	reqBody := map[string]interface{}{
		"model":    provider.Model,
		"messages": messages,
		"stream":   true,
	}
	// The caller passes a zero-value preset when no PresetID was supplied. Its
	// zeroed sampling fields (e.g. max_tokens=0) would be rejected or misapplied
	// upstream, so only forward them for a preset that was actually loaded.
	if preset != nil && preset.ID != 0 {
		reqBody["temperature"] = preset.Temperature
		reqBody["top_p"] = preset.TopP
		reqBody["max_tokens"] = preset.MaxTokens
		reqBody["frequency_penalty"] = preset.FrequencyPenalty
		reqBody["presence_penalty"] = preset.PresencePenalty
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return err
	}

	// Bind the upstream call to the client's request context so that it is
	// cancelled when the client's request context is cancelled.
	url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
	req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, url, bytes.NewReader(jsonData))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	if provider.UpstreamKey != "" {
		req.Header.Set("Authorization", "Bearer "+provider.UpstreamKey)
	}

	// Client.Timeout bounds the whole exchange, including reading the body.
	client := &http.Client{Timeout: 300 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
	}

	// Verify flushing is possible before committing any SSE headers.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return fmt.Errorf("streaming not supported")
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	// emit writes one complete SSE event to the client and flushes it so the
	// client receives it immediately.
	emit := func(event []byte) error {
		if _, err := c.Writer.Write(event); err != nil {
			return err
		}
		flusher.Flush()
		return nil
	}

	// streamChunk is the subset of an OpenAI-style streaming chunk needed to
	// accumulate the reply text; all other fields are ignored.
	type streamChunk struct {
		Choices []struct {
			Delta struct {
				Content string `json:"content"`
			} `json:"delta"`
		} `json:"choices"`
	}

	var fullResponse strings.Builder
	reader := bufio.NewReader(resp.Body)
	for {
		line, readErr := reader.ReadBytes('\n')
		// ReadBytes can return a final unterminated line together with io.EOF,
		// so only bail out early on real read errors.
		if readErr != nil && readErr != io.EOF {
			return readErr
		}

		lineStr := strings.TrimRight(string(line), "\r\n")
		if strings.HasPrefix(lineStr, "data: ") {
			data := strings.TrimSpace(strings.TrimPrefix(lineStr, "data: "))

			// End-of-stream marker: pass it on and stop reading.
			if data == "[DONE]" {
				if err := emit([]byte("data: [DONE]\n\n")); err != nil {
					return err
				}
				break
			}

			// Accumulate the first choice's delta text; chunks that are not
			// valid JSON are still forwarded, just not accumulated.
			var chunk streamChunk
			if err := json.Unmarshal([]byte(data), &chunk); err == nil && len(chunk.Choices) > 0 {
				fullResponse.WriteString(chunk.Choices[0].Delta.Content)
			}

			// Forward the original line terminated by a blank line: SSE
			// dispatches an event only at a blank line, and the upstream blank
			// lines are not forwarded on their own.
			if err := emit([]byte(lineStr + "\n\n")); err != nil {
				return err
			}
		}

		if readErr == io.EOF {
			break
		}
	}

	// Apply the preset's output regex scripts. The result is only used for
	// logging; the client has already received the raw chunks.
	finalContent := fullResponse.String()
	if preset != nil {
		finalContent = s.applyOutputRegex(finalContent, preset.RegexScripts)
	}

	// Log the exchange, using the last message's content as the input.
	var originalMsg string
	if len(messages) > 0 {
		originalMsg = messages[len(messages)-1].Content
	}
	s.logRequest(userId, preset, provider, originalMsg, finalContent, nil, time.Since(startTime))
	return nil
}