201 lines
6.0 KiB
Go
201 lines
6.0 KiB
Go
package app
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.echol.cn/loser/ai_proxy/server/global"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
|
"github.com/gin-gonic/gin"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// ProcessClaudeMessage 处理 Claude 消息请求
|
|
func (s *AiProxyService) ProcessClaudeMessage(ctx context.Context, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
|
|
// 1. 根据模型获取配置
|
|
if req.Model == "" {
|
|
return nil, fmt.Errorf("model 参数不能为空")
|
|
}
|
|
|
|
preset, provider, err := s.getConfigByModel(req.Model)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 2. 注入预设
|
|
if preset != nil {
|
|
injector := NewPresetInjector(preset)
|
|
req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
|
|
}
|
|
|
|
// 3. 转发请求到上游
|
|
resp, err := s.forwardClaudeRequest(ctx, provider, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 4. 处理响应
|
|
if preset != nil && len(resp.Content) > 0 {
|
|
injector := NewPresetInjector(preset)
|
|
resp.Content[0].Text = injector.ProcessResponse(resp.Content[0].Text)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// ProcessClaudeMessageStream 处理 Claude 流式消息请求
|
|
func (s *AiProxyService) ProcessClaudeMessageStream(c *gin.Context, req *request.ClaudeMessageRequest) {
|
|
// 1. 根据模型获取配置
|
|
if req.Model == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
|
|
return
|
|
}
|
|
|
|
preset, provider, err := s.getConfigByModel(req.Model)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// 2. 注入预设
|
|
var injector *PresetInjector
|
|
if preset != nil {
|
|
injector = NewPresetInjector(preset)
|
|
req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
|
|
}
|
|
|
|
// 3. 设置 SSE 响应头
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
// 4. 转发流式请求
|
|
err = s.forwardClaudeStreamRequest(c, provider, req, injector)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("Claude流式请求失败", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// forwardClaudeRequest 转发 Claude 请求
|
|
func (s *AiProxyService) forwardClaudeRequest(ctx context.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
|
|
if req.Model == "" && provider.Model != "" {
|
|
req.Model = provider.Model
|
|
}
|
|
|
|
reqBody, _ := json.Marshal(req)
|
|
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
|
|
httpReq, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("x-api-key", provider.APIKey)
|
|
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
|
httpResp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer httpResp.Body.Close()
|
|
|
|
if httpResp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(httpResp.Body)
|
|
return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
|
}
|
|
|
|
var resp response.ClaudeMessageResponse
|
|
json.NewDecoder(httpResp.Body).Decode(&resp)
|
|
return &resp, nil
|
|
}
|
|
|
|
// forwardClaudeStreamRequest 转发 Claude 流式请求
|
|
func (s *AiProxyService) forwardClaudeStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest, injector *PresetInjector) error {
|
|
if req.Model == "" && provider.Model != "" {
|
|
req.Model = provider.Model
|
|
}
|
|
|
|
reqBody, _ := json.Marshal(req)
|
|
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
|
|
httpReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewReader(reqBody))
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("x-api-key", provider.APIKey)
|
|
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
|
httpResp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer httpResp.Body.Close()
|
|
|
|
reader := bufio.NewReader(httpResp.Body)
|
|
flusher, _ := c.Writer.(http.Flusher)
|
|
|
|
for {
|
|
line, err := reader.ReadBytes('\n')
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if len(bytes.TrimSpace(line)) == 0 {
|
|
continue
|
|
}
|
|
|
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
|
data := bytes.TrimPrefix(line, []byte("data: "))
|
|
var chunk response.ClaudeStreamResponse
|
|
if json.Unmarshal(data, &chunk) == nil && chunk.Delta != nil {
|
|
if injector != nil {
|
|
chunk.Delta.Text = injector.ProcessResponse(chunk.Delta.Text)
|
|
}
|
|
processedData, _ := json.Marshal(chunk)
|
|
c.Writer.Write([]byte("data: "))
|
|
c.Writer.Write(processedData)
|
|
c.Writer.Write([]byte("\n\n"))
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// convertToOpenAIMessages 转换 Claude 消息为 OpenAI 格式
|
|
func (s *AiProxyService) convertToOpenAIMessages(messages []request.ClaudeMessage) []request.ChatMessage {
|
|
result := make([]request.ChatMessage, len(messages))
|
|
for i, msg := range messages {
|
|
content := ""
|
|
// 处理字符串类型的 content
|
|
if str, ok := msg.Content.(string); ok {
|
|
content = str
|
|
} else if blocks, ok := msg.Content.([]interface{}); ok {
|
|
// 处理对象数组类型的 content (Claude API 标准格式)
|
|
for _, block := range blocks {
|
|
if blockMap, ok := block.(map[string]interface{}); ok {
|
|
if blockMap["type"] == "text" {
|
|
if text, ok := blockMap["text"].(string); ok {
|
|
content += text
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
result[i] = request.ChatMessage{Role: msg.Role, Content: content}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// convertClaudeMessages 转换 OpenAI 消息为 Claude 格式
|
|
func (s *AiProxyService) convertClaudeMessages(messages []request.ChatMessage) []request.ClaudeMessage {
|
|
result := make([]request.ClaudeMessage, len(messages))
|
|
for i, msg := range messages {
|
|
result[i] = request.ClaudeMessage{Role: msg.Role, Content: msg.Content}
|
|
}
|
|
return result
|
|
}
|