package app

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	appModel "git.echol.cn/loser/st/server/model/app"
)

const (
	// aiStreamTimeout bounds a whole streaming call. http.Client.Timeout also
	// covers reading the response body, so a stream is cut off after this long.
	aiStreamTimeout = 120 * time.Second
	// aiNonStreamTimeout bounds a whole non-streaming call.
	aiNonStreamTimeout = 60 * time.Second
	// aiSSEMaxLineBytes is the longest SSE line the stream readers accept.
	aiSSEMaxLineBytes = 1024 * 1024
)

// errNoAIMessages is reported when a conversation contains nothing but system
// messages, for providers that require at least one non-system message.
var errNoAIMessages = errors.New("没有有效的消息内容")

// AIMessagePayload is the provider-independent message format used for AI calls.
type AIMessagePayload struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// AIStreamChunk is a single chunk of a streaming response.
type AIStreamChunk struct {
	Content          string `json:"content"`          // incremental text
	Done             bool   `json:"done"`             // whether the stream has ended
	Model            string `json:"model"`            // model used
	PromptTokens     int    `json:"promptTokens"`     // prompt tokens (only set on the final chunk)
	CompletionTokens int    `json:"completionTokens"` // completion tokens (only set on the final chunk)
	Error            string `json:"error"`            // error message
}

// BuildPrompt builds the message list for an AI request from a character card
// and the message history. The character's system prompt, when non-empty, comes
// first; history entries whose role is not "assistant", "user" or "system" are
// dropped.
func BuildPrompt(character *appModel.AICharacter, messages []appModel.AIMessage) []AIMessagePayload {
	var payload []AIMessagePayload

	// 1. System prompt.
	if systemPrompt := buildSystemPrompt(character); systemPrompt != "" {
		payload = append(payload, AIMessagePayload{Role: "system", Content: systemPrompt})
	}

	// 2. Message history.
	for _, msg := range messages {
		switch msg.Role {
		case "assistant", "user", "system":
			payload = append(payload, AIMessagePayload{Role: msg.Role, Content: msg.Content})
		}
	}
	return payload
}

// buildSystemPrompt joins the non-empty parts of a character card (description,
// personality, scenario, system prompt, example messages) with blank lines.
// It returns "" for a nil character.
func buildSystemPrompt(character *appModel.AICharacter) string {
	if character == nil {
		return ""
	}

	var parts []string
	if character.Description != "" {
		parts = append(parts, character.Description)
	}
	if character.Personality != "" {
		parts = append(parts, "Personality: "+character.Personality)
	}
	if character.Scenario != "" {
		parts = append(parts, "Scenario: "+character.Scenario)
	}
	if character.SystemPrompt != "" {
		parts = append(parts, character.SystemPrompt)
	}
	if len(character.ExampleMessages) > 0 {
		parts = append(parts, "Example dialogue:\n"+strings.Join(character.ExampleMessages, "\n"))
	}
	return strings.Join(parts, "\n\n")
}

// StreamAIResponse calls the provider in streaming mode and delivers the result
// on ch. Text arrives as chunks with Content set; the stream ends with one chunk
// that has Done set (carrying Error on failure). ch is always closed before
// StreamAIResponse returns.
//
// Sends stop as soon as ctx is cancelled, so a receiver that went away cannot
// block this call forever; in that case the terminal chunk may not be delivered
// and the receiver only observes the close.
func StreamAIResponse(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
	defer close(ch)

	if ctx == nil {
		ctx = context.Background()
	}
	if provider == nil {
		sendAIChunk(ctx, ch, AIStreamChunk{Error: "提供商配置为空", Done: true})
		return
	}

	apiKey := decryptAPIKey(provider.APIKey)
	baseURL := provider.BaseURL

	switch provider.ProviderType {
	case "openai", "custom":
		streamOpenAI(ctx, baseURL, apiKey, modelName, messages, ch)
	case "claude":
		streamClaude(ctx, baseURL, apiKey, modelName, messages, ch)
	case "gemini":
		streamGemini(ctx, baseURL, apiKey, modelName, messages, ch)
	default:
		sendAIChunk(ctx, ch, AIStreamChunk{Error: "不支持的提供商类型: " + provider.ProviderType, Done: true})
	}
}

