🎨 添加中间件 && 完善预设注入功能 && 新增流式传输

This commit is contained in:
2026-03-03 06:10:39 +08:00
parent e1c70fe218
commit 557c865948
16 changed files with 1010 additions and 12117 deletions

View File

@@ -1,6 +1,9 @@
package app
import (
"encoding/json"
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
@@ -110,9 +113,132 @@ func (s *AiPresetService) GetAiPresetList(userId uint, info req.PageInfo) (list
// ImportAiPreset 导入AI预设支持SillyTavern格式
func (s *AiPresetService) ImportAiPreset(userId uint, req *request.ImportAiPresetRequest) (preset app.AiPreset, err error) {
// TODO: 解析SillyTavern JSON格式
// 这里需要实现JSON解析逻辑将SillyTavern格式转换为我们的格式
return preset, nil
// 解析 SillyTavern JSON 格式
var stData map[string]interface{}
var jsonData []byte
// 类型断言处理 req.Data
switch v := req.Data.(type) {
case string:
jsonData = []byte(v)
case []byte:
jsonData = v
default:
return preset, fmt.Errorf("不支持的数据类型")
}
if err := json.Unmarshal(jsonData, &stData); err != nil {
return preset, fmt.Errorf("JSON 解析失败: %w", err)
}
// 提取基本信息
preset = app.AiPreset{
UserID: userId,
Name: req.Name,
Description: getStringValue(stData, "description"),
IsPublic: false,
}
// 提取参数
if temp, ok := stData["temperature"].(float64); ok {
preset.Temperature = temp
}
if topP, ok := stData["top_p"].(float64); ok {
preset.TopP = topP
}
if maxTokens, ok := stData["openai_max_tokens"].(float64); ok {
preset.MaxTokens = int(maxTokens)
} else if maxTokens, ok := stData["max_tokens"].(float64); ok {
preset.MaxTokens = int(maxTokens)
}
if freqPenalty, ok := stData["frequency_penalty"].(float64); ok {
preset.FrequencyPenalty = freqPenalty
}
if presPenalty, ok := stData["presence_penalty"].(float64); ok {
preset.PresencePenalty = presPenalty
}
if stream, ok := stData["stream_openai"].(bool); ok {
preset.StreamEnabled = stream
}
// 提取提示词
prompts := make([]app.Prompt, 0)
if promptsData, ok := stData["prompts"].([]interface{}); ok {
for i, p := range promptsData {
if promptMap, ok := p.(map[string]interface{}); ok {
prompt := app.Prompt{
Name: getStringValue(promptMap, "name"),
Role: getStringValue(promptMap, "role"),
Content: getStringValue(promptMap, "content"),
SystemPrompt: getBoolValue(promptMap, "system_prompt"),
Marker: getBoolValue(promptMap, "marker"),
InjectionOrder: i,
InjectionDepth: int(getFloatValue(promptMap, "injection_depth")),
InjectionPosition: int(getFloatValue(promptMap, "injection_position")),
}
prompts = append(prompts, prompt)
}
}
}
preset.Prompts = prompts
// 提取正则脚本
regexScripts := make([]app.RegexScript, 0)
if extensions, ok := stData["extensions"].(map[string]interface{}); ok {
if scripts, ok := extensions["regex_scripts"].([]interface{}); ok {
for _, s := range scripts {
if scriptMap, ok := s.(map[string]interface{}); ok {
script := app.RegexScript{
ScriptName: getStringValue(scriptMap, "script_name"),
FindRegex: getStringValue(scriptMap, "find_regex"),
ReplaceString: getStringValue(scriptMap, "replace_string"),
Disabled: getBoolValue(scriptMap, "disabled"),
Placement: getIntArray(scriptMap, "placement"),
}
regexScripts = append(regexScripts, script)
}
}
}
}
preset.RegexScripts = regexScripts
// 保存到数据库
err = global.GVA_DB.Create(&preset).Error
return preset, err
}
// 辅助函数
func getStringValue(m map[string]interface{}, key string) string {
if v, ok := m[key].(string); ok {
return v
}
return ""
}
func getBoolValue(m map[string]interface{}, key string) bool {
if v, ok := m[key].(bool); ok {
return v
}
return false
}
func getFloatValue(m map[string]interface{}, key string) float64 {
if v, ok := m[key].(float64); ok {
return v
}
return 0
}
func getIntArray(m map[string]interface{}, key string) []int {
result := make([]int, 0)
if arr, ok := m[key].([]interface{}); ok {
for _, v := range arr {
if num, ok := v.(float64); ok {
result = append(result, int(num))
}
}
}
return result
}
// ExportAiPreset 导出AI预设

View File

@@ -7,6 +7,9 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"sort"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
@@ -32,10 +35,25 @@ func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userId uint,
// 2. 获取提供商配置
var provider app.AiProvider
// TODO: 根据 binding_key 或默认配置获取 provider
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
if err != nil {
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
// 根据 binding_key 或预设绑定获取 provider
if req.BindingKey != "" {
// 通过 binding_key 查找绑定关系
var binding app.AiPresetBinding
err = global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
Order("priority ASC").
First(&binding).Error
if err == nil {
err = global.GVA_DB.First(&provider, binding.ProviderID).Error
}
}
// 如果没有找到,使用默认的活跃提供商
if provider.ID == 0 {
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
if err != nil {
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
}
}
// 3. 构建注入后的消息
@@ -67,20 +85,38 @@ func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionReques
return req.Messages, nil
}
// TODO: 实现完整的预设注入逻辑
// 1. 按 injection_order 排序 prompts
// 2. 根据 injection_depth 插入到对话历史中
// 3. 替换变量 {{user}}, {{char}}
// 4. 应用正则脚本 (placement=1)
sortedPrompts := make([]app.Prompt, len(preset.Prompts))
copy(sortedPrompts, preset.Prompts)
sort.Slice(sortedPrompts, func(i, j int) bool {
return sortedPrompts[i].InjectionOrder < sortedPrompts[j].InjectionOrder
})
messages := make([]request.Message, 0)
// 简化实现:直接添加系统提示词
for _, prompt := range preset.Prompts {
if prompt.SystemPrompt && !prompt.Marker {
// 2. 根据 injection_depth 插入到对话历史中
for _, prompt := range sortedPrompts {
if prompt.Marker {
continue // 跳过标记提示词
}
// 替换变量
content := s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard)
// 根据 injection_depth 决定插入位置
// depth=0: 插入到最前面(系统提示词)
// depth>0: 从对话历史末尾往前数 depth 条消息的位置插入
if prompt.InjectionDepth == 0 || prompt.SystemPrompt {
messages = append(messages, request.Message{
Role: prompt.Role,
Content: s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard),
Content: content,
})
} else {
// 先添加用户消息,稍后根据 depth 插入
// 这里简化处理,将非系统提示词也添加到前面
messages = append(messages, request.Message{
Role: prompt.Role,
Content: content,
})
}
}
@@ -88,7 +124,7 @@ func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionReques
// 添加用户消息
messages = append(messages, req.Messages...)
// 应用输入正则脚本
// 4. 应用输入正则脚本 (placement=1)
for i := range messages {
messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts)
}
@@ -124,7 +160,16 @@ func (s *AiProxyService) applyInputRegex(content string, scripts []app.RegexScri
if !containsPlacement(script.Placement, 1) {
continue
}
// TODO: 实现正则替换逻辑
// 编译正则表达式
re, err := regexp.Compile(script.FindRegex)
if err != nil {
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
continue
}
// 执行替换
content = re.ReplaceAllString(content, script.ReplaceString)
}
return content
}
@@ -138,7 +183,16 @@ func (s *AiProxyService) applyOutputRegex(content string, scripts []app.RegexScr
if !containsPlacement(script.Placement, 2) {
continue
}
// TODO: 实现正则替换逻辑
// 编译正则表达式
re, err := regexp.Compile(script.FindRegex)
if err != nil {
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
continue
}
// 执行替换
content = re.ReplaceAllString(content, script.ReplaceString)
}
return content
}
@@ -234,7 +288,7 @@ func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider
// 辅助函数
func replaceAll(s, old, new string) string {
return s // TODO: 实现字符串替换
return strings.ReplaceAll(s, old, new)
}
func containsPlacement(placements []int, target int) bool {

View File

@@ -0,0 +1,193 @@
package app
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ProcessChatCompletionStream 处理流式聊天补全请求
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userId uint, req *request.ChatCompletionRequest) {
startTime := time.Now()
// 1. 获取预设配置
var preset app.AiPreset
if req.PresetID > 0 {
err := global.GVA_DB.First(&preset, req.PresetID).Error
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "预设不存在"})
return
}
}
// 2. 获取提供商配置
var provider app.AiProvider
if req.BindingKey != "" {
var binding app.AiPresetBinding
err := global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
Order("priority ASC").
First(&binding).Error
if err == nil {
global.GVA_DB.First(&provider, binding.ProviderID)
}
}
if provider.ID == 0 {
err := global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到可用的AI提供商"})
return
}
}
// 3. 构建注入后的消息
messages, err := s.buildInjectedMessages(req, &preset)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "构建消息失败"})
return
}
// 4. 转发流式请求到上游AI
err = s.forwardStreamToAI(c, &provider, &preset, messages, userId, startTime)
if err != nil {
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
// forwardStreamToAI 转发流式请求到上游AI
func (s *AiProxyService) forwardStreamToAI(c *gin.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message, userId uint, startTime time.Time) error {
// 构建请求体
reqBody := map[string]interface{}{
"model": provider.Model,
"messages": messages,
"stream": true,
}
if preset != nil {
reqBody["temperature"] = preset.Temperature
reqBody["top_p"] = preset.TopP
reqBody["max_tokens"] = preset.MaxTokens
reqBody["frequency_penalty"] = preset.FrequencyPenalty
reqBody["presence_penalty"] = preset.PresencePenalty
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return err
}
// 创建HTTP请求
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
req, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
if provider.UpstreamKey != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
}
// 发送请求
client := &http.Client{Timeout: 300 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
}
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Transfer-Encoding", "chunked")
// 读取并转发流式响应
reader := bufio.NewReader(resp.Body)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return fmt.Errorf("streaming not supported")
}
var fullResponse strings.Builder
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return err
}
// 跳过空行
if len(bytes.TrimSpace(line)) == 0 {
continue
}
// 解析SSE数据
lineStr := string(line)
if strings.HasPrefix(lineStr, "data: ") {
data := strings.TrimPrefix(lineStr, "data: ")
data = strings.TrimSpace(data)
// 检查是否是结束标记
if data == "[DONE]" {
c.Writer.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
break
}
// 解析JSON并提取内容
var chunk map[string]interface{}
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if delta, ok := choice["delta"].(map[string]interface{}); ok {
if content, ok := delta["content"].(string); ok {
fullResponse.WriteString(content)
}
}
}
}
}
// 转发原始数据
c.Writer.Write(line)
flusher.Flush()
}
}
// 应用输出正则脚本
finalContent := fullResponse.String()
if preset != nil {
finalContent = s.applyOutputRegex(finalContent, preset.RegexScripts)
}
// 记录日志
var originalMsg string
if len(messages) > 0 {
originalMsg = messages[len(messages)-1].Content
}
s.logRequest(userId, preset, provider, originalMsg, finalContent, nil, time.Since(startTime))
return nil
}