🎨 优化项目结构 && 完善ai配置
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -16,264 +15,252 @@ import (
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AiProxyService struct{}
|
||||
|
||||
// ProcessChatCompletion 处理聊天补全请求
|
||||
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userId uint, req *request.ChatCompletionRequest) (resp response.ChatCompletionResponse, err error) {
|
||||
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userID uint, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 获取预设配置
|
||||
var preset app.AiPreset
|
||||
if req.PresetID > 0 {
|
||||
err = global.GVA_DB.First(&preset, req.PresetID).Error
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("预设不存在: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取提供商配置
|
||||
var provider app.AiProvider
|
||||
|
||||
// 根据 binding_key 或预设绑定获取 provider
|
||||
if req.BindingKey != "" {
|
||||
// 通过 binding_key 查找绑定关系
|
||||
var binding app.AiPresetBinding
|
||||
err = global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
|
||||
Order("priority ASC").
|
||||
First(&binding).Error
|
||||
if err == nil {
|
||||
err = global.GVA_DB.First(&provider, binding.ProviderID).Error
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到,使用默认的活跃提供商
|
||||
if provider.ID == 0 {
|
||||
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 构建注入后的消息
|
||||
messages, err := s.buildInjectedMessages(req, &preset)
|
||||
// 1. 获取绑定配置
|
||||
binding, err := s.getBinding(userID, req)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("构建消息失败: %w", err)
|
||||
return nil, fmt.Errorf("获取绑定配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 转发到上游AI
|
||||
resp, err = s.forwardToAI(ctx, &provider, &preset, messages)
|
||||
// 2. 注入预设
|
||||
injector := NewPresetInjector(&binding.Preset)
|
||||
req.Messages = injector.InjectMessages(req.Messages)
|
||||
injector.ApplyPresetParameters(req)
|
||||
|
||||
// 3. 转发请求到上游
|
||||
resp, err := s.forwardRequest(ctx, &binding.Provider, req)
|
||||
if err != nil {
|
||||
// 记录失败日志
|
||||
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, "", err, time.Since(startTime))
|
||||
return resp, err
|
||||
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. 应用输出正则脚本
|
||||
resp.Choices[0].Message.Content = s.applyOutputRegex(resp.Choices[0].Message.Content, preset.RegexScripts)
|
||||
// 4. 处理响应
|
||||
if len(resp.Choices) > 0 {
|
||||
resp.Choices[0].Message.Content = injector.ProcessResponse(resp.Choices[0].Message.Content)
|
||||
}
|
||||
|
||||
// 6. 记录成功日志
|
||||
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, resp.Choices[0].Message.Content, nil, time.Since(startTime))
|
||||
// 5. 记录日志
|
||||
s.logRequest(userID, binding, req, resp, nil, time.Since(startTime))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// buildInjectedMessages 构建注入预设后的消息数组
|
||||
func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionRequest, preset *app.AiPreset) ([]request.Message, error) {
|
||||
if preset == nil || preset.ID == 0 {
|
||||
return req.Messages, nil
|
||||
}
|
||||
// ProcessChatCompletionStream 处理流式聊天补全请求
|
||||
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint, req *request.ChatCompletionRequest) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 按 injection_order 排序 prompts
|
||||
sortedPrompts := make([]app.Prompt, len(preset.Prompts))
|
||||
copy(sortedPrompts, preset.Prompts)
|
||||
sort.Slice(sortedPrompts, func(i, j int) bool {
|
||||
return sortedPrompts[i].InjectionOrder < sortedPrompts[j].InjectionOrder
|
||||
})
|
||||
|
||||
messages := make([]request.Message, 0)
|
||||
|
||||
// 2. 根据 injection_depth 插入到对话历史中
|
||||
for _, prompt := range sortedPrompts {
|
||||
if prompt.Marker {
|
||||
continue // 跳过标记提示词
|
||||
}
|
||||
|
||||
// 替换变量
|
||||
content := s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard)
|
||||
|
||||
// 根据 injection_depth 决定插入位置
|
||||
// depth=0: 插入到最前面(系统提示词)
|
||||
// depth>0: 从对话历史末尾往前数 depth 条消息的位置插入
|
||||
if prompt.InjectionDepth == 0 || prompt.SystemPrompt {
|
||||
messages = append(messages, request.Message{
|
||||
Role: prompt.Role,
|
||||
Content: content,
|
||||
})
|
||||
} else {
|
||||
// 先添加用户消息,稍后根据 depth 插入
|
||||
// 这里简化处理,将非系统提示词也添加到前面
|
||||
messages = append(messages, request.Message{
|
||||
Role: prompt.Role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 添加用户消息
|
||||
messages = append(messages, req.Messages...)
|
||||
|
||||
// 4. 应用输入正则脚本 (placement=1)
|
||||
for i := range messages {
|
||||
messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// replaceVariables 替换变量
|
||||
func (s *AiProxyService) replaceVariables(content string, vars map[string]string, card *request.CharacterCard) string {
|
||||
result := content
|
||||
|
||||
// 替换自定义变量
|
||||
for key, value := range vars {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
result = replaceAll(result, placeholder, value)
|
||||
}
|
||||
|
||||
// 替换角色卡片变量
|
||||
if card != nil {
|
||||
result = replaceAll(result, "{{char}}", card.Name)
|
||||
result = replaceAll(result, "{{char_name}}", card.Name)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// applyInputRegex 应用输入正则脚本
|
||||
func (s *AiProxyService) applyInputRegex(content string, scripts []app.RegexScript) string {
|
||||
for _, script := range scripts {
|
||||
if script.Disabled {
|
||||
continue
|
||||
}
|
||||
if !containsPlacement(script.Placement, 1) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 编译正则表达式
|
||||
re, err := regexp.Compile(script.FindRegex)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行替换
|
||||
content = re.ReplaceAllString(content, script.ReplaceString)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// applyOutputRegex 应用输出正则脚本
|
||||
func (s *AiProxyService) applyOutputRegex(content string, scripts []app.RegexScript) string {
|
||||
for _, script := range scripts {
|
||||
if script.Disabled {
|
||||
continue
|
||||
}
|
||||
if !containsPlacement(script.Placement, 2) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 编译正则表达式
|
||||
re, err := regexp.Compile(script.FindRegex)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行替换
|
||||
content = re.ReplaceAllString(content, script.ReplaceString)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// forwardToAI 转发请求到上游AI
|
||||
func (s *AiProxyService) forwardToAI(ctx context.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message) (response.ChatCompletionResponse, error) {
|
||||
// 构建请求体
|
||||
reqBody := map[string]interface{}{
|
||||
"model": provider.Model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if preset != nil {
|
||||
reqBody["temperature"] = preset.Temperature
|
||||
reqBody["top_p"] = preset.TopP
|
||||
reqBody["max_tokens"] = preset.MaxTokens
|
||||
reqBody["frequency_penalty"] = preset.FrequencyPenalty
|
||||
reqBody["presence_penalty"] = preset.PresencePenalty
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
// 1. 获取绑定配置
|
||||
binding, err := s.getBinding(userID, req)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
// 2. 注入预设
|
||||
injector := NewPresetInjector(&binding.Preset)
|
||||
req.Messages = injector.InjectMessages(req.Messages)
|
||||
injector.ApplyPresetParameters(req)
|
||||
|
||||
// 3. 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// 4. 转发流式请求
|
||||
err = s.forwardStreamRequest(c, &binding.Provider, req, injector)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
|
||||
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
|
||||
// getBinding 获取绑定配置
|
||||
func (s *AiProxyService) getBinding(userID uint, req *request.ChatCompletionRequest) (*app.AiPresetBinding, error) {
|
||||
var binding app.AiPresetBinding
|
||||
|
||||
query := global.GVA_DB.Preload("Preset").Preload("Provider").Where("user_id = ? AND enabled = ?", userID, true)
|
||||
|
||||
// 优先使用 binding_name
|
||||
if req.BindingName != "" {
|
||||
query = query.Where("name = ?", req.BindingName)
|
||||
} else if req.PresetName != "" && req.ProviderName != "" {
|
||||
// 使用 preset_name 和 provider_name
|
||||
query = query.Joins("JOIN ai_presets ON ai_presets.id = ai_preset_bindings.preset_id").
|
||||
Joins("JOIN ai_providers ON ai_providers.id = ai_preset_bindings.provider_id").
|
||||
Where("ai_presets.name = ? AND ai_providers.name = ?", req.PresetName, req.ProviderName)
|
||||
} else {
|
||||
// 使用默认绑定(第一个启用的)
|
||||
query = query.Order("id ASC")
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if provider.UpstreamKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
|
||||
if err := query.First(&binding).Error; err != nil {
|
||||
return nil, fmt.Errorf("未找到可用的绑定配置")
|
||||
}
|
||||
|
||||
if !binding.Provider.Enabled {
|
||||
return nil, fmt.Errorf("提供商已禁用")
|
||||
}
|
||||
|
||||
if !binding.Preset.Enabled {
|
||||
return nil, fmt.Errorf("预设已禁用")
|
||||
}
|
||||
|
||||
return &binding, nil
|
||||
}
|
||||
|
||||
// forwardRequest 转发请求到上游 AI 服务
|
||||
func (s *AiProxyService) forwardRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
|
||||
// 使用提供商的默认模型(如果请求中没有指定)
|
||||
if req.Model == "" && provider.Model != "" {
|
||||
req.Model = provider.Model
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return response.ChatCompletionResponse{}, fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(httpResp.Body)
|
||||
return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var aiResp response.ChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &aiResp); err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
var resp response.ChatCompletionResponse
|
||||
if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
return aiResp, nil
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// forwardStreamRequest 转发流式请求
|
||||
func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ChatCompletionRequest, injector *PresetInjector) error {
|
||||
// 使用提供商的默认模型
|
||||
if req.Model == "" && provider.Model != "" {
|
||||
req.Model = provider.Model
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(httpResp.Body)
|
||||
return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 读取并转发流式响应
|
||||
reader := bufio.NewReader(httpResp.Body)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return fmt.Errorf("不支持流式响应")
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 跳过空行
|
||||
if len(bytes.TrimSpace(line)) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理 SSE 数据
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||||
data = bytes.TrimSpace(data)
|
||||
|
||||
// 检查是否是结束标记
|
||||
if string(data) == "[DONE]" {
|
||||
c.Writer.Write([]byte("data: [DONE]\n\n"))
|
||||
flusher.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// 解析并处理响应
|
||||
var chunk response.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal(data, &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 应用输出正则处理
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
chunk.Choices[0].Delta.Content = injector.ProcessResponse(chunk.Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
// 重新序列化并发送
|
||||
processedData, _ := json.Marshal(chunk)
|
||||
c.Writer.Write([]byte("data: "))
|
||||
c.Writer.Write(processedData)
|
||||
c.Writer.Write([]byte("\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logRequest 记录请求日志
|
||||
func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider *app.AiProvider, originalMsg, responseText string, err error, latency time.Duration) {
|
||||
func (s *AiProxyService) logRequest(userID uint, binding *app.AiPresetBinding, req *request.ChatCompletionRequest, resp *response.ChatCompletionResponse, err error, duration time.Duration) {
|
||||
log := app.AiRequestLog{
|
||||
UserID: &userId,
|
||||
OriginalMessage: originalMsg,
|
||||
ResponseText: responseText,
|
||||
LatencyMs: int(latency.Milliseconds()),
|
||||
}
|
||||
|
||||
if preset != nil {
|
||||
presetID := preset.ID
|
||||
log.PresetID = &presetID
|
||||
}
|
||||
|
||||
if provider != nil {
|
||||
providerID := provider.ID
|
||||
log.ProviderID = &providerID
|
||||
UserID: userID,
|
||||
BindingID: binding.ID,
|
||||
ProviderID: binding.ProviderID,
|
||||
PresetID: binding.PresetID,
|
||||
Model: req.Model,
|
||||
Duration: duration.Milliseconds(),
|
||||
RequestTime: time.Now(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -281,21 +268,12 @@ func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider
|
||||
log.ErrorMessage = err.Error()
|
||||
} else {
|
||||
log.Status = "success"
|
||||
if resp != nil {
|
||||
log.PromptTokens = resp.Usage.PromptTokens
|
||||
log.CompletionTokens = resp.Usage.CompletionTokens
|
||||
log.TotalTokens = resp.Usage.TotalTokens
|
||||
}
|
||||
}
|
||||
|
||||
global.GVA_DB.Create(&log)
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func replaceAll(s, old, new string) string {
|
||||
return strings.ReplaceAll(s, old, new)
|
||||
}
|
||||
|
||||
func containsPlacement(placements []int, target int) bool {
|
||||
for _, p := range placements {
|
||||
if p == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user