// ==================== Shared helpers ====================

// sendAIChunk delivers c on ch unless ctx is cancelled first. It reports whether
// the chunk was handed to the receiver.
func sendAIChunk(ctx context.Context, ch chan<- AIStreamChunk, c AIStreamChunk) bool {
	select {
	case ch <- c:
		return true
	case <-ctx.Done():
		return false
	}
}

// newAIJSONRequest builds a POST request whose body is the JSON encoding of
// body, with Content-Type set to application/json plus the given headers.
func newAIJSONRequest(ctx context.Context, url string, headers map[string]string, body interface{}) (*http.Request, error) {
	bodyJSON, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyJSON))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	for name, value := range headers {
		req.Header.Set(name, value)
	}
	return req, nil
}

// startAIStream sends a streaming request and returns the response of an
// HTTP 200 reply; the caller must close its body. On any failure it emits a
// terminal error chunk on ch and returns nil.
func startAIStream(ctx context.Context, url string, headers map[string]string, body interface{}, ch chan<- AIStreamChunk) *http.Response {
	req, err := newAIJSONRequest(ctx, url, headers, body)
	if err != nil {
		sendAIChunk(ctx, ch, AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true})
		return nil
	}

	client := &http.Client{Timeout: aiStreamTimeout}
	resp, err := client.Do(req)
	if err != nil {
		sendAIChunk(ctx, ch, AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true})
		return nil
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// The body only enriches the error message, so a failed read is not
		// reported separately.
		respBody, _ := io.ReadAll(resp.Body)
		sendAIChunk(ctx, ch, AIStreamChunk{
			Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)),
			Done:  true,
		})
		return nil
	}
	return resp
}

// scanAISSE reads r line by line and passes the payload of every SSE "data"
// line to handle until handle returns true, the stream ends, or reading fails.
// The single space after "data:" is optional, as in the SSE format. It returns
// the read error, if any.
func scanAISSE(r io.Reader, handle func(data string) (stop bool)) error {
	scanner := bufio.NewScanner(r)
	// The default 64 KiB token limit is too small for large chunks.
	scanner.Buffer(make([]byte, 0, 64*1024), aiSSEMaxLineBytes)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if handle(data) {
			return nil
		}
	}
	return scanner.Err()
}

// finishAIStream emits the terminal chunk of a stream: an error chunk when
// reading the response failed, otherwise final.
func finishAIStream(ctx context.Context, ch chan<- AIStreamChunk, final AIStreamChunk, readErr error) {
	if readErr != nil {
		sendAIChunk(ctx, ch, AIStreamChunk{Error: "读取响应失败: " + readErr.Error(), Done: true})
		return
	}
	sendAIChunk(ctx, ch, final)
}

