Files
ai_proxy/server/service/app/ai_claude.go

208 lines
6.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ProcessClaudeMessage handles a non-streaming Claude message request.
//
// It looks up the preset and provider for req.Model, optionally injects the
// preset into the request messages, forwards the request upstream,
// runs the first content block of the reply through the preset's response
// processing, and finally fills StandardUsage with OpenAI-style token counts.
//
// An error is returned when req.Model is empty, when the configuration lookup
// fails, or when the upstream call fails.
func (s *AiProxyService) ProcessClaudeMessage(ctx context.Context, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
	// Step 1: resolve the configuration that belongs to the requested model.
	if len(req.Model) == 0 {
		return nil, fmt.Errorf("model 参数不能为空")
	}
	preset, provider, lookupErr := s.getConfigByModel(req.Model)
	if lookupErr != nil {
		return nil, lookupErr
	}
	hasPreset := preset != nil

	// Step 2: inject the preset. The injector operates on OpenAI-style chat
	// messages, so the Claude messages are converted there and back.
	if hasPreset {
		in := NewPresetInjector(preset)
		openAIMessages := s.convertToOpenAIMessages(req.Messages)
		injected := in.InjectMessages(openAIMessages)
		req.Messages = s.convertClaudeMessages(injected)
	}

	// Step 3: forward the request to the upstream provider.
	reply, forwardErr := s.forwardClaudeRequest(ctx, provider, req)
	if forwardErr != nil {
		return nil, forwardErr
	}

	// Step 4: post-process the reply; only the first content block is touched.
	if hasPreset && len(reply.Content) > 0 {
		out := NewPresetInjector(preset)
		first := &reply.Content[0]
		first.Text = out.ProcessResponse(first.Text)
	}

	// Step 5: always populate standard_usage with OpenAI-style usage figures.
	promptTokens := reply.Usage.InputTokens
	completionTokens := reply.Usage.OutputTokens
	reply.StandardUsage = &response.ChatCompletionUsage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	return reply, nil
}
// ProcessClaudeMessageStream handles a streaming (SSE) Claude message request.
//
// It validates req.Model, resolves the preset and provider for that model,
// optionally injects the preset into the request messages, sets the SSE
// response headers and then relays the upstream stream to the client.
//
// Validation and configuration errors are answered with HTTP 400. A failure
// of the upstream call is always logged; if nothing has been written to the
// client yet, it is additionally reported as a JSON error with HTTP 502
// instead of leaving the client with an empty 200 event stream.
func (s *AiProxyService) ProcessClaudeMessageStream(c *gin.Context, req *request.ClaudeMessageRequest) {
	// 1. Resolve the configuration for the requested model.
	if req.Model == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
		return
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// 2. Inject the preset. The injector is kept (possibly nil) so the
	// forwarder can also apply it to the streamed response chunks.
	var injector *PresetInjector
	if preset != nil {
		injector = NewPresetInjector(preset)
		req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
	}

	// 3. Set the SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	// 4. Forward the streaming request.
	if err := s.forwardClaudeStreamRequest(c, provider, req, injector); err != nil {
		global.GVA_LOG.Error("Claude流式请求失败", zap.Error(err))
		// Once part of the stream has been sent the status line is already on
		// the wire, so an error response is only possible before the first write.
		if !c.Writer.Written() {
			// Replace the SSE content type set above so the JSON error body
			// is labelled correctly.
			c.Header("Content-Type", "application/json; charset=utf-8")
			c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
		}
	}
}
// forwardClaudeRequest sends a non-streaming Claude request to the provider's
// /v1/messages endpoint and decodes the JSON reply.
//
// If req.Model is empty and the provider defines a model, the provider's model
// is used. The request is bound to ctx and to a client timeout of
// provider.Timeout seconds.
//
// An error is returned when the request cannot be serialized or constructed,
// when the HTTP call fails, when the upstream answers with a status other than
// 200 (the upstream body is included in the message), or when the reply cannot
// be decoded.
func (s *AiProxyService) forwardClaudeRequest(ctx context.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		// Without this check a malformed BaseURL would lead to a nil
		// request being dereferenced below.
		return nil, fmt.Errorf("创建上游请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", provider.APIKey)
	httpReq.Header.Set("anthropic-version", "2023-06-01")

	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message, so a read
		// failure here is deliberately ignored.
		body, _ := io.ReadAll(httpResp.Body)
		return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	var resp response.ClaudeMessageResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
		// Do not hand a zero-valued response to the caller as if it were valid.
		return nil, fmt.Errorf("解析上游响应失败: %w", err)
	}
	return &resp, nil
}
// forwardClaudeStreamRequest sends a Claude request to the provider's
// /v1/messages endpoint and relays the upstream event stream to the client.
//
// If req.Model is empty and the provider defines a model, the provider's model
// is used. The upstream body is read line by line; only lines starting with
// "data: " whose JSON payload decodes into a chunk with a non-nil Delta are
// forwarded. When injector is non-nil, the delta text of each forwarded chunk
// is passed through injector.ProcessResponse first. Every forwarded chunk is
// re-serialized, written as "data: <json>\n\n" and flushed immediately.
//
// It returns nil when the upstream stream ends normally, and an error when the
// request cannot be built or sent, when the upstream answers with a status
// other than 200, when reading the upstream stream fails, or when writing to
// the client fails.
func (s *AiProxyService) forwardClaudeStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest, injector *PresetInjector) error {
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("序列化请求失败: %w", err)
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
	httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("创建上游请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", provider.APIKey)
	httpReq.Header.Set("anthropic-version", "2023-06-01")

	// NOTE(review): http.Client.Timeout also covers reading the response
	// body, so a stream running longer than provider.Timeout seconds is cut
	// off — confirm this is intended for streaming.
	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		// An upstream error is reported instead of being parsed as a stream.
		// The body only enriches the message, so a read failure is ignored.
		body, _ := io.ReadAll(httpResp.Body)
		return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	dataPrefix := []byte("data: ")
	reader := bufio.NewReader(httpResp.Body)
	for {
		// ReadBytes may return data together with an error (e.g. a final line
		// without a trailing newline), so the line is handled before the error.
		line, readErr := reader.ReadBytes('\n')

		if bytes.HasPrefix(line, dataPrefix) {
			var chunk response.ClaudeStreamResponse
			// Payloads that fail to decode or carry no delta are skipped.
			if json.Unmarshal(line[len(dataPrefix):], &chunk) == nil && chunk.Delta != nil {
				if injector != nil {
					chunk.Delta.Text = injector.ProcessResponse(chunk.Delta.Text)
				}
				processedData, err := json.Marshal(chunk)
				if err != nil {
					return fmt.Errorf("序列化流式响应失败: %w", err)
				}

				// Assemble the whole SSE frame so it goes out in one write.
				frame := make([]byte, 0, len(dataPrefix)+len(processedData)+2)
				frame = append(frame, dataPrefix...)
				frame = append(frame, processedData...)
				frame = append(frame, '\n', '\n')
				if _, err := c.Writer.Write(frame); err != nil {
					return fmt.Errorf("写入客户端失败: %w", err)
				}
				c.Writer.Flush()
			}
		}

		if readErr != nil {
			if readErr == io.EOF {
				return nil
			}
			// Any non-EOF read error (timeout, reset, cancelled context) must
			// end the loop; otherwise it would keep retrying the read forever.
			return fmt.Errorf("读取上游流失败: %w", readErr)
		}
	}
}
// convertToOpenAIMessages converts Claude messages into OpenAI-style chat
// messages with plain string content.
//
// A string content is copied as is. A content given as a list of blocks (the
// Claude API block format) is reduced to the concatenation of the "text"
// fields of its blocks whose "type" is "text"; all other blocks are ignored.
// Any other content shape yields an empty string. The role is carried over
// unchanged and the result has the same length and order as the input.
func (s *AiProxyService) convertToOpenAIMessages(messages []request.ClaudeMessage) []request.ChatMessage {
	converted := make([]request.ChatMessage, 0, len(messages))
	for _, m := range messages {
		var text strings.Builder
		switch body := m.Content.(type) {
		case string:
			// Plain string content.
			text.WriteString(body)
		case []interface{}:
			// Block-list content: keep only the text blocks.
			for _, raw := range body {
				part, isMap := raw.(map[string]interface{})
				if !isMap || part["type"] != "text" {
					continue
				}
				if piece, isString := part["text"].(string); isString {
					text.WriteString(piece)
				}
			}
		}
		converted = append(converted, request.ChatMessage{Role: m.Role, Content: text.String()})
	}
	return converted
}
// convertClaudeMessages converts OpenAI-style chat messages into Claude
// messages, copying role and content one to one. The result has the same
// length and order as the input.
func (s *AiProxyService) convertClaudeMessages(messages []request.ChatMessage) []request.ClaudeMessage {
	converted := make([]request.ClaudeMessage, 0, len(messages))
	for idx := range messages {
		src := messages[idx]
		converted = append(converted, request.ClaudeMessage{
			Role:    src.Role,
			Content: src.Content,
		})
	}
	return converted
}