Files
ai_proxy/server/service/app/ai_claude.go

201 lines
6.0 KiB
Go

package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ProcessClaudeMessage handles a non-streaming Claude Messages request.
//
// It resolves the preset and upstream provider configured for req.Model,
// injects the preset's messages into req when a preset exists, forwards the
// request upstream, and finally runs the preset's response processing over
// the text of the first content block of the reply.
//
// req.Messages is rewritten in place when a preset is applied. An error is
// returned when req.Model is empty, when the model lookup fails, or when the
// upstream call fails.
func (s *AiProxyService) ProcessClaudeMessage(ctx context.Context, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
	if req.Model == "" {
		return nil, fmt.Errorf("model 参数不能为空")
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		return nil, err
	}

	hasPreset := preset != nil
	if hasPreset {
		// The injector operates on OpenAI-style messages, so round-trip the
		// Claude messages through that shape.
		inj := NewPresetInjector(preset)
		asChat := s.convertToOpenAIMessages(req.Messages)
		req.Messages = s.convertClaudeMessages(inj.InjectMessages(asChat))
	}

	resp, err := s.forwardClaudeRequest(ctx, provider, req)
	if err != nil {
		return nil, err
	}

	if !hasPreset || len(resp.Content) == 0 {
		return resp, nil
	}
	// Only the first content block is post-processed.
	resp.Content[0].Text = NewPresetInjector(preset).ProcessResponse(resp.Content[0].Text)
	return resp, nil
}
// ProcessClaudeMessageStream handles a streaming Claude Messages request and
// writes the server-sent event stream directly to c.
//
// An empty model, or a failure looking up the model's configuration, is
// answered with HTTP 400 JSON. When a preset is configured, its messages are
// injected into req before forwarding, and the same injector is handed to the
// forwarder for response processing. Forwarding errors are logged. If nothing
// has been written to the client at that point, a 502 JSON error is sent
// instead of leaving the client with an empty 200 event stream.
func (s *AiProxyService) ProcessClaudeMessageStream(c *gin.Context, req *request.ClaudeMessageRequest) {
	// 1. Resolve the preset and provider for the requested model.
	if req.Model == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
		return
	}
	preset, provider, err := s.getConfigByModel(req.Model)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// 2. Inject preset messages; the injector is reused for the response.
	var injector *PresetInjector
	if preset != nil {
		injector = NewPresetInjector(preset)
		req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
	}

	// 3. Stage the SSE response headers. They are only sent to the client
	// together with the first write.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	// 4. Forward the streaming request.
	err = s.forwardClaudeStreamRequest(c, provider, req, injector)
	if err == nil {
		return
	}
	global.GVA_LOG.Error("Claude流式请求失败", zap.Error(err))
	// The failure happened before anything was sent (e.g. the upstream could
	// not be reached). Report a real error status rather than an empty 200
	// stream. Drop the staged SSE content type so the JSON renderer sets its own.
	if !c.Writer.Written() {
		c.Writer.Header().Del("Content-Type")
		c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
	}
}
// forwardClaudeRequest sends req to provider's /v1/messages endpoint as a
// non-streaming call and decodes the upstream JSON reply.
//
// If req.Model is empty, the provider's default model is filled in. An error
// is returned when the request cannot be encoded or built (e.g. a malformed
// BaseURL), when the HTTP call fails, when upstream answers with a non-200
// status (the upstream body is included in the message), or when the reply
// cannot be decoded.
func (s *AiProxyService) forwardClaudeRequest(ctx context.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		// Without this check httpReq is nil and the Header.Set calls below panic.
		return nil, fmt.Errorf("创建上游请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", provider.APIKey)
	httpReq.Header.Set("anthropic-version", "2023-06-01")

	// A zero provider.Timeout means no client-side timeout; ctx still bounds the call.
	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, err
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(httpResp.Body) // best effort: only used in the error message
		return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	var resp response.ClaudeMessageResponse
	if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
		// Do not hand back an empty response as if the call had succeeded.
		return nil, fmt.Errorf("解析上游响应失败: %w", err)
	}
	return &resp, nil
}
// forwardClaudeStreamRequest proxies a streaming Claude Messages request to
// provider's /v1/messages endpoint. It relays the upstream server-sent events
// to the client through c.Writer, flushing after every data line.
//
// When injector is non-nil, a data payload that decodes into a
// response.ClaudeStreamResponse with a non-empty delta text has that text
// passed through injector.ProcessResponse and is then re-encoded. Every other
// non-blank line (event:, other data: payloads, id:, comments) is forwarded
// unchanged, so the client receives the full upstream event sequence.
//
// If req.Model is empty, the provider's default model is used. An error is
// returned when:
//   - the request cannot be built or sent;
//   - upstream answers with a non-200 status (nothing is written to c then);
//   - reading the upstream stream fails;
//   - writing to the client fails.
func (s *AiProxyService) forwardClaudeStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest, injector *PresetInjector) error {
	if req.Model == "" && provider.Model != "" {
		req.Model = provider.Model
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("序列化请求失败: %w", err)
	}

	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
	httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("创建上游请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", provider.APIKey)
	httpReq.Header.Set("anthropic-version", "2023-06-01")

	// NOTE(review): http.Client.Timeout also bounds reading the response body.
	// A stream that outlives provider.Timeout is therefore cut off; that read
	// error is reported below rather than mistaken for a normal end of stream.
	client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()

	// An upstream error body has no "data:" lines. Without this check it would
	// be discarded and the client would see an empty stream.
	if httpResp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(httpResp.Body) // best effort: only used in the error message
		return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
	}

	dataPrefix := []byte("data:")
	flusher, _ := c.Writer.(http.Flusher)
	reader := bufio.NewReader(httpResp.Body)
	var out bytes.Buffer
	for {
		line, readErr := reader.ReadBytes('\n')
		// ReadBytes can return a final unterminated line together with io.EOF,
		// so handle whatever was read before inspecting readErr.
		line = bytes.TrimRight(line, "\r\n")
		if len(bytes.TrimSpace(line)) > 0 {
			out.Reset()
			isData := bytes.HasPrefix(line, dataPrefix)
			if isData {
				payload := bytes.TrimPrefix(line[len(dataPrefix):], []byte(" "))
				if injector != nil {
					var chunk response.ClaudeStreamResponse
					// Only rewrite events that carry delta text. Re-encoding other
					// payloads could drop fields the struct may not model.
					if json.Unmarshal(payload, &chunk) == nil && chunk.Delta != nil && chunk.Delta.Text != "" {
						chunk.Delta.Text = injector.ProcessResponse(chunk.Delta.Text)
						if processed, mErr := json.Marshal(chunk); mErr == nil {
							payload = processed
						}
					}
				}
				out.WriteString("data: ")
				out.Write(payload)
				// The blank line terminates the SSE event.
				out.WriteString("\n\n")
			} else {
				out.Write(line)
				out.WriteByte('\n')
			}
			if _, err := c.Writer.Write(out.Bytes()); err != nil {
				return fmt.Errorf("写入客户端失败: %w", err)
			}
			if isData && flusher != nil {
				flusher.Flush()
			}
		}

		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			// Any other error (timeout, cancellation, connection reset) ends the
			// stream. Reading again would just keep failing without progress.
			return fmt.Errorf("读取上游流失败: %w", readErr)
		}
	}
}
// convertToOpenAIMessages flattens Claude-format messages into the
// OpenAI-style request.ChatMessage shape, preserving order and Role.
//
// Content is resolved as follows:
//   - A string is copied as-is.
//   - A []interface{} (presumably decoded JSON content blocks) yields the
//     concatenated "text" of every map block whose "type" is "text". All
//     other blocks are dropped.
//   - Any other Content value yields an empty string.
func (s *AiProxyService) convertToOpenAIMessages(messages []request.ClaudeMessage) []request.ChatMessage {
	converted := make([]request.ChatMessage, len(messages))
	for idx, m := range messages {
		var sb strings.Builder
		switch v := m.Content.(type) {
		case string:
			sb.WriteString(v)
		case []interface{}:
			for _, raw := range v {
				block, isMap := raw.(map[string]interface{})
				if !isMap || block["type"] != "text" {
					continue
				}
				if txt, isStr := block["text"].(string); isStr {
					sb.WriteString(txt)
				}
			}
		}
		converted[idx] = request.ChatMessage{Role: m.Role, Content: sb.String()}
	}
	return converted
}
// convertClaudeMessages maps OpenAI-style chat messages back to Claude
// messages, copying Role and Content unchanged and preserving order.
// The result is always a non-nil slice with one entry per input message.
func (s *AiProxyService) convertClaudeMessages(messages []request.ChatMessage) []request.ClaudeMessage {
	out := make([]request.ClaudeMessage, 0, len(messages))
	for _, m := range messages {
		out = append(out, request.ClaudeMessage{Role: m.Role, Content: m.Content})
	}
	return out
}