Files
ai_proxy/server/service/app/ai_proxy.go

398 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AiProxyService implements the application-layer service that proxies
// OpenAI-compatible chat-completion traffic to configured upstream providers.
type AiProxyService struct{}
// ProcessChatCompletion handles a non-streaming chat-completion request:
// it resolves the provider/preset configuration for the requested model,
// injects the preset into the conversation, forwards the request upstream,
// applies preset output processing, and logs the full exchange.
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
	// Log the incoming request parameters.
	global.GVA_LOG.Info("收到 ChatCompletion 请求",
		zap.String("model", req.Model),
		zap.Any("messages", req.Messages),
		zap.Any("full_request", req),
	)
	// 1. Resolve provider and preset configuration from the model name.
	if req.Model == "" {
		return nil, fmt.Errorf("model 参数不能为空")
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		return nil, err
	}
	// 2. Inject the preset (messages and parameter overrides).
	var injector *PresetInjector
	if preset != nil {
		injector = NewPresetInjector(preset)
		req.Messages = injector.InjectMessages(req.Messages)
		injector.ApplyPresetParameters(req)
	}
	// 3. Forward the request to the upstream provider. On success resp is
	// always non-nil, so no nil guard is needed below.
	resp, err := s.forwardRequest(ctx, provider, req)
	if err != nil {
		return nil, err
	}
	// Apply preset output processing (same injector instance so regex match
	// logs accumulate), then capture the final AI output for logging.
	aiOutput := ""
	if len(resp.Choices) > 0 {
		if injector != nil {
			resp.Choices[0].Message.Content = injector.ProcessResponse(resp.Choices[0].Message.Content)
		}
		aiOutput = resp.Choices[0].Message.Content
	}
	// 4. Mirror upstream usage into standard_usage so callers can parse one
	// uniform accounting format.
	if resp.Usage != nil && (resp.Usage.PromptTokens > 0 || resp.Usage.CompletionTokens > 0 || resp.Usage.TotalTokens > 0) {
		resp.StandardUsage = &response.ChatCompletionUsage{
			PromptTokens:     resp.Usage.PromptTokens,
			CompletionTokens: resp.Usage.CompletionTokens,
			TotalTokens:      resp.Usage.TotalTokens,
		}
	}
	// Assemble the unified response log entry.
	logFields := []zap.Field{
		zap.String("ai_output", aiOutput),
	}
	if resp.Usage != nil {
		logFields = append(logFields, zap.Any("usage", resp.Usage))
	}
	if resp.StandardUsage != nil {
		logFields = append(logFields, zap.Any("standard_usage", resp.StandardUsage))
	}
	// Record which regex scripts fired and how often.
	if injector != nil {
		regexLogs := injector.GetRegexLogs()
		if regexLogs != nil && (regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0) {
			triggeredScripts := make([]string, 0, len(regexLogs.InputScripts)+len(regexLogs.OutputScripts))
			for _, scriptLog := range regexLogs.InputScripts {
				if scriptLog.MatchCount > 0 {
					triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
				}
			}
			for _, scriptLog := range regexLogs.OutputScripts {
				if scriptLog.MatchCount > 0 {
					triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
				}
			}
			if len(triggeredScripts) > 0 {
				logFields = append(logFields,
					zap.Strings("triggered_regex_scripts", triggeredScripts),
					zap.Int("total_matches", regexLogs.TotalMatches),
				)
			}
		}
	}
	logFields = append(logFields, zap.Any("full_response", resp))
	global.GVA_LOG.Info("ChatCompletion 响应", logFields...)
	return resp, nil
}
// ProcessChatCompletionStream handles a streaming chat-completion request.
// It resolves the model configuration, injects the preset, switches the
// response to SSE, and relays the upstream stream to the client.
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, req *request.ChatCompletionRequest) {
	// Log the incoming request parameters.
	global.GVA_LOG.Info("收到 ChatCompletion 流式请求",
		zap.String("model", req.Model),
		zap.Any("messages", req.Messages),
		zap.Any("full_request", req),
	)
	// 1. Resolve provider and preset configuration from the model name.
	// Configuration errors happen before any SSE bytes are written, so they
	// can still be reported with a proper HTTP status.
	if req.Model == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
		return
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// 2. Inject the preset (messages and parameter overrides).
	var injector *PresetInjector
	if preset != nil {
		injector = NewPresetInjector(preset)
		req.Messages = injector.InjectMessages(req.Messages)
		injector.ApplyPresetParameters(req)
	}
	// 3. Switch the response to Server-Sent Events; X-Accel-Buffering
	// disables proxy buffering so chunks reach the client immediately.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")
	// 4. Relay the upstream stream. Headers are already committed as SSE at
	// this point, so on failure emit a final SSE error payload instead of
	// leaving the client with a silently truncated stream.
	if err := s.forwardStreamRequest(c, provider, req, injector); err != nil {
		global.GVA_LOG.Error("流式请求失败", zap.Error(err))
		fmt.Fprintf(c.Writer, "data: {\"error\": %q}\n\n", err.Error())
		c.Writer.Flush()
	}
}
// getConfigByModel looks up the enabled model configuration by name and
// returns its preset (may be nil when no preset is attached) and provider.
func (s *AiProxyService) getConfigByModel(modelName string) (*app.AiPreset, *app.AiProvider, error) {
	// Find an enabled model row together with its provider and preset.
	var model app.AiModel
	err := global.GVA_DB.Preload("Provider").Preload("Preset").
		Where("name = ? AND enabled = ?", modelName, true).
		First(&model).Error
	if err != nil {
		// Keep the client-facing message generic, but log the underlying DB
		// error so a connection failure is not mistaken for a missing model.
		global.GVA_LOG.Error("查询模型配置失败", zap.String("model", modelName), zap.Error(err))
		return nil, nil, fmt.Errorf("未找到模型配置: %s", modelName)
	}
	// Refuse to route through a disabled provider.
	if !model.Provider.Enabled {
		return nil, nil, fmt.Errorf("提供商已禁用")
	}
	return model.Preset, &model.Provider, nil
}
// forwardRequest sends a non-streaming chat-completion request to the
// upstream provider and decodes its JSON response.
func (s *AiProxyService) forwardRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
	// Fall back to the provider's default model when none was requested.
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}
	// Build the upstream request body.
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}
	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
	// NOTE(review): a zero provider.Timeout disables the client timeout
	// entirely — confirm that is intended for long generations.
	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer httpResp.Body.Close()
	if httpResp.StatusCode != http.StatusOK {
		// Cap the error body so a misbehaving upstream cannot make us buffer
		// an unbounded payload into the error message.
		body, _ := io.ReadAll(io.LimitReader(httpResp.Body, 64<<10))
		return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}
	// Decode the upstream response straight from the stream.
	var resp response.ChatCompletionResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	return &resp, nil
}
// forwardStreamRequest forwards a streaming chat-completion request to the
// upstream provider and relays its SSE stream to the client, aggregating the
// delta content so output regex processing and logging can run once at the
// end of the stream.
func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ChatCompletionRequest, injector *PresetInjector) error {
	// Fall back to the provider's default model when none was requested.
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}
	reqBody, err := json.Marshal(req)
	if err != nil {
		return err
	}
	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()
	if httpResp.StatusCode != http.StatusOK {
		// Cap the error body so a misbehaving upstream cannot make us buffer
		// an unbounded payload into the error message.
		body, _ := io.ReadAll(io.LimitReader(httpResp.Body, 64<<10))
		return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}
	// Relay the upstream SSE stream line by line.
	reader := bufio.NewReader(httpResp.Body)
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return fmt.Errorf("不支持流式响应")
	}
	// Aggregate raw AI output for post-stream regex processing and logging.
	var fullContent bytes.Buffer
	for {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return err
		}
		// Skip SSE keep-alive blank lines.
		if len(bytes.TrimSpace(line)) == 0 {
			continue
		}
		// Only "data: " payload lines are relayed.
		if bytes.HasPrefix(line, []byte("data: ")) {
			data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: ")))
			// Propagate the upstream end-of-stream marker and stop.
			if string(data) == "[DONE]" {
				if _, err := c.Writer.Write([]byte("data: [DONE]\n\n")); err != nil {
					return err
				}
				flusher.Flush()
				break
			}
			// Non-JSON payloads are skipped, matching upstream comments or
			// malformed chunks.
			var chunk response.ChatCompletionStreamResponse
			if err := json.Unmarshal(data, &chunk); err != nil {
				continue
			}
			// Collect raw content; output regex runs once after the stream
			// ends to avoid processing partial chunks repeatedly.
			if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
				fullContent.WriteString(chunk.Choices[0].Delta.Content)
			}
			// Relay the raw chunk unchanged. Write errors mean the client is
			// gone — stop instead of reading the upstream to exhaustion.
			if _, err := c.Writer.Write([]byte("data: ")); err != nil {
				return err
			}
			if _, err := c.Writer.Write(data); err != nil {
				return err
			}
			if _, err := c.Writer.Write([]byte("\n\n")); err != nil {
				return err
			}
			flusher.Flush()
		}
	}
	// Apply output regex processing to the complete content (log-only; the
	// client already received the raw stream).
	processedContent := fullContent.String()
	if injector != nil && processedContent != "" {
		processedContent = injector.ProcessResponse(processedContent)
	}
	// Log the completed stream.
	logFields := []zap.Field{
		zap.String("ai_output_original", fullContent.String()),
		zap.String("ai_output_processed", processedContent),
	}
	// Record which regex scripts fired and how often.
	if injector != nil {
		regexLogs := injector.GetRegexLogs()
		if regexLogs != nil && (regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0) {
			triggeredScripts := make([]string, 0, len(regexLogs.InputScripts)+len(regexLogs.OutputScripts))
			for _, scriptLog := range regexLogs.InputScripts {
				if scriptLog.MatchCount > 0 {
					triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
				}
			}
			for _, scriptLog := range regexLogs.OutputScripts {
				if scriptLog.MatchCount > 0 {
					triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
				}
			}
			if len(triggeredScripts) > 0 {
				logFields = append(logFields,
					zap.Strings("triggered_regex_scripts", triggeredScripts),
					zap.Int("total_matches", regexLogs.TotalMatches),
				)
			}
		}
	}
	global.GVA_LOG.Info("ChatCompletion 流式响应完成", logFields...)
	return nil
}
// GetAvailableModels returns the enabled models visible to the given API key
// in the OpenAI model-list response format.
func (s *AiProxyService) GetAvailableModels(apiKey *app.AiApiKey) (*response.ModelListResponse, error) {
	// Fetch every enabled model, restricted to the key's allow-list if set.
	query := global.GVA_DB.Where("enabled = ?", true)
	if len(apiKey.AllowedModels) > 0 {
		query = query.Where("name IN ?", apiKey.AllowedModels)
	}
	var models []app.AiModel
	if err := query.Find(&models).Error; err != nil {
		return nil, fmt.Errorf("查询模型列表失败: %w", err)
	}
	result := &response.ModelListResponse{
		Object: "list",
		Data:   make([]response.ModelInfo, 0, len(models)),
	}
	// Deduplicate by name: the same model may be configured under multiple
	// providers but should appear only once in the public list.
	seen := make(map[string]bool, len(models))
	for i := range models {
		m := &models[i]
		if seen[m.Name] {
			continue
		}
		seen[m.Name] = true
		result.Data = append(result.Data, response.ModelInfo{
			ID:      m.Name,
			Object:  "model",
			Created: m.CreatedAt.Unix(),
			OwnedBy: "system",
		})
	}
	return result, nil
}