Files
st/server/service/app/ai_client.go

626 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
appModel "git.echol.cn/loser/st/server/model/app"
)
// AIMessagePayload is the unified message format sent to AI providers
// (a role/content pair matching the OpenAI chat message shape).
type AIMessagePayload struct {
	Role    string `json:"role"`    // "system", "user" or "assistant"
	Content string `json:"content"` // message text
}
// AIStreamChunk is a single chunk of a streamed AI response.
type AIStreamChunk struct {
	Content          string `json:"content"`          // incremental text delta
	Done             bool   `json:"done"`             // true on the terminal chunk
	Model            string `json:"model"`            // model name used
	PromptTokens     int    `json:"promptTokens"`     // prompt token count (only set on the final chunk)
	CompletionTokens int    `json:"completionTokens"` // completion token count (only set on the final chunk)
	Error            string `json:"error"`            // error message, if any
}
// BuildPrompt assembles the AI request payload from a character card and the
// chat history: an optional system prompt (built from the card) followed by
// the history messages, keeping only roles the downstream APIs understand.
func BuildPrompt(character *appModel.AICharacter, messages []appModel.AIMessage) []AIMessagePayload {
	var out []AIMessagePayload

	// 1. System prompt derived from the character card, if any.
	if sys := buildSystemPrompt(character); sys != "" {
		out = append(out, AIMessagePayload{Role: "system", Content: sys})
	}

	// 2. History: drop messages with unrecognized roles.
	for _, m := range messages {
		switch m.Role {
		case "assistant", "user", "system":
			out = append(out, AIMessagePayload{Role: m.Role, Content: m.Content})
		}
	}
	return out
}
// buildSystemPrompt concatenates the character card fields (description,
// personality, scenario, system prompt, example dialogue) into a single
// system prompt, separating non-empty sections with blank lines.
// Returns "" for a nil character.
func buildSystemPrompt(character *appModel.AICharacter) string {
	if character == nil {
		return ""
	}

	var sections []string
	addSection := func(s string) {
		if s != "" {
			sections = append(sections, s)
		}
	}

	addSection(character.Description)
	if character.Personality != "" {
		addSection("Personality: " + character.Personality)
	}
	if character.Scenario != "" {
		addSection("Scenario: " + character.Scenario)
	}
	addSection(character.SystemPrompt)
	if len(character.ExampleMessages) > 0 {
		addSection("Example dialogue:\n" + strings.Join(character.ExampleMessages, "\n"))
	}

	return strings.Join(sections, "\n\n")
}
// StreamAIResponse dispatches a streaming AI call to the provider-specific
// implementation and pushes incremental results into ch.
// The channel is always closed when the call finishes, so receivers can
// range over it safely.
func StreamAIResponse(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
	defer close(ch)

	key := decryptAPIKey(provider.APIKey)
	base := provider.BaseURL

	switch provider.ProviderType {
	case "openai", "custom":
		streamOpenAI(ctx, base, key, modelName, messages, ch)
	case "claude":
		streamClaude(ctx, base, key, modelName, messages, ch)
	case "gemini":
		streamGemini(ctx, base, key, modelName, messages, ch)
	default:
		ch <- AIStreamChunk{Error: "不支持的提供商类型: " + provider.ProviderType, Done: true}
	}
}
// ==================== OpenAI Compatible ====================

