🎨 优化扩展模块,完成ai接入和对话功能

This commit is contained in:
2026-02-12 23:12:28 +08:00
parent 4e611d3a5e
commit 572f3aa15b
779 changed files with 194400 additions and 3136 deletions

470
server/api/v1/app/chat.go Normal file
View File

@@ -0,0 +1,470 @@
package app
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/middleware"
appModel "git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/common/response"
appService "git.echol.cn/loser/st/server/service/app"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type ChatApi struct{}
// ==================== 对话管理 ====================
// CreateChat 创建对话
// @Tags 对话
// @Summary 创建新对话
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.CreateChatRequest true "角色卡ID"
// @Success 200 {object} response.Response{data=appResponse.ChatResponse} "创建成功"
// @Router /app/chat [post]
func (ca *ChatApi) CreateChat(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.CreateChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
chat, err := chatService.CreateChat(req, userID)
if err != nil {
global.GVA_LOG.Error("创建对话失败", zap.Error(err))
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(chat, c)
}
// GetChatList 获取对话列表
// @Tags 对话
// @Summary 获取对话列表
// @Security ApiKeyAuth
// @Produce application/json
// @Param page query int true "页码"
// @Param pageSize query int true "每页数量"
// @Success 200 {object} response.Response{data=appResponse.ChatListResponse} "获取成功"
// @Router /app/chat/list [get]
func (ca *ChatApi) GetChatList(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.ChatListRequest
if err := c.ShouldBindQuery(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
if req.Page == 0 {
req.Page = 1
}
if req.PageSize == 0 {
req.PageSize = 20
}
list, err := chatService.GetChatList(req, userID)
if err != nil {
global.GVA_LOG.Error("获取对话列表失败", zap.Error(err))
response.FailWithMessage("获取失败", c)
return
}
response.OkWithData(list, c)
}
// GetChatDetail 获取对话详情(含消息)
// @Tags 对话
// @Summary 获取对话详情
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "对话ID"
// @Success 200 {object} response.Response{data=appResponse.ChatDetailResponse} "获取成功"
// @Router /app/chat/:id [get]
func (ca *ChatApi) GetChatDetail(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
detail, err := chatService.GetChatDetail(uint(id), userID)
if err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(detail, c)
}
// GetChatMessages 分页获取消息
// @Tags 对话
// @Summary 获取对话消息
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "对话ID"
// @Param page query int true "页码"
// @Param pageSize query int true "每页数量"
// @Success 200 {object} response.Response{data=appResponse.MessageListResponse} "获取成功"
// @Router /app/chat/:id/messages [get]
func (ca *ChatApi) GetChatMessages(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
var req request.ChatMessagesRequest
if err := c.ShouldBindQuery(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
if req.Page == 0 {
req.Page = 1
}
if req.PageSize == 0 {
req.PageSize = 50
}
messages, err := chatService.GetChatMessages(uint(id), req, userID)
if err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(messages, c)
}
// DeleteChat 删除对话
// @Tags 对话
// @Summary 删除对话
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "对话ID"
// @Success 200 {object} response.Response{msg=string} "删除成功"
// @Router /app/chat/:id [delete]
func (ca *ChatApi) DeleteChat(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
if err := chatService.DeleteChat(uint(id), userID); err != nil {
global.GVA_LOG.Error("删除对话失败", zap.Error(err))
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithMessage("删除成功", c)
}
// ==================== 消息操作 ====================
// SendMessage 发送消息并获取AI回复SSE 流式响应)
// @Tags 对话
// @Summary 发送消息SSE流式
// @Security ApiKeyAuth
// @accept application/json
// @Produce text/event-stream
// @Param data body request.SendMessageRequest true "消息内容"
// @Router /app/chat/send [post]
func (ca *ChatApi) SendMessage(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.SendMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
return
}
// 1. 获取对话上下文
chat, character, historyMsgs, err := chatService.GetChatForAI(req.ChatID, userID)
if err != nil {
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
return
}
// 2. 保存用户消息
_, err = chatService.SaveUserMessage(chat.ID, userID, req.Content)
if err != nil {
c.JSON(500, gin.H{"code": 7, "msg": "保存消息失败"})
return
}
// 2.1 世界书匹配(根据角色卡 + 历史消息)
var worldInfoTexts []string
if character != nil && chat.CharacterID != nil {
// 收集最近的消息文本(用于关键词匹配),附加当前用户输入
messagesForMatch := make([]string, 0, len(historyMsgs)+1)
for _, m := range historyMsgs {
if m.Content != "" {
messagesForMatch = append(messagesForMatch, m.Content)
}
}
if req.Content != "" {
messagesForMatch = append(messagesForMatch, req.Content)
}
matchReq := request.MatchWorldInfoRequest{
CharacterID: *chat.CharacterID,
Messages: messagesForMatch,
ScanDepth: 10,
MaxTokens: 2000,
}
if result, wErr := worldInfoService.MatchWorldInfo(userID, &matchReq); wErr != nil {
global.GVA_LOG.Warn("匹配世界书失败",
zap.Uint("chatID", chat.ID),
zap.Uint("characterID", *chat.CharacterID),
zap.Error(wErr))
} else if result != nil && len(result.Entries) > 0 {
// 按 position 分组,暂时统一作为 System 级别补充上下文
beforeChar := make([]string, 0)
afterChar := make([]string, 0)
for _, entry := range result.Entries {
text := strings.TrimSpace(entry.Content)
if text == "" {
continue
}
if entry.Position == "after_char" {
afterChar = append(afterChar, text)
} else {
// 默认 before_char
beforeChar = append(beforeChar, text)
}
}
if len(beforeChar) > 0 {
worldInfoTexts = append(worldInfoTexts, strings.Join(beforeChar, "\n\n"))
}
if len(afterChar) > 0 {
worldInfoTexts = append(worldInfoTexts, strings.Join(afterChar, "\n\n"))
}
}
}
// 3. 获取 AI 提供商和模型
var provider *appModel.AIProvider
var modelName string
if req.ProviderID != nil {
// 指定了提供商
p, pErr := providerService.GetUserDefaultProvider(userID) // 简化:暂时仅用默认
if pErr != nil {
c.JSON(400, gin.H{"code": 7, "msg": "指定的 AI 接口不可用"})
return
}
provider = p
} else {
// 使用默认提供商
p, pErr := providerService.GetUserDefaultProvider(userID)
if pErr != nil {
c.JSON(400, gin.H{"code": 7, "msg": pErr.Error()})
return
}
provider = p
}
if req.ModelName != "" {
modelName = req.ModelName
} else {
// 使用提供商的第一个启用的聊天模型
modelName = getDefaultChatModelForProvider(provider.ID)
if modelName == "" {
c.JSON(400, gin.H{"code": 7, "msg": "该 AI 接口没有可用的聊天模型"})
return
}
}
// 4. 构建 Prompt角色卡 + 世界书 + 历史消息)
prompt := appService.BuildPrompt(character, historyMsgs)
// 将世界书内容插入为额外 System 提示(靠近角色定义)
if len(worldInfoTexts) > 0 {
worldInfoContent := "World Info:\n" + strings.Join(worldInfoTexts, "\n\n")
inserted := false
for i, msg := range prompt {
if msg.Role == "system" {
// 在第一个 system 消息之后插入一条世界书消息
newPrompt := make([]appService.AIMessagePayload, 0, len(prompt)+1)
newPrompt = append(newPrompt, prompt[:i+1]...)
newPrompt = append(newPrompt, appService.AIMessagePayload{
Role: "system",
Content: worldInfoContent,
})
newPrompt = append(newPrompt, prompt[i+1:]...)
prompt = newPrompt
inserted = true
break
}
}
if !inserted {
prompt = append([]appService.AIMessagePayload{{
Role: "system",
Content: worldInfoContent,
}}, prompt...)
}
}
// 添加用户的新消息
prompt = append(prompt, appService.AIMessagePayload{
Role: "user",
Content: req.Content,
})
// 5. 设置 SSE 响应头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
// 6. 流式调用 AI
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
ch := make(chan appService.AIStreamChunk, 100)
go appService.StreamAIResponse(ctx, provider, modelName, prompt, ch)
var fullContent string
var promptTokens, completionTokens int
flusher, ok := c.Writer.(interface{ Flush() })
if !ok {
c.JSON(500, gin.H{"code": 7, "msg": "服务器不支持流式响应"})
return
}
for chunk := range ch {
if chunk.Error != "" {
data, _ := json.Marshal(map[string]interface{}{
"type": "error",
"error": chunk.Error,
})
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
flusher.Flush()
break
}
if chunk.Content != "" {
fullContent += chunk.Content
data, _ := json.Marshal(map[string]interface{}{
"type": "content",
"content": chunk.Content,
})
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
flusher.Flush()
}
if chunk.Done {
promptTokens = chunk.PromptTokens
completionTokens = chunk.CompletionTokens
// 保存 AI 回复
if fullContent != "" {
chatService.SaveAssistantMessage(
chat.ID, chat.CharacterID, fullContent,
modelName, promptTokens, completionTokens,
)
}
data, _ := json.Marshal(map[string]interface{}{
"type": "done",
"model": modelName,
"promptTokens": promptTokens,
"completionTokens": completionTokens,
})
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
flusher.Flush()
}
}
}
// EditMessage 编辑消息
// @Tags 对话
// @Summary 编辑消息内容
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.EditMessageRequest true "编辑信息"
// @Success 200 {object} response.Response{data=appResponse.MessageResponse} "编辑成功"
// @Router /app/chat/message/edit [post]
func (ca *ChatApi) EditMessage(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.EditMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
msg, err := chatService.EditMessage(req, userID)
if err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(msg, c)
}
// DeleteMessage 删除消息
// @Tags 对话
// @Summary 删除消息
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.DeleteMessageRequest true "消息ID"
// @Success 200 {object} response.Response{msg=string} "删除成功"
// @Router /app/chat/message/delete [post]
func (ca *ChatApi) DeleteMessage(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.DeleteMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
if err := chatService.DeleteMessage(req.MessageID, userID); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithMessage("删除成功", c)
}
// ==================== 内部辅助 ====================
// getDefaultChatModelForProvider 获取提供商的默认聊天模型
func getDefaultChatModelForProvider(providerID uint) string {
var model appModel.AIModel
err := global.GVA_DB.Where("provider_id = ? AND model_type = ? AND is_enabled = ?",
providerID, "chat", true).
Order("created_at ASC").
First(&model).Error
if err != nil {
return ""
}
return model.ModelName
}

View File

@@ -8,9 +8,13 @@ type ApiGroup struct {
WorldInfoApi
ExtensionApi
RegexScriptApi
ProviderApi
ChatApi
}
var (
authService = service.ServiceGroupApp.AppServiceGroup.AuthService
characterService = service.ServiceGroupApp.AppServiceGroup.CharacterService
providerService = service.ServiceGroupApp.AppServiceGroup.ProviderService
chatService = service.ServiceGroupApp.AppServiceGroup.ChatService
)

View File

@@ -3,7 +3,10 @@ package app
import (
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strconv"
"git.echol.cn/loser/st/server/global"
@@ -563,3 +566,72 @@ func (a *ExtensionApi) InstallExtensionFromGit(c *gin.Context) {
sysResponse.OkWithData(response.ToExtensionResponse(extension), c)
}
// ProxyExtensionAsset 获取扩展资源文件(从本地文件系统读取)
// @Summary 获取扩展资源文件
// @Description 从本地存储读取扩展的 JS/CSS 等资源文件(与原版 SillyTavern 一致,扩展文件存储在本地)
// @Tags 扩展管理
// @Produce octet-stream
// @Param id path int true "扩展ID"
// @Param path path string true "资源文件路径"
// @Success 200 {file} binary
// @Router /app/extension/:id/asset/*path [get]
func (a *ExtensionApi) ProxyExtensionAsset(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
sysResponse.FailWithMessage("无效的扩展ID", c)
return
}
extensionID := uint(id)
// 获取资源路径(去掉前导 /
assetPath := c.Param("path")
if len(assetPath) > 0 && assetPath[0] == '/' {
assetPath = assetPath[1:]
}
if assetPath == "" {
sysResponse.FailWithMessage("资源路径不能为空", c)
return
}
// 通过扩展 ID 查库获取信息(公开路由,不做 userID 过滤)
extInfo, err := extensionService.GetExtensionByID(extensionID)
if err != nil {
sysResponse.FailWithMessage("扩展不存在", c)
return
}
// 从本地文件系统读取资源
localPath, err := extensionService.GetExtensionAssetLocalPath(extInfo.Name, assetPath)
if err != nil {
global.GVA_LOG.Error("获取扩展资源失败",
zap.Error(err),
zap.String("name", extInfo.Name),
zap.String("asset", assetPath))
sysResponse.FailWithMessage("资源不存在: "+err.Error(), c)
return
}
// 读取文件内容
data, err := os.ReadFile(localPath)
if err != nil {
global.GVA_LOG.Error("读取扩展资源文件失败", zap.Error(err), zap.String("path", localPath))
sysResponse.FailWithMessage("资源读取失败", c)
return
}
// 根据文件扩展名设置正确的 Content-Type
fileExt := filepath.Ext(assetPath)
contentType := mime.TypeByExtension(fileExt)
if contentType == "" {
contentType = "application/octet-stream"
}
// 设置缓存和 CORS 头
c.Header("Content-Type", contentType)
c.Header("Cache-Control", "public, max-age=3600")
c.Header("Access-Control-Allow-Origin", "*")
c.Data(http.StatusOK, contentType, data)
}

View File

@@ -0,0 +1,476 @@
package app
import (
"strconv"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/middleware"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/common/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type ProviderApi struct{}
// ==================== 提供商接口 ====================
// CreateProvider 创建AI提供商
// @Tags AI配置
// @Summary 创建AI提供商
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.CreateProviderRequest true "提供商信息"
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "创建成功"
// @Router /app/provider [post]
func (pa *ProviderApi) CreateProvider(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.CreateProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
provider, err := providerService.CreateProvider(req, userID)
if err != nil {
global.GVA_LOG.Error("创建AI提供商失败", zap.Error(err))
response.FailWithMessage("创建失败: "+err.Error(), c)
return
}
response.OkWithData(provider, c)
}
// GetProviderList 获取AI提供商列表
// @Tags AI配置
// @Summary 获取AI提供商列表
// @Security ApiKeyAuth
// @Produce application/json
// @Param page query int true "页码"
// @Param pageSize query int true "每页数量"
// @Param keyword query string false "搜索关键词"
// @Success 200 {object} response.Response{data=appResponse.ProviderListResponse} "获取成功"
// @Router /app/provider/list [get]
func (pa *ProviderApi) GetProviderList(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.ProviderListRequest
if err := c.ShouldBindQuery(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
// 默认分页
if req.Page == 0 {
req.Page = 1
}
if req.PageSize == 0 {
req.PageSize = 20
}
list, err := providerService.GetProviderList(req, userID)
if err != nil {
global.GVA_LOG.Error("获取AI提供商列表失败", zap.Error(err))
response.FailWithMessage("获取失败", c)
return
}
response.OkWithData(list, c)
}
// GetProviderDetail 获取AI提供商详情
// @Tags AI配置
// @Summary 获取AI提供商详情
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "提供商ID"
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "获取成功"
// @Router /app/provider/:id [get]
func (pa *ProviderApi) GetProviderDetail(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
provider, err := providerService.GetProviderDetail(uint(id), userID)
if err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(provider, c)
}
// UpdateProvider 更新AI提供商
// @Tags AI配置
// @Summary 更新AI提供商
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.UpdateProviderRequest true "更新信息"
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "更新成功"
// @Router /app/provider [put]
func (pa *ProviderApi) UpdateProvider(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.UpdateProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
provider, err := providerService.UpdateProvider(req, userID)
if err != nil {
global.GVA_LOG.Error("更新AI提供商失败", zap.Error(err))
response.FailWithMessage("更新失败: "+err.Error(), c)
return
}
response.OkWithData(provider, c)
}
// DeleteProvider 删除AI提供商
// @Tags AI配置
// @Summary 删除AI提供商
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "提供商ID"
// @Success 200 {object} response.Response{msg=string} "删除成功"
// @Router /app/provider/:id [delete]
func (pa *ProviderApi) DeleteProvider(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
if err := providerService.DeleteProvider(uint(id), userID); err != nil {
global.GVA_LOG.Error("删除AI提供商失败", zap.Error(err))
response.FailWithMessage("删除失败: "+err.Error(), c)
return
}
response.OkWithMessage("删除成功", c)
}
// SetDefaultProvider 设置默认提供商
// @Tags AI配置
// @Summary 设置默认AI提供商
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.SetDefaultProviderRequest true "提供商ID"
// @Success 200 {object} response.Response{msg=string} "设置成功"
// @Router /app/provider/setDefault [post]
func (pa *ProviderApi) SetDefaultProvider(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.SetDefaultProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
if err := providerService.SetDefaultProvider(req.ProviderID, userID); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithMessage("设置成功", c)
}
// TestProvider 测试提供商连通性(不需要先保存)
// @Tags AI配置
// @Summary 测试AI提供商连通性
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.TestProviderRequest true "测试参数"
// @Success 200 {object} response.Response{data=appResponse.TestProviderResponse} "测试结果"
// @Router /app/provider/test [post]
func (pa *ProviderApi) TestProvider(c *gin.Context) {
var req request.TestProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
result, err := providerService.TestProvider(req)
if err != nil {
response.FailWithMessage("测试失败: "+err.Error(), c)
return
}
response.OkWithData(result, c)
}
// TestExistingProvider 测试已保存的提供商连通性
// @Tags AI配置
// @Summary 测试已保存的提供商连通性
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "提供商ID"
// @Success 200 {object} response.Response{data=appResponse.TestProviderResponse} "测试结果"
// @Router /app/provider/test/:id [get]
func (pa *ProviderApi) TestExistingProvider(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
result, err := providerService.TestExistingProvider(uint(id), userID)
if err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(result, c)
}
// ==================== 模型接口 ====================
// AddModel 为提供商添加模型
// @Tags AI配置
// @Summary 添加AI模型
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.CreateModelRequest true "模型信息"
// @Success 200 {object} response.Response{data=appResponse.ModelResponse} "创建成功"
// @Router /app/provider/model [post]
func (pa *ProviderApi) AddModel(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.CreateModelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
model, err := providerService.AddModel(req, userID)
if err != nil {
global.GVA_LOG.Error("添加模型失败", zap.Error(err))
response.FailWithMessage("添加失败: "+err.Error(), c)
return
}
response.OkWithData(model, c)
}
// UpdateModel 更新模型
// @Tags AI配置
// @Summary 更新AI模型
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.UpdateModelRequest true "更新信息"
// @Success 200 {object} response.Response{data=appResponse.ModelResponse} "更新成功"
// @Router /app/provider/model [put]
func (pa *ProviderApi) UpdateModel(c *gin.Context) {
userID := middleware.GetAppUserID(c)
var req request.UpdateModelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
model, err := providerService.UpdateModel(req, userID)
if err != nil {
global.GVA_LOG.Error("更新模型失败", zap.Error(err))
response.FailWithMessage("更新失败: "+err.Error(), c)
return
}
response.OkWithData(model, c)
}
// DeleteModel 删除模型
// @Tags AI配置
// @Summary 删除AI模型
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "模型ID"
// @Success 200 {object} response.Response{msg=string} "删除成功"
// @Router /app/provider/model/:id [delete]
func (pa *ProviderApi) DeleteModel(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
if err := providerService.DeleteModel(uint(id), userID); err != nil {
global.GVA_LOG.Error("删除模型失败", zap.Error(err))
response.FailWithMessage("删除失败: "+err.Error(), c)
return
}
response.OkWithMessage("删除成功", c)
}
// ==================== 远程模型获取 ====================
// FetchRemoteModels 获取远程可用模型列表(不需要先保存)
// @Tags AI配置
// @Summary 从远程获取可用模型列表
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.FetchRemoteModelsRequest true "提供商配置"
// @Success 200 {object} response.Response{data=appResponse.FetchRemoteModelsResponse} "获取结果"
// @Router /app/provider/fetchModels [post]
func (pa *ProviderApi) FetchRemoteModels(c *gin.Context) {
var req request.FetchRemoteModelsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
result, err := providerService.FetchRemoteModels(req)
if err != nil {
response.FailWithMessage("获取失败: "+err.Error(), c)
return
}
response.OkWithData(result, c)
}
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
// @Tags AI配置
// @Summary 获取已保存提供商的远程可用模型
// @Security ApiKeyAuth
// @Produce application/json
// @Param id path uint true "提供商ID"
// @Success 200 {object} response.Response{data=appResponse.FetchRemoteModelsResponse} "获取结果"
// @Router /app/provider/fetchModels/:id [get]
func (pa *ProviderApi) FetchRemoteModelsExisting(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
result, err := providerService.FetchRemoteModelsExisting(uint(id), userID)
if err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(result, c)
}
// ==================== 测试消息 ====================
// SendTestMessage 发送测试消息(不需要先保存)
// @Tags AI配置
// @Summary 发送测试消息验证AI接口
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.SendTestMessageRequest true "测试参数"
// @Success 200 {object} response.Response{data=appResponse.SendTestMessageResponse} "测试结果"
// @Router /app/provider/testMessage [post]
func (pa *ProviderApi) SendTestMessage(c *gin.Context) {
var req request.SendTestMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
result, err := providerService.SendTestMessage(req)
if err != nil {
response.FailWithMessage("测试失败: "+err.Error(), c)
return
}
response.OkWithData(result, c)
}
// SendTestMessageExisting 使用已保存的提供商发送测试消息
// @Tags AI配置
// @Summary 使用已保存的提供商发送测试消息
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param id path uint true "提供商ID"
// @Param data body object{modelName=string,message=string} true "测试参数"
// @Success 200 {object} response.Response{data=appResponse.SendTestMessageResponse} "测试结果"
// @Router /app/provider/testMessage/:id [post]
func (pa *ProviderApi) SendTestMessageExisting(c *gin.Context) {
userID := middleware.GetAppUserID(c)
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 32)
if err != nil {
response.FailWithMessage("无效的ID", c)
return
}
var body struct {
ModelName string `json:"modelName" binding:"required"`
Message string `json:"message"`
}
if err := c.ShouldBindJSON(&body); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
result, err := providerService.SendTestMessageExisting(uint(id), userID, body.ModelName, body.Message)
if err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(result, c)
}
// ==================== 辅助接口 ====================
// GetProviderTypes 获取支持的提供商类型列表
// @Tags AI配置
// @Summary 获取支持的AI提供商类型
// @Produce application/json
// @Success 200 {object} response.Response{data=[]appResponse.ProviderTypeOption} "获取成功"
// @Router /app/provider/types [get]
func (pa *ProviderApi) GetProviderTypes(c *gin.Context) {
types := providerService.GetProviderTypes()
response.OkWithData(types, c)
}
// GetPresetModels 获取预设模型列表
// @Tags AI配置
// @Summary 获取指定提供商类型的预设模型
// @Produce application/json
// @Param type query string true "提供商类型(openai/claude/gemini/custom)"
// @Success 200 {object} response.Response{data=[]appResponse.PresetModelOption} "获取成功"
// @Router /app/provider/presetModels [get]
func (pa *ProviderApi) GetPresetModels(c *gin.Context) {
providerType := c.Query("type")
if providerType == "" {
response.FailWithMessage("请指定提供商类型", c)
return
}
models := providerService.GetPresetModels(providerType)
response.OkWithData(models, c)
}

View File

@@ -45,6 +45,11 @@ func Routers() *gin.Engine {
Router.Use(gin.Logger())
}
// 跨域配置(前台应用需要)
// 必须在静态文件路由之前注册,否则静态文件跨域会失败
Router.Use(middleware.Cors())
global.GVA_LOG.Info("use middleware cors")
if !global.GVA_CONFIG.MCP.Separate {
sseServer := McpRun()
@@ -63,20 +68,23 @@ func Routers() *gin.Engine {
exampleRouter := router.RouterGroupApp.Example
appRouter := router.RouterGroupApp.App // 前台应用路由
// 前台用户端静态文件服务web-app
// 开发环境:直接使用 web-app/public 目录
// 生产环境:使用打包后的 web-app/dist 目录
webAppPath := "../web-app/public"
if _, err := os.Stat(webAppPath); err == nil {
Router.Static("/css", webAppPath+"/css")
Router.Static("/scripts", webAppPath+"/scripts")
Router.Static("/img", webAppPath+"/img")
Router.Static("/fonts", webAppPath+"/fonts")
Router.Static("/webfonts", webAppPath+"/webfonts")
Router.StaticFile("/auth.html", webAppPath+"/auth.html")
Router.StaticFile("/dashboard-example.html", webAppPath+"/dashboard-example.html")
Router.StaticFile("/favicon.ico", webAppPath+"/favicon.ico")
global.GVA_LOG.Info("前台静态文件服务已启动: " + webAppPath)
// SillyTavern 核心脚本静态文件服务
// 扩展通过 ES module 相对路径 import 引用这些核心模块(如 ../../../../../script.js → /script.js
// 所有核心文件存储在 data/st-core-scripts/ 下,完全独立于 web-app/ 目录
// 扩展文件存储在 data/st-core-scripts/scripts/extensions/third-party/{name}/ 下
stCorePath := "data/st-core-scripts"
if _, err := os.Stat(stCorePath); err == nil {
Router.Static("/scripts", stCorePath+"/scripts")
Router.Static("/css", stCorePath+"/css")
Router.Static("/img", stCorePath+"/img")
Router.Static("/webfonts", stCorePath+"/webfonts")
Router.Static("/lib", stCorePath+"/lib") // SillyTavern 扩展依赖的第三方库
Router.Static("/locales", stCorePath+"/locales") // 国际化文件
Router.StaticFile("/script.js", stCorePath+"/script.js") // SillyTavern 主入口
Router.StaticFile("/lib.js", stCorePath+"/lib.js") // Webpack 编译后的 lib.js
global.GVA_LOG.Info("SillyTavern 核心脚本服务已启动: " + stCorePath)
} else {
global.GVA_LOG.Warn("SillyTavern 核心脚本目录不存在: " + stCorePath + ",扩展功能将不可用")
}
// 管理后台前端静态文件web
@@ -90,9 +98,6 @@ func Routers() *gin.Engine {
Router.StaticFS(global.GVA_CONFIG.Local.StorePath, justFilesFilesystem{http.Dir(global.GVA_CONFIG.Local.StorePath)}) // Router.Use(middleware.LoadTls()) // 如果需要使用https 请打开此中间件 然后前往 core/server.go 将启动模式 更变为 Router.RunTLS("端口","你的cre/pem文件","你的key文件")
// 跨域配置(前台应用需要)
Router.Use(middleware.Cors()) // 直接放行全部跨域请求
global.GVA_LOG.Info("use middleware cors")
docs.SwaggerInfo.BasePath = global.GVA_CONFIG.System.RouterPrefix
Router.GET(global.GVA_CONFIG.System.RouterPrefix+"/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
global.GVA_LOG.Info("register swagger handler")
@@ -149,6 +154,8 @@ func Routers() *gin.Engine {
appRouter.InitWorldInfoRouter(appGroup) // 世界书路由:/app/worldbook/*
appRouter.InitExtensionRouter(appGroup) // 扩展路由:/app/extension/*
appRouter.InitRegexScriptRouter(appGroup) // 正则脚本路由:/app/regex/*
appRouter.InitProviderRouter(appGroup) // AI提供商路由/app/provider/*
appRouter.InitChatRouter(appGroup) // 对话路由:/app/chat/*
}
//插件路由安装

View File

@@ -67,7 +67,7 @@ func (AIExtension) TableName() string {
return "ai_extensions"
}
// AIExtensionManifest 扩展清单结构 (对应 manifest.json)
// AIExtensionManifest 扩展清单结构 (对应 manifest.json,兼容 SillyTavern 格式)
type AIExtensionManifest struct {
Name string `json:"name"`
DisplayName string `json:"display_name,omitempty"`
@@ -75,6 +75,7 @@ type AIExtensionManifest struct {
Description string `json:"description"`
Author string `json:"author"`
Homepage string `json:"homepage,omitempty"`
HomePage string `json:"homePage,omitempty"` // SillyTavern 兼容(驼峰命名)
Repository string `json:"repository,omitempty"`
License string `json:"license,omitempty"`
Tags []string `json:"tags,omitempty"`
@@ -83,13 +84,54 @@ type AIExtensionManifest struct {
Dependencies map[string]string `json:"dependencies,omitempty"` // {"extension-name": ">=1.0.0"}
Conflicts []string `json:"conflicts,omitempty"`
Entry string `json:"entry,omitempty"` // 主入口文件
Js string `json:"js,omitempty"` // SillyTavern 兼容: JS 入口文件
Style string `json:"style,omitempty"` // 样式文件
Css string `json:"css,omitempty"` // SillyTavern 兼容: CSS 样式文件
Assets []string `json:"assets,omitempty"` // 资源文件列表
Settings map[string]interface{} `json:"settings,omitempty"` // 默认设置
Options map[string]interface{} `json:"options,omitempty"` // 扩展选项
Metadata map[string]interface{} `json:"metadata,omitempty"` // 扩展元数据
AutoUpdate bool `json:"auto_update,omitempty"` // 是否自动更新SillyTavern 兼容)
InlineScript string `json:"inline_script,omitempty"` // 内联脚本SillyTavern 兼容)
Requires []string `json:"requires,omitempty"` // SillyTavern 兼容: 依赖列表
Optional []string `json:"optional,omitempty"` // SillyTavern 兼容: 可选依赖
LoadingOrder int `json:"loading_order,omitempty"` // SillyTavern 兼容: 加载顺序
I18n map[string]string `json:"i18n,omitempty"` // SillyTavern 兼容: 国际化文件
}
// GetEffectiveName 获取有效名称,兼容 SillyTavern manifest 没有 name 字段的情况
func (m *AIExtensionManifest) GetEffectiveName() string {
if m.Name != "" {
return m.Name
}
if m.DisplayName != "" {
return m.DisplayName
}
return ""
}
// GetEffectiveHomepage 获取有效主页地址
func (m *AIExtensionManifest) GetEffectiveHomepage() string {
if m.Homepage != "" {
return m.Homepage
}
return m.HomePage
}
// GetEffectiveEntry 获取有效的 JS 入口文件路径
func (m *AIExtensionManifest) GetEffectiveEntry() string {
if m.Entry != "" {
return m.Entry
}
return m.Js
}
// GetEffectiveStyle 获取有效的 CSS 样式文件路径
func (m *AIExtensionManifest) GetEffectiveStyle() string {
if m.Style != "" {
return m.Style
}
return m.Css
}
// AIExtensionSettings 用户的扩展配置(已废弃,配置现在直接存储在 AIExtension.Settings 中)

View File

@@ -6,14 +6,21 @@ import (
)
// AIProvider AI 服务提供商配置
// 每个用户可以配置多个 AI 提供商(如 OpenAI、Claude、Gemini 等)
// UserID 为 NULL 表示系统预设的提供商配置
type AIProvider struct {
global.GVA_MODEL
UserID *uint `json:"userId" gorm:"index;comment:用户IDNULL表示系统配置"`
User *AppUser `json:"user" gorm:"foreignKey:UserID"`
User *AppUser `json:"user,omitempty" gorm:"foreignKey:UserID"`
ProviderName string `json:"providerName" gorm:"type:varchar(100);not null;index;comment:提供商名称"`
APIConfig datatypes.JSON `json:"apiConfig" gorm:"type:jsonb;not null;comment:API配置加密存储"`
ProviderType string `json:"providerType" gorm:"type:varchar(50);not null;default:openai;comment:提供商类型(openai/claude/gemini/custom)"`
BaseURL string `json:"baseUrl" gorm:"type:varchar(500);comment:API基础地址"`
APIKey string `json:"-" gorm:"type:varchar(500);comment:API密钥加密存储"`
APIConfig datatypes.JSON `json:"apiConfig" gorm:"type:jsonb;comment:额外API配置"`
Capabilities datatypes.JSON `json:"capabilities" gorm:"type:jsonb;comment:支持的能力(chat/image_gen等)"`
IsEnabled bool `json:"isEnabled" gorm:"default:true;comment:是否启用"`
IsDefault bool `json:"isDefault" gorm:"default:false;comment:是否为默认提供商"`
SortOrder int `json:"sortOrder" gorm:"default:0;comment:排序权重"`
}
func (AIProvider) TableName() string {
@@ -21,12 +28,14 @@ func (AIProvider) TableName() string {
}
// AIModel AI 模型配置
// 每个提供商下可以有多个模型(如 gpt-4o、gpt-4o-mini 等)
type AIModel struct {
global.GVA_MODEL
ProviderID uint `json:"providerId" gorm:"not null;index;comment:提供商ID"`
Provider *AIProvider `json:"provider" gorm:"foreignKey:ProviderID"`
ModelName string `json:"modelName" gorm:"type:varchar(200);not null;comment:模型名称"`
Provider *AIProvider `json:"provider,omitempty" gorm:"foreignKey:ProviderID"`
ModelName string `json:"modelName" gorm:"type:varchar(200);not null;comment:模型标识API调用用"`
DisplayName string `json:"displayName" gorm:"type:varchar(200);comment:模型显示名称"`
ModelType string `json:"modelType" gorm:"type:varchar(50);not null;default:chat;comment:模型类型(chat/image_gen)"`
Config datatypes.JSON `json:"config" gorm:"type:jsonb;comment:模型参数配置"`
IsEnabled bool `json:"isEnabled" gorm:"default:true;comment:是否启用"`
}

View File

@@ -0,0 +1,45 @@
package request
// CreateChatRequest 创建对话请求
type CreateChatRequest struct {
CharacterID uint `json:"characterId" binding:"required"` // 角色卡ID
Title string `json:"title"` // 对话标题(可选,默认使用角色名)
}
// ChatListRequest 对话列表请求
type ChatListRequest struct {
Page int `form:"page" binding:"min=1"`
PageSize int `form:"pageSize" binding:"min=1,max=50"`
}
// SendMessageRequest 发送消息请求
type SendMessageRequest struct {
ChatID uint `json:"chatId" binding:"required"` // 对话ID
Content string `json:"content" binding:"required"` // 消息内容
ProviderID *uint `json:"providerId"` // 指定提供商(可选,默认使用用户默认)
ModelName string `json:"modelName"` // 指定模型(可选)
}
// RegenerateRequest 重新生成请求
type RegenerateRequest struct {
ChatID uint `json:"chatId" binding:"required"` // 对话ID
MessageID uint `json:"messageId" binding:"required"` // 要重新生成的消息ID
ModelName string `json:"modelName"` // 可选,指定其他模型
}
// EditMessageRequest 编辑消息请求
type EditMessageRequest struct {
MessageID uint `json:"messageId" binding:"required"`
Content string `json:"content" binding:"required"`
}
// DeleteMessageRequest 删除消息请求
type DeleteMessageRequest struct {
MessageID uint `json:"messageId" binding:"required"`
}
// ChatMessagesRequest 获取对话消息请求
type ChatMessagesRequest struct {
Page int `form:"page" binding:"min=1"`
PageSize int `form:"pageSize" binding:"min=1,max=100"`
}

View File

@@ -0,0 +1,80 @@
package request
// CreateProviderRequest 创建AI提供商请求
type CreateProviderRequest struct {
ProviderName string `json:"providerName" binding:"required,min=1,max=100"` // 提供商名称(如"我的OpenAI"
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"` // 提供商类型
BaseURL string `json:"baseUrl"` // API基础地址自定义或中转站
APIKey string `json:"apiKey" binding:"required"` // API密钥
APIConfig map[string]interface{} `json:"apiConfig"` // 额外配置
Models []CreateModelRequest `json:"models"` // 同时创建的模型列表
}
// UpdateProviderRequest 更新AI提供商请求
type UpdateProviderRequest struct {
ID uint `json:"id" binding:"required"`
ProviderName string `json:"providerName" binding:"required,min=1,max=100"`
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey"` // 为空表示不修改
APIConfig map[string]interface{} `json:"apiConfig"`
IsEnabled *bool `json:"isEnabled"`
IsDefault *bool `json:"isDefault"`
SortOrder *int `json:"sortOrder"`
}
// CreateModelRequest 创建AI模型请求
type CreateModelRequest struct {
ProviderID uint `json:"providerId"` // 关联提供商ID
ModelName string `json:"modelName" binding:"required,min=1,max=200"` // 模型标识(如 gpt-4o
DisplayName string `json:"displayName"` // 显示名称
ModelType string `json:"modelType" binding:"required,oneof=chat image_gen"` // 模型类型
Config map[string]interface{} `json:"config"` // 模型参数配置
IsEnabled *bool `json:"isEnabled"`
}
// UpdateModelRequest 更新AI模型请求
type UpdateModelRequest struct {
ID uint `json:"id" binding:"required"`
ModelName string `json:"modelName" binding:"required,min=1,max=200"`
DisplayName string `json:"displayName"`
ModelType string `json:"modelType" binding:"required,oneof=chat image_gen"`
Config map[string]interface{} `json:"config"`
IsEnabled *bool `json:"isEnabled"`
}
// ProviderListRequest 提供商列表请求
type ProviderListRequest struct {
Page int `form:"page" binding:"min=1"`
PageSize int `form:"pageSize" binding:"min=1,max=100"`
Keyword string `form:"keyword"`
}
// TestProviderRequest 测试提供商连通性请求
type TestProviderRequest struct {
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey" binding:"required"`
ModelName string `json:"modelName"` // 可选,用于测试特定模型
}
// SetDefaultProviderRequest 设置默认提供商请求
type SetDefaultProviderRequest struct {
ProviderID uint `json:"providerId" binding:"required"`
}
// FetchRemoteModelsRequest 获取远程可用模型列表请求
type FetchRemoteModelsRequest struct {
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey" binding:"required"`
}
// SendTestMessageRequest 发送测试消息请求
type SendTestMessageRequest struct {
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"`
BaseURL string `json:"baseUrl"`
APIKey string `json:"apiKey" binding:"required"`
ModelName string `json:"modelName" binding:"required"` // 要测试的模型
Message string `json:"message"` // 测试消息内容,为空则使用默认
}

View File

@@ -0,0 +1,64 @@
package response
import (
"time"
)
// ChatResponse 对话响应
type ChatResponse struct {
ID uint `json:"id"`
Title string `json:"title"`
CharacterID *uint `json:"characterId"`
CharacterName string `json:"characterName"`
CharacterAvatar string `json:"characterAvatar"`
ChatType string `json:"chatType"`
LastMessageAt *time.Time `json:"lastMessageAt"`
MessageCount int `json:"messageCount"`
IsPinned bool `json:"isPinned"`
LastMessage *MessageBrief `json:"lastMessage,omitempty"` // 最后一条消息摘要
CreatedAt time.Time `json:"createdAt"`
}
// MessageBrief 消息摘要(用于对话列表显示)
type MessageBrief struct {
Content string `json:"content"`
Role string `json:"role"`
}
// ChatListResponse 对话列表响应
type ChatListResponse struct {
List []ChatResponse `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"pageSize"`
}
// MessageResponse 消息响应
type MessageResponse struct {
ID uint `json:"id"`
ChatID uint `json:"chatId"`
Content string `json:"content"`
Role string `json:"role"` // user / assistant / system
CharacterID *uint `json:"characterId"`
CharacterName string `json:"characterName,omitempty"`
Model string `json:"model,omitempty"`
PromptTokens int `json:"promptTokens,omitempty"`
CompletionTokens int `json:"completionTokens,omitempty"`
TotalTokens int `json:"totalTokens,omitempty"`
SequenceNumber int `json:"sequenceNumber"`
CreatedAt time.Time `json:"createdAt"`
}
// MessageListResponse 消息列表响应
type MessageListResponse struct {
List []MessageResponse `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"pageSize"`
}
// ChatDetailResponse 对话详情响应(包含角色信息 + 消息列表)
type ChatDetailResponse struct {
Chat ChatResponse `json:"chat"`
Messages []MessageResponse `json:"messages"`
}

View File

@@ -8,39 +8,44 @@ import (
// ExtensionResponse 扩展响应
type ExtensionResponse struct {
ID uint `json:"id"`
UserID uint `json:"userId"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
Version string `json:"version"`
Author string `json:"author"`
Description string `json:"description"`
Homepage string `json:"homepage"`
Repository string `json:"repository"`
License string `json:"license"`
Tags []string `json:"tags"`
ExtensionType string `json:"extensionType"`
Category string `json:"category"`
Dependencies map[string]string `json:"dependencies"`
Conflicts []string `json:"conflicts"`
ManifestData map[string]interface{} `json:"manifestData"`
ScriptPath string `json:"scriptPath"`
StylePath string `json:"stylePath"`
AssetsPaths []string `json:"assetsPaths"`
Settings map[string]interface{} `json:"settings"`
Options map[string]interface{} `json:"options"`
IsEnabled bool `json:"isEnabled"`
IsInstalled bool `json:"isInstalled"`
IsSystemExt bool `json:"isSystemExt"`
InstallSource string `json:"installSource"`
InstallDate time.Time `json:"installDate"`
LastEnabled time.Time `json:"lastEnabled"`
UsageCount int `json:"usageCount"`
ErrorCount int `json:"errorCount"`
LoadTime int `json:"loadTime"`
Metadata map[string]interface{} `json:"metadata"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
ID uint `json:"id"`
UserID uint `json:"userId"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
Version string `json:"version"`
Author string `json:"author"`
Description string `json:"description"`
Homepage string `json:"homepage"`
Repository string `json:"repository"`
License string `json:"license"`
Tags []string `json:"tags"`
ExtensionType string `json:"extensionType"`
Category string `json:"category"`
Dependencies map[string]string `json:"dependencies"`
Conflicts []string `json:"conflicts"`
ManifestData map[string]interface{} `json:"manifestData"`
ScriptPath string `json:"scriptPath"`
StylePath string `json:"stylePath"`
AssetsPaths []string `json:"assetsPaths"`
Settings map[string]interface{} `json:"settings"`
Options map[string]interface{} `json:"options"`
IsEnabled bool `json:"isEnabled"`
IsInstalled bool `json:"isInstalled"`
IsSystemExt bool `json:"isSystemExt"`
InstallSource string `json:"installSource"`
SourceURL string `json:"sourceUrl"`
Branch string `json:"branch"`
AutoUpdate bool `json:"autoUpdate"`
InstallDate time.Time `json:"installDate"`
LastEnabled time.Time `json:"lastEnabled"`
LastUpdateCheck *time.Time `json:"lastUpdateCheck"`
AvailableVersion string `json:"availableVersion"`
UsageCount int `json:"usageCount"`
ErrorCount int `json:"errorCount"`
LoadTime int `json:"loadTime"`
Metadata map[string]interface{} `json:"metadata"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// ExtensionListResponse 扩展列表响应
@@ -151,38 +156,43 @@ func ToExtensionResponse(ext *app.AIExtension) ExtensionResponse {
}
return ExtensionResponse{
ID: ext.ID,
UserID: ext.UserID,
Name: ext.Name,
DisplayName: ext.DisplayName,
Version: ext.Version,
Author: ext.Author,
Description: ext.Description,
Homepage: ext.Homepage,
Repository: ext.Repository,
License: ext.License,
Tags: tags,
ExtensionType: ext.ExtensionType,
Category: ext.Category,
Dependencies: dependencies,
Conflicts: conflicts,
ManifestData: manifestData,
ScriptPath: ext.ScriptPath,
StylePath: ext.StylePath,
AssetsPaths: assetsPaths,
Settings: settings,
Options: options,
IsEnabled: ext.IsEnabled,
IsInstalled: ext.IsInstalled,
IsSystemExt: ext.IsSystemExt,
InstallSource: ext.InstallSource,
InstallDate: ext.InstallDate,
LastEnabled: ext.LastEnabled,
UsageCount: ext.UsageCount,
ErrorCount: ext.ErrorCount,
LoadTime: ext.LoadTime,
Metadata: metadata,
CreatedAt: ext.CreatedAt,
UpdatedAt: ext.UpdatedAt,
ID: ext.ID,
UserID: ext.UserID,
Name: ext.Name,
DisplayName: ext.DisplayName,
Version: ext.Version,
Author: ext.Author,
Description: ext.Description,
Homepage: ext.Homepage,
Repository: ext.Repository,
License: ext.License,
Tags: tags,
ExtensionType: ext.ExtensionType,
Category: ext.Category,
Dependencies: dependencies,
Conflicts: conflicts,
ManifestData: manifestData,
ScriptPath: ext.ScriptPath,
StylePath: ext.StylePath,
AssetsPaths: assetsPaths,
Settings: settings,
Options: options,
IsEnabled: ext.IsEnabled,
IsInstalled: ext.IsInstalled,
IsSystemExt: ext.IsSystemExt,
InstallSource: ext.InstallSource,
SourceURL: ext.SourceURL,
Branch: ext.Branch,
AutoUpdate: ext.AutoUpdate,
InstallDate: ext.InstallDate,
LastEnabled: ext.LastEnabled,
LastUpdateCheck: ext.LastUpdateCheck,
AvailableVersion: ext.AvailableVersion,
UsageCount: ext.UsageCount,
ErrorCount: ext.ErrorCount,
LoadTime: ext.LoadTime,
Metadata: metadata,
CreatedAt: ext.CreatedAt,
UpdatedAt: ext.UpdatedAt,
}
}

View File

@@ -0,0 +1,92 @@
package response
import (
"encoding/json"
"time"
)
// ProviderResponse 提供商响应
type ProviderResponse struct {
ID uint `json:"id"`
ProviderName string `json:"providerName"`
ProviderType string `json:"providerType"`
BaseURL string `json:"baseUrl"`
APIKeySet bool `json:"apiKeySet"` // 是否已设置API密钥不返回明文
APIKeyHint string `json:"apiKeyHint"` // API密钥提示如 sk-****1234
APIConfig json.RawMessage `json:"apiConfig"`
Capabilities json.RawMessage `json:"capabilities"`
IsEnabled bool `json:"isEnabled"`
IsDefault bool `json:"isDefault"`
SortOrder int `json:"sortOrder"`
Models []ModelResponse `json:"models"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// ModelResponse 模型响应
type ModelResponse struct {
ID uint `json:"id"`
ProviderID uint `json:"providerId"`
ModelName string `json:"modelName"`
DisplayName string `json:"displayName"`
ModelType string `json:"modelType"`
Config json.RawMessage `json:"config"`
IsEnabled bool `json:"isEnabled"`
CreatedAt time.Time `json:"createdAt"`
}
// ProviderListResponse 提供商列表响应
type ProviderListResponse struct {
List []ProviderResponse `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"pageSize"`
}
// TestProviderResponse 测试连通性响应
type TestProviderResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Models []string `json:"models,omitempty"` // 获取到的可用模型列表
Latency int64 `json:"latency"` // 响应延迟(毫秒)
}
// ProviderTypeOption 提供商类型选项(前端下拉用)
type ProviderTypeOption struct {
Value string `json:"value"` // 类型标识
Label string `json:"label"` // 显示名称
Description string `json:"description"` // 描述
DefaultURL string `json:"defaultUrl"` // 默认API地址
}
// PresetModelOption 预设模型选项
type PresetModelOption struct {
ModelName string `json:"modelName"`
DisplayName string `json:"displayName"`
ModelType string `json:"modelType"` // chat / image_gen
}
// RemoteModel 远程获取到的模型信息
type RemoteModel struct {
ID string `json:"id"` // 模型标识API调用用
DisplayName string `json:"displayName"` // 显示名称
OwnedBy string `json:"ownedBy"` // 所有者/来源
}
// FetchRemoteModelsResponse 获取远程模型列表响应
type FetchRemoteModelsResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Models []RemoteModel `json:"models"`
Latency int64 `json:"latency"` // 响应延迟(毫秒)
}
// SendTestMessageResponse 发送测试消息响应
type SendTestMessageResponse struct {
Success bool `json:"success"`
Message string `json:"message"` // 状态信息
Reply string `json:"reply"` // AI 回复内容
Model string `json:"model"` // 实际使用的模型
Latency int64 `json:"latency"` // 响应延迟(毫秒)
Tokens int `json:"tokens"` // 消耗的token数
}

29
server/router/app/chat.go Normal file
View File

@@ -0,0 +1,29 @@
package app
import (
v1 "git.echol.cn/loser/st/server/api/v1"
"git.echol.cn/loser/st/server/middleware"
"github.com/gin-gonic/gin"
)
type ChatRouter struct{}
func (cr *ChatRouter) InitChatRouter(Router *gin.RouterGroup) {
chatApi := v1.ApiGroupApp.AppApiGroup.ChatApi
// 所有对话接口都需要登录
chatRouter := Router.Group("chat").Use(middleware.AppJWTAuth())
{
// 对话管理
chatRouter.POST("", chatApi.CreateChat) // 创建对话
chatRouter.GET("/list", chatApi.GetChatList) // 对话列表
chatRouter.GET("/:id", chatApi.GetChatDetail) // 对话详情
chatRouter.GET("/:id/messages", chatApi.GetChatMessages) // 获取消息
chatRouter.DELETE("/:id", chatApi.DeleteChat) // 删除对话
// 消息操作
chatRouter.POST("/send", chatApi.SendMessage) // 发送消息SSE 流式)
chatRouter.POST("/message/edit", chatApi.EditMessage) // 编辑消息
chatRouter.POST("/message/delete", chatApi.DeleteMessage) // 删除消息
}
}

View File

@@ -6,4 +6,6 @@ type RouterGroup struct {
WorldInfoRouter
ExtensionRouter
RegexScriptRouter
ProviderRouter
ChatRouter
}

View File

@@ -43,4 +43,13 @@ func (r *ExtensionRouter) InitExtensionRouter(Router *gin.RouterGroup) {
// 统计
extensionRouter.POST("/stats", extensionApi.UpdateExtensionStats) // 更新扩展统计
}
// 扩展资源文件 - 公开路由(不需要鉴权)
// 原因:<script type="module"> 标签无法携带 JWT header
// 且 ES module 的 import 语句也无法携带认证信息。
// 与原版 SillyTavern 一致:扩展文件作为公开静态资源提供。
extensionPublicRouter := Router.Group("extension")
{
extensionPublicRouter.GET("/:id/asset/*path", extensionApi.ProxyExtensionAsset)
}
}

View File

@@ -0,0 +1,49 @@
package app
import (
v1 "git.echol.cn/loser/st/server/api/v1"
"git.echol.cn/loser/st/server/middleware"
"github.com/gin-gonic/gin"
)
type ProviderRouter struct{}
func (pr *ProviderRouter) InitProviderRouter(Router *gin.RouterGroup) {
providerApi := v1.ApiGroupApp.AppApiGroup.ProviderApi
// 公开接口(不需要登录)
providerPublicRouter := Router.Group("provider")
{
providerPublicRouter.GET("/types", providerApi.GetProviderTypes) // 获取支持的提供商类型
providerPublicRouter.GET("/presetModels", providerApi.GetPresetModels) // 获取预设模型列表
}
// 需要登录的接口
providerAuthRouter := Router.Group("provider").Use(middleware.AppJWTAuth())
{
// 提供商 CRUD
providerAuthRouter.POST("", providerApi.CreateProvider) // 创建提供商
providerAuthRouter.GET("/list", providerApi.GetProviderList) // 获取列表
providerAuthRouter.GET("/:id", providerApi.GetProviderDetail) // 获取详情
providerAuthRouter.PUT("", providerApi.UpdateProvider) // 更新提供商
providerAuthRouter.DELETE("/:id", providerApi.DeleteProvider) // 删除提供商
providerAuthRouter.POST("/setDefault", providerApi.SetDefaultProvider) // 设置默认
// 连通性测试
providerAuthRouter.POST("/test", providerApi.TestProvider) // 测试连接(不需要先保存)
providerAuthRouter.GET("/test/:id", providerApi.TestExistingProvider) // 测试已保存的连接
// 远程模型获取
providerAuthRouter.POST("/fetchModels", providerApi.FetchRemoteModels) // 获取远程模型列表
providerAuthRouter.GET("/fetchModels/:id", providerApi.FetchRemoteModelsExisting) // 获取已保存提供商的远程模型
// 测试消息
providerAuthRouter.POST("/testMessage", providerApi.SendTestMessage) // 发送测试消息
providerAuthRouter.POST("/testMessage/:id", providerApi.SendTestMessageExisting) // 已保存提供商发测试消息
// 模型管理
providerAuthRouter.POST("/model", providerApi.AddModel) // 添加模型
providerAuthRouter.PUT("/model", providerApi.UpdateModel) // 更新模型
providerAuthRouter.DELETE("/model/:id", providerApi.DeleteModel) // 删除模型
}
}

View File

@@ -0,0 +1,625 @@
package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
appModel "git.echol.cn/loser/st/server/model/app"
)
// AIMessage AI调用的消息格式统一
type AIMessagePayload struct {
Role string `json:"role"`
Content string `json:"content"`
}
// AIStreamChunk 流式响应的单个块
type AIStreamChunk struct {
Content string `json:"content"` // 增量文本
Done bool `json:"done"` // 是否结束
Model string `json:"model"` // 使用的模型
PromptTokens int `json:"promptTokens"` // 提示词Token仅结束时有值
CompletionTokens int `json:"completionTokens"` // 补全Token仅结束时有值
Error string `json:"error"` // 错误信息
}
// BuildPrompt 根据角色卡和消息历史构建 AI 请求的 prompt
func BuildPrompt(character *appModel.AICharacter, messages []appModel.AIMessage) []AIMessagePayload {
var payload []AIMessagePayload
// 1. 系统提示词System Prompt
systemPrompt := buildSystemPrompt(character)
if systemPrompt != "" {
payload = append(payload, AIMessagePayload{
Role: "system",
Content: systemPrompt,
})
}
// 2. 历史消息
for _, msg := range messages {
role := msg.Role
if role == "assistant" || role == "user" || role == "system" {
payload = append(payload, AIMessagePayload{
Role: role,
Content: msg.Content,
})
}
}
return payload
}
// buildSystemPrompt 构建系统提示词
func buildSystemPrompt(character *appModel.AICharacter) string {
if character == nil {
return ""
}
var parts []string
// 角色描述
if character.Description != "" {
parts = append(parts, character.Description)
}
// 角色性格
if character.Personality != "" {
parts = append(parts, "Personality: "+character.Personality)
}
// 场景
if character.Scenario != "" {
parts = append(parts, "Scenario: "+character.Scenario)
}
// 系统提示词
if character.SystemPrompt != "" {
parts = append(parts, character.SystemPrompt)
}
// 示例消息
if len(character.ExampleMessages) > 0 {
parts = append(parts, "Example dialogue:\n"+strings.Join(character.ExampleMessages, "\n"))
}
return strings.Join(parts, "\n\n")
}
// StreamAIResponse 流式调用 AI 并通过 channel 返回结果
func StreamAIResponse(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
defer close(ch)
apiKey := decryptAPIKey(provider.APIKey)
baseURL := provider.BaseURL
switch provider.ProviderType {
case "openai", "custom":
streamOpenAI(ctx, baseURL, apiKey, modelName, messages, ch)
case "claude":
streamClaude(ctx, baseURL, apiKey, modelName, messages, ch)
case "gemini":
streamGemini(ctx, baseURL, apiKey, modelName, messages, ch)
default:
ch <- AIStreamChunk{Error: "不支持的提供商类型: " + provider.ProviderType, Done: true}
}
}
// ==================== OpenAI Compatible ====================
func streamOpenAI(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
body := map[string]interface{}{
"model": modelName,
"messages": messages,
"stream": true,
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
var fullContent string
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
ch <- AIStreamChunk{Done: true, Model: modelName, Content: ""}
return
}
var chunk struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
Usage *struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
content := chunk.Choices[0].Delta.Content
fullContent += content
ch <- AIStreamChunk{Content: content}
}
if chunk.Usage != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: chunk.Usage.PromptTokens,
CompletionTokens: chunk.Usage.CompletionTokens,
}
return
}
}
// 如果扫描结束但没收到 [DONE]
if fullContent != "" {
ch <- AIStreamChunk{Done: true, Model: modelName}
}
}
// ==================== Claude ====================
func streamClaude(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
// Claude 需要把 system 消息分离出来
var systemContent string
var claudeMessages []map[string]string
for _, msg := range messages {
if msg.Role == "system" {
systemContent += msg.Content + "\n"
} else {
claudeMessages = append(claudeMessages, map[string]string{
"role": msg.Role,
"content": msg.Content,
})
}
}
// 确保至少有一条消息
if len(claudeMessages) == 0 {
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
return
}
body := map[string]interface{}{
"model": modelName,
"messages": claudeMessages,
"max_tokens": 4096,
"stream": true,
}
if systemContent != "" {
body["system"] = strings.TrimSpace(systemContent)
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var event struct {
Type string `json:"type"`
Delta *struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"delta"`
Usage *struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
switch event.Type {
case "content_block_delta":
if event.Delta != nil && event.Delta.Text != "" {
ch <- AIStreamChunk{Content: event.Delta.Text}
}
case "message_delta":
if event.Usage != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: event.Usage.InputTokens,
CompletionTokens: event.Usage.OutputTokens,
}
return
}
case "message_stop":
ch <- AIStreamChunk{Done: true, Model: modelName}
return
case "error":
ch <- AIStreamChunk{Error: "Claude API 错误", Done: true}
return
}
}
}
// ==================== Gemini ====================
func streamGemini(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
strings.TrimRight(baseURL, "/"), modelName, apiKey)
// 构建 Gemini 格式的消息
var systemInstruction string
var contents []map[string]interface{}
for _, msg := range messages {
if msg.Role == "system" {
systemInstruction += msg.Content + "\n"
continue
}
role := msg.Role
if role == "assistant" {
role = "model"
}
contents = append(contents, map[string]interface{}{
"role": role,
"parts": []map[string]string{
{"text": msg.Content},
},
})
}
if len(contents) == 0 {
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
return
}
body := map[string]interface{}{
"contents": contents,
}
if systemInstruction != "" {
body["systemInstruction"] = map[string]interface{}{
"parts": []map[string]string{
{"text": strings.TrimSpace(systemInstruction)},
},
}
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
return
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
respBody, _ := io.ReadAll(resp.Body)
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
return
}
scanner := bufio.NewScanner(resp.Body)
// Gemini 返回较大的 chunks增大 buffer
scanner.Buffer(make([]byte, 0), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
var chunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
UsageMetadata *struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
} `json:"usageMetadata"`
}
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
for _, candidate := range chunk.Candidates {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
ch <- AIStreamChunk{Content: part.Text}
}
}
}
if chunk.UsageMetadata != nil {
ch <- AIStreamChunk{
Done: true,
Model: modelName,
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
}
return
}
}
ch <- AIStreamChunk{Done: true, Model: modelName}
}
// ==================== 非流式调用(用于生图等) ====================
// CallAINonStream 非流式调用 AI
func CallAINonStream(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload) (string, error) {
apiKey := decryptAPIKey(provider.APIKey)
baseURL := provider.BaseURL
switch provider.ProviderType {
case "openai", "custom":
return callOpenAINonStream(ctx, baseURL, apiKey, modelName, messages)
case "claude":
return callClaudeNonStream(ctx, baseURL, apiKey, modelName, messages)
case "gemini":
return callGeminiNonStream(ctx, baseURL, apiKey, modelName, messages)
default:
return "", errors.New("不支持的提供商类型: " + provider.ProviderType)
}
}
func callOpenAINonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
body := map[string]interface{}{
"model": modelName,
"messages": messages,
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Choices) > 0 {
return result.Choices[0].Message.Content, nil
}
return "", errors.New("AI 未返回有效回复")
}
func callClaudeNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
var systemContent string
var claudeMessages []map[string]string
for _, msg := range messages {
if msg.Role == "system" {
systemContent += msg.Content + "\n"
} else {
claudeMessages = append(claudeMessages, map[string]string{
"role": msg.Role,
"content": msg.Content,
})
}
}
body := map[string]interface{}{
"model": modelName,
"messages": claudeMessages,
"max_tokens": 4096,
}
if systemContent != "" {
body["system"] = strings.TrimSpace(systemContent)
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Content) > 0 {
return result.Content[0].Text, nil
}
return "", errors.New("AI 未返回有效回复")
}
func callGeminiNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
strings.TrimRight(baseURL, "/"), modelName, apiKey)
var systemInstruction string
var contents []map[string]interface{}
for _, msg := range messages {
if msg.Role == "system" {
systemInstruction += msg.Content + "\n"
continue
}
role := msg.Role
if role == "assistant" {
role = "model"
}
contents = append(contents, map[string]interface{}{
"role": role,
"parts": []map[string]string{{"text": msg.Content}},
})
}
body := map[string]interface{}{"contents": contents}
if systemInstruction != "" {
body["systemInstruction"] = map[string]interface{}{
"parts": []map[string]string{{"text": strings.TrimSpace(systemInstruction)}},
}
}
bodyJSON, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 60 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", err
}
if len(result.Candidates) > 0 && len(result.Candidates[0].Content.Parts) > 0 {
return result.Candidates[0].Content.Parts[0].Text, nil
}
return "", errors.New("AI 未返回有效回复")
}

