194 lines
4.9 KiB
Go
194 lines
4.9 KiB
Go
package app
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.echol.cn/loser/ai_proxy/server/global"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
|
"github.com/gin-gonic/gin"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// ProcessChatCompletionStream 处理流式聊天补全请求
|
|
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userId uint, req *request.ChatCompletionRequest) {
|
|
startTime := time.Now()
|
|
|
|
// 1. 获取预设配置
|
|
var preset app.AiPreset
|
|
if req.PresetID > 0 {
|
|
err := global.GVA_DB.First(&preset, req.PresetID).Error
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "预设不存在"})
|
|
return
|
|
}
|
|
}
|
|
|
|
// 2. 获取提供商配置
|
|
var provider app.AiProvider
|
|
if req.BindingKey != "" {
|
|
var binding app.AiPresetBinding
|
|
err := global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
|
|
Order("priority ASC").
|
|
First(&binding).Error
|
|
if err == nil {
|
|
global.GVA_DB.First(&provider, binding.ProviderID)
|
|
}
|
|
}
|
|
|
|
if provider.ID == 0 {
|
|
err := global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到可用的AI提供商"})
|
|
return
|
|
}
|
|
}
|
|
|
|
// 3. 构建注入后的消息
|
|
messages, err := s.buildInjectedMessages(req, &preset)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "构建消息失败"})
|
|
return
|
|
}
|
|
|
|
// 4. 转发流式请求到上游AI
|
|
err = s.forwardStreamToAI(c, &provider, &preset, messages, userId, startTime)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
}
|
|
|
|
// forwardStreamToAI 转发流式请求到上游AI
|
|
func (s *AiProxyService) forwardStreamToAI(c *gin.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message, userId uint, startTime time.Time) error {
|
|
// 构建请求体
|
|
reqBody := map[string]interface{}{
|
|
"model": provider.Model,
|
|
"messages": messages,
|
|
"stream": true,
|
|
}
|
|
|
|
if preset != nil {
|
|
reqBody["temperature"] = preset.Temperature
|
|
reqBody["top_p"] = preset.TopP
|
|
reqBody["max_tokens"] = preset.MaxTokens
|
|
reqBody["frequency_penalty"] = preset.FrequencyPenalty
|
|
reqBody["presence_penalty"] = preset.PresencePenalty
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 创建HTTP请求
|
|
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
|
|
req, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
if provider.UpstreamKey != "" {
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
|
|
}
|
|
|
|
// 发送请求
|
|
client := &http.Client{Timeout: 300 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
|
|
}
|
|
|
|
// 设置SSE响应头
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
c.Header("Transfer-Encoding", "chunked")
|
|
|
|
// 读取并转发流式响应
|
|
reader := bufio.NewReader(resp.Body)
|
|
flusher, ok := c.Writer.(http.Flusher)
|
|
if !ok {
|
|
return fmt.Errorf("streaming not supported")
|
|
}
|
|
|
|
var fullResponse strings.Builder
|
|
for {
|
|
line, err := reader.ReadBytes('\n')
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
return err
|
|
}
|
|
|
|
// 跳过空行
|
|
if len(bytes.TrimSpace(line)) == 0 {
|
|
continue
|
|
}
|
|
|
|
// 解析SSE数据
|
|
lineStr := string(line)
|
|
if strings.HasPrefix(lineStr, "data: ") {
|
|
data := strings.TrimPrefix(lineStr, "data: ")
|
|
data = strings.TrimSpace(data)
|
|
|
|
// 检查是否是结束标记
|
|
if data == "[DONE]" {
|
|
c.Writer.Write([]byte("data: [DONE]\n\n"))
|
|
flusher.Flush()
|
|
break
|
|
}
|
|
|
|
// 解析JSON并提取内容
|
|
var chunk map[string]interface{}
|
|
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
|
|
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
|
if choice, ok := choices[0].(map[string]interface{}); ok {
|
|
if delta, ok := choice["delta"].(map[string]interface{}); ok {
|
|
if content, ok := delta["content"].(string); ok {
|
|
fullResponse.WriteString(content)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 转发原始数据
|
|
c.Writer.Write(line)
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
// 应用输出正则脚本
|
|
finalContent := fullResponse.String()
|
|
if preset != nil {
|
|
finalContent = s.applyOutputRegex(finalContent, preset.RegexScripts)
|
|
}
|
|
|
|
// 记录日志
|
|
var originalMsg string
|
|
if len(messages) > 0 {
|
|
originalMsg = messages[len(messages)-1].Content
|
|
}
|
|
s.logRequest(userId, preset, provider, originalMsg, finalContent, nil, time.Since(startTime))
|
|
|
|
return nil
|
|
}
|