Files
ai_proxy/server/service/app/ai_claude.go
2026-03-04 17:57:44 +08:00

327 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ProcessClaudeMessage handles a non-streaming Claude Messages request.
//
// It resolves the preset and provider configured for req.Model, injects the
// preset (if any) into req.Messages, forwards the request upstream, runs every
// non-empty text content block of the reply through the same injector, fills
// StandardUsage with OpenAI-style token counts, and logs request and response.
//
// req.Messages is replaced in place when a preset applies. An error is
// returned when req.Model is empty, when getConfigByModel fails, or when the
// upstream call fails.
//
// NOTE(review): with a preset, messages round-trip through
// convertToOpenAIMessages, which keeps only text; non-text content blocks
// (images, tool results, ...) are dropped. Confirm this is acceptable.
func (s *AiProxyService) ProcessClaudeMessage(ctx context.Context, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
	// Log the incoming request parameters.
	global.GVA_LOG.Info("收到 Claude Messages 请求",
		zap.String("model", req.Model),
		zap.Any("messages", req.Messages),
		zap.Any("full_request", req),
	)

	// 1. Resolve the preset/provider configuration for the requested model.
	if req.Model == "" {
		return nil, fmt.Errorf("model 参数不能为空")
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		return nil, err
	}

	// 2. Inject the preset. The injector works on OpenAI-style messages, so
	// convert there and back.
	var injector *PresetInjector
	if preset != nil {
		injector = NewPresetInjector(preset)
		req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
	}

	// 3. Forward the request upstream.
	resp, err := s.forwardClaudeRequest(ctx, provider, req)
	if err != nil {
		return nil, err
	}

	// 4. Post-process the reply with the same injector instance (so its regex
	// logs cover both input and output). Every text-bearing block is handled,
	// not just Content[0]: a reply may contain several blocks, and the first
	// one is not necessarily the text (e.g. a thinking block may precede it).
	var output strings.Builder
	for i := range resp.Content {
		text := resp.Content[i].Text
		if text == "" {
			continue
		}
		if injector != nil {
			text = injector.ProcessResponse(text)
			resp.Content[i].Text = text
		}
		output.WriteString(text)
	}
	aiOutput := output.String()

	// 5. Populate standard_usage with OpenAI-style usage numbers.
	if resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 {
		resp.StandardUsage = &response.ChatCompletionUsage{
			PromptTokens:     resp.Usage.InputTokens,
			CompletionTokens: resp.Usage.OutputTokens,
			TotalTokens:      resp.Usage.InputTokens + resp.Usage.OutputTokens,
		}
	}

	// Log the response content.
	logFields := []zap.Field{
		zap.String("ai_output", aiOutput),
		zap.Any("usage", resp.Usage),
	}

	// Append regex-script execution details (from the same injector instance).
	if injector != nil {
		regexLogs := injector.GetRegexLogs()
		if regexLogs != nil && (regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0) {
			// Collect the names of scripts that actually matched.
			triggeredScripts := make([]string, 0)
			for _, scriptLog := range regexLogs.InputScripts {
				if scriptLog.MatchCount > 0 {
					triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
				}
			}
			for _, scriptLog := range regexLogs.OutputScripts {
				if scriptLog.MatchCount > 0 {
					triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
				}
			}
			if len(triggeredScripts) > 0 {
				logFields = append(logFields,
					zap.Strings("triggered_regex_scripts", triggeredScripts),
					zap.Int("total_matches", regexLogs.TotalMatches),
				)
			}
		}
	}
	logFields = append(logFields, zap.Any("full_response", resp))
	global.GVA_LOG.Info("Claude Messages 响应", logFields...)
	return resp, nil
}
// ProcessClaudeMessageStream handles a streaming Claude Messages request and
// writes the relayed server-sent events directly to c.
//
// An empty model or a configuration lookup failure is answered with HTTP 400
// JSON. If forwarding fails before any bytes have been written to the client
// (e.g. the upstream is unreachable or returned a non-200 status), the pending
// SSE headers are discarded and an HTTP 502 JSON error is sent instead, so the
// client does not receive an empty 200 stream. Errors that occur after
// streaming has started can only be logged, since the status line is already
// on the wire.
func (s *AiProxyService) ProcessClaudeMessageStream(c *gin.Context, req *request.ClaudeMessageRequest) {
	// Log the incoming request parameters.
	global.GVA_LOG.Info("收到 Claude Messages 流式请求",
		zap.String("model", req.Model),
		zap.Any("messages", req.Messages),
		zap.Any("full_request", req),
	)

	// 1. Resolve the preset/provider configuration for the requested model.
	if req.Model == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
		return
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// 2. Inject the preset (via OpenAI-style messages and back).
	var injector *PresetInjector
	if preset != nil {
		injector = NewPresetInjector(preset)
		req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
	}

	// 3. Set SSE response headers (not sent until the first write).
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	// 4. Forward the streaming request.
	err = s.forwardClaudeStreamRequest(c, provider, req, injector)
	if err == nil {
		return
	}
	global.GVA_LOG.Error("Claude流式请求失败", zap.Error(err))
	if !c.Writer.Written() {
		// Nothing reached the client yet: replace the SSE headers with a JSON
		// error response instead of an empty 200 event stream.
		h := c.Writer.Header()
		h.Del("Content-Type")
		h.Del("Cache-Control")
		h.Del("Connection")
		c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
	}
}
// forwardClaudeRequest sends a non-streaming Claude Messages request to the
// provider's "/v1/messages" endpoint and decodes the JSON reply.
//
// If req.Model is empty, provider.Model is filled in (req is mutated). The
// request carries the provider's API key in "x-api-key" and the
// "anthropic-version: 2023-06-01" header, and uses provider.Timeout (seconds)
// as the client timeout. Errors are returned for request encoding/building
// failures, transport failures, a non-200 upstream status (message includes
// the status code and body), and an undecodable response body.
func (s *AiProxyService) forwardClaudeRequest(ctx context.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("序列化 Claude 请求失败: %w", err)
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
	// A malformed BaseURL makes NewRequestWithContext fail; using the nil
	// request would panic on Header.Set below.
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("创建上游请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", provider.APIKey)
	httpReq.Header.Set("anthropic-version", "2023-06-01")

	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(httpResp.Body)
		return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	var resp response.ClaudeMessageResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
		// Without this check a malformed body would surface as an empty,
		// seemingly successful response.
		return nil, fmt.Errorf("解析上游响应失败: %w", err)
	}
	return &resp, nil
}
// forwardClaudeStreamRequest forwards a streaming Claude Messages request to
// the provider and relays the upstream server-sent events to the client.
//
// Behaviour:
//   - If req.Model is empty, provider.Model is filled in (req is mutated).
//   - A non-200 upstream status is returned as an error before anything is
//     written to c, so the caller can still send an error response.
//   - "data: " lines are parsed as response.ClaudeStreamResponse. When a chunk
//     carries delta text and injector is non-nil, the text goes through
//     injector.ProcessResponse; the chunk is re-encoded only if that changed
//     the text, otherwise the original payload is relayed byte-for-byte.
//     Unparseable data lines are relayed unchanged.
//   - Other non-blank lines (e.g. "event:" lines) are relayed verbatim, so
//     clients that dispatch on the SSE event name keep working.
//   - A read error other than io.EOF (e.g. the request context was cancelled
//     because the client went away) or a failed write to the client stops the
//     relay and is returned after the summary log is written.
//
// NOTE(review): ProcessResponse runs per delta chunk, so a regex whose match
// spans two chunks will not fire. Re-encoding a modified chunk through
// ClaudeStreamResponse drops any JSON fields that struct does not declare.
// http.Client.Timeout also bounds reading the whole body, so a stream longer
// than provider.Timeout seconds is cut off. Confirm these are intended.
func (s *AiProxyService) forwardClaudeStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest, injector *PresetInjector) error {
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("序列化 Claude 请求失败: %w", err)
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
	httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("创建上游请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", provider.APIKey)
	httpReq.Header.Set("anthropic-version", "2023-06-01")

	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error, which the caller logs.
		body, _ := io.ReadAll(httpResp.Body)
		return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	dataPrefix := []byte("data: ")
	// writeSSE writes all parts to the client and flushes so events are
	// delivered immediately.
	writeSSE := func(parts ...[]byte) error {
		for _, p := range parts {
			if _, err := c.Writer.Write(p); err != nil {
				return fmt.Errorf("写入客户端失败: %w", err)
			}
		}
		c.Writer.Flush()
		return nil
	}

	reader := bufio.NewReader(httpResp.Body)
	var (
		fullContent                         bytes.Buffer // raw upstream text, for logging only
		totalInputTokens, totalOutputTokens int
		streamErr                           error
	)

relay:
	for {
		line, readErr := reader.ReadBytes('\n')
		if readErr != nil && readErr != io.EOF {
			// Any non-EOF error is terminal; looping again would spin forever.
			// The partial line is dropped rather than relayed truncated.
			streamErr = fmt.Errorf("读取上游流失败: %w", readErr)
			break
		}

		// On io.EOF, line may still hold a final unterminated line; handle it
		// before leaving the loop.
		trimmed := bytes.TrimRight(line, "\r\n")
		switch {
		case len(bytes.TrimSpace(trimmed)) == 0:
			// Blank event separator; one is emitted after every data line.
		case bytes.HasPrefix(trimmed, dataPrefix):
			data := trimmed[len(dataPrefix):]
			payload := data
			var chunk response.ClaudeStreamResponse
			if json.Unmarshal(data, &chunk) == nil {
				// Collect usage. Different events carry different counters, so
				// a zero value must not overwrite a count already seen.
				if chunk.Usage != nil {
					if chunk.Usage.InputTokens > 0 {
						totalInputTokens = chunk.Usage.InputTokens
					}
					if chunk.Usage.OutputTokens > 0 {
						totalOutputTokens = chunk.Usage.OutputTokens
					}
				}
				if chunk.Delta != nil && chunk.Delta.Text != "" {
					fullContent.WriteString(chunk.Delta.Text)
					if injector != nil {
						if processed := injector.ProcessResponse(chunk.Delta.Text); processed != chunk.Delta.Text {
							chunk.Delta.Text = processed
							encoded, err := json.Marshal(chunk)
							if err != nil {
								streamErr = fmt.Errorf("序列化流式数据失败: %w", err)
								break relay
							}
							payload = encoded
						}
					}
				}
			}
			if streamErr = writeSSE(dataPrefix, payload, []byte("\n\n")); streamErr != nil {
				break relay
			}
		default:
			// "event:", "id:", "retry:" and comment lines pass through as-is.
			if streamErr = writeSSE(trimmed, []byte("\n")); streamErr != nil {
				break relay
			}
		}

		if readErr == io.EOF {
			break
		}
	}

	// Log a summary of the (possibly partial) streamed response.
	logFields := []zap.Field{
		zap.String("ai_output", fullContent.String()),
		zap.Int("input_tokens", totalInputTokens),
		zap.Int("output_tokens", totalOutputTokens),
		zap.Int("total_tokens", totalInputTokens+totalOutputTokens),
	}

	// Append regex-script execution details.
	if injector != nil {
		regexLogs := injector.GetRegexLogs()
		if regexLogs != nil && (regexLogs.TotalMatches > 0 || len(regexLogs.InputScripts) > 0 || len(regexLogs.OutputScripts) > 0) {
			// Collect the names of scripts that actually matched.
			triggeredScripts := make([]string, 0)
			for _, scriptLog := range regexLogs.InputScripts {
				if scriptLog.MatchCount > 0 {
					triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输入:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
				}
			}
			for _, scriptLog := range regexLogs.OutputScripts {
				if scriptLog.MatchCount > 0 {
					triggeredScripts = append(triggeredScripts, fmt.Sprintf("%s(输出:%d次)", scriptLog.ScriptName, scriptLog.MatchCount))
				}
			}
			if len(triggeredScripts) > 0 {
				logFields = append(logFields,
					zap.Strings("triggered_regex_scripts", triggeredScripts),
					zap.Int("total_matches", regexLogs.TotalMatches),
				)
			}
		}
	}
	global.GVA_LOG.Info("Claude Messages 流式响应完成", logFields...)
	return streamErr
}
// convertToOpenAIMessages flattens Claude-format messages into OpenAI-style
// chat messages so they can be handed to the preset injector.
//
// Role is copied unchanged. A string Content is copied as-is. An array Content
// (Claude content blocks, as decoded from JSON into []interface{}) becomes the
// concatenation, with no separator, of the "text" field of every block whose
// "type" is "text"; all other blocks are dropped. Any other Content shape
// yields an empty string. The result has the same length as the input.
func (s *AiProxyService) convertToOpenAIMessages(messages []request.ClaudeMessage) []request.ChatMessage {
	converted := make([]request.ChatMessage, len(messages))
	for idx := range messages {
		var text strings.Builder
		switch content := messages[idx].Content.(type) {
		case string:
			text.WriteString(content)
		case []interface{}:
			for _, raw := range content {
				block, isMap := raw.(map[string]interface{})
				if !isMap || block["type"] != "text" {
					continue
				}
				if piece, isStr := block["text"].(string); isStr {
					text.WriteString(piece)
				}
			}
		}
		converted[idx] = request.ChatMessage{Role: messages[idx].Role, Content: text.String()}
	}
	return converted
}
// convertClaudeMessages maps OpenAI-style chat messages back into Claude
// messages, copying Role and Content field-for-field.
//
// The result has the same length and order as the input and is never nil; an
// empty input yields an empty, non-nil slice. Role values are not translated.
// NOTE(review): if the injector emits a "system" role, confirm the upstream
// accepts it inside messages rather than as a separate system prompt.
func (s *AiProxyService) convertClaudeMessages(messages []request.ChatMessage) []request.ClaudeMessage {
	out := make([]request.ClaudeMessage, 0, len(messages))
	for _, m := range messages {
		out = append(out, request.ClaudeMessage{
			Role:    m.Role,
			Content: m.Content,
		})
	}
	return out
}