408
server/service/app/chat.go Normal file
View File

@@ -0,0 +1,408 @@
package app
import (
"errors"
"time"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/app/response"
"gorm.io/gorm"
)
type ChatService struct{}
// ==================== 对话 CRUD ====================
// CreateChat 创建对话
func (cs *ChatService) CreateChat(req request.CreateChatRequest, userID uint) (response.ChatResponse, error) {
// 获取角色卡信息
var character app.AICharacter
err := global.GVA_DB.First(&character, req.CharacterID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ChatResponse{}, errors.New("角色卡不存在")
}
return response.ChatResponse{}, err
}
// 对话标题
title := req.Title
if title == "" {
title = character.Name
}
now := time.Now()
chat := app.AIChat{
Title: title,
UserID: userID,
CharacterID: &req.CharacterID,
ChatType: "single",
LastMessageAt: &now,
}
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 创建对话
if err := tx.Create(&chat).Error; err != nil {
return err
}
// 如果角色卡有 FirstMessage自动创建第一条系统消息
if character.FirstMessage != "" {
firstMsg := app.AIMessage{
ChatID: chat.ID,
Content: character.FirstMessage,
Role: "assistant",
CharacterID: &req.CharacterID,
SequenceNumber: 1,
}
if err := tx.Create(&firstMsg).Error; err != nil {
return err
}
chat.MessageCount = 1
tx.Model(&chat).Update("message_count", 1)
}
// 更新角色卡使用次数
tx.Model(&character).Update("total_chats", gorm.Expr("total_chats + 1"))
return nil
})
if err != nil {
return response.ChatResponse{}, err
}
return toChatResponse(&chat, &character, nil), nil
}
// GetChatList 获取用户的对话列表
func (cs *ChatService) GetChatList(req request.ChatListRequest, userID uint) (response.ChatListResponse, error) {
db := global.GVA_DB.Model(&app.AIChat{}).Where("user_id = ?", userID)
var total int64
db.Count(&total)
var chats []app.AIChat
offset := (req.Page - 1) * req.PageSize
err := db.Preload("Character").
Order("is_pinned DESC, last_message_at DESC").
Offset(offset).Limit(req.PageSize).
Find(&chats).Error
if err != nil {
return response.ChatListResponse{}, err
}
// 获取每个对话的最后一条消息
chatIDs := make([]uint, len(chats))
for i, c := range chats {
chatIDs[i] = c.ID
}
lastMessages := make(map[uint]*response.MessageBrief)
if len(chatIDs) > 0 {
var messages []app.AIMessage
// 获取每个对话的最后一条消息(通过子查询)
global.GVA_DB.Where("chat_id IN ? AND is_deleted = ?", chatIDs, false).
Order("sequence_number DESC").
Find(&messages)
// 只保留每个对话的最后一条
for _, msg := range messages {
if _, exists := lastMessages[msg.ChatID]; !exists {
content := msg.Content
if len(content) > 100 {
content = content[:100] + "..."
}
lastMessages[msg.ChatID] = &response.MessageBrief{
Content: content,
Role: msg.Role,
}
}
}
}
list := make([]response.ChatResponse, len(chats))
for i, c := range chats {
list[i] = toChatResponse(&c, c.Character, lastMessages[c.ID])
}
return response.ChatListResponse{
List: list,
Total: total,
Page: req.Page,
PageSize: req.PageSize,
}, nil
}
// GetChatDetail 获取对话详情(包含消息列表)
func (cs *ChatService) GetChatDetail(chatID uint, userID uint) (response.ChatDetailResponse, error) {
var chat app.AIChat
err := global.GVA_DB.Preload("Character").
Where("id = ? AND user_id = ?", chatID, userID).
First(&chat).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ChatDetailResponse{}, errors.New("对话不存在")
}
return response.ChatDetailResponse{}, err
}
// 获取消息列表最近的50条
var messages []app.AIMessage
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
Order("sequence_number ASC").
Limit(50).
Find(&messages)
msgList := make([]response.MessageResponse, len(messages))
for i, msg := range messages {
msgList[i] = toMessageResponse(&msg, chat.Character)
}
return response.ChatDetailResponse{
Chat: toChatResponse(&chat, chat.Character, nil),
Messages: msgList,
}, nil
}
// GetChatMessages 分页获取对话消息
func (cs *ChatService) GetChatMessages(chatID uint, req request.ChatMessagesRequest, userID uint) (response.MessageListResponse, error) {
// 验证对话归属
var chat app.AIChat
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
if err != nil {
return response.MessageListResponse{}, errors.New("对话不存在")
}
db := global.GVA_DB.Model(&app.AIMessage{}).
Where("chat_id = ? AND is_deleted = ?", chatID, false)
var total int64
db.Count(&total)
var messages []app.AIMessage
offset := (req.Page - 1) * req.PageSize
err = db.Order("sequence_number ASC").
Offset(offset).Limit(req.PageSize).
Find(&messages).Error
if err != nil {
return response.MessageListResponse{}, err
}
// 预加载角色信息
var character *app.AICharacter
if chat.CharacterID != nil {
var c app.AICharacter
global.GVA_DB.First(&c, *chat.CharacterID)
character = &c
}
list := make([]response.MessageResponse, len(messages))
for i, msg := range messages {
list[i] = toMessageResponse(&msg, character)
}
return response.MessageListResponse{
List: list,
Total: total,
Page: req.Page,
PageSize: req.PageSize,
}, nil
}
// DeleteChat 删除对话
func (cs *ChatService) DeleteChat(chatID uint, userID uint) error {
var chat app.AIChat
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("对话不存在")
}
return err
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 删除消息
if err := tx.Where("chat_id = ?", chatID).Delete(&app.AIMessage{}).Error; err != nil {
return err
}
// 删除消息变体
if err := tx.Where("message_id IN (?)",
tx.Model(&app.AIMessage{}).Select("id").Where("chat_id = ?", chatID),
).Delete(&app.AIMessageSwipe{}).Error; err != nil {
// 忽略错误,可能没有变体
}
// 删除对话
return tx.Delete(&chat).Error
})
}
// ==================== 消息操作 ====================
// SaveUserMessage 保存用户消息(内部方法)
func (cs *ChatService) SaveUserMessage(chatID uint, userID uint, content string) (*app.AIMessage, error) {
// 获取当前最大序号
var maxSeq int
global.GVA_DB.Model(&app.AIMessage{}).
Where("chat_id = ?", chatID).
Select("COALESCE(MAX(sequence_number), 0)").
Scan(&maxSeq)
msg := &app.AIMessage{
ChatID: chatID,
Content: content,
Role: "user",
SenderID: &userID,
SequenceNumber: maxSeq + 1,
}
if err := global.GVA_DB.Create(msg).Error; err != nil {
return nil, err
}
// 更新对话的消息数和最后消息时间
now := time.Now()
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
"message_count": gorm.Expr("message_count + 1"),
"last_message_at": now,
})
return msg, nil
}
// SaveAssistantMessage 保存AI回复消息内部方法
func (cs *ChatService) SaveAssistantMessage(chatID uint, characterID *uint, content string, model string, promptTokens, completionTokens int) (*app.AIMessage, error) {
var maxSeq int
global.GVA_DB.Model(&app.AIMessage{}).
Where("chat_id = ?", chatID).
Select("COALESCE(MAX(sequence_number), 0)").
Scan(&maxSeq)
msg := &app.AIMessage{
ChatID: chatID,
Content: content,
Role: "assistant",
CharacterID: characterID,
SequenceNumber: maxSeq + 1,
Model: model,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
if err := global.GVA_DB.Create(msg).Error; err != nil {
return nil, err
}
now := time.Now()
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
"message_count": gorm.Expr("message_count + 1"),
"last_message_at": now,
})
return msg, nil
}
// EditMessage 编辑消息
func (cs *ChatService) EditMessage(req request.EditMessageRequest, userID uint) (response.MessageResponse, error) {
var msg app.AIMessage
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
Where("ai_messages.id = ? AND ai_chats.user_id = ?", req.MessageID, userID).
First(&msg).Error
if err != nil {
return response.MessageResponse{}, errors.New("消息不存在")
}
msg.Content = req.Content
if err := global.GVA_DB.Save(&msg).Error; err != nil {
return response.MessageResponse{}, err
}
return toMessageResponse(&msg, nil), nil
}
// DeleteMessage 删除消息(软删除)
func (cs *ChatService) DeleteMessage(messageID uint, userID uint) error {
var msg app.AIMessage
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
Where("ai_messages.id = ? AND ai_chats.user_id = ?", messageID, userID).
First(&msg).Error
if err != nil {
return errors.New("消息不存在")
}
return global.GVA_DB.Model(&msg).Update("is_deleted", true).Error
}
// GetChatForAI 获取对话用于AI调用的完整上下文内部方法
func (cs *ChatService) GetChatForAI(chatID uint, userID uint) (*app.AIChat, *app.AICharacter, []app.AIMessage, error) {
var chat app.AIChat
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
if err != nil {
return nil, nil, nil, errors.New("对话不存在")
}
var character *app.AICharacter
if chat.CharacterID != nil {
var c app.AICharacter
if err := global.GVA_DB.First(&c, *chat.CharacterID).Error; err == nil {
character = &c
}
}
// 获取历史消息最近30条
var messages []app.AIMessage
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
Order("sequence_number ASC").
Limit(30).
Find(&messages)
return &chat, character, messages, nil
}
// ==================== 辅助函数 ====================
func toChatResponse(chat *app.AIChat, character *app.AICharacter, lastMsg *response.MessageBrief) response.ChatResponse {
resp := response.ChatResponse{
ID: chat.ID,
Title: chat.Title,
CharacterID: chat.CharacterID,
ChatType: chat.ChatType,
LastMessageAt: chat.LastMessageAt,
MessageCount: chat.MessageCount,
IsPinned: chat.IsPinned,
LastMessage: lastMsg,
CreatedAt: chat.CreatedAt,
}
if character != nil {
resp.CharacterName = character.Name
resp.CharacterAvatar = character.Avatar
}
return resp
}
func toMessageResponse(msg *app.AIMessage, character *app.AICharacter) response.MessageResponse {
resp := response.MessageResponse{
ID: msg.ID,
ChatID: msg.ChatID,
Content: msg.Content,
Role: msg.Role,
CharacterID: msg.CharacterID,
Model: msg.Model,
PromptTokens: msg.PromptTokens,
CompletionTokens: msg.CompletionTokens,
TotalTokens: msg.TotalTokens,
SequenceNumber: msg.SequenceNumber,
CreatedAt: msg.CreatedAt,
}
if msg.Role == "assistant" && character != nil {
resp.CharacterName = character.Name
}
return resp
}

