🎨 优化扩展模块,完成ai接入和对话功能
This commit is contained in:
625
server/service/app/ai_client.go
Normal file
625
server/service/app/ai_client.go
Normal file
@@ -0,0 +1,625 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appModel "git.echol.cn/loser/st/server/model/app"
|
||||
)
|
||||
|
||||
// AIMessage AI调用的消息格式(统一)
|
||||
type AIMessagePayload struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// AIStreamChunk 流式响应的单个块
|
||||
type AIStreamChunk struct {
|
||||
Content string `json:"content"` // 增量文本
|
||||
Done bool `json:"done"` // 是否结束
|
||||
Model string `json:"model"` // 使用的模型
|
||||
PromptTokens int `json:"promptTokens"` // 提示词Token(仅结束时有值)
|
||||
CompletionTokens int `json:"completionTokens"` // 补全Token(仅结束时有值)
|
||||
Error string `json:"error"` // 错误信息
|
||||
}
|
||||
|
||||
// BuildPrompt 根据角色卡和消息历史构建 AI 请求的 prompt
|
||||
func BuildPrompt(character *appModel.AICharacter, messages []appModel.AIMessage) []AIMessagePayload {
|
||||
var payload []AIMessagePayload
|
||||
|
||||
// 1. 系统提示词(System Prompt)
|
||||
systemPrompt := buildSystemPrompt(character)
|
||||
if systemPrompt != "" {
|
||||
payload = append(payload, AIMessagePayload{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
// 2. 历史消息
|
||||
for _, msg := range messages {
|
||||
role := msg.Role
|
||||
if role == "assistant" || role == "user" || role == "system" {
|
||||
payload = append(payload, AIMessagePayload{
|
||||
Role: role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildSystemPrompt 构建系统提示词
|
||||
func buildSystemPrompt(character *appModel.AICharacter) string {
|
||||
if character == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
|
||||
// 角色描述
|
||||
if character.Description != "" {
|
||||
parts = append(parts, character.Description)
|
||||
}
|
||||
|
||||
// 角色性格
|
||||
if character.Personality != "" {
|
||||
parts = append(parts, "Personality: "+character.Personality)
|
||||
}
|
||||
|
||||
// 场景
|
||||
if character.Scenario != "" {
|
||||
parts = append(parts, "Scenario: "+character.Scenario)
|
||||
}
|
||||
|
||||
// 系统提示词
|
||||
if character.SystemPrompt != "" {
|
||||
parts = append(parts, character.SystemPrompt)
|
||||
}
|
||||
|
||||
// 示例消息
|
||||
if len(character.ExampleMessages) > 0 {
|
||||
parts = append(parts, "Example dialogue:\n"+strings.Join(character.ExampleMessages, "\n"))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// StreamAIResponse 流式调用 AI 并通过 channel 返回结果
|
||||
func StreamAIResponse(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||||
defer close(ch)
|
||||
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
baseURL := provider.BaseURL
|
||||
|
||||
switch provider.ProviderType {
|
||||
case "openai", "custom":
|
||||
streamOpenAI(ctx, baseURL, apiKey, modelName, messages, ch)
|
||||
case "claude":
|
||||
streamClaude(ctx, baseURL, apiKey, modelName, messages, ch)
|
||||
case "gemini":
|
||||
streamGemini(ctx, baseURL, apiKey, modelName, messages, ch)
|
||||
default:
|
||||
ch <- AIStreamChunk{Error: "不支持的提供商类型: " + provider.ProviderType, Done: true}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== OpenAI Compatible ====================
|
||||
|
||||
func streamOpenAI(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
}
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
var fullContent string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
ch <- AIStreamChunk{Done: true, Model: modelName, Content: ""}
|
||||
return
|
||||
}
|
||||
|
||||
var chunk struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
Usage *struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
content := chunk.Choices[0].Delta.Content
|
||||
fullContent += content
|
||||
ch <- AIStreamChunk{Content: content}
|
||||
}
|
||||
|
||||
if chunk.Usage != nil {
|
||||
ch <- AIStreamChunk{
|
||||
Done: true,
|
||||
Model: modelName,
|
||||
PromptTokens: chunk.Usage.PromptTokens,
|
||||
CompletionTokens: chunk.Usage.CompletionTokens,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果扫描结束但没收到 [DONE]
|
||||
if fullContent != "" {
|
||||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Claude ====================
|
||||
|
||||
func streamClaude(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||||
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
|
||||
|
||||
// Claude 需要把 system 消息分离出来
|
||||
var systemContent string
|
||||
var claudeMessages []map[string]string
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemContent += msg.Content + "\n"
|
||||
} else {
|
||||
claudeMessages = append(claudeMessages, map[string]string{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 确保至少有一条消息
|
||||
if len(claudeMessages) == 0 {
|
||||
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": claudeMessages,
|
||||
"max_tokens": 4096,
|
||||
"stream": true,
|
||||
}
|
||||
if systemContent != "" {
|
||||
body["system"] = strings.TrimSpace(systemContent)
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Delta *struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"delta"`
|
||||
Usage *struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "content_block_delta":
|
||||
if event.Delta != nil && event.Delta.Text != "" {
|
||||
ch <- AIStreamChunk{Content: event.Delta.Text}
|
||||
}
|
||||
case "message_delta":
|
||||
if event.Usage != nil {
|
||||
ch <- AIStreamChunk{
|
||||
Done: true,
|
||||
Model: modelName,
|
||||
PromptTokens: event.Usage.InputTokens,
|
||||
CompletionTokens: event.Usage.OutputTokens,
|
||||
}
|
||||
return
|
||||
}
|
||||
case "message_stop":
|
||||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||||
return
|
||||
case "error":
|
||||
ch <- AIStreamChunk{Error: "Claude API 错误", Done: true}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Gemini ====================
|
||||
|
||||
func streamGemini(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||||
url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
|
||||
strings.TrimRight(baseURL, "/"), modelName, apiKey)
|
||||
|
||||
// 构建 Gemini 格式的消息
|
||||
var systemInstruction string
|
||||
var contents []map[string]interface{}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemInstruction += msg.Content + "\n"
|
||||
continue
|
||||
}
|
||||
|
||||
role := msg.Role
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
|
||||
contents = append(contents, map[string]interface{}{
|
||||
"role": role,
|
||||
"parts": []map[string]string{
|
||||
{"text": msg.Content},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if len(contents) == 0 {
|
||||
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
body := map[string]interface{}{
|
||||
"contents": contents,
|
||||
}
|
||||
if systemInstruction != "" {
|
||||
body["systemInstruction"] = map[string]interface{}{
|
||||
"parts": []map[string]string{
|
||||
{"text": strings.TrimSpace(systemInstruction)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// Gemini 返回较大的 chunks,增大 buffer
|
||||
scanner.Buffer(make([]byte, 0), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var chunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata *struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, candidate := range chunk.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
ch <- AIStreamChunk{Content: part.Text}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.UsageMetadata != nil {
|
||||
ch <- AIStreamChunk{
|
||||
Done: true,
|
||||
Model: modelName,
|
||||
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||||
}
|
||||
|
||||
// ==================== 非流式调用(用于生图等) ====================
|
||||
|
||||
// CallAINonStream 非流式调用 AI
|
||||
func CallAINonStream(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload) (string, error) {
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
baseURL := provider.BaseURL
|
||||
|
||||
switch provider.ProviderType {
|
||||
case "openai", "custom":
|
||||
return callOpenAINonStream(ctx, baseURL, apiKey, modelName, messages)
|
||||
case "claude":
|
||||
return callClaudeNonStream(ctx, baseURL, apiKey, modelName, messages)
|
||||
case "gemini":
|
||||
return callGeminiNonStream(ctx, baseURL, apiKey, modelName, messages)
|
||||
default:
|
||||
return "", errors.New("不支持的提供商类型: " + provider.ProviderType)
|
||||
}
|
||||
}
|
||||
|
||||
func callOpenAINonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": messages,
|
||||
}
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(result.Choices) > 0 {
|
||||
return result.Choices[0].Message.Content, nil
|
||||
}
|
||||
return "", errors.New("AI 未返回有效回复")
|
||||
}
|
||||
|
||||
func callClaudeNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||||
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
|
||||
|
||||
var systemContent string
|
||||
var claudeMessages []map[string]string
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemContent += msg.Content + "\n"
|
||||
} else {
|
||||
claudeMessages = append(claudeMessages, map[string]string{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": claudeMessages,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
if systemContent != "" {
|
||||
body["system"] = strings.TrimSpace(systemContent)
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(result.Content) > 0 {
|
||||
return result.Content[0].Text, nil
|
||||
}
|
||||
return "", errors.New("AI 未返回有效回复")
|
||||
}
|
||||
|
||||
func callGeminiNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||||
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
|
||||
strings.TrimRight(baseURL, "/"), modelName, apiKey)
|
||||
|
||||
var systemInstruction string
|
||||
var contents []map[string]interface{}
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemInstruction += msg.Content + "\n"
|
||||
continue
|
||||
}
|
||||
role := msg.Role
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
contents = append(contents, map[string]interface{}{
|
||||
"role": role,
|
||||
"parts": []map[string]string{{"text": msg.Content}},
|
||||
})
|
||||
}
|
||||
|
||||
body := map[string]interface{}{"contents": contents}
|
||||
if systemInstruction != "" {
|
||||
body["systemInstruction"] = map[string]interface{}{
|
||||
"parts": []map[string]string{{"text": strings.TrimSpace(systemInstruction)}},
|
||||
}
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
} `json:"candidates"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(result.Candidates) > 0 && len(result.Candidates[0].Content.Parts) > 0 {
|
||||
return result.Candidates[0].Content.Parts[0].Text, nil
|
||||
}
|
||||
return "", errors.New("AI 未返回有效回复")
|
||||
}
|
||||
408
server/service/app/chat.go
Normal file
408
server/service/app/chat.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/model/app"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/app/response"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatService struct{}
|
||||
|
||||
// ==================== 对话 CRUD ====================
|
||||
|
||||
// CreateChat 创建对话
|
||||
func (cs *ChatService) CreateChat(req request.CreateChatRequest, userID uint) (response.ChatResponse, error) {
|
||||
// 获取角色卡信息
|
||||
var character app.AICharacter
|
||||
err := global.GVA_DB.First(&character, req.CharacterID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return response.ChatResponse{}, errors.New("角色卡不存在")
|
||||
}
|
||||
return response.ChatResponse{}, err
|
||||
}
|
||||
|
||||
// 对话标题
|
||||
title := req.Title
|
||||
if title == "" {
|
||||
title = character.Name
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
chat := app.AIChat{
|
||||
Title: title,
|
||||
UserID: userID,
|
||||
CharacterID: &req.CharacterID,
|
||||
ChatType: "single",
|
||||
LastMessageAt: &now,
|
||||
}
|
||||
|
||||
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建对话
|
||||
if err := tx.Create(&chat).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果角色卡有 FirstMessage,自动创建第一条系统消息
|
||||
if character.FirstMessage != "" {
|
||||
firstMsg := app.AIMessage{
|
||||
ChatID: chat.ID,
|
||||
Content: character.FirstMessage,
|
||||
Role: "assistant",
|
||||
CharacterID: &req.CharacterID,
|
||||
SequenceNumber: 1,
|
||||
}
|
||||
if err := tx.Create(&firstMsg).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
chat.MessageCount = 1
|
||||
tx.Model(&chat).Update("message_count", 1)
|
||||
}
|
||||
|
||||
// 更新角色卡使用次数
|
||||
tx.Model(&character).Update("total_chats", gorm.Expr("total_chats + 1"))
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return response.ChatResponse{}, err
|
||||
}
|
||||
|
||||
return toChatResponse(&chat, &character, nil), nil
|
||||
}
|
||||
|
||||
// GetChatList 获取用户的对话列表
|
||||
func (cs *ChatService) GetChatList(req request.ChatListRequest, userID uint) (response.ChatListResponse, error) {
|
||||
db := global.GVA_DB.Model(&app.AIChat{}).Where("user_id = ?", userID)
|
||||
|
||||
var total int64
|
||||
db.Count(&total)
|
||||
|
||||
var chats []app.AIChat
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
err := db.Preload("Character").
|
||||
Order("is_pinned DESC, last_message_at DESC").
|
||||
Offset(offset).Limit(req.PageSize).
|
||||
Find(&chats).Error
|
||||
if err != nil {
|
||||
return response.ChatListResponse{}, err
|
||||
}
|
||||
|
||||
// 获取每个对话的最后一条消息
|
||||
chatIDs := make([]uint, len(chats))
|
||||
for i, c := range chats {
|
||||
chatIDs[i] = c.ID
|
||||
}
|
||||
|
||||
lastMessages := make(map[uint]*response.MessageBrief)
|
||||
if len(chatIDs) > 0 {
|
||||
var messages []app.AIMessage
|
||||
// 获取每个对话的最后一条消息(通过子查询)
|
||||
global.GVA_DB.Where("chat_id IN ? AND is_deleted = ?", chatIDs, false).
|
||||
Order("sequence_number DESC").
|
||||
Find(&messages)
|
||||
|
||||
// 只保留每个对话的最后一条
|
||||
for _, msg := range messages {
|
||||
if _, exists := lastMessages[msg.ChatID]; !exists {
|
||||
content := msg.Content
|
||||
if len(content) > 100 {
|
||||
content = content[:100] + "..."
|
||||
}
|
||||
lastMessages[msg.ChatID] = &response.MessageBrief{
|
||||
Content: content,
|
||||
Role: msg.Role,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
list := make([]response.ChatResponse, len(chats))
|
||||
for i, c := range chats {
|
||||
list[i] = toChatResponse(&c, c.Character, lastMessages[c.ID])
|
||||
}
|
||||
|
||||
return response.ChatListResponse{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChatDetail 获取对话详情(包含消息列表)
|
||||
func (cs *ChatService) GetChatDetail(chatID uint, userID uint) (response.ChatDetailResponse, error) {
|
||||
var chat app.AIChat
|
||||
err := global.GVA_DB.Preload("Character").
|
||||
Where("id = ? AND user_id = ?", chatID, userID).
|
||||
First(&chat).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return response.ChatDetailResponse{}, errors.New("对话不存在")
|
||||
}
|
||||
return response.ChatDetailResponse{}, err
|
||||
}
|
||||
|
||||
// 获取消息列表(最近的50条)
|
||||
var messages []app.AIMessage
|
||||
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
|
||||
Order("sequence_number ASC").
|
||||
Limit(50).
|
||||
Find(&messages)
|
||||
|
||||
msgList := make([]response.MessageResponse, len(messages))
|
||||
for i, msg := range messages {
|
||||
msgList[i] = toMessageResponse(&msg, chat.Character)
|
||||
}
|
||||
|
||||
return response.ChatDetailResponse{
|
||||
Chat: toChatResponse(&chat, chat.Character, nil),
|
||||
Messages: msgList,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChatMessages 分页获取对话消息
|
||||
func (cs *ChatService) GetChatMessages(chatID uint, req request.ChatMessagesRequest, userID uint) (response.MessageListResponse, error) {
|
||||
// 验证对话归属
|
||||
var chat app.AIChat
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||||
if err != nil {
|
||||
return response.MessageListResponse{}, errors.New("对话不存在")
|
||||
}
|
||||
|
||||
db := global.GVA_DB.Model(&app.AIMessage{}).
|
||||
Where("chat_id = ? AND is_deleted = ?", chatID, false)
|
||||
|
||||
var total int64
|
||||
db.Count(&total)
|
||||
|
||||
var messages []app.AIMessage
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
err = db.Order("sequence_number ASC").
|
||||
Offset(offset).Limit(req.PageSize).
|
||||
Find(&messages).Error
|
||||
if err != nil {
|
||||
return response.MessageListResponse{}, err
|
||||
}
|
||||
|
||||
// 预加载角色信息
|
||||
var character *app.AICharacter
|
||||
if chat.CharacterID != nil {
|
||||
var c app.AICharacter
|
||||
global.GVA_DB.First(&c, *chat.CharacterID)
|
||||
character = &c
|
||||
}
|
||||
|
||||
list := make([]response.MessageResponse, len(messages))
|
||||
for i, msg := range messages {
|
||||
list[i] = toMessageResponse(&msg, character)
|
||||
}
|
||||
|
||||
return response.MessageListResponse{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteChat 删除对话
|
||||
func (cs *ChatService) DeleteChat(chatID uint, userID uint) error {
|
||||
var chat app.AIChat
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("对话不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除消息
|
||||
if err := tx.Where("chat_id = ?", chatID).Delete(&app.AIMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除消息变体
|
||||
if err := tx.Where("message_id IN (?)",
|
||||
tx.Model(&app.AIMessage{}).Select("id").Where("chat_id = ?", chatID),
|
||||
).Delete(&app.AIMessageSwipe{}).Error; err != nil {
|
||||
// 忽略错误,可能没有变体
|
||||
}
|
||||
// 删除对话
|
||||
return tx.Delete(&chat).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 消息操作 ====================
|
||||
|
||||
// SaveUserMessage 保存用户消息(内部方法)
|
||||
func (cs *ChatService) SaveUserMessage(chatID uint, userID uint, content string) (*app.AIMessage, error) {
|
||||
// 获取当前最大序号
|
||||
var maxSeq int
|
||||
global.GVA_DB.Model(&app.AIMessage{}).
|
||||
Where("chat_id = ?", chatID).
|
||||
Select("COALESCE(MAX(sequence_number), 0)").
|
||||
Scan(&maxSeq)
|
||||
|
||||
msg := &app.AIMessage{
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Role: "user",
|
||||
SenderID: &userID,
|
||||
SequenceNumber: maxSeq + 1,
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Create(msg).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 更新对话的消息数和最后消息时间
|
||||
now := time.Now()
|
||||
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
|
||||
"message_count": gorm.Expr("message_count + 1"),
|
||||
"last_message_at": now,
|
||||
})
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// SaveAssistantMessage 保存AI回复消息(内部方法)
|
||||
func (cs *ChatService) SaveAssistantMessage(chatID uint, characterID *uint, content string, model string, promptTokens, completionTokens int) (*app.AIMessage, error) {
|
||||
var maxSeq int
|
||||
global.GVA_DB.Model(&app.AIMessage{}).
|
||||
Where("chat_id = ?", chatID).
|
||||
Select("COALESCE(MAX(sequence_number), 0)").
|
||||
Scan(&maxSeq)
|
||||
|
||||
msg := &app.AIMessage{
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Role: "assistant",
|
||||
CharacterID: characterID,
|
||||
SequenceNumber: maxSeq + 1,
|
||||
Model: model,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Create(msg).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
|
||||
"message_count": gorm.Expr("message_count + 1"),
|
||||
"last_message_at": now,
|
||||
})
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// EditMessage 编辑消息
|
||||
func (cs *ChatService) EditMessage(req request.EditMessageRequest, userID uint) (response.MessageResponse, error) {
|
||||
var msg app.AIMessage
|
||||
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
|
||||
Where("ai_messages.id = ? AND ai_chats.user_id = ?", req.MessageID, userID).
|
||||
First(&msg).Error
|
||||
if err != nil {
|
||||
return response.MessageResponse{}, errors.New("消息不存在")
|
||||
}
|
||||
|
||||
msg.Content = req.Content
|
||||
if err := global.GVA_DB.Save(&msg).Error; err != nil {
|
||||
return response.MessageResponse{}, err
|
||||
}
|
||||
|
||||
return toMessageResponse(&msg, nil), nil
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息(软删除)
|
||||
func (cs *ChatService) DeleteMessage(messageID uint, userID uint) error {
|
||||
var msg app.AIMessage
|
||||
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
|
||||
Where("ai_messages.id = ? AND ai_chats.user_id = ?", messageID, userID).
|
||||
First(&msg).Error
|
||||
if err != nil {
|
||||
return errors.New("消息不存在")
|
||||
}
|
||||
|
||||
return global.GVA_DB.Model(&msg).Update("is_deleted", true).Error
|
||||
}
|
||||
|
||||
// GetChatForAI 获取对话用于AI调用的完整上下文(内部方法)
|
||||
func (cs *ChatService) GetChatForAI(chatID uint, userID uint) (*app.AIChat, *app.AICharacter, []app.AIMessage, error) {
|
||||
var chat app.AIChat
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.New("对话不存在")
|
||||
}
|
||||
|
||||
var character *app.AICharacter
|
||||
if chat.CharacterID != nil {
|
||||
var c app.AICharacter
|
||||
if err := global.GVA_DB.First(&c, *chat.CharacterID).Error; err == nil {
|
||||
character = &c
|
||||
}
|
||||
}
|
||||
|
||||
// 获取历史消息(最近30条)
|
||||
var messages []app.AIMessage
|
||||
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
|
||||
Order("sequence_number ASC").
|
||||
Limit(30).
|
||||
Find(&messages)
|
||||
|
||||
return &chat, character, messages, nil
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
func toChatResponse(chat *app.AIChat, character *app.AICharacter, lastMsg *response.MessageBrief) response.ChatResponse {
|
||||
resp := response.ChatResponse{
|
||||
ID: chat.ID,
|
||||
Title: chat.Title,
|
||||
CharacterID: chat.CharacterID,
|
||||
ChatType: chat.ChatType,
|
||||
LastMessageAt: chat.LastMessageAt,
|
||||
MessageCount: chat.MessageCount,
|
||||
IsPinned: chat.IsPinned,
|
||||
LastMessage: lastMsg,
|
||||
CreatedAt: chat.CreatedAt,
|
||||
}
|
||||
|
||||
if character != nil {
|
||||
resp.CharacterName = character.Name
|
||||
resp.CharacterAvatar = character.Avatar
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func toMessageResponse(msg *app.AIMessage, character *app.AICharacter) response.MessageResponse {
|
||||
resp := response.MessageResponse{
|
||||
ID: msg.ID,
|
||||
ChatID: msg.ChatID,
|
||||
Content: msg.Content,
|
||||
Role: msg.Role,
|
||||
CharacterID: msg.CharacterID,
|
||||
Model: msg.Model,
|
||||
PromptTokens: msg.PromptTokens,
|
||||
CompletionTokens: msg.CompletionTokens,
|
||||
TotalTokens: msg.TotalTokens,
|
||||
SequenceNumber: msg.SequenceNumber,
|
||||
CreatedAt: msg.CreatedAt,
|
||||
}
|
||||
|
||||
if msg.Role == "assistant" && character != nil {
|
||||
resp.CharacterName = character.Name
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -6,4 +6,6 @@ type AppServiceGroup struct {
|
||||
WorldInfoService
|
||||
ExtensionService
|
||||
RegexScriptService
|
||||
ProviderService
|
||||
ChatService
|
||||
}
|
||||
|
||||
@@ -21,15 +21,70 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// extensionDataDir 扩展本地存储根目录
|
||||
// 与原版 SillyTavern 完全一致的路径结构:scripts/extensions/third-party/{name}/
|
||||
// 扩展 JS 中的相对路径 import(如 ../../../../../script.js)依赖此目录层级来正确解析
|
||||
// 所有 SillyTavern 核心脚本和扩展文件统一存储在 data/st-core-scripts/ 下,独立于 web-app/
|
||||
// 扩展代码是公共的(不按用户隔离),用户间差异仅在于数据库中的配置和启用状态
|
||||
const extensionDataDir = "data/st-core-scripts/scripts/extensions/third-party"
|
||||
|
||||
// getExtensionStorePath 获取扩展的本地存储路径: {extensionDataDir}/{extensionName}/
|
||||
func getExtensionStorePath(extensionName string) string {
|
||||
return filepath.Join(extensionDataDir, extensionName)
|
||||
}
|
||||
|
||||
// GetExtensionAssetLocalPath 获取扩展资源文件的本地绝对路径
|
||||
func (es *ExtensionService) GetExtensionAssetLocalPath(extensionName string, assetPath string) (string, error) {
|
||||
storePath := getExtensionStorePath(extensionName)
|
||||
fullPath := filepath.Join(storePath, assetPath)
|
||||
|
||||
// 安全检查:防止路径遍历攻击
|
||||
absStore, _ := filepath.Abs(storePath)
|
||||
absFile, _ := filepath.Abs(fullPath)
|
||||
if !strings.HasPrefix(absFile, absStore) {
|
||||
return "", errors.New("非法的资源路径")
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("资源文件不存在: %s", assetPath)
|
||||
}
|
||||
|
||||
return fullPath, nil
|
||||
}
|
||||
|
||||
// ensureExtensionDir 确保扩展存储目录存在
|
||||
func ensureExtensionDir(extensionName string) (string, error) {
|
||||
storePath := getExtensionStorePath(extensionName)
|
||||
if err := os.MkdirAll(storePath, 0755); err != nil {
|
||||
return "", fmt.Errorf("创建扩展存储目录失败: %w", err)
|
||||
}
|
||||
return storePath, nil
|
||||
}
|
||||
|
||||
// removeExtensionDir 删除扩展的本地存储目录
|
||||
func removeExtensionDir(extensionName string) error {
|
||||
storePath := getExtensionStorePath(extensionName)
|
||||
if _, err := os.Stat(storePath); os.IsNotExist(err) {
|
||||
return nil // 目录不存在,无需删除
|
||||
}
|
||||
return os.RemoveAll(storePath)
|
||||
}
|
||||
|
||||
type ExtensionService struct{}
|
||||
|
||||
// CreateExtension 创建/安装扩展
|
||||
func (es *ExtensionService) CreateExtension(userID uint, req *request.CreateExtensionRequest) (*app.AIExtension, error) {
|
||||
// 校验名称
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("扩展名称不能为空")
|
||||
}
|
||||
|
||||
// 检查扩展是否已存在
|
||||
var existing app.AIExtension
|
||||
err := global.GVA_DB.Where("user_id = ? AND name = ?", userID, req.Name).First(&existing).Error
|
||||
if err == nil {
|
||||
return nil, errors.New("扩展已存在")
|
||||
return nil, fmt.Errorf("扩展 %s 已存在", req.Name)
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
@@ -136,15 +191,18 @@ func (es *ExtensionService) DeleteExtension(userID, extensionID uint, deleteFile
|
||||
return errors.New("系统内置扩展不允许删除")
|
||||
}
|
||||
|
||||
// TODO: 如果 deleteFiles=true,删除扩展文件
|
||||
// 这需要文件系统支持
|
||||
// 删除本地扩展文件(与原版 SillyTavern 一致:卸载扩展时清理本地文件)
|
||||
if err := removeExtensionDir(extension.Name); err != nil {
|
||||
global.GVA_LOG.Warn("删除扩展本地文件失败", zap.Error(err), zap.String("name", extension.Name))
|
||||
// 不阻断删除流程
|
||||
}
|
||||
|
||||
// 删除扩展(配置已经在扩展记录的 Settings 字段中,无需单独删除)
|
||||
// 删除数据库记录
|
||||
if err := global.GVA_DB.Delete(&extension).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID))
|
||||
global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID), zap.String("name", extension.Name))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -157,6 +215,15 @@ func (es *ExtensionService) GetExtension(userID, extensionID uint) (*app.AIExten
|
||||
return &extension, nil
|
||||
}
|
||||
|
||||
// GetExtensionByID 通过扩展 ID 获取扩展信息(不限制用户,用于公开资源路由)
|
||||
func (es *ExtensionService) GetExtensionByID(extensionID uint) (*app.AIExtension, error) {
|
||||
var extension app.AIExtension
|
||||
if err := global.GVA_DB.Where("id = ?", extensionID).First(&extension).Error; err != nil {
|
||||
return nil, errors.New("扩展不存在")
|
||||
}
|
||||
return &extension, nil
|
||||
}
|
||||
|
||||
// GetExtensionList 获取扩展列表
|
||||
func (es *ExtensionService) GetExtensionList(userID uint, req *request.ExtensionListRequest) (*response.ExtensionListResponse, error) {
|
||||
var extensions []app.AIExtension
|
||||
@@ -550,9 +617,37 @@ func isGitURL(url string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// downloadAndInstallFromManifestURL 从 Manifest URL 下载并安装
|
||||
// GetExtensionAssetURL 根据扩展的安装来源构建资源文件的远程 URL
|
||||
func (es *ExtensionService) GetExtensionAssetURL(extension *app.AIExtension, assetPath string) (string, error) {
|
||||
if extension.SourceURL == "" {
|
||||
return "", errors.New("扩展没有源地址")
|
||||
}
|
||||
|
||||
sourceURL := strings.TrimSuffix(strings.TrimSuffix(extension.SourceURL, "/"), ".git")
|
||||
branch := extension.Branch
|
||||
if branch == "" {
|
||||
branch = "main"
|
||||
}
|
||||
|
||||
// GitLab: repo/-/raw/branch/path
|
||||
if strings.Contains(sourceURL, "gitlab.com") {
|
||||
return fmt.Sprintf("%s/-/raw/%s/%s", sourceURL, branch, assetPath), nil
|
||||
}
|
||||
// GitHub: raw.githubusercontent.com/user/repo/branch/path
|
||||
if strings.Contains(sourceURL, "github.com") {
|
||||
rawURL := strings.Replace(sourceURL, "github.com", "raw.githubusercontent.com", 1)
|
||||
return fmt.Sprintf("%s/%s/%s", rawURL, branch, assetPath), nil
|
||||
}
|
||||
// Gitee: repo/raw/branch/path
|
||||
if strings.Contains(sourceURL, "gitee.com") {
|
||||
return fmt.Sprintf("%s/raw/%s/%s", sourceURL, branch, assetPath), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s", sourceURL, assetPath), nil
|
||||
}
|
||||
|
||||
// downloadAndInstallFromManifestURL 从 Manifest URL 下载并安装(同时下载资源文件到本地)
|
||||
func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manifestURL string) (*app.AIExtension, error) {
|
||||
// 创建 HTTP 客户端
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
@@ -568,7 +663,6 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
return nil, fmt.Errorf("下载 manifest.json 失败: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取响应内容
|
||||
manifestData, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取 manifest.json 失败: %w", err)
|
||||
@@ -580,21 +674,63 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if manifest.Name == "" {
|
||||
return nil, errors.New("manifest.json 缺少 name 字段")
|
||||
// 获取有效名称
|
||||
effectiveName := manifest.GetEffectiveName()
|
||||
if effectiveName == "" {
|
||||
return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
|
||||
}
|
||||
|
||||
// 检查扩展是否已存在
|
||||
var existing app.AIExtension
|
||||
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, manifest.Name).First(&existing).Error
|
||||
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("扩展 %s 已安装", manifest.Name)
|
||||
return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建本地存储目录并保存 manifest.json
|
||||
storePath, err := ensureExtensionDir(effectiveName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(storePath, "manifest.json"), manifestData, 0644); err != nil {
|
||||
_ = removeExtensionDir(effectiveName)
|
||||
return nil, fmt.Errorf("保存 manifest.json 失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取 manifest URL 的基础目录(用于下载关联资源)
|
||||
baseURL := manifestURL[:strings.LastIndex(manifestURL, "/")+1]
|
||||
|
||||
// 下载 JS/CSS 等资源文件到本地
|
||||
filesToDownload := []string{}
|
||||
if entry := manifest.GetEffectiveEntry(); entry != "" {
|
||||
filesToDownload = append(filesToDownload, entry)
|
||||
}
|
||||
if style := manifest.GetEffectiveStyle(); style != "" {
|
||||
filesToDownload = append(filesToDownload, style)
|
||||
}
|
||||
filesToDownload = append(filesToDownload, manifest.Assets...)
|
||||
|
||||
for _, file := range filesToDownload {
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
fileURL := baseURL + file
|
||||
if err := downloadFileToLocal(client, fileURL, filepath.Join(storePath, file)); err != nil {
|
||||
global.GVA_LOG.Warn("下载扩展资源文件失败(非致命)",
|
||||
zap.String("file", file),
|
||||
zap.String("url", fileURL),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("扩展文件已保存到本地",
|
||||
zap.String("name", effectiveName),
|
||||
zap.String("path", storePath))
|
||||
|
||||
// 将 manifest 转换为 map[string]interface{}
|
||||
var manifestMap map[string]interface{}
|
||||
if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
|
||||
@@ -603,13 +739,13 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
|
||||
// 构建创建请求
|
||||
createReq := &request.CreateExtensionRequest{
|
||||
Name: manifest.Name,
|
||||
Name: effectiveName,
|
||||
DisplayName: manifest.DisplayName,
|
||||
Version: manifest.Version,
|
||||
Author: manifest.Author,
|
||||
Description: manifest.Description,
|
||||
Homepage: manifest.Homepage,
|
||||
Repository: manifest.Repository, // 使用 manifest 中的 repository
|
||||
Homepage: manifest.GetEffectiveHomepage(),
|
||||
Repository: manifest.Repository,
|
||||
License: manifest.License,
|
||||
Tags: manifest.Tags,
|
||||
ExtensionType: manifest.Type,
|
||||
@@ -617,13 +753,13 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
Dependencies: manifest.Dependencies,
|
||||
Conflicts: manifest.Conflicts,
|
||||
ManifestData: manifestMap,
|
||||
ScriptPath: manifest.Entry,
|
||||
StylePath: manifest.Style,
|
||||
ScriptPath: manifest.GetEffectiveEntry(),
|
||||
StylePath: manifest.GetEffectiveStyle(),
|
||||
AssetsPaths: manifest.Assets,
|
||||
Settings: manifest.Settings,
|
||||
Options: manifest.Options,
|
||||
InstallSource: "url",
|
||||
SourceURL: manifestURL, // 记录原始 URL 用于更新
|
||||
SourceURL: manifestURL,
|
||||
AutoUpdate: manifest.AutoUpdate,
|
||||
Metadata: nil,
|
||||
}
|
||||
@@ -636,6 +772,7 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
// 创建扩展
|
||||
extension, err := es.CreateExtension(userID, createReq)
|
||||
if err != nil {
|
||||
_ = removeExtensionDir(effectiveName)
|
||||
return nil, fmt.Errorf("创建扩展失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -647,6 +784,31 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
return extension, nil
|
||||
}
|
||||
|
||||
// downloadFileToLocal 下载远程文件到本地路径
|
||||
func downloadFileToLocal(client *http.Client, url string, localPath string) error {
|
||||
// 确保目标文件的父目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(localPath), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(localPath, data, 0644)
|
||||
}
|
||||
|
||||
// UpgradeExtension 升级扩展版本(根据安装来源自动选择更新方式)
|
||||
func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force bool) (*app.AIExtension, error) {
|
||||
// 获取扩展信息
|
||||
@@ -672,7 +834,7 @@ func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force boo
|
||||
}
|
||||
}
|
||||
|
||||
// updateExtensionFromGit 从 Git 仓库更新扩展
|
||||
// updateExtensionFromGit 从 Git 仓库更新扩展(先删除旧记录和文件,再重新安装)
|
||||
func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.AIExtension, force bool) (*app.AIExtension, error) {
|
||||
if extension.SourceURL == "" {
|
||||
return nil, errors.New("缺少 Git 仓库 URL")
|
||||
@@ -683,11 +845,17 @@ func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.A
|
||||
zap.String("sourceUrl", extension.SourceURL),
|
||||
zap.String("branch", extension.Branch))
|
||||
|
||||
// 重新克隆(简单方式,避免处理本地修改)
|
||||
// 先删除旧的数据库记录和本地文件
|
||||
if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
|
||||
return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
|
||||
}
|
||||
_ = removeExtensionDir(extension.Name)
|
||||
|
||||
// 重新克隆安装
|
||||
return es.InstallExtensionFromGit(userID, extension.SourceURL, extension.Branch)
|
||||
}
|
||||
|
||||
// updateExtensionFromURL 从 URL 更新扩展(重新下载 manifest.json)
|
||||
// updateExtensionFromURL 从 URL 更新扩展(先删除旧记录和文件,再重新下载安装)
|
||||
func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.AIExtension) (*app.AIExtension, error) {
|
||||
if extension.SourceURL == "" {
|
||||
return nil, errors.New("缺少 Manifest URL")
|
||||
@@ -697,18 +865,24 @@ func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.A
|
||||
zap.String("name", extension.Name),
|
||||
zap.String("sourceUrl", extension.SourceURL))
|
||||
|
||||
// 重新下载并安装
|
||||
// 先删除旧的数据库记录和本地文件
|
||||
if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
|
||||
return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
|
||||
}
|
||||
_ = removeExtensionDir(extension.Name)
|
||||
|
||||
// 重新下载安装
|
||||
return es.downloadAndInstallFromManifestURL(userID, extension.SourceURL)
|
||||
}
|
||||
|
||||
// InstallExtensionFromGit 从 Git URL 安装扩展
|
||||
// InstallExtensionFromGit 从 Git URL 安装扩展(与原版 SillyTavern 一致:将源码下载到本地)
|
||||
func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch string) (*app.AIExtension, error) {
|
||||
// 验证 Git URL
|
||||
if !strings.Contains(gitUrl, "://") && !strings.HasSuffix(gitUrl, ".git") {
|
||||
return nil, errors.New("无效的 Git URL")
|
||||
}
|
||||
|
||||
// 创建临时目录
|
||||
// 先 clone 到临时目录读取 manifest(获取扩展名后再移动到正式目录)
|
||||
tempDir, err := os.MkdirTemp("", "extension-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建临时目录失败: %w", err)
|
||||
@@ -717,10 +891,9 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
|
||||
|
||||
global.GVA_LOG.Info("开始从 Git 克隆扩展",
|
||||
zap.String("gitUrl", gitUrl),
|
||||
zap.String("branch", branch),
|
||||
zap.String("tempDir", tempDir))
|
||||
zap.String("branch", branch))
|
||||
|
||||
// 执行 git clone
|
||||
// 执行 git clone(浅克隆)
|
||||
cmd := exec.Command("git", "clone", "--depth=1", "--branch="+branch, gitUrl, tempDir)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
@@ -744,31 +917,53 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
|
||||
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取有效名称(兼容 SillyTavern manifest 没有 name 字段的情况)
|
||||
effectiveName := manifest.GetEffectiveName()
|
||||
if effectiveName == "" {
|
||||
return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
|
||||
}
|
||||
|
||||
// 检查扩展是否已存在
|
||||
var existing app.AIExtension
|
||||
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, manifest.Name).First(&existing).Error
|
||||
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("扩展 %s 已安装", manifest.Name)
|
||||
return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将扩展文件保存到公共目录: web-app/public/scripts/extensions/third-party/{extensionName}/
|
||||
storePath, err := ensureExtensionDir(effectiveName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 清空目标目录(如果有残留文件)后复制 clone 内容
|
||||
_ = os.RemoveAll(storePath)
|
||||
if err := copyDir(tempDir, storePath); err != nil {
|
||||
return nil, fmt.Errorf("保存扩展文件失败: %w", err)
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("扩展文件已保存到本地",
|
||||
zap.String("name", effectiveName),
|
||||
zap.String("path", storePath))
|
||||
|
||||
// 将 manifest 转换为 map[string]interface{}
|
||||
var manifestMap map[string]interface{}
|
||||
if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
|
||||
return nil, fmt.Errorf("转换 manifest 失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建创建请求
|
||||
// 构建创建请求(使用兼容方法获取字段值)
|
||||
createReq := &request.CreateExtensionRequest{
|
||||
Name: manifest.Name,
|
||||
Name: effectiveName,
|
||||
DisplayName: manifest.DisplayName,
|
||||
Version: manifest.Version,
|
||||
Author: manifest.Author,
|
||||
Description: manifest.Description,
|
||||
Homepage: manifest.Homepage,
|
||||
Repository: manifest.Repository, // 使用 manifest 中的 repository
|
||||
Homepage: manifest.GetEffectiveHomepage(),
|
||||
Repository: manifest.Repository,
|
||||
License: manifest.License,
|
||||
Tags: manifest.Tags,
|
||||
ExtensionType: manifest.Type,
|
||||
@@ -776,28 +971,69 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
|
||||
Dependencies: manifest.Dependencies,
|
||||
Conflicts: manifest.Conflicts,
|
||||
ManifestData: manifestMap,
|
||||
ScriptPath: manifest.Entry,
|
||||
StylePath: manifest.Style,
|
||||
ScriptPath: manifest.GetEffectiveEntry(),
|
||||
StylePath: manifest.GetEffectiveStyle(),
|
||||
AssetsPaths: manifest.Assets,
|
||||
Settings: manifest.Settings,
|
||||
Options: manifest.Options,
|
||||
InstallSource: "git",
|
||||
SourceURL: gitUrl, // 记录 Git URL 用于更新
|
||||
Branch: branch, // 记录分支
|
||||
SourceURL: gitUrl,
|
||||
Branch: branch,
|
||||
AutoUpdate: manifest.AutoUpdate,
|
||||
Metadata: manifest.Metadata,
|
||||
}
|
||||
|
||||
// 确保扩展类型有效
|
||||
if createReq.ExtensionType == "" {
|
||||
createReq.ExtensionType = "ui"
|
||||
}
|
||||
|
||||
// 创建扩展记录
|
||||
extension, err := es.CreateExtension(userID, createReq)
|
||||
if err != nil {
|
||||
// 创建失败则清理本地文件
|
||||
_ = removeExtensionDir(effectiveName)
|
||||
return nil, fmt.Errorf("创建扩展记录失败: %w", err)
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("从 Git 安装扩展成功",
|
||||
zap.Uint("extensionID", extension.ID),
|
||||
zap.String("name", extension.Name),
|
||||
zap.String("version", extension.Version))
|
||||
zap.String("version", extension.Version),
|
||||
zap.String("localPath", storePath))
|
||||
|
||||
return extension, nil
|
||||
}
|
||||
|
||||
// copyDir 递归复制目录(排除 .git 目录以节省空间)
|
||||
func copyDir(src, dst string) error {
|
||||
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 计算相对路径
|
||||
relPath, err := filepath.Rel(src, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 排除 .git 目录
|
||||
if info.IsDir() && info.Name() == ".git" {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
dstPath := filepath.Join(dst, relPath)
|
||||
|
||||
if info.IsDir() {
|
||||
return os.MkdirAll(dstPath, info.Mode())
|
||||
}
|
||||
|
||||
// 复制文件
|
||||
srcFile, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dstPath, srcFile, info.Mode())
|
||||
})
|
||||
}
|
||||
|
||||
949
server/service/app/provider.go
Normal file
949
server/service/app/provider.go
Normal file
@@ -0,0 +1,949 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/model/app"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/app/response"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ProviderService struct{}
|
||||
|
||||
// ==================== 提供商 CRUD ====================
|
||||
|
||||
// CreateProvider 创建AI提供商
|
||||
func (ps *ProviderService) CreateProvider(req request.CreateProviderRequest, userID uint) (response.ProviderResponse, error) {
|
||||
// 加密 API Key
|
||||
encryptedKey := encryptAPIKey(req.APIKey)
|
||||
|
||||
// 序列化额外配置
|
||||
apiConfigJSON, _ := json.Marshal(req.APIConfig)
|
||||
if req.APIConfig == nil {
|
||||
apiConfigJSON = []byte("{}")
|
||||
}
|
||||
|
||||
// 根据类型确定能力
|
||||
capabilities := getDefaultCapabilities(req.ProviderType)
|
||||
capJSON, _ := json.Marshal(capabilities)
|
||||
|
||||
// 如果 BaseURL 为空,使用默认地址
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||||
}
|
||||
|
||||
provider := app.AIProvider{
|
||||
UserID: &userID,
|
||||
ProviderName: req.ProviderName,
|
||||
ProviderType: req.ProviderType,
|
||||
BaseURL: baseURL,
|
||||
APIKey: encryptedKey,
|
||||
APIConfig: datatypes.JSON(apiConfigJSON),
|
||||
Capabilities: datatypes.JSON(capJSON),
|
||||
IsEnabled: true,
|
||||
IsDefault: false,
|
||||
}
|
||||
|
||||
err := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建提供商
|
||||
if err := tx.Create(&provider).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果携带了模型列表,同时创建模型
|
||||
if len(req.Models) > 0 {
|
||||
for _, m := range req.Models {
|
||||
modelConfigJSON, _ := json.Marshal(m.Config)
|
||||
if m.Config == nil {
|
||||
modelConfigJSON = []byte("{}")
|
||||
}
|
||||
isEnabled := true
|
||||
if m.IsEnabled != nil {
|
||||
isEnabled = *m.IsEnabled
|
||||
}
|
||||
model := app.AIModel{
|
||||
ProviderID: provider.ID,
|
||||
ModelName: m.ModelName,
|
||||
DisplayName: m.DisplayName,
|
||||
ModelType: m.ModelType,
|
||||
Config: datatypes.JSON(modelConfigJSON),
|
||||
IsEnabled: isEnabled,
|
||||
}
|
||||
if err := tx.Create(&model).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 没有指定模型时,自动添加预设模型
|
||||
presets := getPresetModels(req.ProviderType)
|
||||
for _, p := range presets {
|
||||
model := app.AIModel{
|
||||
ProviderID: provider.ID,
|
||||
ModelName: p.ModelName,
|
||||
DisplayName: p.DisplayName,
|
||||
ModelType: p.ModelType,
|
||||
Config: datatypes.JSON([]byte("{}")),
|
||||
IsEnabled: true,
|
||||
}
|
||||
if err := tx.Create(&model).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是用户的第一个提供商,自动设为默认
|
||||
var count int64
|
||||
tx.Model(&app.AIProvider{}).Where("user_id = ? AND id != ?", userID, provider.ID).Count(&count)
|
||||
if count == 0 {
|
||||
provider.IsDefault = true
|
||||
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return response.ProviderResponse{}, err
|
||||
}
|
||||
|
||||
return ps.GetProviderDetail(provider.ID, userID)
|
||||
}
|
||||
|
||||
// GetProviderList 获取用户的提供商列表
|
||||
func (ps *ProviderService) GetProviderList(req request.ProviderListRequest, userID uint) (response.ProviderListResponse, error) {
|
||||
db := global.GVA_DB.Model(&app.AIProvider{}).Where("user_id = ?", userID)
|
||||
|
||||
if req.Keyword != "" {
|
||||
keyword := "%" + req.Keyword + "%"
|
||||
db = db.Where("provider_name ILIKE ?", keyword)
|
||||
}
|
||||
|
||||
var total int64
|
||||
db.Count(&total)
|
||||
|
||||
var providers []app.AIProvider
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
err := db.Order("is_default DESC, sort_order ASC, created_at DESC").
|
||||
Offset(offset).Limit(req.PageSize).Find(&providers).Error
|
||||
if err != nil {
|
||||
return response.ProviderListResponse{}, err
|
||||
}
|
||||
|
||||
// 获取所有提供商的模型
|
||||
providerIDs := make([]uint, len(providers))
|
||||
for i, p := range providers {
|
||||
providerIDs[i] = p.ID
|
||||
}
|
||||
|
||||
var models []app.AIModel
|
||||
if len(providerIDs) > 0 {
|
||||
global.GVA_DB.Where("provider_id IN ?", providerIDs).
|
||||
Order("model_type ASC, model_name ASC").Find(&models)
|
||||
}
|
||||
|
||||
// 按提供商ID分组模型
|
||||
modelMap := make(map[uint][]app.AIModel)
|
||||
for _, m := range models {
|
||||
modelMap[m.ProviderID] = append(modelMap[m.ProviderID], m)
|
||||
}
|
||||
|
||||
list := make([]response.ProviderResponse, len(providers))
|
||||
for i, p := range providers {
|
||||
list[i] = toProviderResponse(&p, modelMap[p.ID])
|
||||
}
|
||||
|
||||
return response.ProviderListResponse{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProviderDetail 获取提供商详情
|
||||
func (ps *ProviderService) GetProviderDetail(providerID uint, userID uint) (response.ProviderResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return response.ProviderResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
return response.ProviderResponse{}, err
|
||||
}
|
||||
|
||||
var models []app.AIModel
|
||||
global.GVA_DB.Where("provider_id = ?", providerID).
|
||||
Order("model_type ASC, model_name ASC").Find(&models)
|
||||
|
||||
return toProviderResponse(&provider, models), nil
|
||||
}
|
||||
|
||||
// UpdateProvider 更新提供商
|
||||
func (ps *ProviderService) UpdateProvider(req request.UpdateProviderRequest, userID uint) (response.ProviderResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return response.ProviderResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
return response.ProviderResponse{}, err
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
updates := map[string]interface{}{
|
||||
"provider_name": req.ProviderName,
|
||||
"provider_type": req.ProviderType,
|
||||
"base_url": req.BaseURL,
|
||||
}
|
||||
|
||||
// APIKey 不为空时才更新
|
||||
if req.APIKey != "" {
|
||||
updates["api_key"] = encryptAPIKey(req.APIKey)
|
||||
}
|
||||
|
||||
if req.APIConfig != nil {
|
||||
apiConfigJSON, _ := json.Marshal(req.APIConfig)
|
||||
updates["api_config"] = datatypes.JSON(apiConfigJSON)
|
||||
}
|
||||
|
||||
if req.IsEnabled != nil {
|
||||
updates["is_enabled"] = *req.IsEnabled
|
||||
}
|
||||
|
||||
if req.SortOrder != nil {
|
||||
updates["sort_order"] = *req.SortOrder
|
||||
}
|
||||
|
||||
// 更新能力
|
||||
capabilities := getDefaultCapabilities(req.ProviderType)
|
||||
capJSON, _ := json.Marshal(capabilities)
|
||||
updates["capabilities"] = datatypes.JSON(capJSON)
|
||||
|
||||
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&provider).Updates(updates).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理设置默认
|
||||
if req.IsDefault != nil && *req.IsDefault {
|
||||
// 先取消其他默认
|
||||
if err := tx.Model(&app.AIProvider{}).
|
||||
Where("user_id = ? AND id != ?", userID, req.ID).
|
||||
Update("is_default", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return response.ProviderResponse{}, err
|
||||
}
|
||||
|
||||
return ps.GetProviderDetail(req.ID, userID)
|
||||
}
|
||||
|
||||
// DeleteProvider 删除提供商
|
||||
func (ps *ProviderService) DeleteProvider(providerID uint, userID uint) error {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("提供商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除关联的模型
|
||||
if err := tx.Where("provider_id = ?", providerID).Delete(&app.AIModel{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除提供商
|
||||
if err := tx.Delete(&provider).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果删除的是默认提供商,自动将第一个提供商设为默认
|
||||
if provider.IsDefault {
|
||||
var firstProvider app.AIProvider
|
||||
if err := tx.Where("user_id = ?", userID).Order("created_at ASC").First(&firstProvider).Error; err == nil {
|
||||
tx.Model(&firstProvider).Update("is_default", true)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultProvider 设置默认提供商
|
||||
func (ps *ProviderService) SetDefaultProvider(providerID uint, userID uint) error {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("提供商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 先取消所有默认
|
||||
if err := tx.Model(&app.AIProvider{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("is_default", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置新默认
|
||||
return tx.Model(&provider).Update("is_default", true).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 模型 CRUD ====================
|
||||
|
||||
// AddModel 为提供商添加模型
|
||||
func (ps *ProviderService) AddModel(req request.CreateModelRequest, userID uint) (response.ModelResponse, error) {
|
||||
// 验证提供商归属
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ProviderID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
return response.ModelResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
|
||||
configJSON, _ := json.Marshal(req.Config)
|
||||
if req.Config == nil {
|
||||
configJSON = []byte("{}")
|
||||
}
|
||||
|
||||
isEnabled := true
|
||||
if req.IsEnabled != nil {
|
||||
isEnabled = *req.IsEnabled
|
||||
}
|
||||
|
||||
model := app.AIModel{
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
DisplayName: req.DisplayName,
|
||||
ModelType: req.ModelType,
|
||||
Config: datatypes.JSON(configJSON),
|
||||
IsEnabled: isEnabled,
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Create(&model).Error; err != nil {
|
||||
return response.ModelResponse{}, err
|
||||
}
|
||||
|
||||
return toModelResponse(&model), nil
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
func (ps *ProviderService) UpdateModel(req request.UpdateModelRequest, userID uint) (response.ModelResponse, error) {
|
||||
var model app.AIModel
|
||||
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
|
||||
Where("ai_models.id = ? AND ai_providers.user_id = ?", req.ID, userID).
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
return response.ModelResponse{}, errors.New("模型不存在")
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"model_name": req.ModelName,
|
||||
"display_name": req.DisplayName,
|
||||
"model_type": req.ModelType,
|
||||
}
|
||||
|
||||
if req.Config != nil {
|
||||
configJSON, _ := json.Marshal(req.Config)
|
||||
updates["config"] = datatypes.JSON(configJSON)
|
||||
}
|
||||
|
||||
if req.IsEnabled != nil {
|
||||
updates["is_enabled"] = *req.IsEnabled
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Model(&model).Updates(updates).Error; err != nil {
|
||||
return response.ModelResponse{}, err
|
||||
}
|
||||
|
||||
// 重新查询
|
||||
global.GVA_DB.First(&model, model.ID)
|
||||
return toModelResponse(&model), nil
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
func (ps *ProviderService) DeleteModel(modelID uint, userID uint) error {
|
||||
var model app.AIModel
|
||||
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
|
||||
Where("ai_models.id = ? AND ai_providers.user_id = ?", modelID, userID).
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
return errors.New("模型不存在")
|
||||
}
|
||||
|
||||
return global.GVA_DB.Delete(&model).Error
|
||||
}
|
||||
|
||||
// ==================== 连通性测试 ====================
|
||||
|
||||
// TestProvider 测试提供商连通性
|
||||
func (ps *ProviderService) TestProvider(req request.TestProviderRequest) (response.TestProviderResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||||
}
|
||||
|
||||
// 所有提供商统一使用 OpenAI 兼容的 /models 端点测试连通性
|
||||
result := testOpenAICompatible(baseURL, req.APIKey, req.ModelName)
|
||||
result.Latency = time.Since(startTime).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TestExistingProvider 测试已保存的提供商连通性
|
||||
func (ps *ProviderService) TestExistingProvider(providerID uint, userID uint) (response.TestProviderResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
return response.TestProviderResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
|
||||
return ps.TestProvider(request.TestProviderRequest{
|
||||
ProviderType: provider.ProviderType,
|
||||
BaseURL: provider.BaseURL,
|
||||
APIKey: apiKey,
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 辅助查询 ====================
|
||||
|
||||
// GetProviderTypes 获取支持的提供商类型列表(前端下拉用)
|
||||
func (ps *ProviderService) GetProviderTypes() []response.ProviderTypeOption {
|
||||
return []response.ProviderTypeOption{
|
||||
{
|
||||
Value: "openai",
|
||||
Label: "OpenAI",
|
||||
Description: "支持 GPT-4o、GPT-4、DALL·E 等模型,也兼容所有 OpenAI 格式的中转站",
|
||||
DefaultURL: "https://api.openai.com/v1",
|
||||
},
|
||||
{
|
||||
Value: "claude",
|
||||
Label: "Claude",
|
||||
Description: "Anthropic 的 Claude 系列模型,支持长上下文对话",
|
||||
DefaultURL: "https://api.anthropic.com",
|
||||
},
|
||||
{
|
||||
Value: "gemini",
|
||||
Label: "Google Gemini",
|
||||
Description: "Google 的 Gemini 系列模型,支持多模态",
|
||||
DefaultURL: "https://generativelanguage.googleapis.com",
|
||||
},
|
||||
{
|
||||
Value: "custom",
|
||||
Label: "自定义(OpenAI 兼容)",
|
||||
Description: "兼容 OpenAI 格式的任意接口,如 DeepSeek、通义千问等中转站",
|
||||
DefaultURL: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetPresetModels 获取指定提供商类型的预设模型列表
|
||||
func (ps *ProviderService) GetPresetModels(providerType string) []response.PresetModelOption {
|
||||
presets := getPresetModels(providerType)
|
||||
result := make([]response.PresetModelOption, len(presets))
|
||||
for i, p := range presets {
|
||||
result[i] = response.PresetModelOption{
|
||||
ModelName: p.ModelName,
|
||||
DisplayName: p.DisplayName,
|
||||
ModelType: p.ModelType,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetUserDefaultProvider 获取用户默认提供商(内部方法,给对话功能用)
|
||||
func (ps *ProviderService) GetUserDefaultProvider(userID uint) (*app.AIProvider, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("user_id = ? AND is_default = ? AND is_enabled = ?", userID, true, true).
|
||||
First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("请先配置 AI 接口")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// GetDecryptedAPIKey 获取解密后的API密钥(内部方法,给AI调用用)
|
||||
func (ps *ProviderService) GetDecryptedAPIKey(provider *app.AIProvider) string {
|
||||
return decryptAPIKey(provider.APIKey)
|
||||
}
|
||||
|
||||
// ==================== 内部辅助函数 ====================
|
||||
|
||||
// toProviderResponse 转换为响应对象
|
||||
func toProviderResponse(p *app.AIProvider, models []app.AIModel) response.ProviderResponse {
|
||||
apiConfig := json.RawMessage(p.APIConfig)
|
||||
if len(apiConfig) == 0 {
|
||||
apiConfig = json.RawMessage("{}")
|
||||
}
|
||||
|
||||
capabilities := json.RawMessage(p.Capabilities)
|
||||
if len(capabilities) == 0 {
|
||||
capabilities = json.RawMessage("[]")
|
||||
}
|
||||
|
||||
// 模型列表
|
||||
modelList := make([]response.ModelResponse, len(models))
|
||||
for i, m := range models {
|
||||
modelList[i] = toModelResponse(&m)
|
||||
}
|
||||
|
||||
// API Key 提示
|
||||
apiKeyHint := ""
|
||||
apiKeySet := false
|
||||
if p.APIKey != "" {
|
||||
apiKeySet = true
|
||||
decrypted := decryptAPIKey(p.APIKey)
|
||||
if len(decrypted) > 8 {
|
||||
apiKeyHint = decrypted[:4] + "****" + decrypted[len(decrypted)-4:]
|
||||
} else if len(decrypted) > 0 {
|
||||
apiKeyHint = "****"
|
||||
}
|
||||
}
|
||||
|
||||
return response.ProviderResponse{
|
||||
ID: p.ID,
|
||||
ProviderName: p.ProviderName,
|
||||
ProviderType: p.ProviderType,
|
||||
BaseURL: p.BaseURL,
|
||||
APIKeySet: apiKeySet,
|
||||
APIKeyHint: apiKeyHint,
|
||||
APIConfig: apiConfig,
|
||||
Capabilities: capabilities,
|
||||
IsEnabled: p.IsEnabled,
|
||||
IsDefault: p.IsDefault,
|
||||
SortOrder: p.SortOrder,
|
||||
Models: modelList,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// toModelResponse 转换模型为响应对象
|
||||
func toModelResponse(m *app.AIModel) response.ModelResponse {
|
||||
config := json.RawMessage(m.Config)
|
||||
if len(config) == 0 {
|
||||
config = json.RawMessage("{}")
|
||||
}
|
||||
return response.ModelResponse{
|
||||
ID: m.ID,
|
||||
ProviderID: m.ProviderID,
|
||||
ModelName: m.ModelName,
|
||||
DisplayName: m.DisplayName,
|
||||
ModelType: m.ModelType,
|
||||
Config: config,
|
||||
IsEnabled: m.IsEnabled,
|
||||
CreatedAt: m.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// encryptAPIKey 加密API密钥
|
||||
// TODO: 后续可以替换为更安全的加密方式(如 AES),当前使用简单的 Base64 编码
|
||||
func encryptAPIKey(key string) string {
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
// 简单的混淆处理,生产环境应替换为 AES 加密
|
||||
import_encoding := []byte(key)
|
||||
for i := range import_encoding {
|
||||
import_encoding[i] ^= 0x5A
|
||||
}
|
||||
return fmt.Sprintf("enc:%x", import_encoding)
|
||||
}
|
||||
|
||||
// decryptAPIKey 解密API密钥
|
||||
func decryptAPIKey(encrypted string) string {
|
||||
if encrypted == "" {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(encrypted, "enc:") {
|
||||
return encrypted // 未加密的旧数据,直接返回
|
||||
}
|
||||
|
||||
hexStr := encrypted[4:]
|
||||
var data []byte
|
||||
fmt.Sscanf(hexStr, "%x", &data)
|
||||
for i := range data {
|
||||
data[i] ^= 0x5A
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// getDefaultBaseURL 获取默认API基础地址
|
||||
func getDefaultBaseURL(providerType string) string {
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return "https://api.openai.com/v1"
|
||||
case "claude":
|
||||
return "https://api.anthropic.com"
|
||||
case "gemini":
|
||||
return "https://generativelanguage.googleapis.com"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultCapabilities 获取默认能力列表
|
||||
func getDefaultCapabilities(providerType string) []string {
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return []string{"chat", "image_gen"}
|
||||
case "claude":
|
||||
return []string{"chat"}
|
||||
case "gemini":
|
||||
return []string{"chat", "image_gen"}
|
||||
case "custom":
|
||||
return []string{"chat"}
|
||||
default:
|
||||
return []string{"chat"}
|
||||
}
|
||||
}
|
||||
|
||||
// presetModel 预设模型内部结构
|
||||
type presetModel struct {
|
||||
ModelName string
|
||||
DisplayName string
|
||||
ModelType string
|
||||
}
|
||||
|
||||
// getPresetModels 获取预设模型列表
|
||||
func getPresetModels(providerType string) []presetModel {
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return []presetModel{
|
||||
{ModelName: "gpt-4o", DisplayName: "GPT-4o", ModelType: "chat"},
|
||||
{ModelName: "gpt-4o-mini", DisplayName: "GPT-4o Mini", ModelType: "chat"},
|
||||
{ModelName: "gpt-4.1", DisplayName: "GPT-4.1", ModelType: "chat"},
|
||||
{ModelName: "gpt-4.1-mini", DisplayName: "GPT-4.1 Mini", ModelType: "chat"},
|
||||
{ModelName: "gpt-4.1-nano", DisplayName: "GPT-4.1 Nano", ModelType: "chat"},
|
||||
{ModelName: "o3-mini", DisplayName: "o3-mini", ModelType: "chat"},
|
||||
{ModelName: "dall-e-3", DisplayName: "DALL·E 3", ModelType: "image_gen"},
|
||||
}
|
||||
case "claude":
|
||||
return []presetModel{
|
||||
{ModelName: "claude-sonnet-4-20250514", DisplayName: "Claude Sonnet 4", ModelType: "chat"},
|
||||
{ModelName: "claude-3-5-sonnet-20241022", DisplayName: "Claude 3.5 Sonnet", ModelType: "chat"},
|
||||
{ModelName: "claude-3-5-haiku-20241022", DisplayName: "Claude 3.5 Haiku", ModelType: "chat"},
|
||||
{ModelName: "claude-3-opus-20240229", DisplayName: "Claude 3 Opus", ModelType: "chat"},
|
||||
}
|
||||
case "gemini":
|
||||
return []presetModel{
|
||||
{ModelName: "gemini-2.5-flash-preview-05-20", DisplayName: "Gemini 2.5 Flash", ModelType: "chat"},
|
||||
{ModelName: "gemini-2.5-pro-preview-05-06", DisplayName: "Gemini 2.5 Pro", ModelType: "chat"},
|
||||
{ModelName: "gemini-2.0-flash", DisplayName: "Gemini 2.0 Flash", ModelType: "chat"},
|
||||
{ModelName: "imagen-3.0-generate-002", DisplayName: "Imagen 3", ModelType: "image_gen"},
|
||||
}
|
||||
case "custom":
|
||||
return []presetModel{} // 自定义不提供预设
|
||||
default:
|
||||
return []presetModel{}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 获取远程模型列表 ====================
|
||||
|
||||
// FetchRemoteModels 从远程API获取可用模型列表
|
||||
// 所有提供商类型统一使用 baseURL + /models 端点(OpenAI 兼容格式)
|
||||
func (ps *ProviderService) FetchRemoteModels(req request.FetchRemoteModelsRequest) (response.FetchRemoteModelsResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||||
}
|
||||
|
||||
result := fetchModelsUniversal(baseURL, req.APIKey)
|
||||
result.Latency = time.Since(startTime).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
|
||||
func (ps *ProviderService) FetchRemoteModelsExisting(providerID uint, userID uint) (response.FetchRemoteModelsResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
return response.FetchRemoteModelsResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
|
||||
return ps.FetchRemoteModels(request.FetchRemoteModelsRequest{
|
||||
ProviderType: provider.ProviderType,
|
||||
BaseURL: provider.BaseURL,
|
||||
APIKey: apiKey,
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 发送测试消息 ====================
|
||||
|
||||
// SendTestMessage 发送测试消息(使用指定的 provider 配置)
|
||||
// 所有提供商类型统一使用 baseURL + /chat/completions 端点(OpenAI 兼容格式)
|
||||
func (ps *ProviderService) SendTestMessage(req request.SendTestMessageRequest) (response.SendTestMessageResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||||
}
|
||||
|
||||
message := req.Message
|
||||
if message == "" {
|
||||
message = "你好,请用一句话介绍你自己。"
|
||||
}
|
||||
|
||||
result := sendTestMessageUniversal(baseURL, req.APIKey, req.ModelName, message)
|
||||
result.Latency = time.Since(startTime).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SendTestMessageExisting 发送测试消息(已保存的提供商)
|
||||
func (ps *ProviderService) SendTestMessageExisting(providerID uint, userID uint, modelName string, message string) (response.SendTestMessageResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
return response.SendTestMessageResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
|
||||
return ps.SendTestMessage(request.SendTestMessageRequest{
|
||||
ProviderType: provider.ProviderType,
|
||||
BaseURL: provider.BaseURL,
|
||||
APIKey: apiKey,
|
||||
ModelName: modelName,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 连通性测试实现 ====================
|
||||
|
||||
// testOpenAICompatible 测试 OpenAI 兼容接口
|
||||
func testOpenAICompatible(baseURL, apiKey, modelName string) response.TestProviderResponse {
|
||||
url := strings.TrimRight(baseURL, "/") + "/models"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return response.TestProviderResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return response.TestProviderResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == 401 {
|
||||
return response.TestProviderResponse{Success: false, Message: "API 密钥无效"}
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return response.TestProviderResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
// 解析模型列表
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
var modelNames []string
|
||||
if err := json.Unmarshal(body, &modelsResp); err == nil {
|
||||
for _, m := range modelsResp.Data {
|
||||
modelNames = append(modelNames, m.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return response.TestProviderResponse{
|
||||
Success: true,
|
||||
Message: "连接成功",
|
||||
Models: modelNames,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 获取远程模型列表实现 ====================
|
||||
|
||||
// fetchModelsUniversal 统一获取模型列表(所有提供商通用)
|
||||
// 使用 baseURL + /models 端点,Authorization: Bearer 鉴权
|
||||
func fetchModelsUniversal(baseURL, apiKey string) response.FetchRemoteModelsResponse {
|
||||
url := strings.TrimRight(baseURL, "/") + "/models"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return response.FetchRemoteModelsResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return response.FetchRemoteModelsResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 401 {
|
||||
return response.FetchRemoteModelsResponse{Success: false, Message: "API 密钥无效"}
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return response.FetchRemoteModelsResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 解析 OpenAI 兼容格式: { "data": [{ "id": "xxx", "owned_by": "xxx" }] }
|
||||
var modelsData struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
var models []response.RemoteModel
|
||||
if err := json.Unmarshal(body, &modelsData); err == nil {
|
||||
for _, m := range modelsData.Data {
|
||||
models = append(models, response.RemoteModel{
|
||||
ID: m.ID,
|
||||
OwnedBy: m.OwnedBy,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return response.FetchRemoteModelsResponse{
|
||||
Success: true,
|
||||
Message: fmt.Sprintf("获取成功,共 %d 个模型", len(models)),
|
||||
Models: models,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 发送测试消息实现 ====================
|
||||
|
||||
// sendTestMessageUniversal 统一发送测试消息(所有提供商通用)
|
||||
// 使用 baseURL + /chat/completions 端点,Authorization: Bearer 鉴权
|
||||
func sendTestMessageUniversal(baseURL, apiKey, modelName, message string) response.SendTestMessageResponse {
|
||||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"max_tokens": 100,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": message},
|
||||
},
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == 401 {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "API 密钥无效"}
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
var errResp struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "API 错误: " + errResp.Error.Message}
|
||||
}
|
||||
return response.SendTestMessageResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 OpenAI 兼容格式的 chat completion 响应
|
||||
var chatResp struct {
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "解析响应失败"}
|
||||
}
|
||||
|
||||
reply := ""
|
||||
if len(chatResp.Choices) > 0 {
|
||||
reply = chatResp.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
return response.SendTestMessageResponse{
|
||||
Success: true,
|
||||
Message: "测试成功",
|
||||
Reply: reply,
|
||||
Model: chatResp.Model,
|
||||
Tokens: chatResp.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user