// postAIJSON sends a non-streaming request and returns the body of an HTTP 200
// reply. Any other status is returned as an error that includes the body.
func postAIJSON(ctx context.Context, url string, headers map[string]string, body interface{}) ([]byte, error) {
	req, err := newAIJSONRequest(ctx, url, headers, body)
	if err != nil {
		return nil, err
	}

	client := &http.Client{Timeout: aiNonStreamTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, readErr := io.ReadAll(resp.Body)
	if resp.StatusCode != http.StatusOK {
		// Report the API error even if the body could only be read partially.
		return nil, fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
	}
	if readErr != nil {
		return nil, readErr
	}
	return respBody, nil
}

// splitSystemMessages separates system messages from the rest. It returns the
// system contents joined by newlines with surrounding whitespace trimmed, and
// the remaining messages in their original order.
func splitSystemMessages(messages []AIMessagePayload) (string, []AIMessagePayload) {
	var system strings.Builder
	var conversation []AIMessagePayload
	for _, msg := range messages {
		if msg.Role == "system" {
			system.WriteString(msg.Content)
			system.WriteString("\n")
			continue
		}
		conversation = append(conversation, msg)
	}
	return strings.TrimSpace(system.String()), conversation
}

// ==================== OpenAI Compatible ====================

// streamOpenAI streams a chat completion from an OpenAI-compatible endpoint.
func streamOpenAI(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
	url := strings.TrimRight(baseURL, "/") + "/chat/completions"
	body := map[string]interface{}{
		"model":    modelName,
		"messages": messages,
		"stream":   true,
	}
	headers := map[string]string{"Authorization": "Bearer " + apiKey}

	resp := startAIStream(ctx, url, headers, body, ch)
	if resp == nil {
		return
	}
	defer resp.Body.Close()

	final := AIStreamChunk{Done: true, Model: modelName}
	readErr := scanAISSE(resp.Body, func(data string) bool {
		if data == "[DONE]" {
			return true
		}

		var chunk struct {
			Choices []struct {
				Delta struct {
					Content string `json:"content"`
				} `json:"delta"`
			} `json:"choices"`
			Usage *struct {
				PromptTokens     int `json:"prompt_tokens"`
				CompletionTokens int `json:"completion_tokens"`
			} `json:"usage"`
		}
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			// Payloads that are not valid JSON are skipped.
			return false
		}

		if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
			if !sendAIChunk(ctx, ch, AIStreamChunk{Content: chunk.Choices[0].Delta.Content}) {
				return true
			}
		}
		// Usage is only recorded here: some compatible servers attach it to
		// every chunk, so it must not end the stream.
		if chunk.Usage != nil {
			final.PromptTokens = chunk.Usage.PromptTokens
			final.CompletionTokens = chunk.Usage.CompletionTokens
		}
		return false
	})
	finishAIStream(ctx, ch, final, readErr)
}

// ==================== Claude ====================

// claudeRequestBody builds the Messages API request body. Claude takes system
// text in a separate "system" field rather than as a message. It returns
// errNoAIMessages when no non-system message is left.
func claudeRequestBody(modelName string, messages []AIMessagePayload, stream bool) (map[string]interface{}, error) {
	system, conversation := splitSystemMessages(messages)
	if len(conversation) == 0 {
		return nil, errNoAIMessages
	}

	body := map[string]interface{}{
		"model":      modelName,
		"messages":   conversation,
		"max_tokens": 4096,
	}
	if stream {
		body["stream"] = true
	}
	if system != "" {
		body["system"] = system
	}
	return body, nil
}

// claudeHeaders returns the authentication and version headers for Claude.
func claudeHeaders(apiKey string) map[string]string {
	return map[string]string{
		"x-api-key":         apiKey,
		"anthropic-version": "2023-06-01",
	}
}