// streamOpenAI streams a chat completion from an OpenAI-compatible
// "/chat/completions" SSE endpoint and forwards incremental text deltas
// to ch. A final chunk with Done=true (or Error set) is always emitted
// before returning, so consumers never wait on a silently-ended stream.
func streamOpenAI(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
	url := strings.TrimRight(baseURL, "/") + "/chat/completions"
	body := map[string]interface{}{
		"model":    modelName,
		"messages": messages,
		"stream":   true,
	}
	bodyJSON, _ := json.Marshal(body) // map of marshalable values; error cannot occur here
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
	if err != nil {
		ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
		return
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 120 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		respBody, _ := io.ReadAll(resp.Body)
		ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
		return
	}
	scanner := bufio.NewScanner(resp.Body)
	// SSE data lines can exceed bufio.Scanner's 64 KiB default limit, which
	// would abort the scan with ErrTooLong; raise the ceiling to 1 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			ch <- AIStreamChunk{Done: true, Model: modelName, Content: ""}
			return
		}
		var chunk struct {
			Choices []struct {
				Delta struct {
					Content string `json:"content"`
				} `json:"delta"`
			} `json:"choices"`
			Usage *struct {
				PromptTokens     int `json:"prompt_tokens"`
				CompletionTokens int `json:"completion_tokens"`
			} `json:"usage"`
		}
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // skip malformed payloads
		}
		if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
			ch <- AIStreamChunk{Content: chunk.Choices[0].Delta.Content}
		}
		// Some providers append usage stats in the final data event.
		if chunk.Usage != nil {
			ch <- AIStreamChunk{
				Done:             true,
				Model:            modelName,
				PromptTokens:     chunk.Usage.PromptTokens,
				CompletionTokens: chunk.Usage.CompletionTokens,
			}
			return
		}
	}
	// Stream ended without [DONE] or usage: report any scan error, otherwise
	// always emit a terminal Done chunk.
	if err := scanner.Err(); err != nil {
		ch <- AIStreamChunk{Error: "读取响应失败: " + err.Error(), Done: true}
		return
	}
	ch <- AIStreamChunk{Done: true, Model: modelName}
}
// ==================== Claude ====================

// streamClaude streams a response from the Anthropic Messages API
// ("/v1/messages" with stream=true) and forwards text deltas to ch.
// System messages are lifted out of the message list into the top-level
// "system" field, as the Claude API requires. A final chunk with Done=true
// (or Error set) is always emitted before returning.
func streamClaude(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
	url := strings.TrimRight(baseURL, "/") + "/v1/messages"
	// Claude requires system prompts to be separated from the message list.
	var systemContent string
	var claudeMessages []map[string]string
	for _, msg := range messages {
		if msg.Role == "system" {
			systemContent += msg.Content + "\n"
		} else {
			claudeMessages = append(claudeMessages, map[string]string{
				"role":    msg.Role,
				"content": msg.Content,
			})
		}
	}
	// The API rejects an empty messages array.
	if len(claudeMessages) == 0 {
		ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
		return
	}
	body := map[string]interface{}{
		"model":      modelName,
		"messages":   claudeMessages,
		"max_tokens": 4096,
		"stream":     true,
	}
	if systemContent != "" {
		body["system"] = strings.TrimSpace(systemContent)
	}
	bodyJSON, _ := json.Marshal(body)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
	if err != nil {
		ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
		return
	}
	req.Header.Set("x-api-key", apiKey)
	req.Header.Set("anthropic-version", "2023-06-01")
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 120 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		respBody, _ := io.ReadAll(resp.Body)
		ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
		return
	}
	scanner := bufio.NewScanner(resp.Body)
	// SSE data lines can exceed bufio.Scanner's 64 KiB default limit; raise
	// the ceiling to 1 MiB so large deltas don't abort the scan.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		var event struct {
			Type  string `json:"type"`
			Delta *struct {
				Type string `json:"type"`
				Text string `json:"text"`
			} `json:"delta"`
			Usage *struct {
				InputTokens  int `json:"input_tokens"`
				OutputTokens int `json:"output_tokens"`
			} `json:"usage"`
		}
		if err := json.Unmarshal([]byte(data), &event); err != nil {
			continue // skip malformed/unknown event payloads
		}
		switch event.Type {
		case "content_block_delta":
			if event.Delta != nil && event.Delta.Text != "" {
				ch <- AIStreamChunk{Content: event.Delta.Text}
			}
		case "message_delta":
			// Usage stats arrive on message_delta; treated here as end of stream.
			if event.Usage != nil {
				ch <- AIStreamChunk{
					Done:             true,
					Model:            modelName,
					PromptTokens:     event.Usage.InputTokens,
					CompletionTokens: event.Usage.OutputTokens,
				}
				return
			}
		case "message_stop":
			ch <- AIStreamChunk{Done: true, Model: modelName}
			return
		case "error":
			ch <- AIStreamChunk{Error: "Claude API 错误", Done: true}
			return
		}
	}
	// Stream ended without a terminal event: report any scan error, otherwise
	// always emit a final Done chunk so consumers are not left waiting.
	if err := scanner.Err(); err != nil {
		ch <- AIStreamChunk{Error: "读取响应失败: " + err.Error(), Done: true}
		return
	}
	ch <- AIStreamChunk{Done: true, Model: modelName}
}
// ==================== Gemini ====================

