Files
ai_proxy/server/service/app/ai_proxy.go
2026-03-03 15:39:23 +08:00

280 lines
7.9 KiB
Go

package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AiProxyService proxies chat completion requests for a user: it resolves the
// user's preset binding, injects the preset into the request, forwards it to the
// bound provider and records a request log entry.
type AiProxyService struct{}
// ProcessChatCompletion handles a non-streaming chat completion request.
//
// It resolves the binding (preset + provider) for userID, injects the preset's
// messages and parameters into req (req is modified in place), forwards the
// request upstream and applies the preset's output processing to the content of
// every returned choice. Successful and failed upstream calls are both recorded
// via logRequest; a binding lookup failure is returned without a log entry
// because there is no binding to attribute it to.
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userID uint, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
	startTime := time.Now()

	// 1. Resolve the binding configuration.
	binding, err := s.getBinding(userID, req)
	if err != nil {
		return nil, fmt.Errorf("获取绑定配置失败: %w", err)
	}

	// 2. Inject the preset into the request.
	injector := NewPresetInjector(&binding.Preset)
	req.Messages = injector.InjectMessages(req.Messages)
	injector.ApplyPresetParameters(req)

	// 3. Forward the request upstream.
	resp, err := s.forwardRequest(ctx, &binding.Provider, req)
	if err != nil {
		s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
		return nil, err
	}

	// 4. Post-process every choice, not only the first: when the client asks
	// for several choices (n > 1) the others would otherwise bypass the
	// preset's output processing.
	for i := range resp.Choices {
		resp.Choices[i].Message.Content = injector.ProcessResponse(resp.Choices[i].Message.Content)
	}

	// 5. Record the request.
	s.logRequest(userID, binding, req, resp, nil, time.Since(startTime))
	return resp, nil
}
// ProcessChatCompletionStream handles a streaming (SSE) chat completion request.
//
// It resolves the user's binding, injects the preset into req, sets the SSE
// response headers and relays the upstream stream to the client. A binding
// lookup failure is answered with 400 and JSON. If forwarding fails before any
// byte has been written to the client, a JSON error with 502 is sent instead of
// an empty 200 event stream. Once streaming has started the status can no
// longer change, so a later failure is only logged. Every forward attempt,
// successful or not, is recorded via logRequest (without token usage, since no
// aggregated response is available for a stream here).
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint, req *request.ChatCompletionRequest) {
	startTime := time.Now()

	// 1. Resolve the binding configuration.
	binding, err := s.getBinding(userID, req)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// 2. Inject the preset into the request.
	injector := NewPresetInjector(&binding.Preset)
	req.Messages = injector.InjectMessages(req.Messages)
	injector.ApplyPresetParameters(req)

	// 3. SSE response headers (only staged here; nothing is written yet).
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")

	// 4. Relay the stream.
	err = s.forwardStreamRequest(c, &binding.Provider, req, injector)
	if err == nil {
		s.logRequest(userID, binding, req, nil, nil, time.Since(startTime))
		return
	}

	global.GVA_LOG.Error("流式请求失败", zap.Error(err))
	s.logRequest(userID, binding, req, nil, err, time.Since(startTime))

	if !c.Writer.Written() {
		// Nothing reached the client yet: replace the staged SSE content type
		// (c.JSON keeps an existing Content-Type) and report the failure.
		c.Header("Content-Type", "application/json; charset=utf-8")
		c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
	}
}
// getBinding resolves the preset binding to use for userID.
//
// Only enabled bindings owned by the user are considered. Selection order:
//   - req.BindingName, if set, selects the binding with that name;
//   - otherwise, if both req.PresetName and req.ProviderName are set, the
//     binding whose preset and provider carry those names;
//   - otherwise the enabled binding with the lowest id.
//
// The binding's Preset and Provider are preloaded. An error is returned when no
// binding matches (wrapping the underlying query error) or when the bound
// preset or provider is disabled.
func (s *AiProxyService) getBinding(userID uint, req *request.ChatCompletionRequest) (*app.AiPresetBinding, error) {
	var binding app.AiPresetBinding

	// Columns are qualified with the binding table: the name-based branch joins
	// ai_presets and ai_providers, and since those models also have an Enabled
	// field (and possibly other same-named columns), unqualified names like
	// "enabled" would make the query ambiguous and fail.
	query := global.GVA_DB.Preload("Preset").Preload("Provider").
		Where("ai_preset_bindings.user_id = ? AND ai_preset_bindings.enabled = ?", userID, true)

	switch {
	case req.BindingName != "":
		// An explicit binding name takes precedence.
		query = query.Where("ai_preset_bindings.name = ?", req.BindingName)
	case req.PresetName != "" && req.ProviderName != "":
		// Select by preset name + provider name.
		query = query.Joins("JOIN ai_presets ON ai_presets.id = ai_preset_bindings.preset_id").
			Joins("JOIN ai_providers ON ai_providers.id = ai_preset_bindings.provider_id").
			Where("ai_presets.name = ? AND ai_providers.name = ?", req.PresetName, req.ProviderName)
	default:
		// Default binding: the first enabled one.
		query = query.Order("ai_preset_bindings.id ASC")
	}

	if err := query.First(&binding).Error; err != nil {
		return nil, fmt.Errorf("未找到可用的绑定配置: %w", err)
	}

	if !binding.Provider.Enabled {
		return nil, fmt.Errorf("提供商已禁用")
	}
	if !binding.Preset.Enabled {
		return nil, fmt.Errorf("预设已禁用")
	}
	return &binding, nil
}
// forwardRequest sends a non-streaming chat completion request to
// <provider.BaseURL>/v1/chat/completions with the provider's API key as a
// Bearer token, and decodes the JSON response.
//
// If req.Model is empty it is set (in place) to provider.Model.
// provider.Timeout, in seconds, bounds the whole exchange; 0 means no client
// timeout (ctx still applies). A non-200 status is returned as an error that
// includes the status code and at most the first 4 KiB of the upstream body.
func (s *AiProxyService) forwardRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
	// Cap on how much of an upstream error body is read into memory and
	// embedded in the returned error, which may end up in logs.
	const maxErrorBody = 4 << 10

	// Use the provider's default model when the request names none.
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	// Build the request.
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)

	// Send it.
	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		// Best effort: the status code is reported even if the body can't be read.
		body, _ := io.ReadAll(io.LimitReader(httpResp.Body, maxErrorBody))
		return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	// Decode the response.
	var resp response.ChatCompletionResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	return &resp, nil
}
// forwardStreamRequest sends a streaming chat completion request to
// <provider.BaseURL>/v1/chat/completions and relays the upstream SSE "data:"
// events to the client. The preset's output processing is applied to the
// delta content of every choice, and each event is flushed immediately.
//
// If req.Model is empty it is set (in place) to provider.Model.
//
// provider.Timeout (seconds) acts as an idle timeout: the upstream request is
// aborted when no response headers, or no further line of the stream, arrive
// within that window. A value <= 0 disables it. Unlike http.Client.Timeout, it
// does not cap the total duration of a stream that keeps producing data.
//
// It returns an error when the request cannot be built or sent, the upstream
// answers with a non-200 status (the error includes at most 4 KiB of the body),
// the stream breaks or times out, or writing to the client fails.
func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ChatCompletionRequest, injector *PresetInjector) error {
	// Cap on how much of an upstream error body is read into the error.
	const maxErrorBody = 4 << 10
	dataField := []byte("data:")

	// Use the provider's default model when the request names none.
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return err
	}

	// The upstream call is bound to the client's request context (so it is
	// aborted when that context is canceled) and to an idle timer.
	// provider.Timeout is deliberately not used as http.Client.Timeout: that
	// limit also covers reading the body and would cut off long, healthy streams.
	ctx, cancel := context.WithCancel(c.Request.Context())
	defer cancel()
	idleTimeout := time.Duration(provider.Timeout) * time.Second
	var idleTimer *time.Timer
	if idleTimeout > 0 {
		idleTimer = time.AfterFunc(idleTimeout, cancel)
		defer idleTimer.Stop()
	}
	// wrap annotates errors caused by the idle timer (as opposed to the client
	// request's context being canceled).
	wrap := func(err error) error {
		if idleTimer != nil && ctx.Err() != nil && c.Request.Context().Err() == nil {
			return fmt.Errorf("上游响应超时 (%s 内无数据): %w", idleTimeout, err)
		}
		return err
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)

	client := &http.Client{}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return wrap(err)
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		// Best effort: the status code is reported even if the body can't be read.
		body, _ := io.ReadAll(io.LimitReader(httpResp.Body, maxErrorBody))
		return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return fmt.Errorf("不支持流式响应")
	}

	// emit writes one SSE data event to the client and flushes it.
	emit := func(payload []byte) error {
		frame := make([]byte, 0, len(payload)+8)
		frame = append(frame, "data: "...)
		frame = append(frame, payload...)
		frame = append(frame, "\n\n"...)
		if _, err := c.Writer.Write(frame); err != nil {
			return err
		}
		flusher.Flush()
		return nil
	}

	// relay handles one upstream line; the bool reports that the [DONE]
	// terminator has been forwarded and the stream is complete.
	relay := func(line []byte) (bool, error) {
		line = bytes.TrimRight(line, "\r\n")
		// SSE allows "data:" with or without a following space. Other fields
		// (event:, id:, retry:), comments and blank separator lines are not relayed.
		if !bytes.HasPrefix(line, dataField) {
			return false, nil
		}
		data := bytes.TrimSpace(bytes.TrimPrefix(line, dataField))

		if string(data) == "[DONE]" {
			return true, emit(data)
		}

		var chunk response.ChatCompletionStreamResponse
		if err := json.Unmarshal(data, &chunk); err != nil {
			// Events that are not valid chunks are dropped.
			return false, nil
		}
		// Apply the output regex processing to every choice, not just the first.
		for i := range chunk.Choices {
			if chunk.Choices[i].Delta.Content != "" {
				chunk.Choices[i].Delta.Content = injector.ProcessResponse(chunk.Choices[i].Delta.Content)
			}
		}
		processed, err := json.Marshal(chunk)
		if err != nil {
			// Fall back to the upstream payload rather than sending an empty event.
			processed = data
		}
		return false, emit(processed)
	}

	reader := bufio.NewReader(httpResp.Body)
	for {
		line, readErr := reader.ReadBytes('\n')
		if readErr != nil && readErr != io.EOF {
			return wrap(readErr)
		}
		if len(line) > 0 {
			// Data arrived: push the idle deadline back.
			if idleTimer != nil {
				idleTimer.Reset(idleTimeout)
			}
			// A final line without a trailing newline (returned with io.EOF)
			// is still relayed.
			done, err := relay(line)
			if err != nil {
				return err
			}
			if done {
				return nil
			}
		}
		if readErr == io.EOF {
			return nil
		}
	}
}
// logRequest persists one AiRequestLog row describing a proxied request.
//
// A non-nil err marks the entry as "error" and stores err's message. Otherwise
// the entry is marked "success" and, when resp is non-nil, its token usage is
// copied. Duration is stored in milliseconds. RequestTime is the moment the
// entry is built, i.e. after the request finished. Logging is best-effort: a
// failed insert is reported through global.GVA_LOG and is not returned to the
// caller.
func (s *AiProxyService) logRequest(userID uint, binding *app.AiPresetBinding, req *request.ChatCompletionRequest, resp *response.ChatCompletionResponse, err error, duration time.Duration) {
	entry := app.AiRequestLog{
		UserID:      userID,
		BindingID:   binding.ID,
		ProviderID:  binding.ProviderID,
		PresetID:    binding.PresetID,
		Model:       req.Model,
		Duration:    duration.Milliseconds(),
		RequestTime: time.Now(),
	}

	if err != nil {
		entry.Status = "error"
		entry.ErrorMessage = err.Error()
	} else {
		entry.Status = "success"
		if resp != nil {
			entry.PromptTokens = resp.Usage.PromptTokens
			entry.CompletionTokens = resp.Usage.CompletionTokens
			entry.TotalTokens = resp.Usage.TotalTokens
		}
	}

	// Previously the insert error was discarded, so lost log rows left no trace.
	if dbErr := global.GVA_DB.Create(&entry).Error; dbErr != nil {
		global.GVA_LOG.Error("记录请求日志失败", zap.Error(dbErr), zap.Uint("userID", userID))
	}
}