327 lines
10 KiB
Go
327 lines
10 KiB
Go
package app
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"git.echol.cn/loser/ai_proxy/server/global"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// ProcessClaudeMessage 处理 Claude 消息请求
|
||
func (s *AiProxyService) ProcessClaudeMessage(ctx context.Context, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
|
||
// 记录请求参数
|
||
global.GVA_LOG.Info("收到 Claude Messages 请求",
|
||
zap.String("model", req.Model),
|
||
zap.Any("messages", req.Messages),
|
||
zap.Any("full_request", req),
|
||
)
|
||
|
||
// 1. 根据模型获取配置
|
||
if req.Model == "" {
|
||
return nil, fmt.Errorf("model 参数不能为空")
|
||
}
|
||
|
||
preset, provider, err := s.getConfigByModel(req.Model)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 2. 注入预设
|
||
var injector *PresetInjector
|
||
if preset != nil {
|
||
injector = NewPresetInjector(preset)
|
||
req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
|
||
}
|
||
|
||
// 3. 转发请求到上游
|
||
resp, err := s.forwardClaudeRequest(ctx, provider, req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取 AI 输出内容
|
||
aiOutput := ""
|
||
if len(resp.Content) > 0 {
|
||
aiOutput = resp.Content[0].Text
|
||
}
|
||
|
||
// 4. 处理响应(使用同一个 injector 实例)
|
||
if injector != nil && len(resp.Content) > 0 {
|
||
resp.Content[0].Text = injector.ProcessResponse(resp.Content[0].Text)
|
||
aiOutput = resp.Content[0].Text
|
||
}
|
||
|
||
// 5. 统一填充 standard_usage,转换为 OpenAI 风格的用量统计
|
||
if resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 {
|
||
resp.StandardUsage = &response.ChatCompletionUsage{
|
||
PromptTokens: resp.Usage.InputTokens,
|
||
CompletionTokens: resp.Usage.OutputTokens,
|
||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||
}
|
||
}
|
||
|
||
// 记录响应内容
|
||
logFields := []zap.Field{
|
||
zap.String("ai_output", aiOutput),
|
||
zap.Any("usage", resp.Usage),
|
||
}
|
||
|
||
// 添加正则脚本执行日志(使用同一个 injector 实例)
|
||
if injector != nil {
|
||
regexLogs := injector.GetRegexLogs()
|
||
if regexLogs != nil && (regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0) {
|
||
// 收集触发的脚本名称
|
||
triggeredScripts := make([]string, 0)
|
||
for _, scriptLog := range regexLogs.InputScripts {
|
||
if scriptLog.MatchCount > 0 {
|
||
triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
|
||
}
|
||
}
|
||
for _, scriptLog := range regexLogs.OutputScripts {
|
||
if scriptLog.MatchCount > 0 {
|
||
triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
|
||
}
|
||
}
|
||
|
||
if len(triggeredScripts) > 0 {
|
||
logFields = append(logFields,
|
||
zap.Strings("triggered_regex_scripts", triggeredScripts),
|
||
zap.Int("total_matches", regexLogs.TotalMatches),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
logFields = append(logFields, zap.Any("full_response", resp))
|
||
global.GVA_LOG.Info("Claude Messages 响应", logFields...)
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
// ProcessClaudeMessageStream 处理 Claude 流式消息请求
|
||
func (s *AiProxyService) ProcessClaudeMessageStream(c *gin.Context, req *request.ClaudeMessageRequest) {
|
||
// 记录请求参数
|
||
global.GVA_LOG.Info("收到 Claude Messages 流式请求",
|
||
zap.String("model", req.Model),
|
||
zap.Any("messages", req.Messages),
|
||
zap.Any("full_request", req),
|
||
)
|
||
|
||
// 1. 根据模型获取配置
|
||
if req.Model == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
|
||
return
|
||
}
|
||
|
||
preset, provider, err := s.getConfigByModel(req.Model)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 2. 注入预设
|
||
var injector *PresetInjector
|
||
if preset != nil {
|
||
injector = NewPresetInjector(preset)
|
||
req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
|
||
}
|
||
|
||
// 3. 设置 SSE 响应头
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
|
||
// 4. 转发流式请求
|
||
err = s.forwardClaudeStreamRequest(c, provider, req, injector)
|
||
if err != nil {
|
||
global.GVA_LOG.Error("Claude流式请求失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// forwardClaudeRequest 转发 Claude 请求
|
||
func (s *AiProxyService) forwardClaudeRequest(ctx context.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
|
||
if req.Model == "" && provider.Model != "" {
|
||
req.Model = provider.Model
|
||
}
|
||
|
||
reqBody, _ := json.Marshal(req)
|
||
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
|
||
httpReq, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("x-api-key", provider.APIKey)
|
||
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
||
|
||
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
||
httpResp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer httpResp.Body.Close()
|
||
|
||
if httpResp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(httpResp.Body)
|
||
return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
||
}
|
||
|
||
var resp response.ClaudeMessageResponse
|
||
json.NewDecoder(httpResp.Body).Decode(&resp)
|
||
return &resp, nil
|
||
}
|
||
|
||
// forwardClaudeStreamRequest 转发 Claude 流式请求
|
||
func (s *AiProxyService) forwardClaudeStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest, injector *PresetInjector) error {
|
||
if req.Model == "" && provider.Model != "" {
|
||
req.Model = provider.Model
|
||
}
|
||
|
||
reqBody, _ := json.Marshal(req)
|
||
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
|
||
httpReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewReader(reqBody))
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("x-api-key", provider.APIKey)
|
||
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
||
|
||
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
||
httpResp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer httpResp.Body.Close()
|
||
|
||
if httpResp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(httpResp.Body)
|
||
global.GVA_LOG.Error("Claude 流式请求上游返回错误",
|
||
zap.Int("status_code", httpResp.StatusCode),
|
||
zap.String("response_body", string(body)),
|
||
)
|
||
return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
||
}
|
||
|
||
reader := bufio.NewReader(httpResp.Body)
|
||
flusher, _ := c.Writer.(http.Flusher)
|
||
|
||
// 聚合完整输出用于日志
|
||
var fullContent bytes.Buffer
|
||
var totalInputTokens, totalOutputTokens int
|
||
|
||
for {
|
||
line, err := reader.ReadBytes('\n')
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
if len(bytes.TrimSpace(line)) == 0 {
|
||
continue
|
||
}
|
||
|
||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||
var chunk response.ClaudeStreamResponse
|
||
if json.Unmarshal(data, &chunk) == nil {
|
||
// 收集 usage 信息
|
||
if chunk.Usage != nil {
|
||
totalInputTokens = chunk.Usage.InputTokens
|
||
totalOutputTokens = chunk.Usage.OutputTokens
|
||
}
|
||
|
||
// 处理文本内容
|
||
if chunk.Delta != nil && chunk.Delta.Text != "" {
|
||
fullContent.WriteString(chunk.Delta.Text)
|
||
if injector != nil {
|
||
chunk.Delta.Text = injector.ProcessResponse(chunk.Delta.Text)
|
||
}
|
||
}
|
||
|
||
processedData, _ := json.Marshal(chunk)
|
||
c.Writer.Write([]byte("data: "))
|
||
c.Writer.Write(processedData)
|
||
c.Writer.Write([]byte("\n\n"))
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
}
|
||
|
||
// 记录完整的流式响应日志
|
||
logFields := []zap.Field{
|
||
zap.String("ai_output", fullContent.String()),
|
||
zap.Int("input_tokens", totalInputTokens),
|
||
zap.Int("output_tokens", totalOutputTokens),
|
||
zap.Int("total_tokens", totalInputTokens+totalOutputTokens),
|
||
}
|
||
|
||
// 添加正则脚本执行日志
|
||
if injector != nil {
|
||
regexLogs := injector.GetRegexLogs()
|
||
if regexLogs != nil && (regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0) {
|
||
// 收集触发的脚本名称
|
||
triggeredScripts := make([]string, 0)
|
||
for _, scriptLog := range regexLogs.InputScripts {
|
||
if scriptLog.MatchCount > 0 {
|
||
triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
|
||
}
|
||
}
|
||
for _, scriptLog := range regexLogs.OutputScripts {
|
||
if scriptLog.MatchCount > 0 {
|
||
triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
|
||
}
|
||
}
|
||
|
||
if len(triggeredScripts) > 0 {
|
||
logFields = append(logFields,
|
||
zap.Strings("triggered_regex_scripts", triggeredScripts),
|
||
zap.Int("total_matches", regexLogs.TotalMatches),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
global.GVA_LOG.Info("Claude Messages 流式响应完成", logFields...)
|
||
|
||
return nil
|
||
}
|
||
|
||
// convertToOpenAIMessages 转换 Claude 消息为 OpenAI 格式
|
||
func (s *AiProxyService) convertToOpenAIMessages(messages []request.ClaudeMessage) []request.ChatMessage {
|
||
result := make([]request.ChatMessage, len(messages))
|
||
for i, msg := range messages {
|
||
content := ""
|
||
// 处理字符串类型的 content
|
||
if str, ok := msg.Content.(string); ok {
|
||
content = str
|
||
} else if blocks, ok := msg.Content.([]interface{}); ok {
|
||
// 处理对象数组类型的 content (Claude API 标准格式)
|
||
for _, block := range blocks {
|
||
if blockMap, ok := block.(map[string]interface{}); ok {
|
||
if blockMap["type"] == "text" {
|
||
if text, ok := blockMap["text"].(string); ok {
|
||
content += text
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
result[i] = request.ChatMessage{Role: msg.Role, Content: content}
|
||
}
|
||
return result
|
||
}
|
||
|
||
// convertClaudeMessages 转换 OpenAI 消息为 Claude 格式
|
||
func (s *AiProxyService) convertClaudeMessages(messages []request.ChatMessage) []request.ClaudeMessage {
|
||
result := make([]request.ClaudeMessage, len(messages))
|
||
for i, msg := range messages {
|
||
result[i] = request.ClaudeMessage{Role: msg.Role, Content: msg.Content}
|
||
}
|
||
return result
|
||
}
|