// streamGemini streams a response from the Google Gemini
// ":streamGenerateContent" SSE endpoint and forwards text parts to ch.
// System messages are moved into "systemInstruction" and the "assistant"
// role is mapped to Gemini's "model" role. A final chunk with Done=true
// (or Error set) is always emitted before returning.
//
// NOTE(review): the API key is embedded in the URL query string, so it may
// surface in proxy/server logs — consider the x-goog-api-key header instead.
func streamGemini(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
	url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
		strings.TrimRight(baseURL, "/"), modelName, apiKey)
	// Build Gemini-format messages.
	var systemInstruction string
	var contents []map[string]interface{}
	for _, msg := range messages {
		if msg.Role == "system" {
			systemInstruction += msg.Content + "\n"
			continue
		}
		role := msg.Role
		if role == "assistant" {
			role = "model"
		}
		contents = append(contents, map[string]interface{}{
			"role": role,
			"parts": []map[string]string{
				{"text": msg.Content},
			},
		})
	}
	if len(contents) == 0 {
		ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
		return
	}
	body := map[string]interface{}{
		"contents": contents,
	}
	if systemInstruction != "" {
		body["systemInstruction"] = map[string]interface{}{
			"parts": []map[string]string{
				{"text": strings.TrimSpace(systemInstruction)},
			},
		}
	}
	bodyJSON, _ := json.Marshal(body)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
	if err != nil {
		ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
		return
	}
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 120 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		respBody, _ := io.ReadAll(resp.Body)
		ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
		return
	}
	scanner := bufio.NewScanner(resp.Body)
	// Gemini returns large SSE chunks; raise the scanner's line limit.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		var chunk struct {
			Candidates []struct {
				Content struct {
					Parts []struct {
						Text string `json:"text"`
					} `json:"parts"`
				} `json:"content"`
			} `json:"candidates"`
			UsageMetadata *struct {
				PromptTokenCount     int `json:"promptTokenCount"`
				CandidatesTokenCount int `json:"candidatesTokenCount"`
			} `json:"usageMetadata"`
		}
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // skip malformed payloads
		}
		for _, candidate := range chunk.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.Text != "" {
					ch <- AIStreamChunk{Content: part.Text}
				}
			}
		}
		if chunk.UsageMetadata != nil {
			ch <- AIStreamChunk{
				Done:             true,
				Model:            modelName,
				PromptTokens:     chunk.UsageMetadata.PromptTokenCount,
				CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
			}
			return
		}
	}
	// Surface scan errors (e.g. a line exceeding the buffer, or a broken
	// connection mid-stream) instead of silently reporting success.
	if err := scanner.Err(); err != nil {
		ch <- AIStreamChunk{Error: "读取响应失败: " + err.Error(), Done: true}
		return
	}
	ch <- AIStreamChunk{Done: true, Model: modelName}
}
// ==================== 非流式调用(用于生图等) ====================

