🎨 优化扩展模块,完成ai接入和对话功能

This commit is contained in:
2026-02-12 23:12:28 +08:00
parent 4e611d3a5e
commit 572f3aa15b
779 changed files with 194400 additions and 3136 deletions

View File

@@ -0,0 +1,625 @@
package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
appModel "git.echol.cn/loser/st/server/model/app"
)
// AIMessage AI调用的消息格式统一
type AIMessagePayload struct {
Role string `json:"role"`
Content string `json:"content"`
}
// AIStreamChunk 流式响应的单个块
type AIStreamChunk struct {
Content string `json:"content"` // 增量文本
Done bool `json:"done"` // 是否结束
Model string `json:"model"` // 使用的模型
PromptTokens int `json:"promptTokens"` // 提示词Token仅结束时有值
CompletionTokens int `json:"completionTokens"` // 补全Token仅结束时有值
Error string `json:"error"` // 错误信息
}
// BuildPrompt 根据角色卡和消息历史构建 AI 请求的 prompt
func BuildPrompt(character *appModel.AICharacter, messages []appModel.AIMessage) []AIMessagePayload {
var payload []AIMessagePayload
// 1. 系统提示词System Prompt
systemPrompt := buildSystemPrompt(character)
if systemPrompt != "" {
payload = append(payload, AIMessagePayload{
Role: "system",
Content: systemPrompt,
})
}
// 2. 历史消息
for _, msg := range messages {
role := msg.Role
if role == "assistant" || role == "user" || role == "system" {
payload = append(payload, AIMessagePayload{
Role: role,
Content: msg.Content,
})
}
}
return payload
}
// buildSystemPrompt 构建系统提示词
func buildSystemPrompt(character *appModel.AICharacter) string {
if character == nil {
return ""
}
var parts []string
// 角色描述
if character.Description != "" {
parts = append(parts, character.Description)
}
// 角色性格
if character.Personality != "" {
parts = append(parts, "Personality: "+character.Personality)
}
// 场景
if character.Scenario != "" {
parts = append(parts, "Scenario: "+character.Scenario)
}
// 系统提示词
if character.SystemPrompt != "" {
parts = append(parts, character.SystemPrompt)
}
// 示例消息
if len(character.ExampleMessages) > 0 {
parts = append(parts, "Example dialogue:\n"+strings.Join(character.ExampleMessages, "\n"))
}
return strings.Join(parts, "\n\n")
}
// StreamAIResponse 流式调用 AI 并通过 channel 返回结果
func StreamAIResponse(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
defer close(ch)
apiKey := decryptAPIKey(provider.APIKey)
baseURL := provider.BaseURL
switch provider.ProviderType {
case "openai", "custom":
streamOpenAI(ctx, baseURL, apiKey, modelName, messages, ch)
case "claude":
streamClaude(ctx, baseURL, apiKey, modelName, messages, ch)
case "gemini":
streamGemini(ctx, baseURL, apiKey, modelName, messages, ch)
default:
ch <- AIStreamChunk{Error: "不支持的提供商类型: " + provider.ProviderType, Done: true}
}
}
// ==================== OpenAI Compatible ====================
func streamOpenAI(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
body := map[string]interface{}{
"model": modelName,
"messages": messages,
"stream": true,
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
var fullContent string
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
ch <- AIStreamChunk{Done: true, Model: modelName, Content: ""}
return
}
var chunk struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
Usage *struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
content := chunk.Choices[0].Delta.Content
fullContent += content
ch <- AIStreamChunk{Content: content}
}
if chunk.Usage != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: chunk.Usage.PromptTokens,
CompletionTokens: chunk.Usage.CompletionTokens,
}
return
}
}
// 如果扫描结束但没收到 [DONE]
if fullContent != "" {
ch <- AIStreamChunk{Done: true, Model: modelName}
}
}
// ==================== Claude ====================
func streamClaude(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
// Claude 需要把 system 消息分离出来
var systemContent string
var claudeMessages []map[string]string
for _, msg := range messages {
if msg.Role == "system" {
systemContent += msg.Content + "\n"
} else {
claudeMessages = append(claudeMessages, map[string]string{
"role": msg.Role,
"content": msg.Content,
})
}
}
// 确保至少有一条消息
if len(claudeMessages) == 0 {
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
return
}
body := map[string]interface{}{
"model": modelName,
"messages": claudeMessages,
"max_tokens": 4096,
"stream": true,
}
if systemContent != "" {
body["system"] = strings.TrimSpace(systemContent)
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var event struct {
Type string `json:"type"`
Delta *struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"delta"`
Usage *struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
switch event.Type {
case "content_block_delta":
if event.Delta != nil && event.Delta.Text != "" {
ch <- AIStreamChunk{Content: event.Delta.Text}
}
case "message_delta":
if event.Usage != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: event.Usage.InputTokens,
CompletionTokens: event.Usage.OutputTokens,
}
return
}
case "message_stop":
ch <- AIStreamChunk{Done: true, Model: modelName}
return
case "error":
ch <- AIStreamChunk{Error: "Claude API 错误", Done: true}
return
}
}
}
// ==================== Gemini ====================
func streamGemini(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
strings.TrimRight(baseURL, "/"), modelName, apiKey)
// 构建 Gemini 格式的消息
var systemInstruction string
var contents []map[string]interface{}
for _, msg := range messages {
if msg.Role == "system" {
systemInstruction += msg.Content + "\n"
continue
}
role := msg.Role
if role == "assistant" {
role = "model"
}
contents = append(contents, map[string]interface{}{
"role": role,
"parts": []map[string]string{
{"text": msg.Content},
},
})
}
if len(contents) == 0 {
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
return
}
body := map[string]interface{}{
"contents": contents,
}
if systemInstruction != "" {
body["systemInstruction"] = map[string]interface{}{
"parts": []map[string]string{
{"text": strings.TrimSpace(systemInstruction)},
},
}
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
// Gemini 返回较大的 chunks增大 buffer
scanner.Buffer(make([]byte, 0), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var chunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
UsageMetadata *struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
} `json:"usageMetadata"`
}
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
for _, candidate := range chunk.Candidates {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
ch <- AIStreamChunk{Content: part.Text}
}
}
}
if chunk.UsageMetadata != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
}
return
}
}
ch <- AIStreamChunk{Done: true, Model: modelName}
}
// ==================== 非流式调用(用于生图等) ====================
// CallAINonStream 非流式调用 AI
func CallAINonStream(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload) (string, error) {
apiKey := decryptAPIKey(provider.APIKey)
baseURL := provider.BaseURL
switch provider.ProviderType {
case "openai", "custom":
return callOpenAINonStream(ctx, baseURL, apiKey, modelName, messages)
case "claude":
return callClaudeNonStream(ctx, baseURL, apiKey, modelName, messages)
case "gemini":
return callGeminiNonStream(ctx, baseURL, apiKey, modelName, messages)
default:
return "", errors.New("不支持的提供商类型: " + provider.ProviderType)
}
}
func callOpenAINonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
body := map[string]interface{}{
"model": modelName,
"messages": messages,
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Choices) > 0 {
return result.Choices[0].Message.Content, nil
}
return "", errors.New("AI 未返回有效回复")
}
func callClaudeNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
var systemContent string
var claudeMessages []map[string]string
for _, msg := range messages {
if msg.Role == "system" {
systemContent += msg.Content + "\n"
} else {
claudeMessages = append(claudeMessages, map[string]string{
"role": msg.Role,
"content": msg.Content,
})
}
}
body := map[string]interface{}{
"model": modelName,
"messages": claudeMessages,
"max_tokens": 4096,
}
if systemContent != "" {
body["system"] = strings.TrimSpace(systemContent)
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Content) > 0 {
return result.Content[0].Text, nil
}
return "", errors.New("AI 未返回有效回复")
}
func callGeminiNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
strings.TrimRight(baseURL, "/"), modelName, apiKey)
var systemInstruction string
var contents []map[string]interface{}
for _, msg := range messages {
if msg.Role == "system" {
systemInstruction += msg.Content + "\n"
continue
}
role := msg.Role
if role == "assistant" {
role = "model"
}
contents = append(contents, map[string]interface{}{
"role": role,
"parts": []map[string]string{{"text": msg.Content}},
})
}
body := map[string]interface{}{"contents": contents}
if systemInstruction != "" {
body["systemInstruction"] = map[string]interface{}{
"parts": []map[string]string{{"text": strings.TrimSpace(systemInstruction)}},
}
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Candidates) > 0 && len(result.Candidates[0].Content.Parts) > 0 {
return result.Candidates[0].Content.Parts[0].Text, nil
}
return "", errors.New("AI 未返回有效回复")
}