View File

@@ -6,4 +6,6 @@ type AppServiceGroup struct {
WorldInfoService
ExtensionService
RegexScriptService
ProviderService
ChatService
}

View File

@@ -21,15 +21,70 @@ import (
"gorm.io/gorm"
)
// extensionDataDir 扩展本地存储根目录
// 与原版 SillyTavern 完全一致的路径结构scripts/extensions/third-party/{name}/
// 扩展 JS 中的相对路径 import如 ../../../../../script.js依赖此目录层级来正确解析
// 所有 SillyTavern 核心脚本和扩展文件统一存储在 data/st-core-scripts/ 下,独立于 web-app/
// 扩展代码是公共的(不按用户隔离),用户间差异仅在于数据库中的配置和启用状态
const extensionDataDir = "data/st-core-scripts/scripts/extensions/third-party"
// getExtensionStorePath 获取扩展的本地存储路径: {extensionDataDir}/{extensionName}/
func getExtensionStorePath(extensionName string) string {
return filepath.Join(extensionDataDir, extensionName)
}
// GetExtensionAssetLocalPath 获取扩展资源文件的本地绝对路径
func (es *ExtensionService) GetExtensionAssetLocalPath(extensionName string, assetPath string) (string, error) {
storePath := getExtensionStorePath(extensionName)
fullPath := filepath.Join(storePath, assetPath)
// 安全检查:防止路径遍历攻击
absStore, _ := filepath.Abs(storePath)
absFile, _ := filepath.Abs(fullPath)
if !strings.HasPrefix(absFile, absStore) {
return "", errors.New("非法的资源路径")
}
// 检查文件是否存在
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
return "", fmt.Errorf("资源文件不存在: %s", assetPath)
}
return fullPath, nil
}
// ensureExtensionDir 确保扩展存储目录存在
func ensureExtensionDir(extensionName string) (string, error) {
storePath := getExtensionStorePath(extensionName)
if err := os.MkdirAll(storePath, 0755); err != nil {
return "", fmt.Errorf("创建扩展存储目录失败: %w", err)
}
return storePath, nil
}
// removeExtensionDir 删除扩展的本地存储目录
func removeExtensionDir(extensionName string) error {
storePath := getExtensionStorePath(extensionName)
if _, err := os.Stat(storePath); os.IsNotExist(err) {
return nil // 目录不存在,无需删除
}
return os.RemoveAll(storePath)
}
type ExtensionService struct{}
// CreateExtension 创建/安装扩展
func (es *ExtensionService) CreateExtension(userID uint, req *request.CreateExtensionRequest) (*app.AIExtension, error) {
// 校验名称
if req.Name == "" {
return nil, errors.New("扩展名称不能为空")
}
// 检查扩展是否已存在
var existing app.AIExtension
err := global.GVA_DB.Where("user_id = ? AND name = ?", userID, req.Name).First(&existing).Error
if err == nil {
return nil, errors.New("扩展已存在")
return nil, fmt.Errorf("扩展 %s 已存在", req.Name)
}
if err != gorm.ErrRecordNotFound {
return nil, err
@@ -136,15 +191,18 @@ func (es *ExtensionService) DeleteExtension(userID, extensionID uint, deleteFile
return errors.New("系统内置扩展不允许删除")
}
// TODO: 如果 deleteFiles=true删除扩展文件
// 这需要文件系统支持
// 删除本地扩展文件(与原版 SillyTavern 一致:卸载扩展时清理本地文件
if err := removeExtensionDir(extension.Name); err != nil {
global.GVA_LOG.Warn("删除扩展本地文件失败", zap.Error(err), zap.String("name", extension.Name))
// 不阻断删除流程
}
// 删除扩展(配置已经在扩展记录的 Settings 字段中,无需单独删除)
// 删除数据库记录
if err := global.GVA_DB.Delete(&extension).Error; err != nil {
return err
}
global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID))
global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID), zap.String("name", extension.Name))
return nil
}
@@ -157,6 +215,15 @@ func (es *ExtensionService) GetExtension(userID, extensionID uint) (*app.AIExten
return &extension, nil
}
// GetExtensionByID 通过扩展 ID 获取扩展信息(不限制用户,用于公开资源路由)
func (es *ExtensionService) GetExtensionByID(extensionID uint) (*app.AIExtension, error) {
var extension app.AIExtension
if err := global.GVA_DB.Where("id = ?", extensionID).First(&extension).Error; err != nil {
return nil, errors.New("扩展不存在")
}
return &extension, nil
}
// GetExtensionList 获取扩展列表
func (es *ExtensionService) GetExtensionList(userID uint, req *request.ExtensionListRequest) (*response.ExtensionListResponse, error) {
var extensions []app.AIExtension
@@ -550,9 +617,37 @@ func isGitURL(url string) bool {
return false
}
// downloadAndInstallFromManifestURL 从 Manifest URL 下载并安装
// GetExtensionAssetURL 根据扩展的安装来源构建资源文件的远程 URL
func (es *ExtensionService) GetExtensionAssetURL(extension *app.AIExtension, assetPath string) (string, error) {
if extension.SourceURL == "" {
return "", errors.New("扩展没有源地址")
}
sourceURL := strings.TrimSuffix(strings.TrimSuffix(extension.SourceURL, "/"), ".git")
branch := extension.Branch
if branch == "" {
branch = "main"
}
// GitLab: repo/-/raw/branch/path
if strings.Contains(sourceURL, "gitlab.com") {
return fmt.Sprintf("%s/-/raw/%s/%s", sourceURL, branch, assetPath), nil
}
// GitHub: raw.githubusercontent.com/user/repo/branch/path
if strings.Contains(sourceURL, "github.com") {
rawURL := strings.Replace(sourceURL, "github.com", "raw.githubusercontent.com", 1)
return fmt.Sprintf("%s/%s/%s", rawURL, branch, assetPath), nil
}
// Gitee: repo/raw/branch/path
if strings.Contains(sourceURL, "gitee.com") {
return fmt.Sprintf("%s/raw/%s/%s", sourceURL, branch, assetPath), nil
}
return fmt.Sprintf("%s/%s", sourceURL, assetPath), nil
}
// downloadAndInstallFromManifestURL 从 Manifest URL 下载并安装(同时下载资源文件到本地)
func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manifestURL string) (*app.AIExtension, error) {
// 创建 HTTP 客户端
client := &http.Client{
Timeout: 30 * time.Second,
}
@@ -568,7 +663,6 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
return nil, fmt.Errorf("下载 manifest.json 失败: HTTP %d", resp.StatusCode)
}
// 读取响应内容
manifestData, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取 manifest.json 失败: %w", err)
@@ -580,21 +674,63 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
}
// 验证必填字段
if manifest.Name == "" {
return nil, errors.New("manifest.json 缺少 name 字段")
// 获取有效名称
effectiveName := manifest.GetEffectiveName()
if effectiveName == "" {
return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
}
// 检查扩展是否已存在
var existing app.AIExtension
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, manifest.Name).First(&existing).Error
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
if err == nil {
return nil, fmt.Errorf("扩展 %s 已安装", manifest.Name)
return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
}
if err != gorm.ErrRecordNotFound {
return nil, err
}
// 创建本地存储目录并保存 manifest.json
storePath, err := ensureExtensionDir(effectiveName)
if err != nil {
return nil, err
}
if err := os.WriteFile(filepath.Join(storePath, "manifest.json"), manifestData, 0644); err != nil {
_ = removeExtensionDir(effectiveName)
return nil, fmt.Errorf("保存 manifest.json 失败: %w", err)
}
// 获取 manifest URL 的基础目录(用于下载关联资源)
baseURL := manifestURL[:strings.LastIndex(manifestURL, "/")+1]
// 下载 JS/CSS 等资源文件到本地
filesToDownload := []string{}
if entry := manifest.GetEffectiveEntry(); entry != "" {
filesToDownload = append(filesToDownload, entry)
}
if style := manifest.GetEffectiveStyle(); style != "" {
filesToDownload = append(filesToDownload, style)
}
filesToDownload = append(filesToDownload, manifest.Assets...)
for _, file := range filesToDownload {
if file == "" {
continue
}
fileURL := baseURL + file
if err := downloadFileToLocal(client, fileURL, filepath.Join(storePath, file)); err != nil {
global.GVA_LOG.Warn("下载扩展资源文件失败(非致命)",
zap.String("file", file),
zap.String("url", fileURL),
zap.Error(err))
}
}
global.GVA_LOG.Info("扩展文件已保存到本地",
zap.String("name", effectiveName),
zap.String("path", storePath))
// 将 manifest 转换为 map[string]interface{}
var manifestMap map[string]interface{}
if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
@@ -603,13 +739,13 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
// 构建创建请求
createReq := &request.CreateExtensionRequest{
Name: manifest.Name,
Name: effectiveName,
DisplayName: manifest.DisplayName,
Version: manifest.Version,
Author: manifest.Author,
Description: manifest.Description,
Homepage: manifest.Homepage,
Repository: manifest.Repository, // 使用 manifest 中的 repository
Homepage: manifest.GetEffectiveHomepage(),
Repository: manifest.Repository,
License: manifest.License,
Tags: manifest.Tags,
ExtensionType: manifest.Type,
@@ -617,13 +753,13 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
Dependencies: manifest.Dependencies,
Conflicts: manifest.Conflicts,
ManifestData: manifestMap,
ScriptPath: manifest.Entry,
StylePath: manifest.Style,
ScriptPath: manifest.GetEffectiveEntry(),
StylePath: manifest.GetEffectiveStyle(),
AssetsPaths: manifest.Assets,
Settings: manifest.Settings,
Options: manifest.Options,
InstallSource: "url",
SourceURL: manifestURL, // 记录原始 URL 用于更新
SourceURL: manifestURL,
AutoUpdate: manifest.AutoUpdate,
Metadata: nil,
}
@@ -636,6 +772,7 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
// 创建扩展
extension, err := es.CreateExtension(userID, createReq)
if err != nil {
_ = removeExtensionDir(effectiveName)
return nil, fmt.Errorf("创建扩展失败: %w", err)
}
@@ -647,6 +784,31 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
return extension, nil
}
// downloadFileToLocal 下载远程文件到本地路径
func downloadFileToLocal(client *http.Client, url string, localPath string) error {
// 确保目标文件的父目录存在
if err := os.MkdirAll(filepath.Dir(localPath), 0755); err != nil {
return err
}
resp, err := client.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("HTTP %d", resp.StatusCode)
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
return os.WriteFile(localPath, data, 0644)
}
// UpgradeExtension 升级扩展版本(根据安装来源自动选择更新方式)
func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force bool) (*app.AIExtension, error) {
// 获取扩展信息
@@ -672,7 +834,7 @@ func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force boo
}
}
// updateExtensionFromGit 从 Git 仓库更新扩展
// updateExtensionFromGit 从 Git 仓库更新扩展(先删除旧记录和文件,再重新安装)
func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.AIExtension, force bool) (*app.AIExtension, error) {
if extension.SourceURL == "" {
return nil, errors.New("缺少 Git 仓库 URL")
@@ -683,11 +845,17 @@ func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.A
zap.String("sourceUrl", extension.SourceURL),
zap.String("branch", extension.Branch))
// 重新克隆(简单方式,避免处理本地修改)
// 先删除旧的数据库记录和本地文件
if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
}
_ = removeExtensionDir(extension.Name)
// 重新克隆安装
return es.InstallExtensionFromGit(userID, extension.SourceURL, extension.Branch)
}
// updateExtensionFromURL 从 URL 更新扩展(重新下载 manifest.json
// updateExtensionFromURL 从 URL 更新扩展(先删除旧记录和文件,再重新下载安装
func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.AIExtension) (*app.AIExtension, error) {
if extension.SourceURL == "" {
return nil, errors.New("缺少 Manifest URL")
@@ -697,18 +865,24 @@ func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.A
zap.String("name", extension.Name),
zap.String("sourceUrl", extension.SourceURL))
// 重新下载并安装
// 先删除旧的数据库记录和本地文件
if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
}
_ = removeExtensionDir(extension.Name)
// 重新下载安装
return es.downloadAndInstallFromManifestURL(userID, extension.SourceURL)
}
// InstallExtensionFromGit 从 Git URL 安装扩展
// InstallExtensionFromGit 从 Git URL 安装扩展(与原版 SillyTavern 一致:将源码下载到本地)
func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch string) (*app.AIExtension, error) {
// 验证 Git URL
if !strings.Contains(gitUrl, "://") && !strings.HasSuffix(gitUrl, ".git") {
return nil, errors.New("无效的 Git URL")
}
// 创建临时目录
// 先 clone 到临时目录读取 manifest获取扩展名后再移动到正式目录
tempDir, err := os.MkdirTemp("", "extension-*")
if err != nil {
return nil, fmt.Errorf("创建临时目录失败: %w", err)
@@ -717,10 +891,9 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
global.GVA_LOG.Info("开始从 Git 克隆扩展",
zap.String("gitUrl", gitUrl),
zap.String("branch", branch),
zap.String("tempDir", tempDir))
zap.String("branch", branch))
// 执行 git clone
// 执行 git clone(浅克隆)
cmd := exec.Command("git", "clone", "--depth=1", "--branch="+branch, gitUrl, tempDir)
output, err := cmd.CombinedOutput()
if err != nil {
@@ -744,31 +917,53 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
}
// 获取有效名称(兼容 SillyTavern manifest 没有 name 字段的情况)
effectiveName := manifest.GetEffectiveName()
if effectiveName == "" {
return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
}
// 检查扩展是否已存在
var existing app.AIExtension
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, manifest.Name).First(&existing).Error
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
if err == nil {
return nil, fmt.Errorf("扩展 %s 已安装", manifest.Name)
return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
}
if err != gorm.ErrRecordNotFound {
return nil, err
}
// 将扩展文件保存到公共目录: web-app/public/scripts/extensions/third-party/{extensionName}/
storePath, err := ensureExtensionDir(effectiveName)
if err != nil {
return nil, err
}
// 清空目标目录(如果有残留文件)后复制 clone 内容
_ = os.RemoveAll(storePath)
if err := copyDir(tempDir, storePath); err != nil {
return nil, fmt.Errorf("保存扩展文件失败: %w", err)
}
global.GVA_LOG.Info("扩展文件已保存到本地",
zap.String("name", effectiveName),
zap.String("path", storePath))
// 将 manifest 转换为 map[string]interface{}
var manifestMap map[string]interface{}
if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
return nil, fmt.Errorf("转换 manifest 失败: %w", err)
}
// 构建创建请求
// 构建创建请求(使用兼容方法获取字段值)
createReq := &request.CreateExtensionRequest{
Name: manifest.Name,
Name: effectiveName,
DisplayName: manifest.DisplayName,
Version: manifest.Version,
Author: manifest.Author,
Description: manifest.Description,
Homepage: manifest.Homepage,
Repository: manifest.Repository, // 使用 manifest 中的 repository
Homepage: manifest.GetEffectiveHomepage(),
Repository: manifest.Repository,
License: manifest.License,
Tags: manifest.Tags,
ExtensionType: manifest.Type,
@@ -776,28 +971,69 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
Dependencies: manifest.Dependencies,
Conflicts: manifest.Conflicts,
ManifestData: manifestMap,
ScriptPath: manifest.Entry,
StylePath: manifest.Style,
ScriptPath: manifest.GetEffectiveEntry(),
StylePath: manifest.GetEffectiveStyle(),
AssetsPaths: manifest.Assets,
Settings: manifest.Settings,
Options: manifest.Options,
InstallSource: "git",
SourceURL: gitUrl, // 记录 Git URL 用于更新
Branch: branch, // 记录分支
SourceURL: gitUrl,
Branch: branch,
AutoUpdate: manifest.AutoUpdate,
Metadata: manifest.Metadata,
}
// 确保扩展类型有效
if createReq.ExtensionType == "" {
createReq.ExtensionType = "ui"
}
// 创建扩展记录
extension, err := es.CreateExtension(userID, createReq)
if err != nil {
// 创建失败则清理本地文件
_ = removeExtensionDir(effectiveName)
return nil, fmt.Errorf("创建扩展记录失败: %w", err)
}
global.GVA_LOG.Info("从 Git 安装扩展成功",
zap.Uint("extensionID", extension.ID),
zap.String("name", extension.Name),
zap.String("version", extension.Version))
zap.String("version", extension.Version),
zap.String("localPath", storePath))
return extension, nil
}
// copyDir 递归复制目录(排除 .git 目录以节省空间)
func copyDir(src, dst string) error {
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// 计算相对路径
relPath, err := filepath.Rel(src, path)
if err != nil {
return err
}
// 排除 .git 目录
if info.IsDir() && info.Name() == ".git" {
return filepath.SkipDir
}
dstPath := filepath.Join(dst, relPath)
if info.IsDir() {
return os.MkdirAll(dstPath, info.Mode())
}
// 复制文件
srcFile, err := os.ReadFile(path)
if err != nil {
return err
}
return os.WriteFile(dstPath, srcFile, info.Mode())
})
}

