404 lines
12 KiB
Go
404 lines
12 KiB
Go
package app
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"git.echol.cn/loser/ai_proxy/server/global"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// AiProxyService forwards chat completion requests (streaming and
// non-streaming) to the provider configured for the requested model, and
// lists the models available to an API key. It has no fields; model/provider
// configuration is loaded from global.GVA_DB on each call.
type AiProxyService struct{}
|
||
|
||
// ProcessChatCompletion 处理聊天补全请求
|
||
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
|
||
// 记录请求参数
|
||
global.GVA_LOG.Info("收到 ChatCompletion 请求",
|
||
zap.String("model", req.Model),
|
||
zap.Any("messages", req.Messages),
|
||
zap.Any("full_request", req),
|
||
)
|
||
|
||
// 1. 根据模型获取配置
|
||
if req.Model == "" {
|
||
return nil, fmt.Errorf("model 参数不能为空")
|
||
}
|
||
|
||
preset, provider, err := s.getConfigByModel(req.Model)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 2. 注入预设
|
||
var injector *PresetInjector
|
||
if preset != nil {
|
||
injector = NewPresetInjector(preset)
|
||
req.Messages = injector.InjectMessages(req.Messages)
|
||
injector.ApplyPresetParameters(req)
|
||
}
|
||
|
||
// 3. 转发请求到上游
|
||
resp, err := s.forwardRequest(ctx, provider, req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取 AI 输出内容
|
||
aiOutput := ""
|
||
if len(resp.Choices) > 0 {
|
||
aiOutput = resp.Choices[0].Message.Content
|
||
}
|
||
|
||
// 应用预设处理(使用同一个 injector 实例)
|
||
if injector != nil && len(resp.Choices) > 0 {
|
||
resp.Choices[0].Message.Content = injector.ProcessResponse(resp.Choices[0].Message.Content)
|
||
aiOutput = resp.Choices[0].Message.Content
|
||
|
||
// 添加正则执行日志到响应
|
||
regexLogs := injector.GetRegexLogs()
|
||
if regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0 {
|
||
resp.RegexLogs = regexLogs
|
||
}
|
||
}
|
||
|
||
// 4. 处理响应并收集正则日志
|
||
if resp != nil && resp.Usage != nil {
|
||
// 统一填充 standard_usage,方便上游使用统一格式解析
|
||
if resp.Usage.PromptTokens > 0 || resp.Usage.CompletionTokens > 0 || resp.Usage.TotalTokens > 0 {
|
||
resp.StandardUsage = &response.ChatCompletionUsage{
|
||
PromptTokens: resp.Usage.PromptTokens,
|
||
CompletionTokens: resp.Usage.CompletionTokens,
|
||
TotalTokens: resp.Usage.TotalTokens,
|
||
}
|
||
}
|
||
}
|
||
|
||
// 记录响应内容(统一日志输出)
|
||
logFields := []zap.Field{
|
||
zap.String("ai_output", aiOutput),
|
||
}
|
||
if resp.Usage != nil {
|
||
logFields = append(logFields, zap.Any("usage", resp.Usage))
|
||
}
|
||
if resp.StandardUsage != nil {
|
||
logFields = append(logFields, zap.Any("standard_usage", resp.StandardUsage))
|
||
}
|
||
|
||
// 添加正则脚本执行日志
|
||
if injector != nil {
|
||
regexLogs := injector.GetRegexLogs()
|
||
if regexLogs != nil && (regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0) {
|
||
// 收集触发的脚本名称
|
||
triggeredScripts := make([]string, 0)
|
||
for _, scriptLog := range regexLogs.InputScripts {
|
||
if scriptLog.MatchCount > 0 {
|
||
triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
|
||
}
|
||
}
|
||
for _, scriptLog := range regexLogs.OutputScripts {
|
||
if scriptLog.MatchCount > 0 {
|
||
triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
|
||
}
|
||
}
|
||
|
||
if len(triggeredScripts) > 0 {
|
||
logFields = append(logFields,
|
||
zap.Strings("triggered_regex_scripts", triggeredScripts),
|
||
zap.Int("total_matches", regexLogs.TotalMatches),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
logFields = append(logFields, zap.Any("full_response", resp))
|
||
global.GVA_LOG.Info("ChatCompletion 响应", logFields...)
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
// ProcessChatCompletionStream 处理流式聊天补全请求
|
||
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, req *request.ChatCompletionRequest) {
|
||
// 记录请求参数
|
||
global.GVA_LOG.Info("收到 ChatCompletion 流式请求",
|
||
zap.String("model", req.Model),
|
||
zap.Any("messages", req.Messages),
|
||
zap.Any("full_request", req),
|
||
)
|
||
|
||
// 1. 根据模型获取配置
|
||
if req.Model == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
|
||
return
|
||
}
|
||
|
||
preset, provider, err := s.getConfigByModel(req.Model)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 2. 注入预设
|
||
var injector *PresetInjector
|
||
if preset != nil {
|
||
injector = NewPresetInjector(preset)
|
||
req.Messages = injector.InjectMessages(req.Messages)
|
||
injector.ApplyPresetParameters(req)
|
||
}
|
||
|
||
// 3. 设置 SSE 响应头
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
|
||
// 4. 转发流式请求
|
||
err = s.forwardStreamRequest(c, provider, req, injector)
|
||
if err != nil {
|
||
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// getConfigByModel 根据模型名称获取配置
|
||
func (s *AiProxyService) getConfigByModel(modelName string) (*app.AiPreset, *app.AiProvider, error) {
|
||
// 查找启用的模型配置
|
||
var model app.AiModel
|
||
err := global.GVA_DB.Preload("Provider").Preload("Preset").
|
||
Where("name = ? AND enabled = ?", modelName, true).
|
||
First(&model).Error
|
||
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("未找到模型配置: %s", modelName)
|
||
}
|
||
|
||
// 检查提供商是否启用
|
||
if !model.Provider.Enabled {
|
||
return nil, nil, fmt.Errorf("提供商已禁用")
|
||
}
|
||
|
||
return model.Preset, &model.Provider, nil
|
||
}
|
||
|
||
// forwardRequest 转发请求到上游 AI 服务
|
||
func (s *AiProxyService) forwardRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
|
||
// 使用提供商的默认模型(如果请求中没有指定)
|
||
if req.Model == "" && provider.Model != "" {
|
||
req.Model = provider.Model
|
||
}
|
||
|
||
// 构建请求
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
|
||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||
|
||
// 发送请求
|
||
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
||
httpResp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer httpResp.Body.Close()
|
||
|
||
if httpResp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(httpResp.Body)
|
||
return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
||
}
|
||
|
||
// 解析响应
|
||
var resp response.ChatCompletionResponse
|
||
if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
return &resp, nil
|
||
}
|
||
|
||
// forwardStreamRequest 转发流式请求
|
||
func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ChatCompletionRequest, injector *PresetInjector) error {
|
||
// 使用提供商的默认模型
|
||
if req.Model == "" && provider.Model != "" {
|
||
req.Model = provider.Model
|
||
}
|
||
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
|
||
httpReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||
|
||
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
||
httpResp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer httpResp.Body.Close()
|
||
|
||
if httpResp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(httpResp.Body)
|
||
return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
||
}
|
||
|
||
// 读取并转发流式响应
|
||
reader := bufio.NewReader(httpResp.Body)
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
return fmt.Errorf("不支持流式响应")
|
||
}
|
||
|
||
// 聚合 AI 输出内容用于日志和正则处理
|
||
var fullContent bytes.Buffer
|
||
|
||
for {
|
||
line, err := reader.ReadBytes('\n')
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
return err
|
||
}
|
||
|
||
// 跳过空行
|
||
if len(bytes.TrimSpace(line)) == 0 {
|
||
continue
|
||
}
|
||
|
||
// 处理 SSE 数据
|
||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||
data = bytes.TrimSpace(data)
|
||
|
||
// 检查是否是结束标记
|
||
if string(data) == "[DONE]" {
|
||
c.Writer.Write([]byte("data: [DONE]\n\n"))
|
||
flusher.Flush()
|
||
break
|
||
}
|
||
|
||
// 解析并处理响应
|
||
var chunk response.ChatCompletionStreamResponse
|
||
if err := json.Unmarshal(data, &chunk); err != nil {
|
||
continue
|
||
}
|
||
|
||
// 收集原始内容(不在流式传输时应用正则,避免重复处理)
|
||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||
fullContent.WriteString(chunk.Choices[0].Delta.Content)
|
||
}
|
||
|
||
// 直接转发原始响应
|
||
c.Writer.Write([]byte("data: "))
|
||
c.Writer.Write(data)
|
||
c.Writer.Write([]byte("\n\n"))
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
|
||
// 流式结束后,对完整内容应用输出正则处理(仅用于日志记录)
|
||
processedContent := fullContent.String()
|
||
if injector != nil && processedContent != "" {
|
||
processedContent = injector.ProcessResponse(processedContent)
|
||
}
|
||
|
||
// 流式请求结束后记录日志
|
||
logFields := []zap.Field{
|
||
zap.String("ai_output_original", fullContent.String()),
|
||
zap.String("ai_output_processed", processedContent),
|
||
}
|
||
|
||
// 添加正则脚本执行日志
|
||
if injector != nil {
|
||
regexLogs := injector.GetRegexLogs()
|
||
if regexLogs != nil && (regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0) {
|
||
// 收集触发的脚本名称
|
||
triggeredScripts := make([]string, 0)
|
||
for _, scriptLog := range regexLogs.InputScripts {
|
||
if scriptLog.MatchCount > 0 {
|
||
triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
|
||
}
|
||
}
|
||
for _, scriptLog := range regexLogs.OutputScripts {
|
||
if scriptLog.MatchCount > 0 {
|
||
triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
|
||
}
|
||
}
|
||
|
||
if len(triggeredScripts) > 0 {
|
||
logFields = append(logFields,
|
||
zap.Strings("triggered_regex_scripts", triggeredScripts),
|
||
zap.Int("total_matches", regexLogs.TotalMatches),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
global.GVA_LOG.Info("ChatCompletion 流式响应完成", logFields...)
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetAvailableModels 获取用户可用的模型列表
|
||
func (s *AiProxyService) GetAvailableModels(apiKey *app.AiApiKey) (*response.ModelListResponse, error) {
|
||
// 查询所有启用的模型
|
||
var models []app.AiModel
|
||
query := global.GVA_DB.Where("enabled = ?", true)
|
||
|
||
// 如果 API Key 限制了模型,只返回允许的模型
|
||
if len(apiKey.AllowedModels) > 0 {
|
||
query = query.Where("name IN ?", apiKey.AllowedModels)
|
||
}
|
||
|
||
if err := query.Find(&models).Error; err != nil {
|
||
return nil, fmt.Errorf("查询模型列表失败: %w", err)
|
||
}
|
||
|
||
// 构建响应
|
||
modelList := &response.ModelListResponse{
|
||
Object: "list",
|
||
Data: make([]response.ModelInfo, 0, len(models)),
|
||
}
|
||
|
||
// 去重(同一模型可能在多个提供商下配置)
|
||
seen := make(map[string]bool)
|
||
for _, model := range models {
|
||
if !seen[model.Name] {
|
||
seen[model.Name] = true
|
||
modelList.Data = append(modelList.Data, response.ModelInfo{
|
||
ID: model.Name,
|
||
Object: "model",
|
||
Created: model.CreatedAt.Unix(),
|
||
OwnedBy: "system",
|
||
})
|
||
}
|
||
}
|
||
|
||
return modelList, nil
|
||
}
|