// streamClaude streams a reply from the Claude Messages API.
func streamClaude(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
	url := strings.TrimRight(baseURL, "/") + "/v1/messages"

	body, err := claudeRequestBody(modelName, messages, true)
	if err != nil {
		sendAIChunk(ctx, ch, AIStreamChunk{Error: err.Error(), Done: true})
		return
	}

	resp := startAIStream(ctx, url, claudeHeaders(apiKey), body, ch)
	if resp == nil {
		return
	}
	defer resp.Body.Close()

	final := AIStreamChunk{Done: true, Model: modelName}
	var apiErr string
	readErr := scanAISSE(resp.Body, func(data string) bool {
		var event struct {
			Type    string `json:"type"`
			Message *struct {
				Usage struct {
					InputTokens  int `json:"input_tokens"`
					OutputTokens int `json:"output_tokens"`
				} `json:"usage"`
			} `json:"message"`
			Delta *struct {
				Type string `json:"type"`
				Text string `json:"text"`
			} `json:"delta"`
			Usage *struct {
				InputTokens  int `json:"input_tokens"`
				OutputTokens int `json:"output_tokens"`
			} `json:"usage"`
			Error *struct {
				Type    string `json:"type"`
				Message string `json:"message"`
			} `json:"error"`
		}
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			// Payloads that are not valid JSON are skipped.
			return false
		}

		switch event.Type {
		case "message_start":
			// The prompt token count is reported at the start of the message.
			if event.Message != nil {
				final.PromptTokens = event.Message.Usage.InputTokens
				final.CompletionTokens = event.Message.Usage.OutputTokens
			}
		case "content_block_delta":
			if event.Delta != nil && event.Delta.Text != "" {
				if !sendAIChunk(ctx, ch, AIStreamChunk{Content: event.Delta.Text}) {
					return true
				}
			}
		case "message_delta":
			// Keep the prompt count from message_start unless this event
			// carries its own.
			if event.Usage != nil {
				if event.Usage.InputTokens > 0 {
					final.PromptTokens = event.Usage.InputTokens
				}
				final.CompletionTokens = event.Usage.OutputTokens
			}
		case "message_stop":
			return true
		case "error":
			apiErr = "Claude API 错误"
			if event.Error != nil && event.Error.Message != "" {
				apiErr += ": " + event.Error.Message
			}
			return true
		}
		return false
	})

	if apiErr != "" {
		sendAIChunk(ctx, ch, AIStreamChunk{Error: apiErr, Done: true})
		return
	}
	finishAIStream(ctx, ch, final, readErr)
}

// ==================== Gemini ====================

