626 lines
16 KiB
Go
626 lines
16 KiB
Go
package app
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
appModel "git.echol.cn/loser/st/server/model/app"
|
||
)
|
||
|
||
// AIMessagePayload is the unified message format sent to every AI provider
// (OpenAI-compatible, Claude, Gemini). Role is one of "system", "user" or
// "assistant"; Content is the plain-text message body.
type AIMessagePayload struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// AIStreamChunk is a single unit of a streamed AI response, delivered over a
// channel to the caller. Content carries incremental text; token counts and
// Model are populated on the final (Done) chunk.
type AIStreamChunk struct {
	Content          string `json:"content"`          // incremental text delta
	Done             bool   `json:"done"`             // true on the final chunk of the stream
	Model            string `json:"model"`            // model name used
	PromptTokens     int    `json:"promptTokens"`     // prompt token count (only set when Done)
	CompletionTokens int    `json:"completionTokens"` // completion token count (only set when Done)
	Error            string `json:"error"`            // error message, set when the stream failed
}
// BuildPrompt 根据角色卡和消息历史构建 AI 请求的 prompt
|
||
func BuildPrompt(character *appModel.AICharacter, messages []appModel.AIMessage) []AIMessagePayload {
|
||
var payload []AIMessagePayload
|
||
|
||
// 1. 系统提示词(System Prompt)
|
||
systemPrompt := buildSystemPrompt(character)
|
||
if systemPrompt != "" {
|
||
payload = append(payload, AIMessagePayload{
|
||
Role: "system",
|
||
Content: systemPrompt,
|
||
})
|
||
}
|
||
|
||
// 2. 历史消息
|
||
for _, msg := range messages {
|
||
role := msg.Role
|
||
if role == "assistant" || role == "user" || role == "system" {
|
||
payload = append(payload, AIMessagePayload{
|
||
Role: role,
|
||
Content: msg.Content,
|
||
})
|
||
}
|
||
}
|
||
|
||
return payload
|
||
}
|
||
|
||
// buildSystemPrompt 构建系统提示词
|
||
func buildSystemPrompt(character *appModel.AICharacter) string {
|
||
if character == nil {
|
||
return ""
|
||
}
|
||
|
||
var parts []string
|
||
|
||
// 角色描述
|
||
if character.Description != "" {
|
||
parts = append(parts, character.Description)
|
||
}
|
||
|
||
// 角色性格
|
||
if character.Personality != "" {
|
||
parts = append(parts, "Personality: "+character.Personality)
|
||
}
|
||
|
||
// 场景
|
||
if character.Scenario != "" {
|
||
parts = append(parts, "Scenario: "+character.Scenario)
|
||
}
|
||
|
||
// 系统提示词
|
||
if character.SystemPrompt != "" {
|
||
parts = append(parts, character.SystemPrompt)
|
||
}
|
||
|
||
// 示例消息
|
||
if len(character.ExampleMessages) > 0 {
|
||
parts = append(parts, "Example dialogue:\n"+strings.Join(character.ExampleMessages, "\n"))
|
||
}
|
||
|
||
return strings.Join(parts, "\n\n")
|
||
}
|
||
|
||
// StreamAIResponse 流式调用 AI 并通过 channel 返回结果
|
||
func StreamAIResponse(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||
defer close(ch)
|
||
|
||
apiKey := decryptAPIKey(provider.APIKey)
|
||
baseURL := provider.BaseURL
|
||
|
||
switch provider.ProviderType {
|
||
case "openai", "custom":
|
||
streamOpenAI(ctx, baseURL, apiKey, modelName, messages, ch)
|
||
case "claude":
|
||
streamClaude(ctx, baseURL, apiKey, modelName, messages, ch)
|
||
case "gemini":
|
||
streamGemini(ctx, baseURL, apiKey, modelName, messages, ch)
|
||
default:
|
||
ch <- AIStreamChunk{Error: "不支持的提供商类型: " + provider.ProviderType, Done: true}
|
||
}
|
||
}
|
||
|
||
// ==================== OpenAI Compatible ====================
|
||
|
||
func streamOpenAI(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||
|
||
body := map[string]interface{}{
|
||
"model": modelName,
|
||
"messages": messages,
|
||
"stream": true,
|
||
}
|
||
bodyJSON, _ := json.Marshal(body)
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||
if err != nil {
|
||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||
return
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{Timeout: 120 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != 200 {
|
||
respBody, _ := io.ReadAll(resp.Body)
|
||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||
return
|
||
}
|
||
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
var fullContent string
|
||
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if !strings.HasPrefix(line, "data: ") {
|
||
continue
|
||
}
|
||
|
||
data := strings.TrimPrefix(line, "data: ")
|
||
if data == "[DONE]" {
|
||
ch <- AIStreamChunk{Done: true, Model: modelName, Content: ""}
|
||
return
|
||
}
|
||
|
||
var chunk struct {
|
||
Choices []struct {
|
||
Delta struct {
|
||
Content string `json:"content"`
|
||
} `json:"delta"`
|
||
} `json:"choices"`
|
||
Usage *struct {
|
||
PromptTokens int `json:"prompt_tokens"`
|
||
CompletionTokens int `json:"completion_tokens"`
|
||
} `json:"usage"`
|
||
}
|
||
|
||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||
continue
|
||
}
|
||
|
||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||
content := chunk.Choices[0].Delta.Content
|
||
fullContent += content
|
||
ch <- AIStreamChunk{Content: content}
|
||
}
|
||
|
||
if chunk.Usage != nil {
|
||
ch <- AIStreamChunk{
|
||
Done: true,
|
||
Model: modelName,
|
||
PromptTokens: chunk.Usage.PromptTokens,
|
||
CompletionTokens: chunk.Usage.CompletionTokens,
|
||
}
|
||
return
|
||
}
|
||
}
|
||
|
||
// 如果扫描结束但没收到 [DONE]
|
||
if fullContent != "" {
|
||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||
}
|
||
}
|
||
|
||
// ==================== Claude ====================
|
||
|
||
func streamClaude(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
|
||
|
||
// Claude 需要把 system 消息分离出来
|
||
var systemContent string
|
||
var claudeMessages []map[string]string
|
||
for _, msg := range messages {
|
||
if msg.Role == "system" {
|
||
systemContent += msg.Content + "\n"
|
||
} else {
|
||
claudeMessages = append(claudeMessages, map[string]string{
|
||
"role": msg.Role,
|
||
"content": msg.Content,
|
||
})
|
||
}
|
||
}
|
||
|
||
// 确保至少有一条消息
|
||
if len(claudeMessages) == 0 {
|
||
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
|
||
return
|
||
}
|
||
|
||
body := map[string]interface{}{
|
||
"model": modelName,
|
||
"messages": claudeMessages,
|
||
"max_tokens": 4096,
|
||
"stream": true,
|
||
}
|
||
if systemContent != "" {
|
||
body["system"] = strings.TrimSpace(systemContent)
|
||
}
|
||
|
||
bodyJSON, _ := json.Marshal(body)
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||
if err != nil {
|
||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||
return
|
||
}
|
||
req.Header.Set("x-api-key", apiKey)
|
||
req.Header.Set("anthropic-version", "2023-06-01")
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{Timeout: 120 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != 200 {
|
||
respBody, _ := io.ReadAll(resp.Body)
|
||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||
return
|
||
}
|
||
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if !strings.HasPrefix(line, "data: ") {
|
||
continue
|
||
}
|
||
|
||
data := strings.TrimPrefix(line, "data: ")
|
||
|
||
var event struct {
|
||
Type string `json:"type"`
|
||
Delta *struct {
|
||
Type string `json:"type"`
|
||
Text string `json:"text"`
|
||
} `json:"delta"`
|
||
Usage *struct {
|
||
InputTokens int `json:"input_tokens"`
|
||
OutputTokens int `json:"output_tokens"`
|
||
} `json:"usage"`
|
||
}
|
||
|
||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||
continue
|
||
}
|
||
|
||
switch event.Type {
|
||
case "content_block_delta":
|
||
if event.Delta != nil && event.Delta.Text != "" {
|
||
ch <- AIStreamChunk{Content: event.Delta.Text}
|
||
}
|
||
case "message_delta":
|
||
if event.Usage != nil {
|
||
ch <- AIStreamChunk{
|
||
Done: true,
|
||
Model: modelName,
|
||
PromptTokens: event.Usage.InputTokens,
|
||
CompletionTokens: event.Usage.OutputTokens,
|
||
}
|
||
return
|
||
}
|
||
case "message_stop":
|
||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||
return
|
||
case "error":
|
||
ch <- AIStreamChunk{Error: "Claude API 错误", Done: true}
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// ==================== Gemini ====================
|
||
|
||
func streamGemini(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||
url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
|
||
strings.TrimRight(baseURL, "/"), modelName, apiKey)
|
||
|
||
// 构建 Gemini 格式的消息
|
||
var systemInstruction string
|
||
var contents []map[string]interface{}
|
||
|
||
for _, msg := range messages {
|
||
if msg.Role == "system" {
|
||
systemInstruction += msg.Content + "\n"
|
||
continue
|
||
}
|
||
|
||
role := msg.Role
|
||
if role == "assistant" {
|
||
role = "model"
|
||
}
|
||
|
||
contents = append(contents, map[string]interface{}{
|
||
"role": role,
|
||
"parts": []map[string]string{
|
||
{"text": msg.Content},
|
||
},
|
||
})
|
||
}
|
||
|
||
if len(contents) == 0 {
|
||
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
|
||
return
|
||
}
|
||
|
||
body := map[string]interface{}{
|
||
"contents": contents,
|
||
}
|
||
if systemInstruction != "" {
|
||
body["systemInstruction"] = map[string]interface{}{
|
||
"parts": []map[string]string{
|
||
{"text": strings.TrimSpace(systemInstruction)},
|
||
},
|
||
}
|
||
}
|
||
|
||
bodyJSON, _ := json.Marshal(body)
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||
if err != nil {
|
||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||
return
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{Timeout: 120 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != 200 {
|
||
respBody, _ := io.ReadAll(resp.Body)
|
||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||
return
|
||
}
|
||
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
// Gemini 返回较大的 chunks,增大 buffer
|
||
scanner.Buffer(make([]byte, 0), 1024*1024)
|
||
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if !strings.HasPrefix(line, "data: ") {
|
||
continue
|
||
}
|
||
|
||
data := strings.TrimPrefix(line, "data: ")
|
||
|
||
var chunk struct {
|
||
Candidates []struct {
|
||
Content struct {
|
||
Parts []struct {
|
||
Text string `json:"text"`
|
||
} `json:"parts"`
|
||
} `json:"content"`
|
||
} `json:"candidates"`
|
||
UsageMetadata *struct {
|
||
PromptTokenCount int `json:"promptTokenCount"`
|
||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||
} `json:"usageMetadata"`
|
||
}
|
||
|
||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||
continue
|
||
}
|
||
|
||
for _, candidate := range chunk.Candidates {
|
||
for _, part := range candidate.Content.Parts {
|
||
if part.Text != "" {
|
||
ch <- AIStreamChunk{Content: part.Text}
|
||
}
|
||
}
|
||
}
|
||
|
||
if chunk.UsageMetadata != nil {
|
||
ch <- AIStreamChunk{
|
||
Done: true,
|
||
Model: modelName,
|
||
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
|
||
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
|
||
}
|
||
return
|
||
}
|
||
}
|
||
|
||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||
}
|
||
|
||
// ==================== 非流式调用(用于生图等) ====================
|
||
|
||
// CallAINonStream 非流式调用 AI
|
||
func CallAINonStream(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload) (string, error) {
|
||
apiKey := decryptAPIKey(provider.APIKey)
|
||
baseURL := provider.BaseURL
|
||
|
||
switch provider.ProviderType {
|
||
case "openai", "custom":
|
||
return callOpenAINonStream(ctx, baseURL, apiKey, modelName, messages)
|
||
case "claude":
|
||
return callClaudeNonStream(ctx, baseURL, apiKey, modelName, messages)
|
||
case "gemini":
|
||
return callGeminiNonStream(ctx, baseURL, apiKey, modelName, messages)
|
||
default:
|
||
return "", errors.New("不支持的提供商类型: " + provider.ProviderType)
|
||
}
|
||
}
|
||
|
||
func callOpenAINonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||
|
||
body := map[string]interface{}{
|
||
"model": modelName,
|
||
"messages": messages,
|
||
}
|
||
bodyJSON, _ := json.Marshal(body)
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{Timeout: 60 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
respBody, _ := io.ReadAll(resp.Body)
|
||
if resp.StatusCode != 200 {
|
||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
var result struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if len(result.Choices) > 0 {
|
||
return result.Choices[0].Message.Content, nil
|
||
}
|
||
return "", errors.New("AI 未返回有效回复")
|
||
}
|
||
|
||
func callClaudeNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
|
||
|
||
var systemContent string
|
||
var claudeMessages []map[string]string
|
||
for _, msg := range messages {
|
||
if msg.Role == "system" {
|
||
systemContent += msg.Content + "\n"
|
||
} else {
|
||
claudeMessages = append(claudeMessages, map[string]string{
|
||
"role": msg.Role,
|
||
"content": msg.Content,
|
||
})
|
||
}
|
||
}
|
||
|
||
body := map[string]interface{}{
|
||
"model": modelName,
|
||
"messages": claudeMessages,
|
||
"max_tokens": 4096,
|
||
}
|
||
if systemContent != "" {
|
||
body["system"] = strings.TrimSpace(systemContent)
|
||
}
|
||
|
||
bodyJSON, _ := json.Marshal(body)
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
req.Header.Set("x-api-key", apiKey)
|
||
req.Header.Set("anthropic-version", "2023-06-01")
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{Timeout: 60 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
respBody, _ := io.ReadAll(resp.Body)
|
||
if resp.StatusCode != 200 {
|
||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
var result struct {
|
||
Content []struct {
|
||
Text string `json:"text"`
|
||
} `json:"content"`
|
||
}
|
||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if len(result.Content) > 0 {
|
||
return result.Content[0].Text, nil
|
||
}
|
||
return "", errors.New("AI 未返回有效回复")
|
||
}
|
||
|
||
func callGeminiNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
|
||
strings.TrimRight(baseURL, "/"), modelName, apiKey)
|
||
|
||
var systemInstruction string
|
||
var contents []map[string]interface{}
|
||
for _, msg := range messages {
|
||
if msg.Role == "system" {
|
||
systemInstruction += msg.Content + "\n"
|
||
continue
|
||
}
|
||
role := msg.Role
|
||
if role == "assistant" {
|
||
role = "model"
|
||
}
|
||
contents = append(contents, map[string]interface{}{
|
||
"role": role,
|
||
"parts": []map[string]string{{"text": msg.Content}},
|
||
})
|
||
}
|
||
|
||
body := map[string]interface{}{"contents": contents}
|
||
if systemInstruction != "" {
|
||
body["systemInstruction"] = map[string]interface{}{
|
||
"parts": []map[string]string{{"text": strings.TrimSpace(systemInstruction)}},
|
||
}
|
||
}
|
||
|
||
bodyJSON, _ := json.Marshal(body)
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{Timeout: 60 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
respBody, _ := io.ReadAll(resp.Body)
|
||
if resp.StatusCode != 200 {
|
||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
var result struct {
|
||
Candidates []struct {
|
||
Content struct {
|
||
Parts []struct {
|
||
Text string `json:"text"`
|
||
} `json:"parts"`
|
||
} `json:"content"`
|
||
} `json:"candidates"`
|
||
}
|
||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if len(result.Candidates) > 0 && len(result.Candidates[0].Content.Parts) > 0 {
|
||
return result.Candidates[0].Content.Parts[0].Text, nil
|
||
}
|
||
return "", errors.New("AI 未返回有效回复")
|
||
}
|