🎨 优化扩展模块,完成ai接入和对话功能

This commit is contained in:
2026-02-12 23:12:28 +08:00
parent 4e611d3a5e
commit 572f3aa15b
779 changed files with 194400 additions and 3136 deletions

View File

@@ -0,0 +1,625 @@
package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
appModel "git.echol.cn/loser/st/server/model/app"
)
// AIMessage AI调用的消息格式统一
type AIMessagePayload struct {
Role string `json:"role"`
Content string `json:"content"`
}
// AIStreamChunk 流式响应的单个块
type AIStreamChunk struct {
Content string `json:"content"` // 增量文本
Done bool `json:"done"` // 是否结束
Model string `json:"model"` // 使用的模型
PromptTokens int `json:"promptTokens"` // 提示词Token仅结束时有值
CompletionTokens int `json:"completionTokens"` // 补全Token仅结束时有值
Error string `json:"error"` // 错误信息
}
// BuildPrompt 根据角色卡和消息历史构建 AI 请求的 prompt
func BuildPrompt(character *appModel.AICharacter, messages []appModel.AIMessage) []AIMessagePayload {
var payload []AIMessagePayload
// 1. 系统提示词System Prompt
systemPrompt := buildSystemPrompt(character)
if systemPrompt != "" {
payload = append(payload, AIMessagePayload{
Role: "system",
Content: systemPrompt,
})
}
// 2. 历史消息
for _, msg := range messages {
role := msg.Role
if role == "assistant" || role == "user" || role == "system" {
payload = append(payload, AIMessagePayload{
Role: role,
Content: msg.Content,
})
}
}
return payload
}
// buildSystemPrompt 构建系统提示词
func buildSystemPrompt(character *appModel.AICharacter) string {
if character == nil {
return ""
}
var parts []string
// 角色描述
if character.Description != "" {
parts = append(parts, character.Description)
}
// 角色性格
if character.Personality != "" {
parts = append(parts, "Personality: "+character.Personality)
}
// 场景
if character.Scenario != "" {
parts = append(parts, "Scenario: "+character.Scenario)
}
// 系统提示词
if character.SystemPrompt != "" {
parts = append(parts, character.SystemPrompt)
}
// 示例消息
if len(character.ExampleMessages) > 0 {
parts = append(parts, "Example dialogue:\n"+strings.Join(character.ExampleMessages, "\n"))
}
return strings.Join(parts, "\n\n")
}
// StreamAIResponse 流式调用 AI 并通过 channel 返回结果
func StreamAIResponse(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
defer close(ch)
apiKey := decryptAPIKey(provider.APIKey)
baseURL := provider.BaseURL
switch provider.ProviderType {
case "openai", "custom":
streamOpenAI(ctx, baseURL, apiKey, modelName, messages, ch)
case "claude":
streamClaude(ctx, baseURL, apiKey, modelName, messages, ch)
case "gemini":
streamGemini(ctx, baseURL, apiKey, modelName, messages, ch)
default:
ch <- AIStreamChunk{Error: "不支持的提供商类型: " + provider.ProviderType, Done: true}
}
}
// ==================== OpenAI Compatible ====================
func streamOpenAI(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
body := map[string]interface{}{
"model": modelName,
"messages": messages,
"stream": true,
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
var fullContent string
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
ch <- AIStreamChunk{Done: true, Model: modelName, Content: ""}
return
}
var chunk struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
Usage *struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
content := chunk.Choices[0].Delta.Content
fullContent += content
ch <- AIStreamChunk{Content: content}
}
if chunk.Usage != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: chunk.Usage.PromptTokens,
CompletionTokens: chunk.Usage.CompletionTokens,
}
return
}
}
// 如果扫描结束但没收到 [DONE]
if fullContent != "" {
ch <- AIStreamChunk{Done: true, Model: modelName}
}
}
// ==================== Claude ====================
func streamClaude(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
// Claude 需要把 system 消息分离出来
var systemContent string
var claudeMessages []map[string]string
for _, msg := range messages {
if msg.Role == "system" {
systemContent += msg.Content + "\n"
} else {
claudeMessages = append(claudeMessages, map[string]string{
"role": msg.Role,
"content": msg.Content,
})
}
}
// 确保至少有一条消息
if len(claudeMessages) == 0 {
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
return
}
body := map[string]interface{}{
"model": modelName,
"messages": claudeMessages,
"max_tokens": 4096,
"stream": true,
}
if systemContent != "" {
body["system"] = strings.TrimSpace(systemContent)
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var event struct {
Type string `json:"type"`
Delta *struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"delta"`
Usage *struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
switch event.Type {
case "content_block_delta":
if event.Delta != nil && event.Delta.Text != "" {
ch <- AIStreamChunk{Content: event.Delta.Text}
}
case "message_delta":
if event.Usage != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: event.Usage.InputTokens,
CompletionTokens: event.Usage.OutputTokens,
}
return
}
case "message_stop":
ch <- AIStreamChunk{Done: true, Model: modelName}
return
case "error":
ch <- AIStreamChunk{Error: "Claude API 错误", Done: true}
return
}
}
}
// ==================== Gemini ====================
func streamGemini(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
strings.TrimRight(baseURL, "/"), modelName, apiKey)
// 构建 Gemini 格式的消息
var systemInstruction string
var contents []map[string]interface{}
for _, msg := range messages {
if msg.Role == "system" {
systemInstruction += msg.Content + "\n"
continue
}
role := msg.Role
if role == "assistant" {
role = "model"
}
contents = append(contents, map[string]interface{}{
"role": role,
"parts": []map[string]string{
{"text": msg.Content},
},
})
}
if len(contents) == 0 {
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
return
}
body := map[string]interface{}{
"contents": contents,
}
if systemInstruction != "" {
body["systemInstruction"] = map[string]interface{}{
"parts": []map[string]string{
{"text": strings.TrimSpace(systemInstruction)},
},
}
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
// Gemini 返回较大的 chunks增大 buffer
scanner.Buffer(make([]byte, 0), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var chunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
UsageMetadata *struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
} `json:"usageMetadata"`
}
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
for _, candidate := range chunk.Candidates {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
ch <- AIStreamChunk{Content: part.Text}
}
}
}
if chunk.UsageMetadata != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
}
return
}
}
ch <- AIStreamChunk{Done: true, Model: modelName}
}
// ==================== 非流式调用(用于生图等) ====================
// CallAINonStream 非流式调用 AI
func CallAINonStream(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload) (string, error) {
apiKey := decryptAPIKey(provider.APIKey)
baseURL := provider.BaseURL
switch provider.ProviderType {
case "openai", "custom":
return callOpenAINonStream(ctx, baseURL, apiKey, modelName, messages)
case "claude":
return callClaudeNonStream(ctx, baseURL, apiKey, modelName, messages)
case "gemini":
return callGeminiNonStream(ctx, baseURL, apiKey, modelName, messages)
default:
return "", errors.New("不支持的提供商类型: " + provider.ProviderType)
}
}
func callOpenAINonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
body := map[string]interface{}{
"model": modelName,
"messages": messages,
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Choices) > 0 {
return result.Choices[0].Message.Content, nil
}
return "", errors.New("AI 未返回有效回复")
}
func callClaudeNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
var systemContent string
var claudeMessages []map[string]string
for _, msg := range messages {
if msg.Role == "system" {
systemContent += msg.Content + "\n"
} else {
claudeMessages = append(claudeMessages, map[string]string{
"role": msg.Role,
"content": msg.Content,
})
}
}
body := map[string]interface{}{
"model": modelName,
"messages": claudeMessages,
"max_tokens": 4096,
}
if systemContent != "" {
body["system"] = strings.TrimSpace(systemContent)
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Content) > 0 {
return result.Content[0].Text, nil
}
return "", errors.New("AI 未返回有效回复")
}
func callGeminiNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
strings.TrimRight(baseURL, "/"), modelName, apiKey)
var systemInstruction string
var contents []map[string]interface{}
for _, msg := range messages {
if msg.Role == "system" {
systemInstruction += msg.Content + "\n"
continue
}
role := msg.Role
if role == "assistant" {
role = "model"
}
contents = append(contents, map[string]interface{}{
"role": role,
"parts": []map[string]string{{"text": msg.Content}},
})
}
body := map[string]interface{}{"contents": contents}
if systemInstruction != "" {
body["systemInstruction"] = map[string]interface{}{
"parts": []map[string]string{{"text": strings.TrimSpace(systemInstruction)}},
}
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Candidates) > 0 && len(result.Candidates[0].Content.Parts) > 0 {
return result.Candidates[0].Content.Parts[0].Text, nil
}
return "", errors.New("AI 未返回有效回复")
}

408
server/service/app/chat.go Normal file
View File

@@ -0,0 +1,408 @@
package app
import (
"errors"
"time"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/app/response"
"gorm.io/gorm"
)
type ChatService struct{}
// ==================== 对话 CRUD ====================
// CreateChat 创建对话
func (cs *ChatService) CreateChat(req request.CreateChatRequest, userID uint) (response.ChatResponse, error) {
// 获取角色卡信息
var character app.AICharacter
err := global.GVA_DB.First(&character, req.CharacterID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ChatResponse{}, errors.New("角色卡不存在")
}
return response.ChatResponse{}, err
}
// 对话标题
title := req.Title
if title == "" {
title = character.Name
}
now := time.Now()
chat := app.AIChat{
Title: title,
UserID: userID,
CharacterID: &req.CharacterID,
ChatType: "single",
LastMessageAt: &now,
}
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 创建对话
if err := tx.Create(&chat).Error; err != nil {
return err
}
// 如果角色卡有 FirstMessage自动创建第一条系统消息
if character.FirstMessage != "" {
firstMsg := app.AIMessage{
ChatID: chat.ID,
Content: character.FirstMessage,
Role: "assistant",
CharacterID: &req.CharacterID,
SequenceNumber: 1,
}
if err := tx.Create(&firstMsg).Error; err != nil {
return err
}
chat.MessageCount = 1
tx.Model(&chat).Update("message_count", 1)
}
// 更新角色卡使用次数
tx.Model(&character).Update("total_chats", gorm.Expr("total_chats + 1"))
return nil
})
if err != nil {
return response.ChatResponse{}, err
}
return toChatResponse(&chat, &character, nil), nil
}
// GetChatList 获取用户的对话列表
func (cs *ChatService) GetChatList(req request.ChatListRequest, userID uint) (response.ChatListResponse, error) {
db := global.GVA_DB.Model(&app.AIChat{}).Where("user_id = ?", userID)
var total int64
db.Count(&total)
var chats []app.AIChat
offset := (req.Page - 1) * req.PageSize
err := db.Preload("Character").
Order("is_pinned DESC, last_message_at DESC").
Offset(offset).Limit(req.PageSize).
Find(&chats).Error
if err != nil {
return response.ChatListResponse{}, err
}
// 获取每个对话的最后一条消息
chatIDs := make([]uint, len(chats))
for i, c := range chats {
chatIDs[i] = c.ID
}
lastMessages := make(map[uint]*response.MessageBrief)
if len(chatIDs) > 0 {
var messages []app.AIMessage
// 获取每个对话的最后一条消息(通过子查询)
global.GVA_DB.Where("chat_id IN ? AND is_deleted = ?", chatIDs, false).
Order("sequence_number DESC").
Find(&messages)
// 只保留每个对话的最后一条
for _, msg := range messages {
if _, exists := lastMessages[msg.ChatID]; !exists {
content := msg.Content
if len(content) > 100 {
content = content[:100] + "..."
}
lastMessages[msg.ChatID] = &response.MessageBrief{
Content: content,
Role: msg.Role,
}
}
}
}
list := make([]response.ChatResponse, len(chats))
for i, c := range chats {
list[i] = toChatResponse(&c, c.Character, lastMessages[c.ID])
}
return response.ChatListResponse{
List: list,
Total: total,
Page: req.Page,
PageSize: req.PageSize,
}, nil
}
// GetChatDetail 获取对话详情(包含消息列表)
func (cs *ChatService) GetChatDetail(chatID uint, userID uint) (response.ChatDetailResponse, error) {
var chat app.AIChat
err := global.GVA_DB.Preload("Character").
Where("id = ? AND user_id = ?", chatID, userID).
First(&chat).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ChatDetailResponse{}, errors.New("对话不存在")
}
return response.ChatDetailResponse{}, err
}
// 获取消息列表最近的50条
var messages []app.AIMessage
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
Order("sequence_number ASC").
Limit(50).
Find(&messages)
msgList := make([]response.MessageResponse, len(messages))
for i, msg := range messages {
msgList[i] = toMessageResponse(&msg, chat.Character)
}
return response.ChatDetailResponse{
Chat: toChatResponse(&chat, chat.Character, nil),
Messages: msgList,
}, nil
}
// GetChatMessages 分页获取对话消息
func (cs *ChatService) GetChatMessages(chatID uint, req request.ChatMessagesRequest, userID uint) (response.MessageListResponse, error) {
// 验证对话归属
var chat app.AIChat
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
if err != nil {
return response.MessageListResponse{}, errors.New("对话不存在")
}
db := global.GVA_DB.Model(&app.AIMessage{}).
Where("chat_id = ? AND is_deleted = ?", chatID, false)
var total int64
db.Count(&total)
var messages []app.AIMessage
offset := (req.Page - 1) * req.PageSize
err = db.Order("sequence_number ASC").
Offset(offset).Limit(req.PageSize).
Find(&messages).Error
if err != nil {
return response.MessageListResponse{}, err
}
// 预加载角色信息
var character *app.AICharacter
if chat.CharacterID != nil {
var c app.AICharacter
global.GVA_DB.First(&c, *chat.CharacterID)
character = &c
}
list := make([]response.MessageResponse, len(messages))
for i, msg := range messages {
list[i] = toMessageResponse(&msg, character)
}
return response.MessageListResponse{
List: list,
Total: total,
Page: req.Page,
PageSize: req.PageSize,
}, nil
}
// DeleteChat 删除对话
func (cs *ChatService) DeleteChat(chatID uint, userID uint) error {
var chat app.AIChat
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("对话不存在")
}
return err
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 删除消息
if err := tx.Where("chat_id = ?", chatID).Delete(&app.AIMessage{}).Error; err != nil {
return err
}
// 删除消息变体
if err := tx.Where("message_id IN (?)",
tx.Model(&app.AIMessage{}).Select("id").Where("chat_id = ?", chatID),
).Delete(&app.AIMessageSwipe{}).Error; err != nil {
// 忽略错误,可能没有变体
}
// 删除对话
return tx.Delete(&chat).Error
})
}
// ==================== 消息操作 ====================
// SaveUserMessage 保存用户消息(内部方法)
func (cs *ChatService) SaveUserMessage(chatID uint, userID uint, content string) (*app.AIMessage, error) {
// 获取当前最大序号
var maxSeq int
global.GVA_DB.Model(&app.AIMessage{}).
Where("chat_id = ?", chatID).
Select("COALESCE(MAX(sequence_number), 0)").
Scan(&maxSeq)
msg := &app.AIMessage{
ChatID: chatID,
Content: content,
Role: "user",
SenderID: &userID,
SequenceNumber: maxSeq + 1,
}
if err := global.GVA_DB.Create(msg).Error; err != nil {
return nil, err
}
// 更新对话的消息数和最后消息时间
now := time.Now()
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
"message_count": gorm.Expr("message_count + 1"),
"last_message_at": now,
})
return msg, nil
}
// SaveAssistantMessage 保存AI回复消息内部方法
func (cs *ChatService) SaveAssistantMessage(chatID uint, characterID *uint, content string, model string, promptTokens, completionTokens int) (*app.AIMessage, error) {
var maxSeq int
global.GVA_DB.Model(&app.AIMessage{}).
Where("chat_id = ?", chatID).
Select("COALESCE(MAX(sequence_number), 0)").
Scan(&maxSeq)
msg := &app.AIMessage{
ChatID: chatID,
Content: content,
Role: "assistant",
CharacterID: characterID,
SequenceNumber: maxSeq + 1,
Model: model,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
if err := global.GVA_DB.Create(msg).Error; err != nil {
return nil, err
}
now := time.Now()
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
"message_count": gorm.Expr("message_count + 1"),
"last_message_at": now,
})
return msg, nil
}
// EditMessage 编辑消息
func (cs *ChatService) EditMessage(req request.EditMessageRequest, userID uint) (response.MessageResponse, error) {
var msg app.AIMessage
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
Where("ai_messages.id = ? AND ai_chats.user_id = ?", req.MessageID, userID).
First(&msg).Error
if err != nil {
return response.MessageResponse{}, errors.New("消息不存在")
}
msg.Content = req.Content
if err := global.GVA_DB.Save(&msg).Error; err != nil {
return response.MessageResponse{}, err
}
return toMessageResponse(&msg, nil), nil
}
// DeleteMessage 删除消息(软删除)
func (cs *ChatService) DeleteMessage(messageID uint, userID uint) error {
var msg app.AIMessage
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
Where("ai_messages.id = ? AND ai_chats.user_id = ?", messageID, userID).
First(&msg).Error
if err != nil {
return errors.New("消息不存在")
}
return global.GVA_DB.Model(&msg).Update("is_deleted", true).Error
}
// GetChatForAI 获取对话用于AI调用的完整上下文内部方法
func (cs *ChatService) GetChatForAI(chatID uint, userID uint) (*app.AIChat, *app.AICharacter, []app.AIMessage, error) {
var chat app.AIChat
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
if err != nil {
return nil, nil, nil, errors.New("对话不存在")
}
var character *app.AICharacter
if chat.CharacterID != nil {
var c app.AICharacter
if err := global.GVA_DB.First(&c, *chat.CharacterID).Error; err == nil {
character = &c
}
}
// 获取历史消息最近30条
var messages []app.AIMessage
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
Order("sequence_number ASC").
Limit(30).
Find(&messages)
return &chat, character, messages, nil
}
// ==================== 辅助函数 ====================
func toChatResponse(chat *app.AIChat, character *app.AICharacter, lastMsg *response.MessageBrief) response.ChatResponse {
resp := response.ChatResponse{
ID: chat.ID,
Title: chat.Title,
CharacterID: chat.CharacterID,
ChatType: chat.ChatType,
LastMessageAt: chat.LastMessageAt,
MessageCount: chat.MessageCount,
IsPinned: chat.IsPinned,
LastMessage: lastMsg,
CreatedAt: chat.CreatedAt,
}
if character != nil {
resp.CharacterName = character.Name
resp.CharacterAvatar = character.Avatar
}
return resp
}
func toMessageResponse(msg *app.AIMessage, character *app.AICharacter) response.MessageResponse {
resp := response.MessageResponse{
ID: msg.ID,
ChatID: msg.ChatID,
Content: msg.Content,
Role: msg.Role,
CharacterID: msg.CharacterID,
Model: msg.Model,
PromptTokens: msg.PromptTokens,
CompletionTokens: msg.CompletionTokens,
TotalTokens: msg.TotalTokens,
SequenceNumber: msg.SequenceNumber,
CreatedAt: msg.CreatedAt,
}
if msg.Role == "assistant" && character != nil {
resp.CharacterName = character.Name
}
return resp
}

View File

@@ -6,4 +6,6 @@ type AppServiceGroup struct {
WorldInfoService
ExtensionService
RegexScriptService
ProviderService
ChatService
}

View File

@@ -21,15 +21,70 @@ import (
"gorm.io/gorm"
)
// extensionDataDir 扩展本地存储根目录
// 与原版 SillyTavern 完全一致的路径结构scripts/extensions/third-party/{name}/
// 扩展 JS 中的相对路径 import如 ../../../../../script.js依赖此目录层级来正确解析
// 所有 SillyTavern 核心脚本和扩展文件统一存储在 data/st-core-scripts/ 下,独立于 web-app/
// 扩展代码是公共的(不按用户隔离),用户间差异仅在于数据库中的配置和启用状态
const extensionDataDir = "data/st-core-scripts/scripts/extensions/third-party"
// getExtensionStorePath 获取扩展的本地存储路径: {extensionDataDir}/{extensionName}/
func getExtensionStorePath(extensionName string) string {
return filepath.Join(extensionDataDir, extensionName)
}
// GetExtensionAssetLocalPath 获取扩展资源文件的本地绝对路径
func (es *ExtensionService) GetExtensionAssetLocalPath(extensionName string, assetPath string) (string, error) {
storePath := getExtensionStorePath(extensionName)
fullPath := filepath.Join(storePath, assetPath)
// 安全检查:防止路径遍历攻击
absStore, _ := filepath.Abs(storePath)
absFile, _ := filepath.Abs(fullPath)
if !strings.HasPrefix(absFile, absStore) {
return "", errors.New("非法的资源路径")
}
// 检查文件是否存在
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
return "", fmt.Errorf("资源文件不存在: %s", assetPath)
}
return fullPath, nil
}
// ensureExtensionDir 确保扩展存储目录存在
func ensureExtensionDir(extensionName string) (string, error) {
storePath := getExtensionStorePath(extensionName)
if err := os.MkdirAll(storePath, 0755); err != nil {
return "", fmt.Errorf("创建扩展存储目录失败: %w", err)
}
return storePath, nil
}
// removeExtensionDir 删除扩展的本地存储目录
func removeExtensionDir(extensionName string) error {
storePath := getExtensionStorePath(extensionName)
if _, err := os.Stat(storePath); os.IsNotExist(err) {
return nil // 目录不存在,无需删除
}
return os.RemoveAll(storePath)
}
type ExtensionService struct{}
// CreateExtension 创建/安装扩展
func (es *ExtensionService) CreateExtension(userID uint, req *request.CreateExtensionRequest) (*app.AIExtension, error) {
// 校验名称
if req.Name == "" {
return nil, errors.New("扩展名称不能为空")
}
// 检查扩展是否已存在
var existing app.AIExtension
err := global.GVA_DB.Where("user_id = ? AND name = ?", userID, req.Name).First(&existing).Error
if err == nil {
return nil, errors.New("扩展已存在")
return nil, fmt.Errorf("扩展 %s 已存在", req.Name)
}
if err != gorm.ErrRecordNotFound {
return nil, err
@@ -136,15 +191,18 @@ func (es *ExtensionService) DeleteExtension(userID, extensionID uint, deleteFile
return errors.New("系统内置扩展不允许删除")
}
// TODO: 如果 deleteFiles=true删除扩展文件
// 这需要文件系统支持
// 删除本地扩展文件(与原版 SillyTavern 一致:卸载扩展时清理本地文件
if err := removeExtensionDir(extension.Name); err != nil {
global.GVA_LOG.Warn("删除扩展本地文件失败", zap.Error(err), zap.String("name", extension.Name))
// 不阻断删除流程
}
// 删除扩展(配置已经在扩展记录的 Settings 字段中,无需单独删除)
// 删除数据库记录
if err := global.GVA_DB.Delete(&extension).Error; err != nil {
return err
}
global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID))
global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID), zap.String("name", extension.Name))
return nil
}
@@ -157,6 +215,15 @@ func (es *ExtensionService) GetExtension(userID, extensionID uint) (*app.AIExten
return &extension, nil
}
// GetExtensionByID 通过扩展 ID 获取扩展信息(不限制用户,用于公开资源路由)
func (es *ExtensionService) GetExtensionByID(extensionID uint) (*app.AIExtension, error) {
var extension app.AIExtension
if err := global.GVA_DB.Where("id = ?", extensionID).First(&extension).Error; err != nil {
return nil, errors.New("扩展不存在")
}
return &extension, nil
}
// GetExtensionList 获取扩展列表
func (es *ExtensionService) GetExtensionList(userID uint, req *request.ExtensionListRequest) (*response.ExtensionListResponse, error) {
var extensions []app.AIExtension
@@ -550,9 +617,37 @@ func isGitURL(url string) bool {
return false
}
// downloadAndInstallFromManifestURL 从 Manifest URL 下载并安装
// GetExtensionAssetURL 根据扩展的安装来源构建资源文件的远程 URL
func (es *ExtensionService) GetExtensionAssetURL(extension *app.AIExtension, assetPath string) (string, error) {
if extension.SourceURL == "" {
return "", errors.New("扩展没有源地址")
}
sourceURL := strings.TrimSuffix(strings.TrimSuffix(extension.SourceURL, "/"), ".git")
branch := extension.Branch
if branch == "" {
branch = "main"
}
// GitLab: repo/-/raw/branch/path
if strings.Contains(sourceURL, "gitlab.com") {
return fmt.Sprintf("%s/-/raw/%s/%s", sourceURL, branch, assetPath), nil
}
// GitHub: raw.githubusercontent.com/user/repo/branch/path
if strings.Contains(sourceURL, "github.com") {
rawURL := strings.Replace(sourceURL, "github.com", "raw.githubusercontent.com", 1)
return fmt.Sprintf("%s/%s/%s", rawURL, branch, assetPath), nil
}
// Gitee: repo/raw/branch/path
if strings.Contains(sourceURL, "gitee.com") {
return fmt.Sprintf("%s/raw/%s/%s", sourceURL, branch, assetPath), nil
}
return fmt.Sprintf("%s/%s", sourceURL, assetPath), nil
}
// downloadAndInstallFromManifestURL 从 Manifest URL 下载并安装(同时下载资源文件到本地)
func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manifestURL string) (*app.AIExtension, error) {
// 创建 HTTP 客户端
client := &http.Client{
Timeout: 30 * time.Second,
}
@@ -568,7 +663,6 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
return nil, fmt.Errorf("下载 manifest.json 失败: HTTP %d", resp.StatusCode)
}
// 读取响应内容
manifestData, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取 manifest.json 失败: %w", err)
@@ -580,21 +674,63 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
}
// 验证必填字段
if manifest.Name == "" {
return nil, errors.New("manifest.json 缺少 name 字段")
// 获取有效名称
effectiveName := manifest.GetEffectiveName()
if effectiveName == "" {
return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
}
// 检查扩展是否已存在
var existing app.AIExtension
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, manifest.Name).First(&existing).Error
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
if err == nil {
return nil, fmt.Errorf("扩展 %s 已安装", manifest.Name)
return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
}
if err != gorm.ErrRecordNotFound {
return nil, err
}
// 创建本地存储目录并保存 manifest.json
storePath, err := ensureExtensionDir(effectiveName)
if err != nil {
return nil, err
}
if err := os.WriteFile(filepath.Join(storePath, "manifest.json"), manifestData, 0644); err != nil {
_ = removeExtensionDir(effectiveName)
return nil, fmt.Errorf("保存 manifest.json 失败: %w", err)
}
// 获取 manifest URL 的基础目录(用于下载关联资源)
baseURL := manifestURL[:strings.LastIndex(manifestURL, "/")+1]
// 下载 JS/CSS 等资源文件到本地
filesToDownload := []string{}
if entry := manifest.GetEffectiveEntry(); entry != "" {
filesToDownload = append(filesToDownload, entry)
}
if style := manifest.GetEffectiveStyle(); style != "" {
filesToDownload = append(filesToDownload, style)
}
filesToDownload = append(filesToDownload, manifest.Assets...)
for _, file := range filesToDownload {
if file == "" {
continue
}
fileURL := baseURL + file
if err := downloadFileToLocal(client, fileURL, filepath.Join(storePath, file)); err != nil {
global.GVA_LOG.Warn("下载扩展资源文件失败(非致命)",
zap.String("file", file),
zap.String("url", fileURL),
zap.Error(err))
}
}
global.GVA_LOG.Info("扩展文件已保存到本地",
zap.String("name", effectiveName),
zap.String("path", storePath))
// 将 manifest 转换为 map[string]interface{}
var manifestMap map[string]interface{}
if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
@@ -603,13 +739,13 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
// 构建创建请求
createReq := &request.CreateExtensionRequest{
Name: manifest.Name,
Name: effectiveName,
DisplayName: manifest.DisplayName,
Version: manifest.Version,
Author: manifest.Author,
Description: manifest.Description,
Homepage: manifest.Homepage,
Repository: manifest.Repository, // 使用 manifest 中的 repository
Homepage: manifest.GetEffectiveHomepage(),
Repository: manifest.Repository,
License: manifest.License,
Tags: manifest.Tags,
ExtensionType: manifest.Type,
@@ -617,13 +753,13 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
Dependencies: manifest.Dependencies,
Conflicts: manifest.Conflicts,
ManifestData: manifestMap,
ScriptPath: manifest.Entry,
StylePath: manifest.Style,
ScriptPath: manifest.GetEffectiveEntry(),
StylePath: manifest.GetEffectiveStyle(),
AssetsPaths: manifest.Assets,
Settings: manifest.Settings,
Options: manifest.Options,
InstallSource: "url",
SourceURL: manifestURL, // 记录原始 URL 用于更新
SourceURL: manifestURL,
AutoUpdate: manifest.AutoUpdate,
Metadata: nil,
}
@@ -636,6 +772,7 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
// 创建扩展
extension, err := es.CreateExtension(userID, createReq)
if err != nil {
_ = removeExtensionDir(effectiveName)
return nil, fmt.Errorf("创建扩展失败: %w", err)
}
@@ -647,6 +784,31 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
return extension, nil
}
// downloadFileToLocal 下载远程文件到本地路径
func downloadFileToLocal(client *http.Client, url string, localPath string) error {
// 确保目标文件的父目录存在
if err := os.MkdirAll(filepath.Dir(localPath), 0755); err != nil {
return err
}
resp, err := client.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("HTTP %d", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
return os.WriteFile(localPath, data, 0644)
}
// UpgradeExtension 升级扩展版本(根据安装来源自动选择更新方式)
func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force bool) (*app.AIExtension, error) {
// 获取扩展信息
@@ -672,7 +834,7 @@ func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force boo
}
}
// updateExtensionFromGit 从 Git 仓库更新扩展
// updateExtensionFromGit 从 Git 仓库更新扩展(先删除旧记录和文件,再重新安装)
func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.AIExtension, force bool) (*app.AIExtension, error) {
if extension.SourceURL == "" {
return nil, errors.New("缺少 Git 仓库 URL")
@@ -683,11 +845,17 @@ func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.A
zap.String("sourceUrl", extension.SourceURL),
zap.String("branch", extension.Branch))
// 重新克隆(简单方式,避免处理本地修改)
// 先删除旧的数据库记录和本地文件
if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
}
_ = removeExtensionDir(extension.Name)
// 重新克隆安装
return es.InstallExtensionFromGit(userID, extension.SourceURL, extension.Branch)
}
// updateExtensionFromURL 从 URL 更新扩展(重新下载 manifest.json
// updateExtensionFromURL 从 URL 更新扩展(先删除旧记录和文件,再重新下载安装
func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.AIExtension) (*app.AIExtension, error) {
if extension.SourceURL == "" {
return nil, errors.New("缺少 Manifest URL")
@@ -697,18 +865,24 @@ func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.A
zap.String("name", extension.Name),
zap.String("sourceUrl", extension.SourceURL))
// 重新下载并安装
// 先删除旧的数据库记录和本地文件
if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
}
_ = removeExtensionDir(extension.Name)
// 重新下载安装
return es.downloadAndInstallFromManifestURL(userID, extension.SourceURL)
}
// InstallExtensionFromGit 从 Git URL 安装扩展
// InstallExtensionFromGit 从 Git URL 安装扩展(与原版 SillyTavern 一致:将源码下载到本地)
func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch string) (*app.AIExtension, error) {
// 验证 Git URL
if !strings.Contains(gitUrl, "://") && !strings.HasSuffix(gitUrl, ".git") {
return nil, errors.New("无效的 Git URL")
}
// 创建临时目录
// 先 clone 到临时目录读取 manifest获取扩展名后再移动到正式目录
tempDir, err := os.MkdirTemp("", "extension-*")
if err != nil {
return nil, fmt.Errorf("创建临时目录失败: %w", err)
@@ -717,10 +891,9 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
global.GVA_LOG.Info("开始从 Git 克隆扩展",
zap.String("gitUrl", gitUrl),
zap.String("branch", branch),
zap.String("tempDir", tempDir))
zap.String("branch", branch))
// 执行 git clone
// 执行 git clone(浅克隆)
cmd := exec.Command("git", "clone", "--depth=1", "--branch="+branch, gitUrl, tempDir)
output, err := cmd.CombinedOutput()
if err != nil {
@@ -744,31 +917,53 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
}
// 获取有效名称(兼容 SillyTavern manifest 没有 name 字段的情况)
effectiveName := manifest.GetEffectiveName()
if effectiveName == "" {
return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
}
// 检查扩展是否已存在
var existing app.AIExtension
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, manifest.Name).First(&existing).Error
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
if err == nil {
return nil, fmt.Errorf("扩展 %s 已安装", manifest.Name)
return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
}
if err != gorm.ErrRecordNotFound {
return nil, err
}
// 将扩展文件保存到公共目录: web-app/public/scripts/extensions/third-party/{extensionName}/
storePath, err := ensureExtensionDir(effectiveName)
if err != nil {
return nil, err
}
// 清空目标目录(如果有残留文件)后复制 clone 内容
_ = os.RemoveAll(storePath)
if err := copyDir(tempDir, storePath); err != nil {
return nil, fmt.Errorf("保存扩展文件失败: %w", err)
}
global.GVA_LOG.Info("扩展文件已保存到本地",
zap.String("name", effectiveName),
zap.String("path", storePath))
// 将 manifest 转换为 map[string]interface{}
var manifestMap map[string]interface{}
if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
return nil, fmt.Errorf("转换 manifest 失败: %w", err)
}
// 构建创建请求
// 构建创建请求(使用兼容方法获取字段值)
createReq := &request.CreateExtensionRequest{
Name: manifest.Name,
Name: effectiveName,
DisplayName: manifest.DisplayName,
Version: manifest.Version,
Author: manifest.Author,
Description: manifest.Description,
Homepage: manifest.Homepage,
Repository: manifest.Repository, // 使用 manifest 中的 repository
Homepage: manifest.GetEffectiveHomepage(),
Repository: manifest.Repository,
License: manifest.License,
Tags: manifest.Tags,
ExtensionType: manifest.Type,
@@ -776,28 +971,69 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
Dependencies: manifest.Dependencies,
Conflicts: manifest.Conflicts,
ManifestData: manifestMap,
ScriptPath: manifest.Entry,
StylePath: manifest.Style,
ScriptPath: manifest.GetEffectiveEntry(),
StylePath: manifest.GetEffectiveStyle(),
AssetsPaths: manifest.Assets,
Settings: manifest.Settings,
Options: manifest.Options,
InstallSource: "git",
SourceURL: gitUrl, // 记录 Git URL 用于更新
Branch: branch, // 记录分支
SourceURL: gitUrl,
Branch: branch,
AutoUpdate: manifest.AutoUpdate,
Metadata: manifest.Metadata,
}
// 确保扩展类型有效
if createReq.ExtensionType == "" {
createReq.ExtensionType = "ui"
}
// 创建扩展记录
extension, err := es.CreateExtension(userID, createReq)
if err != nil {
// 创建失败则清理本地文件
_ = removeExtensionDir(effectiveName)
return nil, fmt.Errorf("创建扩展记录失败: %w", err)
}
global.GVA_LOG.Info("从 Git 安装扩展成功",
zap.Uint("extensionID", extension.ID),
zap.String("name", extension.Name),
zap.String("version", extension.Version))
zap.String("version", extension.Version),
zap.String("localPath", storePath))
return extension, nil
}
// copyDir 递归复制目录(排除 .git 目录以节省空间)
func copyDir(src, dst string) error {
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// 计算相对路径
relPath, err := filepath.Rel(src, path)
if err != nil {
return err
}
// 排除 .git 目录
if info.IsDir() && info.Name() == ".git" {
return filepath.SkipDir
}
dstPath := filepath.Join(dst, relPath)
if info.IsDir() {
return os.MkdirAll(dstPath, info.Mode())
}
// 复制文件
srcFile, err := os.ReadFile(path)
if err != nil {
return err
}
return os.WriteFile(dstPath, srcFile, info.Mode())
})
}

View File

@@ -0,0 +1,949 @@
package app
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/app/response"
"gorm.io/datatypes"
"gorm.io/gorm"
)
type ProviderService struct{}
// ==================== 提供商 CRUD ====================
// CreateProvider 创建AI提供商
func (ps *ProviderService) CreateProvider(req request.CreateProviderRequest, userID uint) (response.ProviderResponse, error) {
// 加密 API Key
encryptedKey := encryptAPIKey(req.APIKey)
// 序列化额外配置
apiConfigJSON, _ := json.Marshal(req.APIConfig)
if req.APIConfig == nil {
apiConfigJSON = []byte("{}")
}
// 根据类型确定能力
capabilities := getDefaultCapabilities(req.ProviderType)
capJSON, _ := json.Marshal(capabilities)
// 如果 BaseURL 为空,使用默认地址
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
provider := app.AIProvider{
UserID: &userID,
ProviderName: req.ProviderName,
ProviderType: req.ProviderType,
BaseURL: baseURL,
APIKey: encryptedKey,
APIConfig: datatypes.JSON(apiConfigJSON),
Capabilities: datatypes.JSON(capJSON),
IsEnabled: true,
IsDefault: false,
}
err := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 创建提供商
if err := tx.Create(&provider).Error; err != nil {
return err
}
// 如果携带了模型列表,同时创建模型
if len(req.Models) > 0 {
for _, m := range req.Models {
modelConfigJSON, _ := json.Marshal(m.Config)
if m.Config == nil {
modelConfigJSON = []byte("{}")
}
isEnabled := true
if m.IsEnabled != nil {
isEnabled = *m.IsEnabled
}
model := app.AIModel{
ProviderID: provider.ID,
ModelName: m.ModelName,
DisplayName: m.DisplayName,
ModelType: m.ModelType,
Config: datatypes.JSON(modelConfigJSON),
IsEnabled: isEnabled,
}
if err := tx.Create(&model).Error; err != nil {
return err
}
}
} else {
// 没有指定模型时,自动添加预设模型
presets := getPresetModels(req.ProviderType)
for _, p := range presets {
model := app.AIModel{
ProviderID: provider.ID,
ModelName: p.ModelName,
DisplayName: p.DisplayName,
ModelType: p.ModelType,
Config: datatypes.JSON([]byte("{}")),
IsEnabled: true,
}
if err := tx.Create(&model).Error; err != nil {
return err
}
}
}
// 如果是用户的第一个提供商,自动设为默认
var count int64
tx.Model(&app.AIProvider{}).Where("user_id = ? AND id != ?", userID, provider.ID).Count(&count)
if count == 0 {
provider.IsDefault = true
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return response.ProviderResponse{}, err
}
return ps.GetProviderDetail(provider.ID, userID)
}
// GetProviderList 获取用户的提供商列表
func (ps *ProviderService) GetProviderList(req request.ProviderListRequest, userID uint) (response.ProviderListResponse, error) {
db := global.GVA_DB.Model(&app.AIProvider{}).Where("user_id = ?", userID)
if req.Keyword != "" {
keyword := "%" + req.Keyword + "%"
db = db.Where("provider_name ILIKE ?", keyword)
}
var total int64
db.Count(&total)
var providers []app.AIProvider
offset := (req.Page - 1) * req.PageSize
err := db.Order("is_default DESC, sort_order ASC, created_at DESC").
Offset(offset).Limit(req.PageSize).Find(&providers).Error
if err != nil {
return response.ProviderListResponse{}, err
}
// 获取所有提供商的模型
providerIDs := make([]uint, len(providers))
for i, p := range providers {
providerIDs[i] = p.ID
}
var models []app.AIModel
if len(providerIDs) > 0 {
global.GVA_DB.Where("provider_id IN ?", providerIDs).
Order("model_type ASC, model_name ASC").Find(&models)
}
// 按提供商ID分组模型
modelMap := make(map[uint][]app.AIModel)
for _, m := range models {
modelMap[m.ProviderID] = append(modelMap[m.ProviderID], m)
}
list := make([]response.ProviderResponse, len(providers))
for i, p := range providers {
list[i] = toProviderResponse(&p, modelMap[p.ID])
}
return response.ProviderListResponse{
List: list,
Total: total,
Page: req.Page,
PageSize: req.PageSize,
}, nil
}
// GetProviderDetail 获取提供商详情
func (ps *ProviderService) GetProviderDetail(providerID uint, userID uint) (response.ProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ProviderResponse{}, errors.New("提供商不存在")
}
return response.ProviderResponse{}, err
}
var models []app.AIModel
global.GVA_DB.Where("provider_id = ?", providerID).
Order("model_type ASC, model_name ASC").Find(&models)
return toProviderResponse(&provider, models), nil
}
// UpdateProvider 更新提供商
func (ps *ProviderService) UpdateProvider(req request.UpdateProviderRequest, userID uint) (response.ProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ProviderResponse{}, errors.New("提供商不存在")
}
return response.ProviderResponse{}, err
}
// 更新字段
updates := map[string]interface{}{
"provider_name": req.ProviderName,
"provider_type": req.ProviderType,
"base_url": req.BaseURL,
}
// APIKey 不为空时才更新
if req.APIKey != "" {
updates["api_key"] = encryptAPIKey(req.APIKey)
}
if req.APIConfig != nil {
apiConfigJSON, _ := json.Marshal(req.APIConfig)
updates["api_config"] = datatypes.JSON(apiConfigJSON)
}
if req.IsEnabled != nil {
updates["is_enabled"] = *req.IsEnabled
}
if req.SortOrder != nil {
updates["sort_order"] = *req.SortOrder
}
// 更新能力
capabilities := getDefaultCapabilities(req.ProviderType)
capJSON, _ := json.Marshal(capabilities)
updates["capabilities"] = datatypes.JSON(capJSON)
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&provider).Updates(updates).Error; err != nil {
return err
}
// 处理设置默认
if req.IsDefault != nil && *req.IsDefault {
// 先取消其他默认
if err := tx.Model(&app.AIProvider{}).
Where("user_id = ? AND id != ?", userID, req.ID).
Update("is_default", false).Error; err != nil {
return err
}
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return response.ProviderResponse{}, err
}
return ps.GetProviderDetail(req.ID, userID)
}
// DeleteProvider 删除提供商
func (ps *ProviderService) DeleteProvider(providerID uint, userID uint) error {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("提供商不存在")
}
return err
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 删除关联的模型
if err := tx.Where("provider_id = ?", providerID).Delete(&app.AIModel{}).Error; err != nil {
return err
}
// 删除提供商
if err := tx.Delete(&provider).Error; err != nil {
return err
}
// 如果删除的是默认提供商,自动将第一个提供商设为默认
if provider.IsDefault {
var firstProvider app.AIProvider
if err := tx.Where("user_id = ?", userID).Order("created_at ASC").First(&firstProvider).Error; err == nil {
tx.Model(&firstProvider).Update("is_default", true)
}
}
return nil
})
}
// SetDefaultProvider 设置默认提供商
func (ps *ProviderService) SetDefaultProvider(providerID uint, userID uint) error {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("提供商不存在")
}
return err
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 先取消所有默认
if err := tx.Model(&app.AIProvider{}).
Where("user_id = ?", userID).
Update("is_default", false).Error; err != nil {
return err
}
// 设置新默认
return tx.Model(&provider).Update("is_default", true).Error
})
}
// ==================== 模型 CRUD ====================
// AddModel 为提供商添加模型
func (ps *ProviderService) AddModel(req request.CreateModelRequest, userID uint) (response.ModelResponse, error) {
// 验证提供商归属
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ProviderID, userID).First(&provider).Error
if err != nil {
return response.ModelResponse{}, errors.New("提供商不存在")
}
configJSON, _ := json.Marshal(req.Config)
if req.Config == nil {
configJSON = []byte("{}")
}
isEnabled := true
if req.IsEnabled != nil {
isEnabled = *req.IsEnabled
}
model := app.AIModel{
ProviderID: req.ProviderID,
ModelName: req.ModelName,
DisplayName: req.DisplayName,
ModelType: req.ModelType,
Config: datatypes.JSON(configJSON),
IsEnabled: isEnabled,
}
if err := global.GVA_DB.Create(&model).Error; err != nil {
return response.ModelResponse{}, err
}
return toModelResponse(&model), nil
}
// UpdateModel 更新模型
func (ps *ProviderService) UpdateModel(req request.UpdateModelRequest, userID uint) (response.ModelResponse, error) {
var model app.AIModel
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
Where("ai_models.id = ? AND ai_providers.user_id = ?", req.ID, userID).
First(&model).Error
if err != nil {
return response.ModelResponse{}, errors.New("模型不存在")
}
updates := map[string]interface{}{
"model_name": req.ModelName,
"display_name": req.DisplayName,
"model_type": req.ModelType,
}
if req.Config != nil {
configJSON, _ := json.Marshal(req.Config)
updates["config"] = datatypes.JSON(configJSON)
}
if req.IsEnabled != nil {
updates["is_enabled"] = *req.IsEnabled
}
if err := global.GVA_DB.Model(&model).Updates(updates).Error; err != nil {
return response.ModelResponse{}, err
}
// 重新查询
global.GVA_DB.First(&model, model.ID)
return toModelResponse(&model), nil
}
// DeleteModel 删除模型
func (ps *ProviderService) DeleteModel(modelID uint, userID uint) error {
var model app.AIModel
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
Where("ai_models.id = ? AND ai_providers.user_id = ?", modelID, userID).
First(&model).Error
if err != nil {
return errors.New("模型不存在")
}
return global.GVA_DB.Delete(&model).Error
}
// ==================== 连通性测试 ====================
// TestProvider 测试提供商连通性
func (ps *ProviderService) TestProvider(req request.TestProviderRequest) (response.TestProviderResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
// 所有提供商统一使用 OpenAI 兼容的 /models 端点测试连通性
result := testOpenAICompatible(baseURL, req.APIKey, req.ModelName)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// TestExistingProvider 测试已保存的提供商连通性
func (ps *ProviderService) TestExistingProvider(providerID uint, userID uint) (response.TestProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.TestProviderResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.TestProvider(request.TestProviderRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
})
}
// ==================== 辅助查询 ====================
// GetProviderTypes 获取支持的提供商类型列表(前端下拉用)
func (ps *ProviderService) GetProviderTypes() []response.ProviderTypeOption {
return []response.ProviderTypeOption{
{
Value: "openai",
Label: "OpenAI",
Description: "支持 GPT-4o、GPT-4、DALL·E 等模型,也兼容所有 OpenAI 格式的中转站",
DefaultURL: "https://api.openai.com/v1",
},
{
Value: "claude",
Label: "Claude",
Description: "Anthropic 的 Claude 系列模型,支持长上下文对话",
DefaultURL: "https://api.anthropic.com",
},
{
Value: "gemini",
Label: "Google Gemini",
Description: "Google 的 Gemini 系列模型,支持多模态",
DefaultURL: "https://generativelanguage.googleapis.com",
},
{
Value: "custom",
Label: "自定义OpenAI 兼容)",
Description: "兼容 OpenAI 格式的任意接口,如 DeepSeek、通义千问等中转站",
DefaultURL: "",
},
}
}
// GetPresetModels 获取指定提供商类型的预设模型列表
func (ps *ProviderService) GetPresetModels(providerType string) []response.PresetModelOption {
presets := getPresetModels(providerType)
result := make([]response.PresetModelOption, len(presets))
for i, p := range presets {
result[i] = response.PresetModelOption{
ModelName: p.ModelName,
DisplayName: p.DisplayName,
ModelType: p.ModelType,
}
}
return result
}
// GetUserDefaultProvider 获取用户默认提供商(内部方法,给对话功能用)
func (ps *ProviderService) GetUserDefaultProvider(userID uint) (*app.AIProvider, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("user_id = ? AND is_default = ? AND is_enabled = ?", userID, true, true).
First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("请先配置 AI 接口")
}
return nil, err
}
return &provider, nil
}
// GetDecryptedAPIKey 获取解密后的API密钥内部方法给AI调用用
func (ps *ProviderService) GetDecryptedAPIKey(provider *app.AIProvider) string {
return decryptAPIKey(provider.APIKey)
}
// ==================== 内部辅助函数 ====================
// toProviderResponse 转换为响应对象
func toProviderResponse(p *app.AIProvider, models []app.AIModel) response.ProviderResponse {
apiConfig := json.RawMessage(p.APIConfig)
if len(apiConfig) == 0 {
apiConfig = json.RawMessage("{}")
}
capabilities := json.RawMessage(p.Capabilities)
if len(capabilities) == 0 {
capabilities = json.RawMessage("[]")
}
// 模型列表
modelList := make([]response.ModelResponse, len(models))
for i, m := range models {
modelList[i] = toModelResponse(&m)
}
// API Key 提示
apiKeyHint := ""
apiKeySet := false
if p.APIKey != "" {
apiKeySet = true
decrypted := decryptAPIKey(p.APIKey)
if len(decrypted) > 8 {
apiKeyHint = decrypted[:4] + "****" + decrypted[len(decrypted)-4:]
} else if len(decrypted) > 0 {
apiKeyHint = "****"
}
}
return response.ProviderResponse{
ID: p.ID,
ProviderName: p.ProviderName,
ProviderType: p.ProviderType,
BaseURL: p.BaseURL,
APIKeySet: apiKeySet,
APIKeyHint: apiKeyHint,
APIConfig: apiConfig,
Capabilities: capabilities,
IsEnabled: p.IsEnabled,
IsDefault: p.IsDefault,
SortOrder: p.SortOrder,
Models: modelList,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
// toModelResponse 转换模型为响应对象
func toModelResponse(m *app.AIModel) response.ModelResponse {
config := json.RawMessage(m.Config)
if len(config) == 0 {
config = json.RawMessage("{}")
}
return response.ModelResponse{
ID: m.ID,
ProviderID: m.ProviderID,
ModelName: m.ModelName,
DisplayName: m.DisplayName,
ModelType: m.ModelType,
Config: config,
IsEnabled: m.IsEnabled,
CreatedAt: m.CreatedAt,
}
}
// encryptAPIKey 加密API密钥
// TODO: 后续可以替换为更安全的加密方式(如 AES当前使用简单的 Base64 编码
func encryptAPIKey(key string) string {
if key == "" {
return ""
}
// 简单的混淆处理,生产环境应替换为 AES 加密
import_encoding := []byte(key)
for i := range import_encoding {
import_encoding[i] ^= 0x5A
}
return fmt.Sprintf("enc:%x", import_encoding)
}
// decryptAPIKey 解密API密钥
func decryptAPIKey(encrypted string) string {
if encrypted == "" {
return ""
}
if !strings.HasPrefix(encrypted, "enc:") {
return encrypted // 未加密的旧数据,直接返回
}
hexStr := encrypted[4:]
var data []byte
fmt.Sscanf(hexStr, "%x", &data)
for i := range data {
data[i] ^= 0x5A
}
return string(data)
}
// getDefaultBaseURL 获取默认API基础地址
func getDefaultBaseURL(providerType string) string {
switch providerType {
case "openai":
return "https://api.openai.com/v1"
case "claude":
return "https://api.anthropic.com"
case "gemini":
return "https://generativelanguage.googleapis.com"
default:
return ""
}
}
// getDefaultCapabilities 获取默认能力列表
func getDefaultCapabilities(providerType string) []string {
switch providerType {
case "openai":
return []string{"chat", "image_gen"}
case "claude":
return []string{"chat"}
case "gemini":
return []string{"chat", "image_gen"}
case "custom":
return []string{"chat"}
default:
return []string{"chat"}
}
}
// presetModel 预设模型内部结构
type presetModel struct {
ModelName string
DisplayName string
ModelType string
}
// getPresetModels 获取预设模型列表
func getPresetModels(providerType string) []presetModel {
switch providerType {
case "openai":
return []presetModel{
{ModelName: "gpt-4o", DisplayName: "GPT-4o", ModelType: "chat"},
{ModelName: "gpt-4o-mini", DisplayName: "GPT-4o Mini", ModelType: "chat"},
{ModelName: "gpt-4.1", DisplayName: "GPT-4.1", ModelType: "chat"},
{ModelName: "gpt-4.1-mini", DisplayName: "GPT-4.1 Mini", ModelType: "chat"},
{ModelName: "gpt-4.1-nano", DisplayName: "GPT-4.1 Nano", ModelType: "chat"},
{ModelName: "o3-mini", DisplayName: "o3-mini", ModelType: "chat"},
{ModelName: "dall-e-3", DisplayName: "DALL·E 3", ModelType: "image_gen"},
}
case "claude":
return []presetModel{
{ModelName: "claude-sonnet-4-20250514", DisplayName: "Claude Sonnet 4", ModelType: "chat"},
{ModelName: "claude-3-5-sonnet-20241022", DisplayName: "Claude 3.5 Sonnet", ModelType: "chat"},
{ModelName: "claude-3-5-haiku-20241022", DisplayName: "Claude 3.5 Haiku", ModelType: "chat"},
{ModelName: "claude-3-opus-20240229", DisplayName: "Claude 3 Opus", ModelType: "chat"},
}
case "gemini":
return []presetModel{
{ModelName: "gemini-2.5-flash-preview-05-20", DisplayName: "Gemini 2.5 Flash", ModelType: "chat"},
{ModelName: "gemini-2.5-pro-preview-05-06", DisplayName: "Gemini 2.5 Pro", ModelType: "chat"},
{ModelName: "gemini-2.0-flash", DisplayName: "Gemini 2.0 Flash", ModelType: "chat"},
{ModelName: "imagen-3.0-generate-002", DisplayName: "Imagen 3", ModelType: "image_gen"},
}
case "custom":
return []presetModel{} // 自定义不提供预设
default:
return []presetModel{}
}
}
// ==================== 获取远程模型列表 ====================
// FetchRemoteModels 从远程API获取可用模型列表
// 所有提供商类型统一使用 baseURL + /models 端点OpenAI 兼容格式)
func (ps *ProviderService) FetchRemoteModels(req request.FetchRemoteModelsRequest) (response.FetchRemoteModelsResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
result := fetchModelsUniversal(baseURL, req.APIKey)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
func (ps *ProviderService) FetchRemoteModelsExisting(providerID uint, userID uint) (response.FetchRemoteModelsResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.FetchRemoteModelsResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.FetchRemoteModels(request.FetchRemoteModelsRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
})
}
// ==================== 发送测试消息 ====================
// SendTestMessage 发送测试消息(使用指定的 provider 配置)
// 所有提供商类型统一使用 baseURL + /chat/completions 端点OpenAI 兼容格式)
func (ps *ProviderService) SendTestMessage(req request.SendTestMessageRequest) (response.SendTestMessageResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
message := req.Message
if message == "" {
message = "你好,请用一句话介绍你自己。"
}
result := sendTestMessageUniversal(baseURL, req.APIKey, req.ModelName, message)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// SendTestMessageExisting 发送测试消息(已保存的提供商)
func (ps *ProviderService) SendTestMessageExisting(providerID uint, userID uint, modelName string, message string) (response.SendTestMessageResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.SendTestMessageResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.SendTestMessage(request.SendTestMessageRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
ModelName: modelName,
Message: message,
})
}
// ==================== 连通性测试实现 ====================
// testOpenAICompatible 测试 OpenAI 兼容接口
func testOpenAICompatible(baseURL, apiKey, modelName string) response.TestProviderResponse {
url := strings.TrimRight(baseURL, "/") + "/models"
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return response.TestProviderResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.TestProviderResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == 401 {
return response.TestProviderResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
return response.TestProviderResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
// 解析模型列表
var modelsResp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
var modelNames []string
if err := json.Unmarshal(body, &modelsResp); err == nil {
for _, m := range modelsResp.Data {
modelNames = append(modelNames, m.ID)
}
}
return response.TestProviderResponse{
Success: true,
Message: "连接成功",
Models: modelNames,
}
}
// ==================== 获取远程模型列表实现 ====================
// fetchModelsUniversal 统一获取模型列表(所有提供商通用)
// 使用 baseURL + /models 端点Authorization: Bearer 鉴权
func fetchModelsUniversal(baseURL, apiKey string) response.FetchRemoteModelsResponse {
url := strings.TrimRight(baseURL, "/") + "/models"
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return response.FetchRemoteModelsResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.FetchRemoteModelsResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
if resp.StatusCode == 401 {
return response.FetchRemoteModelsResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
return response.FetchRemoteModelsResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
body, _ := io.ReadAll(resp.Body)
// 解析 OpenAI 兼容格式: { "data": [{ "id": "xxx", "owned_by": "xxx" }] }
var modelsData struct {
Data []struct {
ID string `json:"id"`
OwnedBy string `json:"owned_by"`
} `json:"data"`
}
var models []response.RemoteModel
if err := json.Unmarshal(body, &modelsData); err == nil {
for _, m := range modelsData.Data {
models = append(models, response.RemoteModel{
ID: m.ID,
OwnedBy: m.OwnedBy,
})
}
}
return response.FetchRemoteModelsResponse{
Success: true,
Message: fmt.Sprintf("获取成功,共 %d 个模型", len(models)),
Models: models,
}
}
// ==================== 发送测试消息实现 ====================
// sendTestMessageUniversal 统一发送测试消息(所有提供商通用)
// 使用 baseURL + /chat/completions 端点Authorization: Bearer 鉴权
func sendTestMessageUniversal(baseURL, apiKey, modelName, message string) response.SendTestMessageResponse {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
payload := map[string]interface{}{
"model": modelName,
"max_tokens": 100,
"messages": []map[string]string{
{"role": "user", "content": message},
},
}
payloadBytes, _ := json.Marshal(payload)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payloadBytes))
if err != nil {
return response.SendTestMessageResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.SendTestMessageResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == 401 {
return response.SendTestMessageResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
var errResp struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" {
return response.SendTestMessageResponse{Success: false, Message: "API 错误: " + errResp.Error.Message}
}
return response.SendTestMessageResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
// 解析 OpenAI 兼容格式的 chat completion 响应
var chatResp struct {
Model string `json:"model"`
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal(body, &chatResp); err != nil {
return response.SendTestMessageResponse{Success: false, Message: "解析响应失败"}
}
reply := ""
if len(chatResp.Choices) > 0 {
reply = chatResp.Choices[0].Message.Content
}
return response.SendTestMessageResponse{
Success: true,
Message: "测试成功",
Reply: reply,
Model: chatResp.Model,
Tokens: chatResp.Usage.TotalTokens,
}
}