View File

@@ -0,0 +1,949 @@
package app
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/app/response"
"gorm.io/datatypes"
"gorm.io/gorm"
)
type ProviderService struct{}
// ==================== 提供商 CRUD ====================
// CreateProvider 创建AI提供商
func (ps *ProviderService) CreateProvider(req request.CreateProviderRequest, userID uint) (response.ProviderResponse, error) {
// 加密 API Key
encryptedKey := encryptAPIKey(req.APIKey)
// 序列化额外配置
apiConfigJSON, _ := json.Marshal(req.APIConfig)
if req.APIConfig == nil {
apiConfigJSON = []byte("{}")
}
// 根据类型确定能力
capabilities := getDefaultCapabilities(req.ProviderType)
capJSON, _ := json.Marshal(capabilities)
// 如果 BaseURL 为空,使用默认地址
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
provider := app.AIProvider{
UserID: &userID,
ProviderName: req.ProviderName,
ProviderType: req.ProviderType,
BaseURL: baseURL,
APIKey: encryptedKey,
APIConfig: datatypes.JSON(apiConfigJSON),
Capabilities: datatypes.JSON(capJSON),
IsEnabled: true,
IsDefault: false,
}
err := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 创建提供商
if err := tx.Create(&provider).Error; err != nil {
return err
}
// 如果携带了模型列表,同时创建模型
if len(req.Models) > 0 {
for _, m := range req.Models {
modelConfigJSON, _ := json.Marshal(m.Config)
if m.Config == nil {
modelConfigJSON = []byte("{}")
}
isEnabled := true
if m.IsEnabled != nil {
isEnabled = *m.IsEnabled
}
model := app.AIModel{
ProviderID: provider.ID,
ModelName: m.ModelName,
DisplayName: m.DisplayName,
ModelType: m.ModelType,
Config: datatypes.JSON(modelConfigJSON),
IsEnabled: isEnabled,
}
if err := tx.Create(&model).Error; err != nil {
return err
}
}
} else {
// 没有指定模型时,自动添加预设模型
presets := getPresetModels(req.ProviderType)
for _, p := range presets {
model := app.AIModel{
ProviderID: provider.ID,
ModelName: p.ModelName,
DisplayName: p.DisplayName,
ModelType: p.ModelType,
Config: datatypes.JSON([]byte("{}")),
IsEnabled: true,
}
if err := tx.Create(&model).Error; err != nil {
return err
}
}
}
// 如果是用户的第一个提供商,自动设为默认
var count int64
tx.Model(&app.AIProvider{}).Where("user_id = ? AND id != ?", userID, provider.ID).Count(&count)
if count == 0 {
provider.IsDefault = true
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return response.ProviderResponse{}, err
}
return ps.GetProviderDetail(provider.ID, userID)
}
// GetProviderList 获取用户的提供商列表
func (ps *ProviderService) GetProviderList(req request.ProviderListRequest, userID uint) (response.ProviderListResponse, error) {
db := global.GVA_DB.Model(&app.AIProvider{}).Where("user_id = ?", userID)
if req.Keyword != "" {
keyword := "%" + req.Keyword + "%"
db = db.Where("provider_name ILIKE ?", keyword)
}
var total int64
db.Count(&total)
var providers []app.AIProvider
offset := (req.Page - 1) * req.PageSize
err := db.Order("is_default DESC, sort_order ASC, created_at DESC").
Offset(offset).Limit(req.PageSize).Find(&providers).Error
if err != nil {
return response.ProviderListResponse{}, err
}
// 获取所有提供商的模型
providerIDs := make([]uint, len(providers))
for i, p := range providers {
providerIDs[i] = p.ID
}
var models []app.AIModel
if len(providerIDs) > 0 {
global.GVA_DB.Where("provider_id IN ?", providerIDs).
Order("model_type ASC, model_name ASC").Find(&models)
}
// 按提供商ID分组模型
modelMap := make(map[uint][]app.AIModel)
for _, m := range models {
modelMap[m.ProviderID] = append(modelMap[m.ProviderID], m)
}
list := make([]response.ProviderResponse, len(providers))
for i, p := range providers {
list[i] = toProviderResponse(&p, modelMap[p.ID])
}
return response.ProviderListResponse{
List: list,
Total: total,
Page: req.Page,
PageSize: req.PageSize,
}, nil
}
// GetProviderDetail 获取提供商详情
func (ps *ProviderService) GetProviderDetail(providerID uint, userID uint) (response.ProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ProviderResponse{}, errors.New("提供商不存在")
}
return response.ProviderResponse{}, err
}
var models []app.AIModel
global.GVA_DB.Where("provider_id = ?", providerID).
Order("model_type ASC, model_name ASC").Find(&models)
return toProviderResponse(&provider, models), nil
}
// UpdateProvider 更新提供商
func (ps *ProviderService) UpdateProvider(req request.UpdateProviderRequest, userID uint) (response.ProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ProviderResponse{}, errors.New("提供商不存在")
}
return response.ProviderResponse{}, err
}
// 更新字段
updates := map[string]interface{}{
"provider_name": req.ProviderName,
"provider_type": req.ProviderType,
"base_url": req.BaseURL,
}
// APIKey 不为空时才更新
if req.APIKey != "" {
updates["api_key"] = encryptAPIKey(req.APIKey)
}
if req.APIConfig != nil {
apiConfigJSON, _ := json.Marshal(req.APIConfig)
updates["api_config"] = datatypes.JSON(apiConfigJSON)
}
if req.IsEnabled != nil {
updates["is_enabled"] = *req.IsEnabled
}
if req.SortOrder != nil {
updates["sort_order"] = *req.SortOrder
}
// 更新能力
capabilities := getDefaultCapabilities(req.ProviderType)
capJSON, _ := json.Marshal(capabilities)
updates["capabilities"] = datatypes.JSON(capJSON)
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&provider).Updates(updates).Error; err != nil {
return err
}
// 处理设置默认
if req.IsDefault != nil && *req.IsDefault {
// 先取消其他默认
if err := tx.Model(&app.AIProvider{}).
Where("user_id = ? AND id != ?", userID, req.ID).
Update("is_default", false).Error; err != nil {
return err
}
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return response.ProviderResponse{}, err
}
return ps.GetProviderDetail(req.ID, userID)
}
// DeleteProvider 删除提供商
func (ps *ProviderService) DeleteProvider(providerID uint, userID uint) error {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("提供商不存在")
}
return err
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 删除关联的模型
if err := tx.Where("provider_id = ?", providerID).Delete(&app.AIModel{}).Error; err != nil {
return err
}
// 删除提供商
if err := tx.Delete(&provider).Error; err != nil {
return err
}
// 如果删除的是默认提供商,自动将第一个提供商设为默认
if provider.IsDefault {
var firstProvider app.AIProvider
if err := tx.Where("user_id = ?", userID).Order("created_at ASC").First(&firstProvider).Error; err == nil {
tx.Model(&firstProvider).Update("is_default", true)
}
}
return nil
})
}
// SetDefaultProvider 设置默认提供商
func (ps *ProviderService) SetDefaultProvider(providerID uint, userID uint) error {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("提供商不存在")
}
return err
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 先取消所有默认
if err := tx.Model(&app.AIProvider{}).
Where("user_id = ?", userID).
Update("is_default", false).Error; err != nil {
return err
}
// 设置新默认
return tx.Model(&provider).Update("is_default", true).Error
})
}
// ==================== 模型 CRUD ====================
// AddModel 为提供商添加模型
func (ps *ProviderService) AddModel(req request.CreateModelRequest, userID uint) (response.ModelResponse, error) {
// 验证提供商归属
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ProviderID, userID).First(&provider).Error
if err != nil {
return response.ModelResponse{}, errors.New("提供商不存在")
}
configJSON, _ := json.Marshal(req.Config)
if req.Config == nil {
configJSON = []byte("{}")
}
isEnabled := true
if req.IsEnabled != nil {
isEnabled = *req.IsEnabled
}
model := app.AIModel{
ProviderID: req.ProviderID,
ModelName: req.ModelName,
DisplayName: req.DisplayName,
ModelType: req.ModelType,
Config: datatypes.JSON(configJSON),
IsEnabled: isEnabled,
}
if err := global.GVA_DB.Create(&model).Error; err != nil {
return response.ModelResponse{}, err
}
return toModelResponse(&model), nil
}
// UpdateModel 更新模型
func (ps *ProviderService) UpdateModel(req request.UpdateModelRequest, userID uint) (response.ModelResponse, error) {
var model app.AIModel
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
Where("ai_models.id = ? AND ai_providers.user_id = ?", req.ID, userID).
First(&model).Error
if err != nil {
return response.ModelResponse{}, errors.New("模型不存在")
}
updates := map[string]interface{}{
"model_name": req.ModelName,
"display_name": req.DisplayName,
"model_type": req.ModelType,
}
if req.Config != nil {
configJSON, _ := json.Marshal(req.Config)
updates["config"] = datatypes.JSON(configJSON)
}
if req.IsEnabled != nil {
updates["is_enabled"] = *req.IsEnabled
}
if err := global.GVA_DB.Model(&model).Updates(updates).Error; err != nil {
return response.ModelResponse{}, err
}
// 重新查询
global.GVA_DB.First(&model, model.ID)
return toModelResponse(&model), nil
}
// DeleteModel 删除模型
func (ps *ProviderService) DeleteModel(modelID uint, userID uint) error {
var model app.AIModel
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
Where("ai_models.id = ? AND ai_providers.user_id = ?", modelID, userID).
First(&model).Error
if err != nil {
return errors.New("模型不存在")
}
return global.GVA_DB.Delete(&model).Error
}
// ==================== 连通性测试 ====================
// TestProvider 测试提供商连通性
func (ps *ProviderService) TestProvider(req request.TestProviderRequest) (response.TestProviderResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
// 所有提供商统一使用 OpenAI 兼容的 /models 端点测试连通性
result := testOpenAICompatible(baseURL, req.APIKey, req.ModelName)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// TestExistingProvider 测试已保存的提供商连通性
func (ps *ProviderService) TestExistingProvider(providerID uint, userID uint) (response.TestProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.TestProviderResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.TestProvider(request.TestProviderRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
})
}
// ==================== 辅助查询 ====================
// GetProviderTypes 获取支持的提供商类型列表(前端下拉用)
func (ps *ProviderService) GetProviderTypes() []response.ProviderTypeOption {
return []response.ProviderTypeOption{
{
Value: "openai",
Label: "OpenAI",
Description: "支持 GPT-4o、GPT-4、DALL·E 等模型,也兼容所有 OpenAI 格式的中转站",
DefaultURL: "https://api.openai.com/v1",
},
{
Value: "claude",
Label: "Claude",
Description: "Anthropic 的 Claude 系列模型,支持长上下文对话",
DefaultURL: "https://api.anthropic.com",
},
{
Value: "gemini",
Label: "Google Gemini",
Description: "Google 的 Gemini 系列模型,支持多模态",
DefaultURL: "https://generativelanguage.googleapis.com",
},
{
Value: "custom",
Label: "自定义OpenAI 兼容)",
Description: "兼容 OpenAI 格式的任意接口,如 DeepSeek、通义千问等中转站",
DefaultURL: "",
},
}
}
// GetPresetModels 获取指定提供商类型的预设模型列表
func (ps *ProviderService) GetPresetModels(providerType string) []response.PresetModelOption {
presets := getPresetModels(providerType)
result := make([]response.PresetModelOption, len(presets))
for i, p := range presets {
result[i] = response.PresetModelOption{
ModelName: p.ModelName,
DisplayName: p.DisplayName,
ModelType: p.ModelType,
}
}
return result
}
// GetUserDefaultProvider 获取用户默认提供商(内部方法,给对话功能用)
func (ps *ProviderService) GetUserDefaultProvider(userID uint) (*app.AIProvider, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("user_id = ? AND is_default = ? AND is_enabled = ?", userID, true, true).
First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("请先配置 AI 接口")
}
return nil, err
}
return &provider, nil
}
// GetDecryptedAPIKey 获取解密后的API密钥内部方法给AI调用用
func (ps *ProviderService) GetDecryptedAPIKey(provider *app.AIProvider) string {
return decryptAPIKey(provider.APIKey)
}
// ==================== 内部辅助函数 ====================
// toProviderResponse 转换为响应对象
func toProviderResponse(p *app.AIProvider, models []app.AIModel) response.ProviderResponse {
apiConfig := json.RawMessage(p.APIConfig)
if len(apiConfig) == 0 {
apiConfig = json.RawMessage("{}")
}
capabilities := json.RawMessage(p.Capabilities)
if len(capabilities) == 0 {
capabilities = json.RawMessage("[]")
}
// 模型列表
modelList := make([]response.ModelResponse, len(models))
for i, m := range models {
modelList[i] = toModelResponse(&m)
}
// API Key 提示
apiKeyHint := ""
apiKeySet := false
if p.APIKey != "" {
apiKeySet = true
decrypted := decryptAPIKey(p.APIKey)
if len(decrypted) > 8 {
apiKeyHint = decrypted[:4] + "****" + decrypted[len(decrypted)-4:]
} else if len(decrypted) > 0 {
apiKeyHint = "****"
}
}
return response.ProviderResponse{
ID: p.ID,
ProviderName: p.ProviderName,
ProviderType: p.ProviderType,
BaseURL: p.BaseURL,
APIKeySet: apiKeySet,
APIKeyHint: apiKeyHint,
APIConfig: apiConfig,
Capabilities: capabilities,
IsEnabled: p.IsEnabled,
IsDefault: p.IsDefault,
SortOrder: p.SortOrder,
Models: modelList,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
// toModelResponse 转换模型为响应对象
func toModelResponse(m *app.AIModel) response.ModelResponse {
config := json.RawMessage(m.Config)
if len(config) == 0 {
config = json.RawMessage("{}")
}
return response.ModelResponse{
ID: m.ID,
ProviderID: m.ProviderID,
ModelName: m.ModelName,
DisplayName: m.DisplayName,
ModelType: m.ModelType,
Config: config,
IsEnabled: m.IsEnabled,
CreatedAt: m.CreatedAt,
}
}
// encryptAPIKey 加密API密钥
// TODO: 后续可以替换为更安全的加密方式(如 AES当前使用简单的 Base64 编码
func encryptAPIKey(key string) string {
if key == "" {
return ""
}
// 简单的混淆处理,生产环境应替换为 AES 加密
import_encoding := []byte(key)
for i := range import_encoding {
import_encoding[i] ^= 0x5A
}
return fmt.Sprintf("enc:%x", import_encoding)
}
// decryptAPIKey 解密API密钥
func decryptAPIKey(encrypted string) string {
if encrypted == "" {
return ""
}
if !strings.HasPrefix(encrypted, "enc:") {
return encrypted // 未加密的旧数据,直接返回
}
hexStr := encrypted[4:]
var data []byte
fmt.Sscanf(hexStr, "%x", &data)
for i := range data {
data[i] ^= 0x5A
}
return string(data)
}
// getDefaultBaseURL 获取默认API基础地址
func getDefaultBaseURL(providerType string) string {
switch providerType {
case "openai":
return "https://api.openai.com/v1"
case "claude":
return "https://api.anthropic.com"
case "gemini":
return "https://generativelanguage.googleapis.com"
default:
return ""
}
}
// getDefaultCapabilities 获取默认能力列表
func getDefaultCapabilities(providerType string) []string {
switch providerType {
case "openai":
return []string{"chat", "image_gen"}
case "claude":
return []string{"chat"}
case "gemini":
return []string{"chat", "image_gen"}
case "custom":
return []string{"chat"}
default:
return []string{"chat"}
}
}
// presetModel 预设模型内部结构
type presetModel struct {
ModelName string
DisplayName string
ModelType string
}
// getPresetModels 获取预设模型列表
func getPresetModels(providerType string) []presetModel {
switch providerType {
case "openai":
return []presetModel{
{ModelName: "gpt-4o", DisplayName: "GPT-4o", ModelType: "chat"},
{ModelName: "gpt-4o-mini", DisplayName: "GPT-4o Mini", ModelType: "chat"},
{ModelName: "gpt-4.1", DisplayName: "GPT-4.1", ModelType: "chat"},
{ModelName: "gpt-4.1-mini", DisplayName: "GPT-4.1 Mini", ModelType: "chat"},
{ModelName: "gpt-4.1-nano", DisplayName: "GPT-4.1 Nano", ModelType: "chat"},
{ModelName: "o3-mini", DisplayName: "o3-mini", ModelType: "chat"},
{ModelName: "dall-e-3", DisplayName: "DALL·E 3", ModelType: "image_gen"},
}
case "claude":
return []presetModel{
{ModelName: "claude-sonnet-4-20250514", DisplayName: "Claude Sonnet 4", ModelType: "chat"},
{ModelName: "claude-3-5-sonnet-20241022", DisplayName: "Claude 3.5 Sonnet", ModelType: "chat"},
{ModelName: "claude-3-5-haiku-20241022", DisplayName: "Claude 3.5 Haiku", ModelType: "chat"},
{ModelName: "claude-3-opus-20240229", DisplayName: "Claude 3 Opus", ModelType: "chat"},
}
case "gemini":
return []presetModel{
{ModelName: "gemini-2.5-flash-preview-05-20", DisplayName: "Gemini 2.5 Flash", ModelType: "chat"},
{ModelName: "gemini-2.5-pro-preview-05-06", DisplayName: "Gemini 2.5 Pro", ModelType: "chat"},
{ModelName: "gemini-2.0-flash", DisplayName: "Gemini 2.0 Flash", ModelType: "chat"},
{ModelName: "imagen-3.0-generate-002", DisplayName: "Imagen 3", ModelType: "image_gen"},
}
case "custom":
return []presetModel{} // 自定义不提供预设
default:
return []presetModel{}
}
}
// ==================== 获取远程模型列表 ====================
// FetchRemoteModels 从远程API获取可用模型列表
// 所有提供商类型统一使用 baseURL + /models 端点OpenAI 兼容格式)
func (ps *ProviderService) FetchRemoteModels(req request.FetchRemoteModelsRequest) (response.FetchRemoteModelsResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
result := fetchModelsUniversal(baseURL, req.APIKey)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
func (ps *ProviderService) FetchRemoteModelsExisting(providerID uint, userID uint) (response.FetchRemoteModelsResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.FetchRemoteModelsResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.FetchRemoteModels(request.FetchRemoteModelsRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
})
}
// ==================== 发送测试消息 ====================
// SendTestMessage 发送测试消息(使用指定的 provider 配置)
// 所有提供商类型统一使用 baseURL + /chat/completions 端点OpenAI 兼容格式)
func (ps *ProviderService) SendTestMessage(req request.SendTestMessageRequest) (response.SendTestMessageResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
message := req.Message
if message == "" {
message = "你好,请用一句话介绍你自己。"
}
result := sendTestMessageUniversal(baseURL, req.APIKey, req.ModelName, message)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// SendTestMessageExisting 发送测试消息(已保存的提供商)
func (ps *ProviderService) SendTestMessageExisting(providerID uint, userID uint, modelName string, message string) (response.SendTestMessageResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.SendTestMessageResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.SendTestMessage(request.SendTestMessageRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
ModelName: modelName,
Message: message,
})
}
// ==================== 连通性测试实现 ====================
// testOpenAICompatible 测试 OpenAI 兼容接口
func testOpenAICompatible(baseURL, apiKey, modelName string) response.TestProviderResponse {
url := strings.TrimRight(baseURL, "/") + "/models"
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return response.TestProviderResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.TestProviderResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == 401 {
return response.TestProviderResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
return response.TestProviderResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
// 解析模型列表
var modelsResp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
var modelNames []string
if err := json.Unmarshal(body, &modelsResp); err == nil {
for _, m := range modelsResp.Data {
modelNames = append(modelNames, m.ID)
}
}
return response.TestProviderResponse{
Success: true,
Message: "连接成功",
Models: modelNames,
}
}
// ==================== 获取远程模型列表实现 ====================
// fetchModelsUniversal 统一获取模型列表(所有提供商通用)
// 使用 baseURL + /models 端点Authorization: Bearer 鉴权
func fetchModelsUniversal(baseURL, apiKey string) response.FetchRemoteModelsResponse {
url := strings.TrimRight(baseURL, "/") + "/models"
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return response.FetchRemoteModelsResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.FetchRemoteModelsResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
if resp.StatusCode == 401 {
return response.FetchRemoteModelsResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
return response.FetchRemoteModelsResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
body, _ := io.ReadAll(resp.Body)
// 解析 OpenAI 兼容格式: { "data": [{ "id": "xxx", "owned_by": "xxx" }] }
var modelsData struct {
Data []struct {
ID string `json:"id"`
OwnedBy string `json:"owned_by"`
} `json:"data"`
}
var models []response.RemoteModel
if err := json.Unmarshal(body, &modelsData); err == nil {
for _, m := range modelsData.Data {
models = append(models, response.RemoteModel{
ID: m.ID,
OwnedBy: m.OwnedBy,
})
}
}
return response.FetchRemoteModelsResponse{
Success: true,
Message: fmt.Sprintf("获取成功,共 %d 个模型", len(models)),
Models: models,
}
}
// ==================== 发送测试消息实现 ====================
// sendTestMessageUniversal 统一发送测试消息(所有提供商通用)
// 使用 baseURL + /chat/completions 端点Authorization: Bearer 鉴权
func sendTestMessageUniversal(baseURL, apiKey, modelName, message string) response.SendTestMessageResponse {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
payload := map[string]interface{}{
"model": modelName,
"max_tokens": 100,
"messages": []map[string]string{
{"role": "user", "content": message},
},
}
payloadBytes, _ := json.Marshal(payload)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payloadBytes))
if err != nil {
return response.SendTestMessageResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.SendTestMessageResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == 401 {
return response.SendTestMessageResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
var errResp struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" {
return response.SendTestMessageResponse{Success: false, Message: "API 错误: " + errResp.Error.Message}
}
return response.SendTestMessageResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
// 解析 OpenAI 兼容格式的 chat completion 响应
var chatResp struct {
Model string `json:"model"`
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal(body, &chatResp); err != nil {
return response.SendTestMessageResponse{Success: false, Message: "解析响应失败"}
}
reply := ""
if len(chatResp.Choices) > 0 {
reply = chatResp.Choices[0].Message.Content
}
return response.SendTestMessageResponse{
Success: true,
Message: "测试成功",
Reply: reply,
Model: chatResp.Model,
Tokens: chatResp.Usage.TotalTokens,
}
}