// geminiResponse is the part of a Gemini generateContent response (or of one
// streamed chunk) that this package reads.
type geminiResponse struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"content"`
	} `json:"candidates"`
	UsageMetadata *struct {
		PromptTokenCount     int `json:"promptTokenCount"`
		CandidatesTokenCount int `json:"candidatesTokenCount"`
	} `json:"usageMetadata"`
}

// geminiRequestBody builds a Gemini request body: system messages become
// systemInstruction and the "assistant" role is renamed to "model". It returns
// errNoAIMessages when no non-system message is left.
func geminiRequestBody(messages []AIMessagePayload) (map[string]interface{}, error) {
	system, conversation := splitSystemMessages(messages)
	if len(conversation) == 0 {
		return nil, errNoAIMessages
	}

	contents := make([]map[string]interface{}, 0, len(conversation))
	for _, msg := range conversation {
		role := msg.Role
		if role == "assistant" {
			role = "model"
		}
		contents = append(contents, map[string]interface{}{
			"role":  role,
			"parts": []map[string]string{{"text": msg.Content}},
		})
	}

	body := map[string]interface{}{"contents": contents}
	if system != "" {
		body["systemInstruction"] = map[string]interface{}{
			"parts": []map[string]string{{"text": system}},
		}
	}
	return body, nil
}

// geminiHeaders returns the authentication header for Gemini. The key is sent
// as a header rather than in the query string so that it cannot end up in
// error messages that quote the request URL.
func geminiHeaders(apiKey string) map[string]string {
	return map[string]string{"x-goog-api-key": apiKey}
}

// streamGemini streams a reply from the Gemini streamGenerateContent endpoint.
func streamGemini(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
	url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
		strings.TrimRight(baseURL, "/"), modelName)

	body, err := geminiRequestBody(messages)
	if err != nil {
		sendAIChunk(ctx, ch, AIStreamChunk{Error: err.Error(), Done: true})
		return
	}

	resp := startAIStream(ctx, url, geminiHeaders(apiKey), body, ch)
	if resp == nil {
		return
	}
	defer resp.Body.Close()

	final := AIStreamChunk{Done: true, Model: modelName}
	readErr := scanAISSE(resp.Body, func(data string) bool {
		var chunk geminiResponse
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			// Payloads that are not valid JSON are skipped.
			return false
		}

		for _, candidate := range chunk.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.Text == "" {
					continue
				}
				if !sendAIChunk(ctx, ch, AIStreamChunk{Content: part.Text}) {
					return true
				}
			}
		}
		// usageMetadata can accompany intermediate chunks, so it is only
		// recorded; ending the stream here would drop the rest of the reply.
		if chunk.UsageMetadata != nil {
			final.PromptTokens = chunk.UsageMetadata.PromptTokenCount
			final.CompletionTokens = chunk.UsageMetadata.CandidatesTokenCount
		}
		return false
	})
	finishAIStream(ctx, ch, final, readErr)
}

// ==================== Non-streaming calls (used for image generation etc.) ====================

// CallAINonStream calls the provider without streaming and returns the reply
// text. It returns an error for an unsupported provider type, a transport
// failure, a non-200 response, or a response without any reply.
func CallAINonStream(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload) (string, error) {
	if provider == nil {
		return "", errors.New("提供商配置为空")
	}

	apiKey := decryptAPIKey(provider.APIKey)
	baseURL := provider.BaseURL

	switch provider.ProviderType {
	case "openai", "custom":
		return callOpenAINonStream(ctx, baseURL, apiKey, modelName, messages)
	case "claude":
		return callClaudeNonStream(ctx, baseURL, apiKey, modelName, messages)
	case "gemini":
		return callGeminiNonStream(ctx, baseURL, apiKey, modelName, messages)
	default:
		return "", errors.New("不支持的提供商类型: " + provider.ProviderType)
	}
}

// callOpenAINonStream returns the content of the first choice of a chat
// completion from an OpenAI-compatible endpoint.
func callOpenAINonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
	url := strings.TrimRight(baseURL, "/") + "/chat/completions"
	body := map[string]interface{}{
		"model":    modelName,
		"messages": messages,
	}
	headers := map[string]string{"Authorization": "Bearer " + apiKey}

	respBody, err := postAIJSON(ctx, url, headers, body)
	if err != nil {
		return "", err
	}

	var result struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return "", err
	}
	if len(result.Choices) == 0 {
		return "", errors.New("AI 未返回有效回复")
	}
	return result.Choices[0].Message.Content, nil
}

// callClaudeNonStream returns the text of a Claude Messages API reply. The text
// of all content blocks is concatenated, so a reply whose first block carries
// no text is not returned empty.
func callClaudeNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
	url := strings.TrimRight(baseURL, "/") + "/v1/messages"

	body, err := claudeRequestBody(modelName, messages, false)
	if err != nil {
		return "", err
	}

	respBody, err := postAIJSON(ctx, url, claudeHeaders(apiKey), body)
	if err != nil {
		return "", err
	}

	var result struct {
		Content []struct {
			Text string `json:"text"`
		} `json:"content"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return "", err
	}
	if len(result.Content) == 0 {
		return "", errors.New("AI 未返回有效回复")
	}

	var text strings.Builder
	for _, block := range result.Content {
		text.WriteString(block.Text)
	}
	return text.String(), nil
}

// callGeminiNonStream returns the text of the first candidate of a Gemini
// generateContent reply, concatenating all of its parts.
func callGeminiNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
	url := fmt.Sprintf("%s/v1beta/models/%s:generateContent",
		strings.TrimRight(baseURL, "/"), modelName)

	body, err := geminiRequestBody(messages)
	if err != nil {
		return "", err
	}

	respBody, err := postAIJSON(ctx, url, geminiHeaders(apiKey), body)
	if err != nil {
		return "", err
	}

	var result geminiResponse
	if err := json.Unmarshal(respBody, &result); err != nil {
		return "", err
	}
	if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
		return "", errors.New("AI 未返回有效回复")
	}

	var text strings.Builder
	for _, part := range result.Candidates[0].Content.Parts {
		text.WriteString(part.Text)
	}
	return text.String(), nil
}