// CallAINonStream performs a single blocking (non-streaming) AI call and
// returns the full reply text, dispatching on the provider type.
func CallAINonStream(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload) (string, error) {
	key := decryptAPIKey(provider.APIKey)
	base := provider.BaseURL

	switch provider.ProviderType {
	case "openai", "custom":
		return callOpenAINonStream(ctx, base, key, modelName, messages)
	case "claude":
		return callClaudeNonStream(ctx, base, key, modelName, messages)
	case "gemini":
		return callGeminiNonStream(ctx, base, key, modelName, messages)
	}
	return "", errors.New("不支持的提供商类型: " + provider.ProviderType)
}
// callOpenAINonStream performs a blocking chat completion against an
// OpenAI-compatible "/chat/completions" endpoint and returns the first
// choice's message content.
func callOpenAINonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
	endpoint := strings.TrimRight(baseURL, "/") + "/chat/completions"
	payload, _ := json.Marshal(map[string]interface{}{
		"model":    modelName,
		"messages": messages,
	})

	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(payload))
	if err != nil {
		return "", err
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")

	httpClient := &http.Client{Timeout: 60 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	raw, _ := io.ReadAll(resp.Body)
	if resp.StatusCode != 200 {
		return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(raw))
	}

	var parsed struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return "", err
	}
	if len(parsed.Choices) == 0 {
		return "", errors.New("AI 未返回有效回复")
	}
	return parsed.Choices[0].Message.Content, nil
}
// callClaudeNonStream performs a blocking call to the Anthropic Messages API
// ("/v1/messages") and returns the first content block's text. System
// messages are lifted into the top-level "system" field, as the API requires.
func callClaudeNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
	url := strings.TrimRight(baseURL, "/") + "/v1/messages"
	var systemContent string
	var claudeMessages []map[string]string
	for _, msg := range messages {
		if msg.Role == "system" {
			systemContent += msg.Content + "\n"
		} else {
			claudeMessages = append(claudeMessages, map[string]string{
				"role":    msg.Role,
				"content": msg.Content,
			})
		}
	}
	// The API rejects an empty messages array; fail with an explicit error
	// (mirrors the guard in streamClaude).
	if len(claudeMessages) == 0 {
		return "", errors.New("没有有效的消息内容")
	}
	body := map[string]interface{}{
		"model":      modelName,
		"messages":   claudeMessages,
		"max_tokens": 4096,
	}
	if systemContent != "" {
		body["system"] = strings.TrimSpace(systemContent)
	}
	bodyJSON, _ := json.Marshal(body)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
	if err != nil {
		return "", err
	}
	req.Header.Set("x-api-key", apiKey)
	req.Header.Set("anthropic-version", "2023-06-01")
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 60 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	respBody, _ := io.ReadAll(resp.Body)
	if resp.StatusCode != 200 {
		return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
	}
	var result struct {
		Content []struct {
			Text string `json:"text"`
		} `json:"content"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return "", err
	}
	if len(result.Content) > 0 {
		return result.Content[0].Text, nil
	}
	return "", errors.New("AI 未返回有效回复")
}
// callGeminiNonStream performs a blocking call to the Gemini
// ":generateContent" endpoint and returns the first candidate's first part.
// System messages go into "systemInstruction" and "assistant" maps to the
// "model" role.
//
// NOTE(review): the API key is embedded in the URL query string, so it may
// surface in proxy/server logs — consider the x-goog-api-key header instead.
func callGeminiNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
	url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
		strings.TrimRight(baseURL, "/"), modelName, apiKey)
	var systemInstruction string
	var contents []map[string]interface{}
	for _, msg := range messages {
		if msg.Role == "system" {
			systemInstruction += msg.Content + "\n"
			continue
		}
		role := msg.Role
		if role == "assistant" {
			role = "model"
		}
		contents = append(contents, map[string]interface{}{
			"role":  role,
			"parts": []map[string]string{{"text": msg.Content}},
		})
	}
	// The API rejects an empty contents array; fail with an explicit error
	// (mirrors the guard in streamGemini).
	if len(contents) == 0 {
		return "", errors.New("没有有效的消息内容")
	}
	body := map[string]interface{}{"contents": contents}
	if systemInstruction != "" {
		body["systemInstruction"] = map[string]interface{}{
			"parts": []map[string]string{{"text": strings.TrimSpace(systemInstruction)}},
		}
	}
	bodyJSON, _ := json.Marshal(body)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 60 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	respBody, _ := io.ReadAll(resp.Body)
	if resp.StatusCode != 200 {
		return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
	}
	var result struct {
		Candidates []struct {
			Content struct {
				Parts []struct {
					Text string `json:"text"`
				} `json:"parts"`
			} `json:"content"`
		} `json:"candidates"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return "", err
	}
	if len(result.Candidates) > 0 && len(result.Candidates[0].Content.Parts) > 0 {
		return result.Candidates[0].Content.Parts[0].Text, nil
	}
	return "", errors.New("AI 未返回有效回复")
}