248 lines
6.6 KiB
Go
248 lines
6.6 KiB
Go
package app
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"git.echol.cn/loser/ai_proxy/server/global"
	"git.echol.cn/loser/ai_proxy/server/model/app"
	"git.echol.cn/loser/ai_proxy/server/model/app/request"
	"git.echol.cn/loser/ai_proxy/server/model/app/response"
)
|
|
|
|
// AiProxyService is the receiver for the chat-completion proxy methods in this
// file (preset injection, upstream forwarding, request logging). It has no
// fields.
type AiProxyService struct{}
|
|
|
|
// ProcessChatCompletion 处理聊天补全请求
|
|
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userId uint, req *request.ChatCompletionRequest) (resp response.ChatCompletionResponse, err error) {
|
|
startTime := time.Now()
|
|
|
|
// 1. 获取预设配置
|
|
var preset app.AiPreset
|
|
if req.PresetID > 0 {
|
|
err = global.GVA_DB.First(&preset, req.PresetID).Error
|
|
if err != nil {
|
|
return resp, fmt.Errorf("预设不存在: %w", err)
|
|
}
|
|
}
|
|
|
|
// 2. 获取提供商配置
|
|
var provider app.AiProvider
|
|
// TODO: 根据 binding_key 或默认配置获取 provider
|
|
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
|
if err != nil {
|
|
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
|
|
}
|
|
|
|
// 3. 构建注入后的消息
|
|
messages, err := s.buildInjectedMessages(req, &preset)
|
|
if err != nil {
|
|
return resp, fmt.Errorf("构建消息失败: %w", err)
|
|
}
|
|
|
|
// 4. 转发到上游AI
|
|
resp, err = s.forwardToAI(ctx, &provider, &preset, messages)
|
|
if err != nil {
|
|
// 记录失败日志
|
|
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, "", err, time.Since(startTime))
|
|
return resp, err
|
|
}
|
|
|
|
// 5. 应用输出正则脚本
|
|
resp.Choices[0].Message.Content = s.applyOutputRegex(resp.Choices[0].Message.Content, preset.RegexScripts)
|
|
|
|
// 6. 记录成功日志
|
|
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, resp.Choices[0].Message.Content, nil, time.Since(startTime))
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// buildInjectedMessages 构建注入预设后的消息数组
|
|
func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionRequest, preset *app.AiPreset) ([]request.Message, error) {
|
|
if preset == nil || preset.ID == 0 {
|
|
return req.Messages, nil
|
|
}
|
|
|
|
// TODO: 实现完整的预设注入逻辑
|
|
// 1. 按 injection_order 排序 prompts
|
|
// 2. 根据 injection_depth 插入到对话历史中
|
|
// 3. 替换变量 {{user}}, {{char}}
|
|
// 4. 应用正则脚本 (placement=1)
|
|
|
|
messages := make([]request.Message, 0)
|
|
|
|
// 简化实现:直接添加系统提示词
|
|
for _, prompt := range preset.Prompts {
|
|
if prompt.SystemPrompt && !prompt.Marker {
|
|
messages = append(messages, request.Message{
|
|
Role: prompt.Role,
|
|
Content: s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard),
|
|
})
|
|
}
|
|
}
|
|
|
|
// 添加用户消息
|
|
messages = append(messages, req.Messages...)
|
|
|
|
// 应用输入正则脚本
|
|
for i := range messages {
|
|
messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts)
|
|
}
|
|
|
|
return messages, nil
|
|
}
|
|
|
|
// replaceVariables 替换变量
|
|
func (s *AiProxyService) replaceVariables(content string, vars map[string]string, card *request.CharacterCard) string {
|
|
result := content
|
|
|
|
// 替换自定义变量
|
|
for key, value := range vars {
|
|
placeholder := fmt.Sprintf("{{%s}}", key)
|
|
result = replaceAll(result, placeholder, value)
|
|
}
|
|
|
|
// 替换角色卡片变量
|
|
if card != nil {
|
|
result = replaceAll(result, "{{char}}", card.Name)
|
|
result = replaceAll(result, "{{char_name}}", card.Name)
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// applyInputRegex 应用输入正则脚本
|
|
func (s *AiProxyService) applyInputRegex(content string, scripts []app.RegexScript) string {
|
|
for _, script := range scripts {
|
|
if script.Disabled {
|
|
continue
|
|
}
|
|
if !containsPlacement(script.Placement, 1) {
|
|
continue
|
|
}
|
|
// TODO: 实现正则替换逻辑
|
|
}
|
|
return content
|
|
}
|
|
|
|
// applyOutputRegex 应用输出正则脚本
|
|
func (s *AiProxyService) applyOutputRegex(content string, scripts []app.RegexScript) string {
|
|
for _, script := range scripts {
|
|
if script.Disabled {
|
|
continue
|
|
}
|
|
if !containsPlacement(script.Placement, 2) {
|
|
continue
|
|
}
|
|
// TODO: 实现正则替换逻辑
|
|
}
|
|
return content
|
|
}
|
|
|
|
// forwardToAI 转发请求到上游AI
|
|
func (s *AiProxyService) forwardToAI(ctx context.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message) (response.ChatCompletionResponse, error) {
|
|
// 构建请求体
|
|
reqBody := map[string]interface{}{
|
|
"model": provider.Model,
|
|
"messages": messages,
|
|
}
|
|
|
|
if preset != nil {
|
|
reqBody["temperature"] = preset.Temperature
|
|
reqBody["top_p"] = preset.TopP
|
|
reqBody["max_tokens"] = preset.MaxTokens
|
|
reqBody["frequency_penalty"] = preset.FrequencyPenalty
|
|
reqBody["presence_penalty"] = preset.PresencePenalty
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return response.ChatCompletionResponse{}, err
|
|
}
|
|
|
|
// 创建HTTP请求
|
|
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return response.ChatCompletionResponse{}, err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if provider.UpstreamKey != "" {
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
|
|
}
|
|
|
|
// 发送请求
|
|
client := &http.Client{Timeout: 120 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return response.ChatCompletionResponse{}, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 读取响应
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return response.ChatCompletionResponse{}, err
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return response.ChatCompletionResponse{}, fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
|
|
}
|
|
|
|
// 解析响应
|
|
var aiResp response.ChatCompletionResponse
|
|
if err := json.Unmarshal(body, &aiResp); err != nil {
|
|
return response.ChatCompletionResponse{}, err
|
|
}
|
|
|
|
return aiResp, nil
|
|
}
|
|
|
|
// logRequest 记录请求日志
|
|
func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider *app.AiProvider, originalMsg, responseText string, err error, latency time.Duration) {
|
|
log := app.AiRequestLog{
|
|
UserID: &userId,
|
|
OriginalMessage: originalMsg,
|
|
ResponseText: responseText,
|
|
LatencyMs: int(latency.Milliseconds()),
|
|
}
|
|
|
|
if preset != nil {
|
|
presetID := preset.ID
|
|
log.PresetID = &presetID
|
|
}
|
|
|
|
if provider != nil {
|
|
providerID := provider.ID
|
|
log.ProviderID = &providerID
|
|
}
|
|
|
|
if err != nil {
|
|
log.Status = "error"
|
|
log.ErrorMessage = err.Error()
|
|
} else {
|
|
log.Status = "success"
|
|
}
|
|
|
|
global.GVA_DB.Create(&log)
|
|
}
|
|
|
|
// Helper functions.

// replaceAll returns s with every non-overlapping occurrence of old replaced
// by new. An empty old is treated as "nothing to replace" and s is returned
// unchanged, rather than inserting new between every character.
func replaceAll(s, old, new string) string {
	if old == "" {
		return s
	}
	return strings.ReplaceAll(s, old, new)
}
|
|
|
|
// containsPlacement reports whether target occurs in placements. A nil or
// empty slice contains nothing.
func containsPlacement(placements []int, target int) bool {
	found := false
	// Stop scanning as soon as a match is seen.
	for idx := 0; idx < len(placements) && !found; idx++ {
		found = placements[idx] == target
	}
	return found
}
|