🎨 优化扩展模块,完成ai接入和对话功能
This commit is contained in:
470
server/api/v1/app/chat.go
Normal file
470
server/api/v1/app/chat.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/middleware"
|
||||
appModel "git.echol.cn/loser/st/server/model/app"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/common/response"
|
||||
appService "git.echol.cn/loser/st/server/service/app"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ChatApi struct{}
|
||||
|
||||
// ==================== 对话管理 ====================
|
||||
|
||||
// CreateChat 创建对话
|
||||
// @Tags 对话
|
||||
// @Summary 创建新对话
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.CreateChatRequest true "角色卡ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatResponse} "创建成功"
|
||||
// @Router /app/chat [post]
|
||||
func (ca *ChatApi) CreateChat(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.CreateChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := chatService.CreateChat(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("创建对话失败", zap.Error(err))
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(chat, c)
|
||||
}
|
||||
|
||||
// GetChatList 获取对话列表
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话列表
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param page query int true "页码"
|
||||
// @Param pageSize query int true "每页数量"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatListResponse} "获取成功"
|
||||
// @Router /app/chat/list [get]
|
||||
func (ca *ChatApi) GetChatList(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.ChatListRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Page == 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize == 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
|
||||
list, err := chatService.GetChatList(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取对话列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败", c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(list, c)
|
||||
}
|
||||
|
||||
// GetChatDetail 获取对话详情(含消息)
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话详情
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatDetailResponse} "获取成功"
|
||||
// @Router /app/chat/:id [get]
|
||||
func (ca *ChatApi) GetChatDetail(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
detail, err := chatService.GetChatDetail(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(detail, c)
|
||||
}
|
||||
|
||||
// GetChatMessages 分页获取消息
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话消息
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Param page query int true "页码"
|
||||
// @Param pageSize query int true "每页数量"
|
||||
// @Success 200 {object} response.Response{data=appResponse.MessageListResponse} "获取成功"
|
||||
// @Router /app/chat/:id/messages [get]
|
||||
func (ca *ChatApi) GetChatMessages(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
var req request.ChatMessagesRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Page == 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize == 0 {
|
||||
req.PageSize = 50
|
||||
}
|
||||
|
||||
messages, err := chatService.GetChatMessages(uint(id), req, userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(messages, c)
|
||||
}
|
||||
|
||||
// DeleteChat 删除对话
|
||||
// @Tags 对话
|
||||
// @Summary 删除对话
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/chat/:id [delete]
|
||||
func (ca *ChatApi) DeleteChat(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := chatService.DeleteChat(uint(id), userID); err != nil {
|
||||
global.GVA_LOG.Error("删除对话失败", zap.Error(err))
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// ==================== 消息操作 ====================
|
||||
|
||||
// SendMessage 发送消息并获取AI回复(SSE 流式响应)
|
||||
// @Tags 对话
|
||||
// @Summary 发送消息(SSE流式)
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce text/event-stream
|
||||
// @Param data body request.SendMessageRequest true "消息内容"
|
||||
// @Router /app/chat/send [post]
|
||||
func (ca *ChatApi) SendMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.SendMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 获取对话上下文
|
||||
chat, character, historyMsgs, err := chatService.GetChatForAI(req.ChatID, userID)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 保存用户消息
|
||||
_, err = chatService.SaveUserMessage(chat.ID, userID, req.Content)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"code": 7, "msg": "保存消息失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2.1 世界书匹配(根据角色卡 + 历史消息)
|
||||
var worldInfoTexts []string
|
||||
if character != nil && chat.CharacterID != nil {
|
||||
// 收集最近的消息文本(用于关键词匹配),附加当前用户输入
|
||||
messagesForMatch := make([]string, 0, len(historyMsgs)+1)
|
||||
for _, m := range historyMsgs {
|
||||
if m.Content != "" {
|
||||
messagesForMatch = append(messagesForMatch, m.Content)
|
||||
}
|
||||
}
|
||||
if req.Content != "" {
|
||||
messagesForMatch = append(messagesForMatch, req.Content)
|
||||
}
|
||||
|
||||
matchReq := request.MatchWorldInfoRequest{
|
||||
CharacterID: *chat.CharacterID,
|
||||
Messages: messagesForMatch,
|
||||
ScanDepth: 10,
|
||||
MaxTokens: 2000,
|
||||
}
|
||||
|
||||
if result, wErr := worldInfoService.MatchWorldInfo(userID, &matchReq); wErr != nil {
|
||||
global.GVA_LOG.Warn("匹配世界书失败",
|
||||
zap.Uint("chatID", chat.ID),
|
||||
zap.Uint("characterID", *chat.CharacterID),
|
||||
zap.Error(wErr))
|
||||
} else if result != nil && len(result.Entries) > 0 {
|
||||
// 按 position 分组,暂时统一作为 System 级别补充上下文
|
||||
beforeChar := make([]string, 0)
|
||||
afterChar := make([]string, 0)
|
||||
for _, entry := range result.Entries {
|
||||
text := strings.TrimSpace(entry.Content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
if entry.Position == "after_char" {
|
||||
afterChar = append(afterChar, text)
|
||||
} else {
|
||||
// 默认 before_char
|
||||
beforeChar = append(beforeChar, text)
|
||||
}
|
||||
}
|
||||
|
||||
if len(beforeChar) > 0 {
|
||||
worldInfoTexts = append(worldInfoTexts, strings.Join(beforeChar, "\n\n"))
|
||||
}
|
||||
if len(afterChar) > 0 {
|
||||
worldInfoTexts = append(worldInfoTexts, strings.Join(afterChar, "\n\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 获取 AI 提供商和模型
|
||||
var provider *appModel.AIProvider
|
||||
var modelName string
|
||||
|
||||
if req.ProviderID != nil {
|
||||
// 指定了提供商
|
||||
p, pErr := providerService.GetUserDefaultProvider(userID) // 简化:暂时仅用默认
|
||||
if pErr != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": "指定的 AI 接口不可用"})
|
||||
return
|
||||
}
|
||||
provider = p
|
||||
} else {
|
||||
// 使用默认提供商
|
||||
p, pErr := providerService.GetUserDefaultProvider(userID)
|
||||
if pErr != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": pErr.Error()})
|
||||
return
|
||||
}
|
||||
provider = p
|
||||
}
|
||||
|
||||
if req.ModelName != "" {
|
||||
modelName = req.ModelName
|
||||
} else {
|
||||
// 使用提供商的第一个启用的聊天模型
|
||||
modelName = getDefaultChatModelForProvider(provider.ID)
|
||||
if modelName == "" {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": "该 AI 接口没有可用的聊天模型"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 构建 Prompt(角色卡 + 世界书 + 历史消息)
|
||||
prompt := appService.BuildPrompt(character, historyMsgs)
|
||||
|
||||
// 将世界书内容插入为额外 System 提示(靠近角色定义)
|
||||
if len(worldInfoTexts) > 0 {
|
||||
worldInfoContent := "World Info:\n" + strings.Join(worldInfoTexts, "\n\n")
|
||||
inserted := false
|
||||
for i, msg := range prompt {
|
||||
if msg.Role == "system" {
|
||||
// 在第一个 system 消息之后插入一条世界书消息
|
||||
newPrompt := make([]appService.AIMessagePayload, 0, len(prompt)+1)
|
||||
newPrompt = append(newPrompt, prompt[:i+1]...)
|
||||
newPrompt = append(newPrompt, appService.AIMessagePayload{
|
||||
Role: "system",
|
||||
Content: worldInfoContent,
|
||||
})
|
||||
newPrompt = append(newPrompt, prompt[i+1:]...)
|
||||
prompt = newPrompt
|
||||
inserted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !inserted {
|
||||
prompt = append([]appService.AIMessagePayload{{
|
||||
Role: "system",
|
||||
Content: worldInfoContent,
|
||||
}}, prompt...)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加用户的新消息
|
||||
prompt = append(prompt, appService.AIMessagePayload{
|
||||
Role: "user",
|
||||
Content: req.Content,
|
||||
})
|
||||
|
||||
// 5. 设置 SSE 响应头
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
// 6. 流式调用 AI
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
ch := make(chan appService.AIStreamChunk, 100)
|
||||
go appService.StreamAIResponse(ctx, provider, modelName, prompt, ch)
|
||||
|
||||
var fullContent string
|
||||
var promptTokens, completionTokens int
|
||||
|
||||
flusher, ok := c.Writer.(interface{ Flush() })
|
||||
if !ok {
|
||||
c.JSON(500, gin.H{"code": 7, "msg": "服务器不支持流式响应"})
|
||||
return
|
||||
}
|
||||
|
||||
for chunk := range ch {
|
||||
if chunk.Error != "" {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": chunk.Error,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
if chunk.Content != "" {
|
||||
fullContent += chunk.Content
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "content",
|
||||
"content": chunk.Content,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if chunk.Done {
|
||||
promptTokens = chunk.PromptTokens
|
||||
completionTokens = chunk.CompletionTokens
|
||||
|
||||
// 保存 AI 回复
|
||||
if fullContent != "" {
|
||||
chatService.SaveAssistantMessage(
|
||||
chat.ID, chat.CharacterID, fullContent,
|
||||
modelName, promptTokens, completionTokens,
|
||||
)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "done",
|
||||
"model": modelName,
|
||||
"promptTokens": promptTokens,
|
||||
"completionTokens": completionTokens,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EditMessage 编辑消息
|
||||
// @Tags 对话
|
||||
// @Summary 编辑消息内容
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.EditMessageRequest true "编辑信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.MessageResponse} "编辑成功"
|
||||
// @Router /app/chat/message/edit [post]
|
||||
func (ca *ChatApi) EditMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.EditMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := chatService.EditMessage(req, userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(msg, c)
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息
|
||||
// @Tags 对话
|
||||
// @Summary 删除消息
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.DeleteMessageRequest true "消息ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/chat/message/delete [post]
|
||||
func (ca *ChatApi) DeleteMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.DeleteMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := chatService.DeleteMessage(req.MessageID, userID); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// ==================== 内部辅助 ====================
|
||||
|
||||
// getDefaultChatModelForProvider 获取提供商的默认聊天模型
|
||||
func getDefaultChatModelForProvider(providerID uint) string {
|
||||
var model appModel.AIModel
|
||||
err := global.GVA_DB.Where("provider_id = ? AND model_type = ? AND is_enabled = ?",
|
||||
providerID, "chat", true).
|
||||
Order("created_at ASC").
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return model.ModelName
|
||||
}
|
||||
@@ -8,9 +8,13 @@ type ApiGroup struct {
|
||||
WorldInfoApi
|
||||
ExtensionApi
|
||||
RegexScriptApi
|
||||
ProviderApi
|
||||
ChatApi
|
||||
}
|
||||
|
||||
var (
|
||||
authService = service.ServiceGroupApp.AppServiceGroup.AuthService
|
||||
characterService = service.ServiceGroupApp.AppServiceGroup.CharacterService
|
||||
providerService = service.ServiceGroupApp.AppServiceGroup.ProviderService
|
||||
chatService = service.ServiceGroupApp.AppServiceGroup.ChatService
|
||||
)
|
||||
|
||||
@@ -3,7 +3,10 @@ package app
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
@@ -563,3 +566,72 @@ func (a *ExtensionApi) InstallExtensionFromGit(c *gin.Context) {
|
||||
|
||||
sysResponse.OkWithData(response.ToExtensionResponse(extension), c)
|
||||
}
|
||||
|
||||
// ProxyExtensionAsset 获取扩展资源文件(从本地文件系统读取)
|
||||
// @Summary 获取扩展资源文件
|
||||
// @Description 从本地存储读取扩展的 JS/CSS 等资源文件(与原版 SillyTavern 一致,扩展文件存储在本地)
|
||||
// @Tags 扩展管理
|
||||
// @Produce octet-stream
|
||||
// @Param id path int true "扩展ID"
|
||||
// @Param path path string true "资源文件路径"
|
||||
// @Success 200 {file} binary
|
||||
// @Router /app/extension/:id/asset/*path [get]
|
||||
func (a *ExtensionApi) ProxyExtensionAsset(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
sysResponse.FailWithMessage("无效的扩展ID", c)
|
||||
return
|
||||
}
|
||||
extensionID := uint(id)
|
||||
|
||||
// 获取资源路径(去掉前导 /)
|
||||
assetPath := c.Param("path")
|
||||
if len(assetPath) > 0 && assetPath[0] == '/' {
|
||||
assetPath = assetPath[1:]
|
||||
}
|
||||
if assetPath == "" {
|
||||
sysResponse.FailWithMessage("资源路径不能为空", c)
|
||||
return
|
||||
}
|
||||
|
||||
// 通过扩展 ID 查库获取信息(公开路由,不做 userID 过滤)
|
||||
extInfo, err := extensionService.GetExtensionByID(extensionID)
|
||||
if err != nil {
|
||||
sysResponse.FailWithMessage("扩展不存在", c)
|
||||
return
|
||||
}
|
||||
|
||||
// 从本地文件系统读取资源
|
||||
localPath, err := extensionService.GetExtensionAssetLocalPath(extInfo.Name, assetPath)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取扩展资源失败",
|
||||
zap.Error(err),
|
||||
zap.String("name", extInfo.Name),
|
||||
zap.String("asset", assetPath))
|
||||
sysResponse.FailWithMessage("资源不存在: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
data, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("读取扩展资源文件失败", zap.Error(err), zap.String("path", localPath))
|
||||
sysResponse.FailWithMessage("资源读取失败", c)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据文件扩展名设置正确的 Content-Type
|
||||
fileExt := filepath.Ext(assetPath)
|
||||
contentType := mime.TypeByExtension(fileExt)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
// 设置缓存和 CORS 头
|
||||
c.Header("Content-Type", contentType)
|
||||
c.Header("Cache-Control", "public, max-age=3600")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
c.Data(http.StatusOK, contentType, data)
|
||||
}
|
||||
|
||||
476
server/api/v1/app/provider.go
Normal file
476
server/api/v1/app/provider.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/middleware"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/common/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ProviderApi struct{}
|
||||
|
||||
// ==================== 提供商接口 ====================
|
||||
|
||||
// CreateProvider 创建AI提供商
|
||||
// @Tags AI配置
|
||||
// @Summary 创建AI提供商
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.CreateProviderRequest true "提供商信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "创建成功"
|
||||
// @Router /app/provider [post]
|
||||
func (pa *ProviderApi) CreateProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.CreateProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := providerService.CreateProvider(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("创建AI提供商失败", zap.Error(err))
|
||||
response.FailWithMessage("创建失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(provider, c)
|
||||
}
|
||||
|
||||
// GetProviderList 获取AI提供商列表
|
||||
// @Tags AI配置
|
||||
// @Summary 获取AI提供商列表
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param page query int true "页码"
|
||||
// @Param pageSize query int true "每页数量"
|
||||
// @Param keyword query string false "搜索关键词"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ProviderListResponse} "获取成功"
|
||||
// @Router /app/provider/list [get]
|
||||
func (pa *ProviderApi) GetProviderList(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.ProviderListRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 默认分页
|
||||
if req.Page == 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize == 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
|
||||
list, err := providerService.GetProviderList(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取AI提供商列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败", c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(list, c)
|
||||
}
|
||||
|
||||
// GetProviderDetail 获取AI提供商详情
|
||||
// @Tags AI配置
|
||||
// @Summary 获取AI提供商详情
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "获取成功"
|
||||
// @Router /app/provider/:id [get]
|
||||
func (pa *ProviderApi) GetProviderDetail(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := providerService.GetProviderDetail(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(provider, c)
|
||||
}
|
||||
|
||||
// UpdateProvider 更新AI提供商
|
||||
// @Tags AI配置
|
||||
// @Summary 更新AI提供商
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.UpdateProviderRequest true "更新信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "更新成功"
|
||||
// @Router /app/provider [put]
|
||||
func (pa *ProviderApi) UpdateProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.UpdateProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := providerService.UpdateProvider(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("更新AI提供商失败", zap.Error(err))
|
||||
response.FailWithMessage("更新失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(provider, c)
|
||||
}
|
||||
|
||||
// DeleteProvider 删除AI提供商
|
||||
// @Tags AI配置
|
||||
// @Summary 删除AI提供商
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/provider/:id [delete]
|
||||
func (pa *ProviderApi) DeleteProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := providerService.DeleteProvider(uint(id), userID); err != nil {
|
||||
global.GVA_LOG.Error("删除AI提供商失败", zap.Error(err))
|
||||
response.FailWithMessage("删除失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// SetDefaultProvider 设置默认提供商
|
||||
// @Tags AI配置
|
||||
// @Summary 设置默认AI提供商
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.SetDefaultProviderRequest true "提供商ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "设置成功"
|
||||
// @Router /app/provider/setDefault [post]
|
||||
func (pa *ProviderApi) SetDefaultProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.SetDefaultProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := providerService.SetDefaultProvider(req.ProviderID, userID); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("设置成功", c)
|
||||
}
|
||||
|
||||
// TestProvider 测试提供商连通性(不需要先保存)
|
||||
// @Tags AI配置
|
||||
// @Summary 测试AI提供商连通性
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.TestProviderRequest true "测试参数"
|
||||
// @Success 200 {object} response.Response{data=appResponse.TestProviderResponse} "测试结果"
|
||||
// @Router /app/provider/test [post]
|
||||
func (pa *ProviderApi) TestProvider(c *gin.Context) {
|
||||
var req request.TestProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.TestProvider(req)
|
||||
if err != nil {
|
||||
response.FailWithMessage("测试失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// TestExistingProvider 测试已保存的提供商连通性
|
||||
// @Tags AI配置
|
||||
// @Summary 测试已保存的提供商连通性
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.TestProviderResponse} "测试结果"
|
||||
// @Router /app/provider/test/:id [get]
|
||||
func (pa *ProviderApi) TestExistingProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.TestExistingProvider(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// ==================== 模型接口 ====================
|
||||
|
||||
// AddModel 为提供商添加模型
|
||||
// @Tags AI配置
|
||||
// @Summary 添加AI模型
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.CreateModelRequest true "模型信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ModelResponse} "创建成功"
|
||||
// @Router /app/provider/model [post]
|
||||
func (pa *ProviderApi) AddModel(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.CreateModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
model, err := providerService.AddModel(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("添加模型失败", zap.Error(err))
|
||||
response.FailWithMessage("添加失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(model, c)
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
// @Tags AI配置
|
||||
// @Summary 更新AI模型
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.UpdateModelRequest true "更新信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ModelResponse} "更新成功"
|
||||
// @Router /app/provider/model [put]
|
||||
func (pa *ProviderApi) UpdateModel(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.UpdateModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
model, err := providerService.UpdateModel(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("更新模型失败", zap.Error(err))
|
||||
response.FailWithMessage("更新失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(model, c)
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
// @Tags AI配置
|
||||
// @Summary 删除AI模型
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "模型ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/provider/model/:id [delete]
|
||||
func (pa *ProviderApi) DeleteModel(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := providerService.DeleteModel(uint(id), userID); err != nil {
|
||||
global.GVA_LOG.Error("删除模型失败", zap.Error(err))
|
||||
response.FailWithMessage("删除失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// ==================== 远程模型获取 ====================
|
||||
|
||||
// FetchRemoteModels 获取远程可用模型列表(不需要先保存)
|
||||
// @Tags AI配置
|
||||
// @Summary 从远程获取可用模型列表
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.FetchRemoteModelsRequest true "提供商配置"
|
||||
// @Success 200 {object} response.Response{data=appResponse.FetchRemoteModelsResponse} "获取结果"
|
||||
// @Router /app/provider/fetchModels [post]
|
||||
func (pa *ProviderApi) FetchRemoteModels(c *gin.Context) {
|
||||
var req request.FetchRemoteModelsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.FetchRemoteModels(req)
|
||||
if err != nil {
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
|
||||
// @Tags AI配置
|
||||
// @Summary 获取已保存提供商的远程可用模型
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.FetchRemoteModelsResponse} "获取结果"
|
||||
// @Router /app/provider/fetchModels/:id [get]
|
||||
func (pa *ProviderApi) FetchRemoteModelsExisting(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.FetchRemoteModelsExisting(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// ==================== 测试消息 ====================
|
||||
|
||||
// SendTestMessage 发送测试消息(不需要先保存)
|
||||
// @Tags AI配置
|
||||
// @Summary 发送测试消息验证AI接口
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.SendTestMessageRequest true "测试参数"
|
||||
// @Success 200 {object} response.Response{data=appResponse.SendTestMessageResponse} "测试结果"
|
||||
// @Router /app/provider/testMessage [post]
|
||||
func (pa *ProviderApi) SendTestMessage(c *gin.Context) {
|
||||
var req request.SendTestMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.SendTestMessage(req)
|
||||
if err != nil {
|
||||
response.FailWithMessage("测试失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// SendTestMessageExisting 使用已保存的提供商发送测试消息
|
||||
// @Tags AI配置
|
||||
// @Summary 使用已保存的提供商发送测试消息
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Param data body object{modelName=string,message=string} true "测试参数"
|
||||
// @Success 200 {object} response.Response{data=appResponse.SendTestMessageResponse} "测试结果"
|
||||
// @Router /app/provider/testMessage/:id [post]
|
||||
func (pa *ProviderApi) SendTestMessageExisting(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
ModelName string `json:"modelName" binding:"required"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.SendTestMessageExisting(uint(id), userID, body.ModelName, body.Message)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// ==================== 辅助接口 ====================
|
||||
|
||||
// GetProviderTypes 获取支持的提供商类型列表
|
||||
// @Tags AI配置
|
||||
// @Summary 获取支持的AI提供商类型
|
||||
// @Produce application/json
|
||||
// @Success 200 {object} response.Response{data=[]appResponse.ProviderTypeOption} "获取成功"
|
||||
// @Router /app/provider/types [get]
|
||||
func (pa *ProviderApi) GetProviderTypes(c *gin.Context) {
|
||||
types := providerService.GetProviderTypes()
|
||||
response.OkWithData(types, c)
|
||||
}
|
||||
|
||||
// GetPresetModels 获取预设模型列表
|
||||
// @Tags AI配置
|
||||
// @Summary 获取指定提供商类型的预设模型
|
||||
// @Produce application/json
|
||||
// @Param type query string true "提供商类型(openai/claude/gemini/custom)"
|
||||
// @Success 200 {object} response.Response{data=[]appResponse.PresetModelOption} "获取成功"
|
||||
// @Router /app/provider/presetModels [get]
|
||||
func (pa *ProviderApi) GetPresetModels(c *gin.Context) {
|
||||
providerType := c.Query("type")
|
||||
if providerType == "" {
|
||||
response.FailWithMessage("请指定提供商类型", c)
|
||||
return
|
||||
}
|
||||
|
||||
models := providerService.GetPresetModels(providerType)
|
||||
response.OkWithData(models, c)
|
||||
}
|
||||
@@ -45,6 +45,11 @@ func Routers() *gin.Engine {
|
||||
Router.Use(gin.Logger())
|
||||
}
|
||||
|
||||
// 跨域配置(前台应用需要)
|
||||
// 必须在静态文件路由之前注册,否则静态文件跨域会失败
|
||||
Router.Use(middleware.Cors())
|
||||
global.GVA_LOG.Info("use middleware cors")
|
||||
|
||||
if !global.GVA_CONFIG.MCP.Separate {
|
||||
|
||||
sseServer := McpRun()
|
||||
@@ -63,20 +68,23 @@ func Routers() *gin.Engine {
|
||||
exampleRouter := router.RouterGroupApp.Example
|
||||
appRouter := router.RouterGroupApp.App // 前台应用路由
|
||||
|
||||
// 前台用户端静态文件服务(web-app)
|
||||
// 开发环境:直接使用 web-app/public 目录
|
||||
// 生产环境:使用打包后的 web-app/dist 目录
|
||||
webAppPath := "../web-app/public"
|
||||
if _, err := os.Stat(webAppPath); err == nil {
|
||||
Router.Static("/css", webAppPath+"/css")
|
||||
Router.Static("/scripts", webAppPath+"/scripts")
|
||||
Router.Static("/img", webAppPath+"/img")
|
||||
Router.Static("/fonts", webAppPath+"/fonts")
|
||||
Router.Static("/webfonts", webAppPath+"/webfonts")
|
||||
Router.StaticFile("/auth.html", webAppPath+"/auth.html")
|
||||
Router.StaticFile("/dashboard-example.html", webAppPath+"/dashboard-example.html")
|
||||
Router.StaticFile("/favicon.ico", webAppPath+"/favicon.ico")
|
||||
global.GVA_LOG.Info("前台静态文件服务已启动: " + webAppPath)
|
||||
// SillyTavern 核心脚本静态文件服务
|
||||
// 扩展通过 ES module 相对路径 import 引用这些核心模块(如 ../../../../../script.js → /script.js)
|
||||
// 所有核心文件存储在 data/st-core-scripts/ 下,完全独立于 web-app/ 目录
|
||||
// 扩展文件存储在 data/st-core-scripts/scripts/extensions/third-party/{name}/ 下
|
||||
stCorePath := "data/st-core-scripts"
|
||||
if _, err := os.Stat(stCorePath); err == nil {
|
||||
Router.Static("/scripts", stCorePath+"/scripts")
|
||||
Router.Static("/css", stCorePath+"/css")
|
||||
Router.Static("/img", stCorePath+"/img")
|
||||
Router.Static("/webfonts", stCorePath+"/webfonts")
|
||||
Router.Static("/lib", stCorePath+"/lib") // SillyTavern 扩展依赖的第三方库
|
||||
Router.Static("/locales", stCorePath+"/locales") // 国际化文件
|
||||
Router.StaticFile("/script.js", stCorePath+"/script.js") // SillyTavern 主入口
|
||||
Router.StaticFile("/lib.js", stCorePath+"/lib.js") // Webpack 编译后的 lib.js
|
||||
global.GVA_LOG.Info("SillyTavern 核心脚本服务已启动: " + stCorePath)
|
||||
} else {
|
||||
global.GVA_LOG.Warn("SillyTavern 核心脚本目录不存在: " + stCorePath + ",扩展功能将不可用")
|
||||
}
|
||||
|
||||
// 管理后台前端静态文件(web)
|
||||
@@ -90,9 +98,6 @@ func Routers() *gin.Engine {
|
||||
|
||||
Router.StaticFS(global.GVA_CONFIG.Local.StorePath, justFilesFilesystem{http.Dir(global.GVA_CONFIG.Local.StorePath)}) // Router.Use(middleware.LoadTls()) // 如果需要使用https 请打开此中间件 然后前往 core/server.go 将启动模式 更变为 Router.RunTLS("端口","你的cre/pem文件","你的key文件")
|
||||
|
||||
// 跨域配置(前台应用需要)
|
||||
Router.Use(middleware.Cors()) // 直接放行全部跨域请求
|
||||
global.GVA_LOG.Info("use middleware cors")
|
||||
docs.SwaggerInfo.BasePath = global.GVA_CONFIG.System.RouterPrefix
|
||||
Router.GET(global.GVA_CONFIG.System.RouterPrefix+"/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
global.GVA_LOG.Info("register swagger handler")
|
||||
@@ -149,6 +154,8 @@ func Routers() *gin.Engine {
|
||||
appRouter.InitWorldInfoRouter(appGroup) // 世界书路由:/app/worldbook/*
|
||||
appRouter.InitExtensionRouter(appGroup) // 扩展路由:/app/extension/*
|
||||
appRouter.InitRegexScriptRouter(appGroup) // 正则脚本路由:/app/regex/*
|
||||
appRouter.InitProviderRouter(appGroup) // AI提供商路由:/app/provider/*
|
||||
appRouter.InitChatRouter(appGroup) // 对话路由:/app/chat/*
|
||||
}
|
||||
|
||||
//插件路由安装
|
||||
|
||||
@@ -67,7 +67,7 @@ func (AIExtension) TableName() string {
|
||||
return "ai_extensions"
|
||||
}
|
||||
|
||||
// AIExtensionManifest 扩展清单结构 (对应 manifest.json)
|
||||
// AIExtensionManifest 扩展清单结构 (对应 manifest.json,兼容 SillyTavern 格式)
|
||||
type AIExtensionManifest struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
@@ -75,6 +75,7 @@ type AIExtensionManifest struct {
|
||||
Description string `json:"description"`
|
||||
Author string `json:"author"`
|
||||
Homepage string `json:"homepage,omitempty"`
|
||||
HomePage string `json:"homePage,omitempty"` // SillyTavern 兼容(驼峰命名)
|
||||
Repository string `json:"repository,omitempty"`
|
||||
License string `json:"license,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
@@ -83,13 +84,54 @@ type AIExtensionManifest struct {
|
||||
Dependencies map[string]string `json:"dependencies,omitempty"` // {"extension-name": ">=1.0.0"}
|
||||
Conflicts []string `json:"conflicts,omitempty"`
|
||||
Entry string `json:"entry,omitempty"` // 主入口文件
|
||||
Js string `json:"js,omitempty"` // SillyTavern 兼容: JS 入口文件
|
||||
Style string `json:"style,omitempty"` // 样式文件
|
||||
Css string `json:"css,omitempty"` // SillyTavern 兼容: CSS 样式文件
|
||||
Assets []string `json:"assets,omitempty"` // 资源文件列表
|
||||
Settings map[string]interface{} `json:"settings,omitempty"` // 默认设置
|
||||
Options map[string]interface{} `json:"options,omitempty"` // 扩展选项
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"` // 扩展元数据
|
||||
AutoUpdate bool `json:"auto_update,omitempty"` // 是否自动更新(SillyTavern 兼容)
|
||||
InlineScript string `json:"inline_script,omitempty"` // 内联脚本(SillyTavern 兼容)
|
||||
Requires []string `json:"requires,omitempty"` // SillyTavern 兼容: 依赖列表
|
||||
Optional []string `json:"optional,omitempty"` // SillyTavern 兼容: 可选依赖
|
||||
LoadingOrder int `json:"loading_order,omitempty"` // SillyTavern 兼容: 加载顺序
|
||||
I18n map[string]string `json:"i18n,omitempty"` // SillyTavern 兼容: 国际化文件
|
||||
}
|
||||
|
||||
// GetEffectiveName 获取有效名称,兼容 SillyTavern manifest 没有 name 字段的情况
|
||||
func (m *AIExtensionManifest) GetEffectiveName() string {
|
||||
if m.Name != "" {
|
||||
return m.Name
|
||||
}
|
||||
if m.DisplayName != "" {
|
||||
return m.DisplayName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetEffectiveHomepage 获取有效主页地址
|
||||
func (m *AIExtensionManifest) GetEffectiveHomepage() string {
|
||||
if m.Homepage != "" {
|
||||
return m.Homepage
|
||||
}
|
||||
return m.HomePage
|
||||
}
|
||||
|
||||
// GetEffectiveEntry 获取有效的 JS 入口文件路径
|
||||
func (m *AIExtensionManifest) GetEffectiveEntry() string {
|
||||
if m.Entry != "" {
|
||||
return m.Entry
|
||||
}
|
||||
return m.Js
|
||||
}
|
||||
|
||||
// GetEffectiveStyle 获取有效的 CSS 样式文件路径
|
||||
func (m *AIExtensionManifest) GetEffectiveStyle() string {
|
||||
if m.Style != "" {
|
||||
return m.Style
|
||||
}
|
||||
return m.Css
|
||||
}
|
||||
|
||||
// AIExtensionSettings 用户的扩展配置(已废弃,配置现在直接存储在 AIExtension.Settings 中)
|
||||
|
||||
@@ -6,14 +6,21 @@ import (
|
||||
)
|
||||
|
||||
// AIProvider AI 服务提供商配置
|
||||
// 每个用户可以配置多个 AI 提供商(如 OpenAI、Claude、Gemini 等)
|
||||
// UserID 为 NULL 表示系统预设的提供商配置
|
||||
type AIProvider struct {
|
||||
global.GVA_MODEL
|
||||
UserID *uint `json:"userId" gorm:"index;comment:用户ID(NULL表示系统配置)"`
|
||||
User *AppUser `json:"user" gorm:"foreignKey:UserID"`
|
||||
User *AppUser `json:"user,omitempty" gorm:"foreignKey:UserID"`
|
||||
ProviderName string `json:"providerName" gorm:"type:varchar(100);not null;index;comment:提供商名称"`
|
||||
APIConfig datatypes.JSON `json:"apiConfig" gorm:"type:jsonb;not null;comment:API配置(加密存储)"`
|
||||
ProviderType string `json:"providerType" gorm:"type:varchar(50);not null;default:openai;comment:提供商类型(openai/claude/gemini/custom)"`
|
||||
BaseURL string `json:"baseUrl" gorm:"type:varchar(500);comment:API基础地址"`
|
||||
APIKey string `json:"-" gorm:"type:varchar(500);comment:API密钥(加密存储)"`
|
||||
APIConfig datatypes.JSON `json:"apiConfig" gorm:"type:jsonb;comment:额外API配置"`
|
||||
Capabilities datatypes.JSON `json:"capabilities" gorm:"type:jsonb;comment:支持的能力(chat/image_gen等)"`
|
||||
IsEnabled bool `json:"isEnabled" gorm:"default:true;comment:是否启用"`
|
||||
IsDefault bool `json:"isDefault" gorm:"default:false;comment:是否为默认提供商"`
|
||||
SortOrder int `json:"sortOrder" gorm:"default:0;comment:排序权重"`
|
||||
}
|
||||
|
||||
func (AIProvider) TableName() string {
|
||||
@@ -21,12 +28,14 @@ func (AIProvider) TableName() string {
|
||||
}
|
||||
|
||||
// AIModel AI 模型配置
|
||||
// 每个提供商下可以有多个模型(如 gpt-4o、gpt-4o-mini 等)
|
||||
type AIModel struct {
|
||||
global.GVA_MODEL
|
||||
ProviderID uint `json:"providerId" gorm:"not null;index;comment:提供商ID"`
|
||||
Provider *AIProvider `json:"provider" gorm:"foreignKey:ProviderID"`
|
||||
ModelName string `json:"modelName" gorm:"type:varchar(200);not null;comment:模型名称"`
|
||||
Provider *AIProvider `json:"provider,omitempty" gorm:"foreignKey:ProviderID"`
|
||||
ModelName string `json:"modelName" gorm:"type:varchar(200);not null;comment:模型标识(API调用用)"`
|
||||
DisplayName string `json:"displayName" gorm:"type:varchar(200);comment:模型显示名称"`
|
||||
ModelType string `json:"modelType" gorm:"type:varchar(50);not null;default:chat;comment:模型类型(chat/image_gen)"`
|
||||
Config datatypes.JSON `json:"config" gorm:"type:jsonb;comment:模型参数配置"`
|
||||
IsEnabled bool `json:"isEnabled" gorm:"default:true;comment:是否启用"`
|
||||
}
|
||||
|
||||
45
server/model/app/request/chat.go
Normal file
45
server/model/app/request/chat.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package request
|
||||
|
||||
// CreateChatRequest 创建对话请求
|
||||
type CreateChatRequest struct {
|
||||
CharacterID uint `json:"characterId" binding:"required"` // 角色卡ID
|
||||
Title string `json:"title"` // 对话标题(可选,默认使用角色名)
|
||||
}
|
||||
|
||||
// ChatListRequest 对话列表请求
|
||||
type ChatListRequest struct {
|
||||
Page int `form:"page" binding:"min=1"`
|
||||
PageSize int `form:"pageSize" binding:"min=1,max=50"`
|
||||
}
|
||||
|
||||
// SendMessageRequest 发送消息请求
|
||||
type SendMessageRequest struct {
|
||||
ChatID uint `json:"chatId" binding:"required"` // 对话ID
|
||||
Content string `json:"content" binding:"required"` // 消息内容
|
||||
ProviderID *uint `json:"providerId"` // 指定提供商(可选,默认使用用户默认)
|
||||
ModelName string `json:"modelName"` // 指定模型(可选)
|
||||
}
|
||||
|
||||
// RegenerateRequest 重新生成请求
|
||||
type RegenerateRequest struct {
|
||||
ChatID uint `json:"chatId" binding:"required"` // 对话ID
|
||||
MessageID uint `json:"messageId" binding:"required"` // 要重新生成的消息ID
|
||||
ModelName string `json:"modelName"` // 可选,指定其他模型
|
||||
}
|
||||
|
||||
// EditMessageRequest 编辑消息请求
|
||||
type EditMessageRequest struct {
|
||||
MessageID uint `json:"messageId" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
// DeleteMessageRequest 删除消息请求
|
||||
type DeleteMessageRequest struct {
|
||||
MessageID uint `json:"messageId" binding:"required"`
|
||||
}
|
||||
|
||||
// ChatMessagesRequest 获取对话消息请求
|
||||
type ChatMessagesRequest struct {
|
||||
Page int `form:"page" binding:"min=1"`
|
||||
PageSize int `form:"pageSize" binding:"min=1,max=100"`
|
||||
}
|
||||
80
server/model/app/request/provider.go
Normal file
80
server/model/app/request/provider.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package request
|
||||
|
||||
// CreateProviderRequest 创建AI提供商请求
|
||||
type CreateProviderRequest struct {
|
||||
ProviderName string `json:"providerName" binding:"required,min=1,max=100"` // 提供商名称(如"我的OpenAI")
|
||||
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"` // 提供商类型
|
||||
BaseURL string `json:"baseUrl"` // API基础地址(自定义或中转站)
|
||||
APIKey string `json:"apiKey" binding:"required"` // API密钥
|
||||
APIConfig map[string]interface{} `json:"apiConfig"` // 额外配置
|
||||
Models []CreateModelRequest `json:"models"` // 同时创建的模型列表
|
||||
}
|
||||
|
||||
// UpdateProviderRequest 更新AI提供商请求
|
||||
type UpdateProviderRequest struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
ProviderName string `json:"providerName" binding:"required,min=1,max=100"`
|
||||
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey"` // 为空表示不修改
|
||||
APIConfig map[string]interface{} `json:"apiConfig"`
|
||||
IsEnabled *bool `json:"isEnabled"`
|
||||
IsDefault *bool `json:"isDefault"`
|
||||
SortOrder *int `json:"sortOrder"`
|
||||
}
|
||||
|
||||
// CreateModelRequest 创建AI模型请求
|
||||
type CreateModelRequest struct {
|
||||
ProviderID uint `json:"providerId"` // 关联提供商ID
|
||||
ModelName string `json:"modelName" binding:"required,min=1,max=200"` // 模型标识(如 gpt-4o)
|
||||
DisplayName string `json:"displayName"` // 显示名称
|
||||
ModelType string `json:"modelType" binding:"required,oneof=chat image_gen"` // 模型类型
|
||||
Config map[string]interface{} `json:"config"` // 模型参数配置
|
||||
IsEnabled *bool `json:"isEnabled"`
|
||||
}
|
||||
|
||||
// UpdateModelRequest 更新AI模型请求
|
||||
type UpdateModelRequest struct {
|
||||
ID uint `json:"id" binding:"required"`
|
||||
ModelName string `json:"modelName" binding:"required,min=1,max=200"`
|
||||
DisplayName string `json:"displayName"`
|
||||
ModelType string `json:"modelType" binding:"required,oneof=chat image_gen"`
|
||||
Config map[string]interface{} `json:"config"`
|
||||
IsEnabled *bool `json:"isEnabled"`
|
||||
}
|
||||
|
||||
// ProviderListRequest 提供商列表请求
|
||||
type ProviderListRequest struct {
|
||||
Page int `form:"page" binding:"min=1"`
|
||||
PageSize int `form:"pageSize" binding:"min=1,max=100"`
|
||||
Keyword string `form:"keyword"`
|
||||
}
|
||||
|
||||
// TestProviderRequest 测试提供商连通性请求
|
||||
type TestProviderRequest struct {
|
||||
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey" binding:"required"`
|
||||
ModelName string `json:"modelName"` // 可选,用于测试特定模型
|
||||
}
|
||||
|
||||
// SetDefaultProviderRequest 设置默认提供商请求
|
||||
type SetDefaultProviderRequest struct {
|
||||
ProviderID uint `json:"providerId" binding:"required"`
|
||||
}
|
||||
|
||||
// FetchRemoteModelsRequest 获取远程可用模型列表请求
|
||||
type FetchRemoteModelsRequest struct {
|
||||
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey" binding:"required"`
|
||||
}
|
||||
|
||||
// SendTestMessageRequest 发送测试消息请求
|
||||
type SendTestMessageRequest struct {
|
||||
ProviderType string `json:"providerType" binding:"required,oneof=openai claude gemini custom"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKey string `json:"apiKey" binding:"required"`
|
||||
ModelName string `json:"modelName" binding:"required"` // 要测试的模型
|
||||
Message string `json:"message"` // 测试消息内容,为空则使用默认
|
||||
}
|
||||
64
server/model/app/response/chat.go
Normal file
64
server/model/app/response/chat.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChatResponse 对话响应
|
||||
type ChatResponse struct {
|
||||
ID uint `json:"id"`
|
||||
Title string `json:"title"`
|
||||
CharacterID *uint `json:"characterId"`
|
||||
CharacterName string `json:"characterName"`
|
||||
CharacterAvatar string `json:"characterAvatar"`
|
||||
ChatType string `json:"chatType"`
|
||||
LastMessageAt *time.Time `json:"lastMessageAt"`
|
||||
MessageCount int `json:"messageCount"`
|
||||
IsPinned bool `json:"isPinned"`
|
||||
LastMessage *MessageBrief `json:"lastMessage,omitempty"` // 最后一条消息摘要
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// MessageBrief 消息摘要(用于对话列表显示)
|
||||
type MessageBrief struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// ChatListResponse 对话列表响应
|
||||
type ChatListResponse struct {
|
||||
List []ChatResponse `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"pageSize"`
|
||||
}
|
||||
|
||||
// MessageResponse 消息响应
|
||||
type MessageResponse struct {
|
||||
ID uint `json:"id"`
|
||||
ChatID uint `json:"chatId"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"` // user / assistant / system
|
||||
CharacterID *uint `json:"characterId"`
|
||||
CharacterName string `json:"characterName,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
PromptTokens int `json:"promptTokens,omitempty"`
|
||||
CompletionTokens int `json:"completionTokens,omitempty"`
|
||||
TotalTokens int `json:"totalTokens,omitempty"`
|
||||
SequenceNumber int `json:"sequenceNumber"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// MessageListResponse 消息列表响应
|
||||
type MessageListResponse struct {
|
||||
List []MessageResponse `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"pageSize"`
|
||||
}
|
||||
|
||||
// ChatDetailResponse 对话详情响应(包含角色信息 + 消息列表)
|
||||
type ChatDetailResponse struct {
|
||||
Chat ChatResponse `json:"chat"`
|
||||
Messages []MessageResponse `json:"messages"`
|
||||
}
|
||||
@@ -8,39 +8,44 @@ import (
|
||||
|
||||
// ExtensionResponse 扩展响应
|
||||
type ExtensionResponse struct {
|
||||
ID uint `json:"id"`
|
||||
UserID uint `json:"userId"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Version string `json:"version"`
|
||||
Author string `json:"author"`
|
||||
Description string `json:"description"`
|
||||
Homepage string `json:"homepage"`
|
||||
Repository string `json:"repository"`
|
||||
License string `json:"license"`
|
||||
Tags []string `json:"tags"`
|
||||
ExtensionType string `json:"extensionType"`
|
||||
Category string `json:"category"`
|
||||
Dependencies map[string]string `json:"dependencies"`
|
||||
Conflicts []string `json:"conflicts"`
|
||||
ManifestData map[string]interface{} `json:"manifestData"`
|
||||
ScriptPath string `json:"scriptPath"`
|
||||
StylePath string `json:"stylePath"`
|
||||
AssetsPaths []string `json:"assetsPaths"`
|
||||
Settings map[string]interface{} `json:"settings"`
|
||||
Options map[string]interface{} `json:"options"`
|
||||
IsEnabled bool `json:"isEnabled"`
|
||||
IsInstalled bool `json:"isInstalled"`
|
||||
IsSystemExt bool `json:"isSystemExt"`
|
||||
InstallSource string `json:"installSource"`
|
||||
InstallDate time.Time `json:"installDate"`
|
||||
LastEnabled time.Time `json:"lastEnabled"`
|
||||
UsageCount int `json:"usageCount"`
|
||||
ErrorCount int `json:"errorCount"`
|
||||
LoadTime int `json:"loadTime"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
ID uint `json:"id"`
|
||||
UserID uint `json:"userId"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Version string `json:"version"`
|
||||
Author string `json:"author"`
|
||||
Description string `json:"description"`
|
||||
Homepage string `json:"homepage"`
|
||||
Repository string `json:"repository"`
|
||||
License string `json:"license"`
|
||||
Tags []string `json:"tags"`
|
||||
ExtensionType string `json:"extensionType"`
|
||||
Category string `json:"category"`
|
||||
Dependencies map[string]string `json:"dependencies"`
|
||||
Conflicts []string `json:"conflicts"`
|
||||
ManifestData map[string]interface{} `json:"manifestData"`
|
||||
ScriptPath string `json:"scriptPath"`
|
||||
StylePath string `json:"stylePath"`
|
||||
AssetsPaths []string `json:"assetsPaths"`
|
||||
Settings map[string]interface{} `json:"settings"`
|
||||
Options map[string]interface{} `json:"options"`
|
||||
IsEnabled bool `json:"isEnabled"`
|
||||
IsInstalled bool `json:"isInstalled"`
|
||||
IsSystemExt bool `json:"isSystemExt"`
|
||||
InstallSource string `json:"installSource"`
|
||||
SourceURL string `json:"sourceUrl"`
|
||||
Branch string `json:"branch"`
|
||||
AutoUpdate bool `json:"autoUpdate"`
|
||||
InstallDate time.Time `json:"installDate"`
|
||||
LastEnabled time.Time `json:"lastEnabled"`
|
||||
LastUpdateCheck *time.Time `json:"lastUpdateCheck"`
|
||||
AvailableVersion string `json:"availableVersion"`
|
||||
UsageCount int `json:"usageCount"`
|
||||
ErrorCount int `json:"errorCount"`
|
||||
LoadTime int `json:"loadTime"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// ExtensionListResponse 扩展列表响应
|
||||
@@ -151,38 +156,43 @@ func ToExtensionResponse(ext *app.AIExtension) ExtensionResponse {
|
||||
}
|
||||
|
||||
return ExtensionResponse{
|
||||
ID: ext.ID,
|
||||
UserID: ext.UserID,
|
||||
Name: ext.Name,
|
||||
DisplayName: ext.DisplayName,
|
||||
Version: ext.Version,
|
||||
Author: ext.Author,
|
||||
Description: ext.Description,
|
||||
Homepage: ext.Homepage,
|
||||
Repository: ext.Repository,
|
||||
License: ext.License,
|
||||
Tags: tags,
|
||||
ExtensionType: ext.ExtensionType,
|
||||
Category: ext.Category,
|
||||
Dependencies: dependencies,
|
||||
Conflicts: conflicts,
|
||||
ManifestData: manifestData,
|
||||
ScriptPath: ext.ScriptPath,
|
||||
StylePath: ext.StylePath,
|
||||
AssetsPaths: assetsPaths,
|
||||
Settings: settings,
|
||||
Options: options,
|
||||
IsEnabled: ext.IsEnabled,
|
||||
IsInstalled: ext.IsInstalled,
|
||||
IsSystemExt: ext.IsSystemExt,
|
||||
InstallSource: ext.InstallSource,
|
||||
InstallDate: ext.InstallDate,
|
||||
LastEnabled: ext.LastEnabled,
|
||||
UsageCount: ext.UsageCount,
|
||||
ErrorCount: ext.ErrorCount,
|
||||
LoadTime: ext.LoadTime,
|
||||
Metadata: metadata,
|
||||
CreatedAt: ext.CreatedAt,
|
||||
UpdatedAt: ext.UpdatedAt,
|
||||
ID: ext.ID,
|
||||
UserID: ext.UserID,
|
||||
Name: ext.Name,
|
||||
DisplayName: ext.DisplayName,
|
||||
Version: ext.Version,
|
||||
Author: ext.Author,
|
||||
Description: ext.Description,
|
||||
Homepage: ext.Homepage,
|
||||
Repository: ext.Repository,
|
||||
License: ext.License,
|
||||
Tags: tags,
|
||||
ExtensionType: ext.ExtensionType,
|
||||
Category: ext.Category,
|
||||
Dependencies: dependencies,
|
||||
Conflicts: conflicts,
|
||||
ManifestData: manifestData,
|
||||
ScriptPath: ext.ScriptPath,
|
||||
StylePath: ext.StylePath,
|
||||
AssetsPaths: assetsPaths,
|
||||
Settings: settings,
|
||||
Options: options,
|
||||
IsEnabled: ext.IsEnabled,
|
||||
IsInstalled: ext.IsInstalled,
|
||||
IsSystemExt: ext.IsSystemExt,
|
||||
InstallSource: ext.InstallSource,
|
||||
SourceURL: ext.SourceURL,
|
||||
Branch: ext.Branch,
|
||||
AutoUpdate: ext.AutoUpdate,
|
||||
InstallDate: ext.InstallDate,
|
||||
LastEnabled: ext.LastEnabled,
|
||||
LastUpdateCheck: ext.LastUpdateCheck,
|
||||
AvailableVersion: ext.AvailableVersion,
|
||||
UsageCount: ext.UsageCount,
|
||||
ErrorCount: ext.ErrorCount,
|
||||
LoadTime: ext.LoadTime,
|
||||
Metadata: metadata,
|
||||
CreatedAt: ext.CreatedAt,
|
||||
UpdatedAt: ext.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
92
server/model/app/response/provider.go
Normal file
92
server/model/app/response/provider.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProviderResponse 提供商响应
|
||||
type ProviderResponse struct {
|
||||
ID uint `json:"id"`
|
||||
ProviderName string `json:"providerName"`
|
||||
ProviderType string `json:"providerType"`
|
||||
BaseURL string `json:"baseUrl"`
|
||||
APIKeySet bool `json:"apiKeySet"` // 是否已设置API密钥(不返回明文)
|
||||
APIKeyHint string `json:"apiKeyHint"` // API密钥提示(如 sk-****1234)
|
||||
APIConfig json.RawMessage `json:"apiConfig"`
|
||||
Capabilities json.RawMessage `json:"capabilities"`
|
||||
IsEnabled bool `json:"isEnabled"`
|
||||
IsDefault bool `json:"isDefault"`
|
||||
SortOrder int `json:"sortOrder"`
|
||||
Models []ModelResponse `json:"models"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// ModelResponse 模型响应
|
||||
type ModelResponse struct {
|
||||
ID uint `json:"id"`
|
||||
ProviderID uint `json:"providerId"`
|
||||
ModelName string `json:"modelName"`
|
||||
DisplayName string `json:"displayName"`
|
||||
ModelType string `json:"modelType"`
|
||||
Config json.RawMessage `json:"config"`
|
||||
IsEnabled bool `json:"isEnabled"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// ProviderListResponse 提供商列表响应
|
||||
type ProviderListResponse struct {
|
||||
List []ProviderResponse `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"pageSize"`
|
||||
}
|
||||
|
||||
// TestProviderResponse 测试连通性响应
|
||||
type TestProviderResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Models []string `json:"models,omitempty"` // 获取到的可用模型列表
|
||||
Latency int64 `json:"latency"` // 响应延迟(毫秒)
|
||||
}
|
||||
|
||||
// ProviderTypeOption 提供商类型选项(前端下拉用)
|
||||
type ProviderTypeOption struct {
|
||||
Value string `json:"value"` // 类型标识
|
||||
Label string `json:"label"` // 显示名称
|
||||
Description string `json:"description"` // 描述
|
||||
DefaultURL string `json:"defaultUrl"` // 默认API地址
|
||||
}
|
||||
|
||||
// PresetModelOption 预设模型选项
|
||||
type PresetModelOption struct {
|
||||
ModelName string `json:"modelName"`
|
||||
DisplayName string `json:"displayName"`
|
||||
ModelType string `json:"modelType"` // chat / image_gen
|
||||
}
|
||||
|
||||
// RemoteModel 远程获取到的模型信息
|
||||
type RemoteModel struct {
|
||||
ID string `json:"id"` // 模型标识(API调用用)
|
||||
DisplayName string `json:"displayName"` // 显示名称
|
||||
OwnedBy string `json:"ownedBy"` // 所有者/来源
|
||||
}
|
||||
|
||||
// FetchRemoteModelsResponse 获取远程模型列表响应
|
||||
type FetchRemoteModelsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Models []RemoteModel `json:"models"`
|
||||
Latency int64 `json:"latency"` // 响应延迟(毫秒)
|
||||
}
|
||||
|
||||
// SendTestMessageResponse 发送测试消息响应
|
||||
type SendTestMessageResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"` // 状态信息
|
||||
Reply string `json:"reply"` // AI 回复内容
|
||||
Model string `json:"model"` // 实际使用的模型
|
||||
Latency int64 `json:"latency"` // 响应延迟(毫秒)
|
||||
Tokens int `json:"tokens"` // 消耗的token数
|
||||
}
|
||||
29
server/router/app/chat.go
Normal file
29
server/router/app/chat.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
v1 "git.echol.cn/loser/st/server/api/v1"
|
||||
"git.echol.cn/loser/st/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ChatRouter struct{}
|
||||
|
||||
func (cr *ChatRouter) InitChatRouter(Router *gin.RouterGroup) {
|
||||
chatApi := v1.ApiGroupApp.AppApiGroup.ChatApi
|
||||
|
||||
// 所有对话接口都需要登录
|
||||
chatRouter := Router.Group("chat").Use(middleware.AppJWTAuth())
|
||||
{
|
||||
// 对话管理
|
||||
chatRouter.POST("", chatApi.CreateChat) // 创建对话
|
||||
chatRouter.GET("/list", chatApi.GetChatList) // 对话列表
|
||||
chatRouter.GET("/:id", chatApi.GetChatDetail) // 对话详情
|
||||
chatRouter.GET("/:id/messages", chatApi.GetChatMessages) // 获取消息
|
||||
chatRouter.DELETE("/:id", chatApi.DeleteChat) // 删除对话
|
||||
|
||||
// 消息操作
|
||||
chatRouter.POST("/send", chatApi.SendMessage) // 发送消息(SSE 流式)
|
||||
chatRouter.POST("/message/edit", chatApi.EditMessage) // 编辑消息
|
||||
chatRouter.POST("/message/delete", chatApi.DeleteMessage) // 删除消息
|
||||
}
|
||||
}
|
||||
@@ -6,4 +6,6 @@ type RouterGroup struct {
|
||||
WorldInfoRouter
|
||||
ExtensionRouter
|
||||
RegexScriptRouter
|
||||
ProviderRouter
|
||||
ChatRouter
|
||||
}
|
||||
|
||||
@@ -43,4 +43,13 @@ func (r *ExtensionRouter) InitExtensionRouter(Router *gin.RouterGroup) {
|
||||
// 统计
|
||||
extensionRouter.POST("/stats", extensionApi.UpdateExtensionStats) // 更新扩展统计
|
||||
}
|
||||
|
||||
// 扩展资源文件 - 公开路由(不需要鉴权)
|
||||
// 原因:<script type="module"> 标签无法携带 JWT header,
|
||||
// 且 ES module 的 import 语句也无法携带认证信息。
|
||||
// 与原版 SillyTavern 一致:扩展文件作为公开静态资源提供。
|
||||
extensionPublicRouter := Router.Group("extension")
|
||||
{
|
||||
extensionPublicRouter.GET("/:id/asset/*path", extensionApi.ProxyExtensionAsset)
|
||||
}
|
||||
}
|
||||
|
||||
49
server/router/app/provider.go
Normal file
49
server/router/app/provider.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
v1 "git.echol.cn/loser/st/server/api/v1"
|
||||
"git.echol.cn/loser/st/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ProviderRouter struct{}
|
||||
|
||||
func (pr *ProviderRouter) InitProviderRouter(Router *gin.RouterGroup) {
|
||||
providerApi := v1.ApiGroupApp.AppApiGroup.ProviderApi
|
||||
|
||||
// 公开接口(不需要登录)
|
||||
providerPublicRouter := Router.Group("provider")
|
||||
{
|
||||
providerPublicRouter.GET("/types", providerApi.GetProviderTypes) // 获取支持的提供商类型
|
||||
providerPublicRouter.GET("/presetModels", providerApi.GetPresetModels) // 获取预设模型列表
|
||||
}
|
||||
|
||||
// 需要登录的接口
|
||||
providerAuthRouter := Router.Group("provider").Use(middleware.AppJWTAuth())
|
||||
{
|
||||
// 提供商 CRUD
|
||||
providerAuthRouter.POST("", providerApi.CreateProvider) // 创建提供商
|
||||
providerAuthRouter.GET("/list", providerApi.GetProviderList) // 获取列表
|
||||
providerAuthRouter.GET("/:id", providerApi.GetProviderDetail) // 获取详情
|
||||
providerAuthRouter.PUT("", providerApi.UpdateProvider) // 更新提供商
|
||||
providerAuthRouter.DELETE("/:id", providerApi.DeleteProvider) // 删除提供商
|
||||
providerAuthRouter.POST("/setDefault", providerApi.SetDefaultProvider) // 设置默认
|
||||
|
||||
// 连通性测试
|
||||
providerAuthRouter.POST("/test", providerApi.TestProvider) // 测试连接(不需要先保存)
|
||||
providerAuthRouter.GET("/test/:id", providerApi.TestExistingProvider) // 测试已保存的连接
|
||||
|
||||
// 远程模型获取
|
||||
providerAuthRouter.POST("/fetchModels", providerApi.FetchRemoteModels) // 获取远程模型列表
|
||||
providerAuthRouter.GET("/fetchModels/:id", providerApi.FetchRemoteModelsExisting) // 获取已保存提供商的远程模型
|
||||
|
||||
// 测试消息
|
||||
providerAuthRouter.POST("/testMessage", providerApi.SendTestMessage) // 发送测试消息
|
||||
providerAuthRouter.POST("/testMessage/:id", providerApi.SendTestMessageExisting) // 已保存提供商发测试消息
|
||||
|
||||
// 模型管理
|
||||
providerAuthRouter.POST("/model", providerApi.AddModel) // 添加模型
|
||||
providerAuthRouter.PUT("/model", providerApi.UpdateModel) // 更新模型
|
||||
providerAuthRouter.DELETE("/model/:id", providerApi.DeleteModel) // 删除模型
|
||||
}
|
||||
}
|
||||
625
server/service/app/ai_client.go
Normal file
625
server/service/app/ai_client.go
Normal file
@@ -0,0 +1,625 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appModel "git.echol.cn/loser/st/server/model/app"
|
||||
)
|
||||
|
||||
// AIMessage AI调用的消息格式(统一)
|
||||
type AIMessagePayload struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// AIStreamChunk 流式响应的单个块
|
||||
type AIStreamChunk struct {
|
||||
Content string `json:"content"` // 增量文本
|
||||
Done bool `json:"done"` // 是否结束
|
||||
Model string `json:"model"` // 使用的模型
|
||||
PromptTokens int `json:"promptTokens"` // 提示词Token(仅结束时有值)
|
||||
CompletionTokens int `json:"completionTokens"` // 补全Token(仅结束时有值)
|
||||
Error string `json:"error"` // 错误信息
|
||||
}
|
||||
|
||||
// BuildPrompt 根据角色卡和消息历史构建 AI 请求的 prompt
|
||||
func BuildPrompt(character *appModel.AICharacter, messages []appModel.AIMessage) []AIMessagePayload {
|
||||
var payload []AIMessagePayload
|
||||
|
||||
// 1. 系统提示词(System Prompt)
|
||||
systemPrompt := buildSystemPrompt(character)
|
||||
if systemPrompt != "" {
|
||||
payload = append(payload, AIMessagePayload{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
// 2. 历史消息
|
||||
for _, msg := range messages {
|
||||
role := msg.Role
|
||||
if role == "assistant" || role == "user" || role == "system" {
|
||||
payload = append(payload, AIMessagePayload{
|
||||
Role: role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildSystemPrompt 构建系统提示词
|
||||
func buildSystemPrompt(character *appModel.AICharacter) string {
|
||||
if character == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
|
||||
// 角色描述
|
||||
if character.Description != "" {
|
||||
parts = append(parts, character.Description)
|
||||
}
|
||||
|
||||
// 角色性格
|
||||
if character.Personality != "" {
|
||||
parts = append(parts, "Personality: "+character.Personality)
|
||||
}
|
||||
|
||||
// 场景
|
||||
if character.Scenario != "" {
|
||||
parts = append(parts, "Scenario: "+character.Scenario)
|
||||
}
|
||||
|
||||
// 系统提示词
|
||||
if character.SystemPrompt != "" {
|
||||
parts = append(parts, character.SystemPrompt)
|
||||
}
|
||||
|
||||
// 示例消息
|
||||
if len(character.ExampleMessages) > 0 {
|
||||
parts = append(parts, "Example dialogue:\n"+strings.Join(character.ExampleMessages, "\n"))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// StreamAIResponse 流式调用 AI 并通过 channel 返回结果
|
||||
func StreamAIResponse(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||||
defer close(ch)
|
||||
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
baseURL := provider.BaseURL
|
||||
|
||||
switch provider.ProviderType {
|
||||
case "openai", "custom":
|
||||
streamOpenAI(ctx, baseURL, apiKey, modelName, messages, ch)
|
||||
case "claude":
|
||||
streamClaude(ctx, baseURL, apiKey, modelName, messages, ch)
|
||||
case "gemini":
|
||||
streamGemini(ctx, baseURL, apiKey, modelName, messages, ch)
|
||||
default:
|
||||
ch <- AIStreamChunk{Error: "不支持的提供商类型: " + provider.ProviderType, Done: true}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== OpenAI Compatible ====================
|
||||
|
||||
func streamOpenAI(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
}
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
var fullContent string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
ch <- AIStreamChunk{Done: true, Model: modelName, Content: ""}
|
||||
return
|
||||
}
|
||||
|
||||
var chunk struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
Usage *struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
content := chunk.Choices[0].Delta.Content
|
||||
fullContent += content
|
||||
ch <- AIStreamChunk{Content: content}
|
||||
}
|
||||
|
||||
if chunk.Usage != nil {
|
||||
ch <- AIStreamChunk{
|
||||
Done: true,
|
||||
Model: modelName,
|
||||
PromptTokens: chunk.Usage.PromptTokens,
|
||||
CompletionTokens: chunk.Usage.CompletionTokens,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果扫描结束但没收到 [DONE]
|
||||
if fullContent != "" {
|
||||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Claude ====================
|
||||
|
||||
func streamClaude(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||||
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
|
||||
|
||||
// Claude 需要把 system 消息分离出来
|
||||
var systemContent string
|
||||
var claudeMessages []map[string]string
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemContent += msg.Content + "\n"
|
||||
} else {
|
||||
claudeMessages = append(claudeMessages, map[string]string{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 确保至少有一条消息
|
||||
if len(claudeMessages) == 0 {
|
||||
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": claudeMessages,
|
||||
"max_tokens": 4096,
|
||||
"stream": true,
|
||||
}
|
||||
if systemContent != "" {
|
||||
body["system"] = strings.TrimSpace(systemContent)
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Delta *struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"delta"`
|
||||
Usage *struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "content_block_delta":
|
||||
if event.Delta != nil && event.Delta.Text != "" {
|
||||
ch <- AIStreamChunk{Content: event.Delta.Text}
|
||||
}
|
||||
case "message_delta":
|
||||
if event.Usage != nil {
|
||||
ch <- AIStreamChunk{
|
||||
Done: true,
|
||||
Model: modelName,
|
||||
PromptTokens: event.Usage.InputTokens,
|
||||
CompletionTokens: event.Usage.OutputTokens,
|
||||
}
|
||||
return
|
||||
}
|
||||
case "message_stop":
|
||||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||||
return
|
||||
case "error":
|
||||
ch <- AIStreamChunk{Error: "Claude API 错误", Done: true}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Gemini ====================
|
||||
|
||||
func streamGemini(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload, ch chan<- AIStreamChunk) {
|
||||
url := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s",
|
||||
strings.TrimRight(baseURL, "/"), modelName, apiKey)
|
||||
|
||||
// 构建 Gemini 格式的消息
|
||||
var systemInstruction string
|
||||
var contents []map[string]interface{}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemInstruction += msg.Content + "\n"
|
||||
continue
|
||||
}
|
||||
|
||||
role := msg.Role
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
|
||||
contents = append(contents, map[string]interface{}{
|
||||
"role": role,
|
||||
"parts": []map[string]string{
|
||||
{"text": msg.Content},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if len(contents) == 0 {
|
||||
ch <- AIStreamChunk{Error: "没有有效的消息内容", Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
body := map[string]interface{}{
|
||||
"contents": contents,
|
||||
}
|
||||
if systemInstruction != "" {
|
||||
body["systemInstruction"] = map[string]interface{}{
|
||||
"parts": []map[string]string{
|
||||
{"text": strings.TrimSpace(systemInstruction)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "请求构建失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
ch <- AIStreamChunk{Error: "连接失败: " + err.Error(), Done: true}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
ch <- AIStreamChunk{Error: fmt.Sprintf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody)), Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// Gemini 返回较大的 chunks,增大 buffer
|
||||
scanner.Buffer(make([]byte, 0), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var chunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata *struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, candidate := range chunk.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
ch <- AIStreamChunk{Content: part.Text}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.UsageMetadata != nil {
|
||||
ch <- AIStreamChunk{
|
||||
Done: true,
|
||||
Model: modelName,
|
||||
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch <- AIStreamChunk{Done: true, Model: modelName}
|
||||
}
|
||||
|
||||
// ==================== 非流式调用(用于生图等) ====================
|
||||
|
||||
// CallAINonStream 非流式调用 AI
|
||||
func CallAINonStream(ctx context.Context, provider *appModel.AIProvider, modelName string, messages []AIMessagePayload) (string, error) {
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
baseURL := provider.BaseURL
|
||||
|
||||
switch provider.ProviderType {
|
||||
case "openai", "custom":
|
||||
return callOpenAINonStream(ctx, baseURL, apiKey, modelName, messages)
|
||||
case "claude":
|
||||
return callClaudeNonStream(ctx, baseURL, apiKey, modelName, messages)
|
||||
case "gemini":
|
||||
return callGeminiNonStream(ctx, baseURL, apiKey, modelName, messages)
|
||||
default:
|
||||
return "", errors.New("不支持的提供商类型: " + provider.ProviderType)
|
||||
}
|
||||
}
|
||||
|
||||
func callOpenAINonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": messages,
|
||||
}
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(result.Choices) > 0 {
|
||||
return result.Choices[0].Message.Content, nil
|
||||
}
|
||||
return "", errors.New("AI 未返回有效回复")
|
||||
}
|
||||
|
||||
func callClaudeNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||||
url := strings.TrimRight(baseURL, "/") + "/v1/messages"
|
||||
|
||||
var systemContent string
|
||||
var claudeMessages []map[string]string
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemContent += msg.Content + "\n"
|
||||
} else {
|
||||
claudeMessages = append(claudeMessages, map[string]string{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
body := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"messages": claudeMessages,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
if systemContent != "" {
|
||||
body["system"] = strings.TrimSpace(systemContent)
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(result.Content) > 0 {
|
||||
return result.Content[0].Text, nil
|
||||
}
|
||||
return "", errors.New("AI 未返回有效回复")
|
||||
}
|
||||
|
||||
func callGeminiNonStream(ctx context.Context, baseURL, apiKey, modelName string, messages []AIMessagePayload) (string, error) {
|
||||
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s",
|
||||
strings.TrimRight(baseURL, "/"), modelName, apiKey)
|
||||
|
||||
var systemInstruction string
|
||||
var contents []map[string]interface{}
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemInstruction += msg.Content + "\n"
|
||||
continue
|
||||
}
|
||||
role := msg.Role
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
contents = append(contents, map[string]interface{}{
|
||||
"role": role,
|
||||
"parts": []map[string]string{{"text": msg.Content}},
|
||||
})
|
||||
}
|
||||
|
||||
body := map[string]interface{}{"contents": contents}
|
||||
if systemInstruction != "" {
|
||||
body["systemInstruction"] = map[string]interface{}{
|
||||
"parts": []map[string]string{{"text": strings.TrimSpace(systemInstruction)}},
|
||||
}
|
||||
}
|
||||
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("API 错误 (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
} `json:"candidates"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(result.Candidates) > 0 && len(result.Candidates[0].Content.Parts) > 0 {
|
||||
return result.Candidates[0].Content.Parts[0].Text, nil
|
||||
}
|
||||
return "", errors.New("AI 未返回有效回复")
|
||||
}
|
||||
408
server/service/app/chat.go
Normal file
408
server/service/app/chat.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/model/app"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/app/response"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ChatService struct{}
|
||||
|
||||
// ==================== 对话 CRUD ====================
|
||||
|
||||
// CreateChat 创建对话
|
||||
func (cs *ChatService) CreateChat(req request.CreateChatRequest, userID uint) (response.ChatResponse, error) {
|
||||
// 获取角色卡信息
|
||||
var character app.AICharacter
|
||||
err := global.GVA_DB.First(&character, req.CharacterID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return response.ChatResponse{}, errors.New("角色卡不存在")
|
||||
}
|
||||
return response.ChatResponse{}, err
|
||||
}
|
||||
|
||||
// 对话标题
|
||||
title := req.Title
|
||||
if title == "" {
|
||||
title = character.Name
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
chat := app.AIChat{
|
||||
Title: title,
|
||||
UserID: userID,
|
||||
CharacterID: &req.CharacterID,
|
||||
ChatType: "single",
|
||||
LastMessageAt: &now,
|
||||
}
|
||||
|
||||
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建对话
|
||||
if err := tx.Create(&chat).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果角色卡有 FirstMessage,自动创建第一条系统消息
|
||||
if character.FirstMessage != "" {
|
||||
firstMsg := app.AIMessage{
|
||||
ChatID: chat.ID,
|
||||
Content: character.FirstMessage,
|
||||
Role: "assistant",
|
||||
CharacterID: &req.CharacterID,
|
||||
SequenceNumber: 1,
|
||||
}
|
||||
if err := tx.Create(&firstMsg).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
chat.MessageCount = 1
|
||||
tx.Model(&chat).Update("message_count", 1)
|
||||
}
|
||||
|
||||
// 更新角色卡使用次数
|
||||
tx.Model(&character).Update("total_chats", gorm.Expr("total_chats + 1"))
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return response.ChatResponse{}, err
|
||||
}
|
||||
|
||||
return toChatResponse(&chat, &character, nil), nil
|
||||
}
|
||||
|
||||
// GetChatList 获取用户的对话列表
|
||||
func (cs *ChatService) GetChatList(req request.ChatListRequest, userID uint) (response.ChatListResponse, error) {
|
||||
db := global.GVA_DB.Model(&app.AIChat{}).Where("user_id = ?", userID)
|
||||
|
||||
var total int64
|
||||
db.Count(&total)
|
||||
|
||||
var chats []app.AIChat
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
err := db.Preload("Character").
|
||||
Order("is_pinned DESC, last_message_at DESC").
|
||||
Offset(offset).Limit(req.PageSize).
|
||||
Find(&chats).Error
|
||||
if err != nil {
|
||||
return response.ChatListResponse{}, err
|
||||
}
|
||||
|
||||
// 获取每个对话的最后一条消息
|
||||
chatIDs := make([]uint, len(chats))
|
||||
for i, c := range chats {
|
||||
chatIDs[i] = c.ID
|
||||
}
|
||||
|
||||
lastMessages := make(map[uint]*response.MessageBrief)
|
||||
if len(chatIDs) > 0 {
|
||||
var messages []app.AIMessage
|
||||
// 获取每个对话的最后一条消息(通过子查询)
|
||||
global.GVA_DB.Where("chat_id IN ? AND is_deleted = ?", chatIDs, false).
|
||||
Order("sequence_number DESC").
|
||||
Find(&messages)
|
||||
|
||||
// 只保留每个对话的最后一条
|
||||
for _, msg := range messages {
|
||||
if _, exists := lastMessages[msg.ChatID]; !exists {
|
||||
content := msg.Content
|
||||
if len(content) > 100 {
|
||||
content = content[:100] + "..."
|
||||
}
|
||||
lastMessages[msg.ChatID] = &response.MessageBrief{
|
||||
Content: content,
|
||||
Role: msg.Role,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
list := make([]response.ChatResponse, len(chats))
|
||||
for i, c := range chats {
|
||||
list[i] = toChatResponse(&c, c.Character, lastMessages[c.ID])
|
||||
}
|
||||
|
||||
return response.ChatListResponse{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChatDetail 获取对话详情(包含消息列表)
|
||||
func (cs *ChatService) GetChatDetail(chatID uint, userID uint) (response.ChatDetailResponse, error) {
|
||||
var chat app.AIChat
|
||||
err := global.GVA_DB.Preload("Character").
|
||||
Where("id = ? AND user_id = ?", chatID, userID).
|
||||
First(&chat).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return response.ChatDetailResponse{}, errors.New("对话不存在")
|
||||
}
|
||||
return response.ChatDetailResponse{}, err
|
||||
}
|
||||
|
||||
// 获取消息列表(最近的50条)
|
||||
var messages []app.AIMessage
|
||||
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
|
||||
Order("sequence_number ASC").
|
||||
Limit(50).
|
||||
Find(&messages)
|
||||
|
||||
msgList := make([]response.MessageResponse, len(messages))
|
||||
for i, msg := range messages {
|
||||
msgList[i] = toMessageResponse(&msg, chat.Character)
|
||||
}
|
||||
|
||||
return response.ChatDetailResponse{
|
||||
Chat: toChatResponse(&chat, chat.Character, nil),
|
||||
Messages: msgList,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetChatMessages 分页获取对话消息
|
||||
func (cs *ChatService) GetChatMessages(chatID uint, req request.ChatMessagesRequest, userID uint) (response.MessageListResponse, error) {
|
||||
// 验证对话归属
|
||||
var chat app.AIChat
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||||
if err != nil {
|
||||
return response.MessageListResponse{}, errors.New("对话不存在")
|
||||
}
|
||||
|
||||
db := global.GVA_DB.Model(&app.AIMessage{}).
|
||||
Where("chat_id = ? AND is_deleted = ?", chatID, false)
|
||||
|
||||
var total int64
|
||||
db.Count(&total)
|
||||
|
||||
var messages []app.AIMessage
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
err = db.Order("sequence_number ASC").
|
||||
Offset(offset).Limit(req.PageSize).
|
||||
Find(&messages).Error
|
||||
if err != nil {
|
||||
return response.MessageListResponse{}, err
|
||||
}
|
||||
|
||||
// 预加载角色信息
|
||||
var character *app.AICharacter
|
||||
if chat.CharacterID != nil {
|
||||
var c app.AICharacter
|
||||
global.GVA_DB.First(&c, *chat.CharacterID)
|
||||
character = &c
|
||||
}
|
||||
|
||||
list := make([]response.MessageResponse, len(messages))
|
||||
for i, msg := range messages {
|
||||
list[i] = toMessageResponse(&msg, character)
|
||||
}
|
||||
|
||||
return response.MessageListResponse{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteChat 删除对话
|
||||
func (cs *ChatService) DeleteChat(chatID uint, userID uint) error {
|
||||
var chat app.AIChat
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("对话不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除消息
|
||||
if err := tx.Where("chat_id = ?", chatID).Delete(&app.AIMessage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除消息变体
|
||||
if err := tx.Where("message_id IN (?)",
|
||||
tx.Model(&app.AIMessage{}).Select("id").Where("chat_id = ?", chatID),
|
||||
).Delete(&app.AIMessageSwipe{}).Error; err != nil {
|
||||
// 忽略错误,可能没有变体
|
||||
}
|
||||
// 删除对话
|
||||
return tx.Delete(&chat).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 消息操作 ====================
|
||||
|
||||
// SaveUserMessage 保存用户消息(内部方法)
|
||||
func (cs *ChatService) SaveUserMessage(chatID uint, userID uint, content string) (*app.AIMessage, error) {
|
||||
// 获取当前最大序号
|
||||
var maxSeq int
|
||||
global.GVA_DB.Model(&app.AIMessage{}).
|
||||
Where("chat_id = ?", chatID).
|
||||
Select("COALESCE(MAX(sequence_number), 0)").
|
||||
Scan(&maxSeq)
|
||||
|
||||
msg := &app.AIMessage{
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Role: "user",
|
||||
SenderID: &userID,
|
||||
SequenceNumber: maxSeq + 1,
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Create(msg).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 更新对话的消息数和最后消息时间
|
||||
now := time.Now()
|
||||
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
|
||||
"message_count": gorm.Expr("message_count + 1"),
|
||||
"last_message_at": now,
|
||||
})
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// SaveAssistantMessage 保存AI回复消息(内部方法)
|
||||
func (cs *ChatService) SaveAssistantMessage(chatID uint, characterID *uint, content string, model string, promptTokens, completionTokens int) (*app.AIMessage, error) {
|
||||
var maxSeq int
|
||||
global.GVA_DB.Model(&app.AIMessage{}).
|
||||
Where("chat_id = ?", chatID).
|
||||
Select("COALESCE(MAX(sequence_number), 0)").
|
||||
Scan(&maxSeq)
|
||||
|
||||
msg := &app.AIMessage{
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Role: "assistant",
|
||||
CharacterID: characterID,
|
||||
SequenceNumber: maxSeq + 1,
|
||||
Model: model,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Create(msg).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
global.GVA_DB.Model(&app.AIChat{}).Where("id = ?", chatID).Updates(map[string]interface{}{
|
||||
"message_count": gorm.Expr("message_count + 1"),
|
||||
"last_message_at": now,
|
||||
})
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// EditMessage 编辑消息
|
||||
func (cs *ChatService) EditMessage(req request.EditMessageRequest, userID uint) (response.MessageResponse, error) {
|
||||
var msg app.AIMessage
|
||||
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
|
||||
Where("ai_messages.id = ? AND ai_chats.user_id = ?", req.MessageID, userID).
|
||||
First(&msg).Error
|
||||
if err != nil {
|
||||
return response.MessageResponse{}, errors.New("消息不存在")
|
||||
}
|
||||
|
||||
msg.Content = req.Content
|
||||
if err := global.GVA_DB.Save(&msg).Error; err != nil {
|
||||
return response.MessageResponse{}, err
|
||||
}
|
||||
|
||||
return toMessageResponse(&msg, nil), nil
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息(软删除)
|
||||
func (cs *ChatService) DeleteMessage(messageID uint, userID uint) error {
|
||||
var msg app.AIMessage
|
||||
err := global.GVA_DB.Joins("JOIN ai_chats ON ai_chats.id = ai_messages.chat_id").
|
||||
Where("ai_messages.id = ? AND ai_chats.user_id = ?", messageID, userID).
|
||||
First(&msg).Error
|
||||
if err != nil {
|
||||
return errors.New("消息不存在")
|
||||
}
|
||||
|
||||
return global.GVA_DB.Model(&msg).Update("is_deleted", true).Error
|
||||
}
|
||||
|
||||
// GetChatForAI 获取对话用于AI调用的完整上下文(内部方法)
|
||||
func (cs *ChatService) GetChatForAI(chatID uint, userID uint) (*app.AIChat, *app.AICharacter, []app.AIMessage, error) {
|
||||
var chat app.AIChat
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", chatID, userID).First(&chat).Error
|
||||
if err != nil {
|
||||
return nil, nil, nil, errors.New("对话不存在")
|
||||
}
|
||||
|
||||
var character *app.AICharacter
|
||||
if chat.CharacterID != nil {
|
||||
var c app.AICharacter
|
||||
if err := global.GVA_DB.First(&c, *chat.CharacterID).Error; err == nil {
|
||||
character = &c
|
||||
}
|
||||
}
|
||||
|
||||
// 获取历史消息(最近30条)
|
||||
var messages []app.AIMessage
|
||||
global.GVA_DB.Where("chat_id = ? AND is_deleted = ?", chatID, false).
|
||||
Order("sequence_number ASC").
|
||||
Limit(30).
|
||||
Find(&messages)
|
||||
|
||||
return &chat, character, messages, nil
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
func toChatResponse(chat *app.AIChat, character *app.AICharacter, lastMsg *response.MessageBrief) response.ChatResponse {
|
||||
resp := response.ChatResponse{
|
||||
ID: chat.ID,
|
||||
Title: chat.Title,
|
||||
CharacterID: chat.CharacterID,
|
||||
ChatType: chat.ChatType,
|
||||
LastMessageAt: chat.LastMessageAt,
|
||||
MessageCount: chat.MessageCount,
|
||||
IsPinned: chat.IsPinned,
|
||||
LastMessage: lastMsg,
|
||||
CreatedAt: chat.CreatedAt,
|
||||
}
|
||||
|
||||
if character != nil {
|
||||
resp.CharacterName = character.Name
|
||||
resp.CharacterAvatar = character.Avatar
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func toMessageResponse(msg *app.AIMessage, character *app.AICharacter) response.MessageResponse {
|
||||
resp := response.MessageResponse{
|
||||
ID: msg.ID,
|
||||
ChatID: msg.ChatID,
|
||||
Content: msg.Content,
|
||||
Role: msg.Role,
|
||||
CharacterID: msg.CharacterID,
|
||||
Model: msg.Model,
|
||||
PromptTokens: msg.PromptTokens,
|
||||
CompletionTokens: msg.CompletionTokens,
|
||||
TotalTokens: msg.TotalTokens,
|
||||
SequenceNumber: msg.SequenceNumber,
|
||||
CreatedAt: msg.CreatedAt,
|
||||
}
|
||||
|
||||
if msg.Role == "assistant" && character != nil {
|
||||
resp.CharacterName = character.Name
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -6,4 +6,6 @@ type AppServiceGroup struct {
|
||||
WorldInfoService
|
||||
ExtensionService
|
||||
RegexScriptService
|
||||
ProviderService
|
||||
ChatService
|
||||
}
|
||||
|
||||
@@ -21,15 +21,70 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// extensionDataDir 扩展本地存储根目录
|
||||
// 与原版 SillyTavern 完全一致的路径结构:scripts/extensions/third-party/{name}/
|
||||
// 扩展 JS 中的相对路径 import(如 ../../../../../script.js)依赖此目录层级来正确解析
|
||||
// 所有 SillyTavern 核心脚本和扩展文件统一存储在 data/st-core-scripts/ 下,独立于 web-app/
|
||||
// 扩展代码是公共的(不按用户隔离),用户间差异仅在于数据库中的配置和启用状态
|
||||
const extensionDataDir = "data/st-core-scripts/scripts/extensions/third-party"
|
||||
|
||||
// getExtensionStorePath 获取扩展的本地存储路径: {extensionDataDir}/{extensionName}/
|
||||
func getExtensionStorePath(extensionName string) string {
|
||||
return filepath.Join(extensionDataDir, extensionName)
|
||||
}
|
||||
|
||||
// GetExtensionAssetLocalPath 获取扩展资源文件的本地绝对路径
|
||||
func (es *ExtensionService) GetExtensionAssetLocalPath(extensionName string, assetPath string) (string, error) {
|
||||
storePath := getExtensionStorePath(extensionName)
|
||||
fullPath := filepath.Join(storePath, assetPath)
|
||||
|
||||
// 安全检查:防止路径遍历攻击
|
||||
absStore, _ := filepath.Abs(storePath)
|
||||
absFile, _ := filepath.Abs(fullPath)
|
||||
if !strings.HasPrefix(absFile, absStore) {
|
||||
return "", errors.New("非法的资源路径")
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("资源文件不存在: %s", assetPath)
|
||||
}
|
||||
|
||||
return fullPath, nil
|
||||
}
|
||||
|
||||
// ensureExtensionDir 确保扩展存储目录存在
|
||||
func ensureExtensionDir(extensionName string) (string, error) {
|
||||
storePath := getExtensionStorePath(extensionName)
|
||||
if err := os.MkdirAll(storePath, 0755); err != nil {
|
||||
return "", fmt.Errorf("创建扩展存储目录失败: %w", err)
|
||||
}
|
||||
return storePath, nil
|
||||
}
|
||||
|
||||
// removeExtensionDir 删除扩展的本地存储目录
|
||||
func removeExtensionDir(extensionName string) error {
|
||||
storePath := getExtensionStorePath(extensionName)
|
||||
if _, err := os.Stat(storePath); os.IsNotExist(err) {
|
||||
return nil // 目录不存在,无需删除
|
||||
}
|
||||
return os.RemoveAll(storePath)
|
||||
}
|
||||
|
||||
type ExtensionService struct{}
|
||||
|
||||
// CreateExtension 创建/安装扩展
|
||||
func (es *ExtensionService) CreateExtension(userID uint, req *request.CreateExtensionRequest) (*app.AIExtension, error) {
|
||||
// 校验名称
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("扩展名称不能为空")
|
||||
}
|
||||
|
||||
// 检查扩展是否已存在
|
||||
var existing app.AIExtension
|
||||
err := global.GVA_DB.Where("user_id = ? AND name = ?", userID, req.Name).First(&existing).Error
|
||||
if err == nil {
|
||||
return nil, errors.New("扩展已存在")
|
||||
return nil, fmt.Errorf("扩展 %s 已存在", req.Name)
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
@@ -136,15 +191,18 @@ func (es *ExtensionService) DeleteExtension(userID, extensionID uint, deleteFile
|
||||
return errors.New("系统内置扩展不允许删除")
|
||||
}
|
||||
|
||||
// TODO: 如果 deleteFiles=true,删除扩展文件
|
||||
// 这需要文件系统支持
|
||||
// 删除本地扩展文件(与原版 SillyTavern 一致:卸载扩展时清理本地文件)
|
||||
if err := removeExtensionDir(extension.Name); err != nil {
|
||||
global.GVA_LOG.Warn("删除扩展本地文件失败", zap.Error(err), zap.String("name", extension.Name))
|
||||
// 不阻断删除流程
|
||||
}
|
||||
|
||||
// 删除扩展(配置已经在扩展记录的 Settings 字段中,无需单独删除)
|
||||
// 删除数据库记录
|
||||
if err := global.GVA_DB.Delete(&extension).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID))
|
||||
global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID), zap.String("name", extension.Name))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -157,6 +215,15 @@ func (es *ExtensionService) GetExtension(userID, extensionID uint) (*app.AIExten
|
||||
return &extension, nil
|
||||
}
|
||||
|
||||
// GetExtensionByID 通过扩展 ID 获取扩展信息(不限制用户,用于公开资源路由)
|
||||
func (es *ExtensionService) GetExtensionByID(extensionID uint) (*app.AIExtension, error) {
|
||||
var extension app.AIExtension
|
||||
if err := global.GVA_DB.Where("id = ?", extensionID).First(&extension).Error; err != nil {
|
||||
return nil, errors.New("扩展不存在")
|
||||
}
|
||||
return &extension, nil
|
||||
}
|
||||
|
||||
// GetExtensionList 获取扩展列表
|
||||
func (es *ExtensionService) GetExtensionList(userID uint, req *request.ExtensionListRequest) (*response.ExtensionListResponse, error) {
|
||||
var extensions []app.AIExtension
|
||||
@@ -550,9 +617,37 @@ func isGitURL(url string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// downloadAndInstallFromManifestURL 从 Manifest URL 下载并安装
|
||||
// GetExtensionAssetURL 根据扩展的安装来源构建资源文件的远程 URL
|
||||
func (es *ExtensionService) GetExtensionAssetURL(extension *app.AIExtension, assetPath string) (string, error) {
|
||||
if extension.SourceURL == "" {
|
||||
return "", errors.New("扩展没有源地址")
|
||||
}
|
||||
|
||||
sourceURL := strings.TrimSuffix(strings.TrimSuffix(extension.SourceURL, "/"), ".git")
|
||||
branch := extension.Branch
|
||||
if branch == "" {
|
||||
branch = "main"
|
||||
}
|
||||
|
||||
// GitLab: repo/-/raw/branch/path
|
||||
if strings.Contains(sourceURL, "gitlab.com") {
|
||||
return fmt.Sprintf("%s/-/raw/%s/%s", sourceURL, branch, assetPath), nil
|
||||
}
|
||||
// GitHub: raw.githubusercontent.com/user/repo/branch/path
|
||||
if strings.Contains(sourceURL, "github.com") {
|
||||
rawURL := strings.Replace(sourceURL, "github.com", "raw.githubusercontent.com", 1)
|
||||
return fmt.Sprintf("%s/%s/%s", rawURL, branch, assetPath), nil
|
||||
}
|
||||
// Gitee: repo/raw/branch/path
|
||||
if strings.Contains(sourceURL, "gitee.com") {
|
||||
return fmt.Sprintf("%s/raw/%s/%s", sourceURL, branch, assetPath), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s", sourceURL, assetPath), nil
|
||||
}
|
||||
|
||||
// downloadAndInstallFromManifestURL 从 Manifest URL 下载并安装(同时下载资源文件到本地)
|
||||
func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manifestURL string) (*app.AIExtension, error) {
|
||||
// 创建 HTTP 客户端
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
@@ -568,7 +663,6 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
return nil, fmt.Errorf("下载 manifest.json 失败: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取响应内容
|
||||
manifestData, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取 manifest.json 失败: %w", err)
|
||||
@@ -580,21 +674,63 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if manifest.Name == "" {
|
||||
return nil, errors.New("manifest.json 缺少 name 字段")
|
||||
// 获取有效名称
|
||||
effectiveName := manifest.GetEffectiveName()
|
||||
if effectiveName == "" {
|
||||
return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
|
||||
}
|
||||
|
||||
// 检查扩展是否已存在
|
||||
var existing app.AIExtension
|
||||
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, manifest.Name).First(&existing).Error
|
||||
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("扩展 %s 已安装", manifest.Name)
|
||||
return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建本地存储目录并保存 manifest.json
|
||||
storePath, err := ensureExtensionDir(effectiveName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(storePath, "manifest.json"), manifestData, 0644); err != nil {
|
||||
_ = removeExtensionDir(effectiveName)
|
||||
return nil, fmt.Errorf("保存 manifest.json 失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取 manifest URL 的基础目录(用于下载关联资源)
|
||||
baseURL := manifestURL[:strings.LastIndex(manifestURL, "/")+1]
|
||||
|
||||
// 下载 JS/CSS 等资源文件到本地
|
||||
filesToDownload := []string{}
|
||||
if entry := manifest.GetEffectiveEntry(); entry != "" {
|
||||
filesToDownload = append(filesToDownload, entry)
|
||||
}
|
||||
if style := manifest.GetEffectiveStyle(); style != "" {
|
||||
filesToDownload = append(filesToDownload, style)
|
||||
}
|
||||
filesToDownload = append(filesToDownload, manifest.Assets...)
|
||||
|
||||
for _, file := range filesToDownload {
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
fileURL := baseURL + file
|
||||
if err := downloadFileToLocal(client, fileURL, filepath.Join(storePath, file)); err != nil {
|
||||
global.GVA_LOG.Warn("下载扩展资源文件失败(非致命)",
|
||||
zap.String("file", file),
|
||||
zap.String("url", fileURL),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("扩展文件已保存到本地",
|
||||
zap.String("name", effectiveName),
|
||||
zap.String("path", storePath))
|
||||
|
||||
// 将 manifest 转换为 map[string]interface{}
|
||||
var manifestMap map[string]interface{}
|
||||
if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
|
||||
@@ -603,13 +739,13 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
|
||||
// 构建创建请求
|
||||
createReq := &request.CreateExtensionRequest{
|
||||
Name: manifest.Name,
|
||||
Name: effectiveName,
|
||||
DisplayName: manifest.DisplayName,
|
||||
Version: manifest.Version,
|
||||
Author: manifest.Author,
|
||||
Description: manifest.Description,
|
||||
Homepage: manifest.Homepage,
|
||||
Repository: manifest.Repository, // 使用 manifest 中的 repository
|
||||
Homepage: manifest.GetEffectiveHomepage(),
|
||||
Repository: manifest.Repository,
|
||||
License: manifest.License,
|
||||
Tags: manifest.Tags,
|
||||
ExtensionType: manifest.Type,
|
||||
@@ -617,13 +753,13 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
Dependencies: manifest.Dependencies,
|
||||
Conflicts: manifest.Conflicts,
|
||||
ManifestData: manifestMap,
|
||||
ScriptPath: manifest.Entry,
|
||||
StylePath: manifest.Style,
|
||||
ScriptPath: manifest.GetEffectiveEntry(),
|
||||
StylePath: manifest.GetEffectiveStyle(),
|
||||
AssetsPaths: manifest.Assets,
|
||||
Settings: manifest.Settings,
|
||||
Options: manifest.Options,
|
||||
InstallSource: "url",
|
||||
SourceURL: manifestURL, // 记录原始 URL 用于更新
|
||||
SourceURL: manifestURL,
|
||||
AutoUpdate: manifest.AutoUpdate,
|
||||
Metadata: nil,
|
||||
}
|
||||
@@ -636,6 +772,7 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
// 创建扩展
|
||||
extension, err := es.CreateExtension(userID, createReq)
|
||||
if err != nil {
|
||||
_ = removeExtensionDir(effectiveName)
|
||||
return nil, fmt.Errorf("创建扩展失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -647,6 +784,31 @@ func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manif
|
||||
return extension, nil
|
||||
}
|
||||
|
||||
// downloadFileToLocal 下载远程文件到本地路径
|
||||
func downloadFileToLocal(client *http.Client, url string, localPath string) error {
|
||||
// 确保目标文件的父目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(localPath), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(localPath, data, 0644)
|
||||
}
|
||||
|
||||
// UpgradeExtension 升级扩展版本(根据安装来源自动选择更新方式)
|
||||
func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force bool) (*app.AIExtension, error) {
|
||||
// 获取扩展信息
|
||||
@@ -672,7 +834,7 @@ func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force boo
|
||||
}
|
||||
}
|
||||
|
||||
// updateExtensionFromGit 从 Git 仓库更新扩展
|
||||
// updateExtensionFromGit 从 Git 仓库更新扩展(先删除旧记录和文件,再重新安装)
|
||||
func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.AIExtension, force bool) (*app.AIExtension, error) {
|
||||
if extension.SourceURL == "" {
|
||||
return nil, errors.New("缺少 Git 仓库 URL")
|
||||
@@ -683,11 +845,17 @@ func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.A
|
||||
zap.String("sourceUrl", extension.SourceURL),
|
||||
zap.String("branch", extension.Branch))
|
||||
|
||||
// 重新克隆(简单方式,避免处理本地修改)
|
||||
// 先删除旧的数据库记录和本地文件
|
||||
if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
|
||||
return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
|
||||
}
|
||||
_ = removeExtensionDir(extension.Name)
|
||||
|
||||
// 重新克隆安装
|
||||
return es.InstallExtensionFromGit(userID, extension.SourceURL, extension.Branch)
|
||||
}
|
||||
|
||||
// updateExtensionFromURL 从 URL 更新扩展(重新下载 manifest.json)
|
||||
// updateExtensionFromURL 从 URL 更新扩展(先删除旧记录和文件,再重新下载安装)
|
||||
func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.AIExtension) (*app.AIExtension, error) {
|
||||
if extension.SourceURL == "" {
|
||||
return nil, errors.New("缺少 Manifest URL")
|
||||
@@ -697,18 +865,24 @@ func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.A
|
||||
zap.String("name", extension.Name),
|
||||
zap.String("sourceUrl", extension.SourceURL))
|
||||
|
||||
// 重新下载并安装
|
||||
// 先删除旧的数据库记录和本地文件
|
||||
if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
|
||||
return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
|
||||
}
|
||||
_ = removeExtensionDir(extension.Name)
|
||||
|
||||
// 重新下载安装
|
||||
return es.downloadAndInstallFromManifestURL(userID, extension.SourceURL)
|
||||
}
|
||||
|
||||
// InstallExtensionFromGit 从 Git URL 安装扩展
|
||||
// InstallExtensionFromGit 从 Git URL 安装扩展(与原版 SillyTavern 一致:将源码下载到本地)
|
||||
func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch string) (*app.AIExtension, error) {
|
||||
// 验证 Git URL
|
||||
if !strings.Contains(gitUrl, "://") && !strings.HasSuffix(gitUrl, ".git") {
|
||||
return nil, errors.New("无效的 Git URL")
|
||||
}
|
||||
|
||||
// 创建临时目录
|
||||
// 先 clone 到临时目录读取 manifest(获取扩展名后再移动到正式目录)
|
||||
tempDir, err := os.MkdirTemp("", "extension-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建临时目录失败: %w", err)
|
||||
@@ -717,10 +891,9 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
|
||||
|
||||
global.GVA_LOG.Info("开始从 Git 克隆扩展",
|
||||
zap.String("gitUrl", gitUrl),
|
||||
zap.String("branch", branch),
|
||||
zap.String("tempDir", tempDir))
|
||||
zap.String("branch", branch))
|
||||
|
||||
// 执行 git clone
|
||||
// 执行 git clone(浅克隆)
|
||||
cmd := exec.Command("git", "clone", "--depth=1", "--branch="+branch, gitUrl, tempDir)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
@@ -744,31 +917,53 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
|
||||
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取有效名称(兼容 SillyTavern manifest 没有 name 字段的情况)
|
||||
effectiveName := manifest.GetEffectiveName()
|
||||
if effectiveName == "" {
|
||||
return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
|
||||
}
|
||||
|
||||
// 检查扩展是否已存在
|
||||
var existing app.AIExtension
|
||||
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, manifest.Name).First(&existing).Error
|
||||
err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("扩展 %s 已安装", manifest.Name)
|
||||
return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 将扩展文件保存到公共目录: web-app/public/scripts/extensions/third-party/{extensionName}/
|
||||
storePath, err := ensureExtensionDir(effectiveName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 清空目标目录(如果有残留文件)后复制 clone 内容
|
||||
_ = os.RemoveAll(storePath)
|
||||
if err := copyDir(tempDir, storePath); err != nil {
|
||||
return nil, fmt.Errorf("保存扩展文件失败: %w", err)
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("扩展文件已保存到本地",
|
||||
zap.String("name", effectiveName),
|
||||
zap.String("path", storePath))
|
||||
|
||||
// 将 manifest 转换为 map[string]interface{}
|
||||
var manifestMap map[string]interface{}
|
||||
if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
|
||||
return nil, fmt.Errorf("转换 manifest 失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建创建请求
|
||||
// 构建创建请求(使用兼容方法获取字段值)
|
||||
createReq := &request.CreateExtensionRequest{
|
||||
Name: manifest.Name,
|
||||
Name: effectiveName,
|
||||
DisplayName: manifest.DisplayName,
|
||||
Version: manifest.Version,
|
||||
Author: manifest.Author,
|
||||
Description: manifest.Description,
|
||||
Homepage: manifest.Homepage,
|
||||
Repository: manifest.Repository, // 使用 manifest 中的 repository
|
||||
Homepage: manifest.GetEffectiveHomepage(),
|
||||
Repository: manifest.Repository,
|
||||
License: manifest.License,
|
||||
Tags: manifest.Tags,
|
||||
ExtensionType: manifest.Type,
|
||||
@@ -776,28 +971,69 @@ func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch
|
||||
Dependencies: manifest.Dependencies,
|
||||
Conflicts: manifest.Conflicts,
|
||||
ManifestData: manifestMap,
|
||||
ScriptPath: manifest.Entry,
|
||||
StylePath: manifest.Style,
|
||||
ScriptPath: manifest.GetEffectiveEntry(),
|
||||
StylePath: manifest.GetEffectiveStyle(),
|
||||
AssetsPaths: manifest.Assets,
|
||||
Settings: manifest.Settings,
|
||||
Options: manifest.Options,
|
||||
InstallSource: "git",
|
||||
SourceURL: gitUrl, // 记录 Git URL 用于更新
|
||||
Branch: branch, // 记录分支
|
||||
SourceURL: gitUrl,
|
||||
Branch: branch,
|
||||
AutoUpdate: manifest.AutoUpdate,
|
||||
Metadata: manifest.Metadata,
|
||||
}
|
||||
|
||||
// 确保扩展类型有效
|
||||
if createReq.ExtensionType == "" {
|
||||
createReq.ExtensionType = "ui"
|
||||
}
|
||||
|
||||
// 创建扩展记录
|
||||
extension, err := es.CreateExtension(userID, createReq)
|
||||
if err != nil {
|
||||
// 创建失败则清理本地文件
|
||||
_ = removeExtensionDir(effectiveName)
|
||||
return nil, fmt.Errorf("创建扩展记录失败: %w", err)
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("从 Git 安装扩展成功",
|
||||
zap.Uint("extensionID", extension.ID),
|
||||
zap.String("name", extension.Name),
|
||||
zap.String("version", extension.Version))
|
||||
zap.String("version", extension.Version),
|
||||
zap.String("localPath", storePath))
|
||||
|
||||
return extension, nil
|
||||
}
|
||||
|
||||
// copyDir 递归复制目录(排除 .git 目录以节省空间)
|
||||
func copyDir(src, dst string) error {
|
||||
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 计算相对路径
|
||||
relPath, err := filepath.Rel(src, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 排除 .git 目录
|
||||
if info.IsDir() && info.Name() == ".git" {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
dstPath := filepath.Join(dst, relPath)
|
||||
|
||||
if info.IsDir() {
|
||||
return os.MkdirAll(dstPath, info.Mode())
|
||||
}
|
||||
|
||||
// 复制文件
|
||||
srcFile, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dstPath, srcFile, info.Mode())
|
||||
})
|
||||
}
|
||||
|
||||
949
server/service/app/provider.go
Normal file
949
server/service/app/provider.go
Normal file
@@ -0,0 +1,949 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/model/app"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/app/response"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ProviderService struct{}
|
||||
|
||||
// ==================== 提供商 CRUD ====================
|
||||
|
||||
// CreateProvider 创建AI提供商
|
||||
func (ps *ProviderService) CreateProvider(req request.CreateProviderRequest, userID uint) (response.ProviderResponse, error) {
|
||||
// 加密 API Key
|
||||
encryptedKey := encryptAPIKey(req.APIKey)
|
||||
|
||||
// 序列化额外配置
|
||||
apiConfigJSON, _ := json.Marshal(req.APIConfig)
|
||||
if req.APIConfig == nil {
|
||||
apiConfigJSON = []byte("{}")
|
||||
}
|
||||
|
||||
// 根据类型确定能力
|
||||
capabilities := getDefaultCapabilities(req.ProviderType)
|
||||
capJSON, _ := json.Marshal(capabilities)
|
||||
|
||||
// 如果 BaseURL 为空,使用默认地址
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||||
}
|
||||
|
||||
provider := app.AIProvider{
|
||||
UserID: &userID,
|
||||
ProviderName: req.ProviderName,
|
||||
ProviderType: req.ProviderType,
|
||||
BaseURL: baseURL,
|
||||
APIKey: encryptedKey,
|
||||
APIConfig: datatypes.JSON(apiConfigJSON),
|
||||
Capabilities: datatypes.JSON(capJSON),
|
||||
IsEnabled: true,
|
||||
IsDefault: false,
|
||||
}
|
||||
|
||||
err := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建提供商
|
||||
if err := tx.Create(&provider).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果携带了模型列表,同时创建模型
|
||||
if len(req.Models) > 0 {
|
||||
for _, m := range req.Models {
|
||||
modelConfigJSON, _ := json.Marshal(m.Config)
|
||||
if m.Config == nil {
|
||||
modelConfigJSON = []byte("{}")
|
||||
}
|
||||
isEnabled := true
|
||||
if m.IsEnabled != nil {
|
||||
isEnabled = *m.IsEnabled
|
||||
}
|
||||
model := app.AIModel{
|
||||
ProviderID: provider.ID,
|
||||
ModelName: m.ModelName,
|
||||
DisplayName: m.DisplayName,
|
||||
ModelType: m.ModelType,
|
||||
Config: datatypes.JSON(modelConfigJSON),
|
||||
IsEnabled: isEnabled,
|
||||
}
|
||||
if err := tx.Create(&model).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 没有指定模型时,自动添加预设模型
|
||||
presets := getPresetModels(req.ProviderType)
|
||||
for _, p := range presets {
|
||||
model := app.AIModel{
|
||||
ProviderID: provider.ID,
|
||||
ModelName: p.ModelName,
|
||||
DisplayName: p.DisplayName,
|
||||
ModelType: p.ModelType,
|
||||
Config: datatypes.JSON([]byte("{}")),
|
||||
IsEnabled: true,
|
||||
}
|
||||
if err := tx.Create(&model).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是用户的第一个提供商,自动设为默认
|
||||
var count int64
|
||||
tx.Model(&app.AIProvider{}).Where("user_id = ? AND id != ?", userID, provider.ID).Count(&count)
|
||||
if count == 0 {
|
||||
provider.IsDefault = true
|
||||
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return response.ProviderResponse{}, err
|
||||
}
|
||||
|
||||
return ps.GetProviderDetail(provider.ID, userID)
|
||||
}
|
||||
|
||||
// GetProviderList 获取用户的提供商列表
|
||||
func (ps *ProviderService) GetProviderList(req request.ProviderListRequest, userID uint) (response.ProviderListResponse, error) {
|
||||
db := global.GVA_DB.Model(&app.AIProvider{}).Where("user_id = ?", userID)
|
||||
|
||||
if req.Keyword != "" {
|
||||
keyword := "%" + req.Keyword + "%"
|
||||
db = db.Where("provider_name ILIKE ?", keyword)
|
||||
}
|
||||
|
||||
var total int64
|
||||
db.Count(&total)
|
||||
|
||||
var providers []app.AIProvider
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
err := db.Order("is_default DESC, sort_order ASC, created_at DESC").
|
||||
Offset(offset).Limit(req.PageSize).Find(&providers).Error
|
||||
if err != nil {
|
||||
return response.ProviderListResponse{}, err
|
||||
}
|
||||
|
||||
// 获取所有提供商的模型
|
||||
providerIDs := make([]uint, len(providers))
|
||||
for i, p := range providers {
|
||||
providerIDs[i] = p.ID
|
||||
}
|
||||
|
||||
var models []app.AIModel
|
||||
if len(providerIDs) > 0 {
|
||||
global.GVA_DB.Where("provider_id IN ?", providerIDs).
|
||||
Order("model_type ASC, model_name ASC").Find(&models)
|
||||
}
|
||||
|
||||
// 按提供商ID分组模型
|
||||
modelMap := make(map[uint][]app.AIModel)
|
||||
for _, m := range models {
|
||||
modelMap[m.ProviderID] = append(modelMap[m.ProviderID], m)
|
||||
}
|
||||
|
||||
list := make([]response.ProviderResponse, len(providers))
|
||||
for i, p := range providers {
|
||||
list[i] = toProviderResponse(&p, modelMap[p.ID])
|
||||
}
|
||||
|
||||
return response.ProviderListResponse{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProviderDetail 获取提供商详情
|
||||
func (ps *ProviderService) GetProviderDetail(providerID uint, userID uint) (response.ProviderResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return response.ProviderResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
return response.ProviderResponse{}, err
|
||||
}
|
||||
|
||||
var models []app.AIModel
|
||||
global.GVA_DB.Where("provider_id = ?", providerID).
|
||||
Order("model_type ASC, model_name ASC").Find(&models)
|
||||
|
||||
return toProviderResponse(&provider, models), nil
|
||||
}
|
||||
|
||||
// UpdateProvider 更新提供商
|
||||
func (ps *ProviderService) UpdateProvider(req request.UpdateProviderRequest, userID uint) (response.ProviderResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return response.ProviderResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
return response.ProviderResponse{}, err
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
updates := map[string]interface{}{
|
||||
"provider_name": req.ProviderName,
|
||||
"provider_type": req.ProviderType,
|
||||
"base_url": req.BaseURL,
|
||||
}
|
||||
|
||||
// APIKey 不为空时才更新
|
||||
if req.APIKey != "" {
|
||||
updates["api_key"] = encryptAPIKey(req.APIKey)
|
||||
}
|
||||
|
||||
if req.APIConfig != nil {
|
||||
apiConfigJSON, _ := json.Marshal(req.APIConfig)
|
||||
updates["api_config"] = datatypes.JSON(apiConfigJSON)
|
||||
}
|
||||
|
||||
if req.IsEnabled != nil {
|
||||
updates["is_enabled"] = *req.IsEnabled
|
||||
}
|
||||
|
||||
if req.SortOrder != nil {
|
||||
updates["sort_order"] = *req.SortOrder
|
||||
}
|
||||
|
||||
// 更新能力
|
||||
capabilities := getDefaultCapabilities(req.ProviderType)
|
||||
capJSON, _ := json.Marshal(capabilities)
|
||||
updates["capabilities"] = datatypes.JSON(capJSON)
|
||||
|
||||
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&provider).Updates(updates).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理设置默认
|
||||
if req.IsDefault != nil && *req.IsDefault {
|
||||
// 先取消其他默认
|
||||
if err := tx.Model(&app.AIProvider{}).
|
||||
Where("user_id = ? AND id != ?", userID, req.ID).
|
||||
Update("is_default", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return response.ProviderResponse{}, err
|
||||
}
|
||||
|
||||
return ps.GetProviderDetail(req.ID, userID)
|
||||
}
|
||||
|
||||
// DeleteProvider 删除提供商
|
||||
func (ps *ProviderService) DeleteProvider(providerID uint, userID uint) error {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("提供商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除关联的模型
|
||||
if err := tx.Where("provider_id = ?", providerID).Delete(&app.AIModel{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除提供商
|
||||
if err := tx.Delete(&provider).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果删除的是默认提供商,自动将第一个提供商设为默认
|
||||
if provider.IsDefault {
|
||||
var firstProvider app.AIProvider
|
||||
if err := tx.Where("user_id = ?", userID).Order("created_at ASC").First(&firstProvider).Error; err == nil {
|
||||
tx.Model(&firstProvider).Update("is_default", true)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultProvider 设置默认提供商
|
||||
func (ps *ProviderService) SetDefaultProvider(providerID uint, userID uint) error {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("提供商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 先取消所有默认
|
||||
if err := tx.Model(&app.AIProvider{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("is_default", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置新默认
|
||||
return tx.Model(&provider).Update("is_default", true).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 模型 CRUD ====================
|
||||
|
||||
// AddModel 为提供商添加模型
|
||||
func (ps *ProviderService) AddModel(req request.CreateModelRequest, userID uint) (response.ModelResponse, error) {
|
||||
// 验证提供商归属
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ProviderID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
return response.ModelResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
|
||||
configJSON, _ := json.Marshal(req.Config)
|
||||
if req.Config == nil {
|
||||
configJSON = []byte("{}")
|
||||
}
|
||||
|
||||
isEnabled := true
|
||||
if req.IsEnabled != nil {
|
||||
isEnabled = *req.IsEnabled
|
||||
}
|
||||
|
||||
model := app.AIModel{
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
DisplayName: req.DisplayName,
|
||||
ModelType: req.ModelType,
|
||||
Config: datatypes.JSON(configJSON),
|
||||
IsEnabled: isEnabled,
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Create(&model).Error; err != nil {
|
||||
return response.ModelResponse{}, err
|
||||
}
|
||||
|
||||
return toModelResponse(&model), nil
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
func (ps *ProviderService) UpdateModel(req request.UpdateModelRequest, userID uint) (response.ModelResponse, error) {
|
||||
var model app.AIModel
|
||||
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
|
||||
Where("ai_models.id = ? AND ai_providers.user_id = ?", req.ID, userID).
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
return response.ModelResponse{}, errors.New("模型不存在")
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"model_name": req.ModelName,
|
||||
"display_name": req.DisplayName,
|
||||
"model_type": req.ModelType,
|
||||
}
|
||||
|
||||
if req.Config != nil {
|
||||
configJSON, _ := json.Marshal(req.Config)
|
||||
updates["config"] = datatypes.JSON(configJSON)
|
||||
}
|
||||
|
||||
if req.IsEnabled != nil {
|
||||
updates["is_enabled"] = *req.IsEnabled
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Model(&model).Updates(updates).Error; err != nil {
|
||||
return response.ModelResponse{}, err
|
||||
}
|
||||
|
||||
// 重新查询
|
||||
global.GVA_DB.First(&model, model.ID)
|
||||
return toModelResponse(&model), nil
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
func (ps *ProviderService) DeleteModel(modelID uint, userID uint) error {
|
||||
var model app.AIModel
|
||||
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
|
||||
Where("ai_models.id = ? AND ai_providers.user_id = ?", modelID, userID).
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
return errors.New("模型不存在")
|
||||
}
|
||||
|
||||
return global.GVA_DB.Delete(&model).Error
|
||||
}
|
||||
|
||||
// ==================== 连通性测试 ====================
|
||||
|
||||
// TestProvider 测试提供商连通性
|
||||
func (ps *ProviderService) TestProvider(req request.TestProviderRequest) (response.TestProviderResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||||
}
|
||||
|
||||
// 所有提供商统一使用 OpenAI 兼容的 /models 端点测试连通性
|
||||
result := testOpenAICompatible(baseURL, req.APIKey, req.ModelName)
|
||||
result.Latency = time.Since(startTime).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TestExistingProvider 测试已保存的提供商连通性
|
||||
func (ps *ProviderService) TestExistingProvider(providerID uint, userID uint) (response.TestProviderResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
return response.TestProviderResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
|
||||
return ps.TestProvider(request.TestProviderRequest{
|
||||
ProviderType: provider.ProviderType,
|
||||
BaseURL: provider.BaseURL,
|
||||
APIKey: apiKey,
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 辅助查询 ====================
|
||||
|
||||
// GetProviderTypes 获取支持的提供商类型列表(前端下拉用)
|
||||
func (ps *ProviderService) GetProviderTypes() []response.ProviderTypeOption {
|
||||
return []response.ProviderTypeOption{
|
||||
{
|
||||
Value: "openai",
|
||||
Label: "OpenAI",
|
||||
Description: "支持 GPT-4o、GPT-4、DALL·E 等模型,也兼容所有 OpenAI 格式的中转站",
|
||||
DefaultURL: "https://api.openai.com/v1",
|
||||
},
|
||||
{
|
||||
Value: "claude",
|
||||
Label: "Claude",
|
||||
Description: "Anthropic 的 Claude 系列模型,支持长上下文对话",
|
||||
DefaultURL: "https://api.anthropic.com",
|
||||
},
|
||||
{
|
||||
Value: "gemini",
|
||||
Label: "Google Gemini",
|
||||
Description: "Google 的 Gemini 系列模型,支持多模态",
|
||||
DefaultURL: "https://generativelanguage.googleapis.com",
|
||||
},
|
||||
{
|
||||
Value: "custom",
|
||||
Label: "自定义(OpenAI 兼容)",
|
||||
Description: "兼容 OpenAI 格式的任意接口,如 DeepSeek、通义千问等中转站",
|
||||
DefaultURL: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetPresetModels 获取指定提供商类型的预设模型列表
|
||||
func (ps *ProviderService) GetPresetModels(providerType string) []response.PresetModelOption {
|
||||
presets := getPresetModels(providerType)
|
||||
result := make([]response.PresetModelOption, len(presets))
|
||||
for i, p := range presets {
|
||||
result[i] = response.PresetModelOption{
|
||||
ModelName: p.ModelName,
|
||||
DisplayName: p.DisplayName,
|
||||
ModelType: p.ModelType,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetUserDefaultProvider 获取用户默认提供商(内部方法,给对话功能用)
|
||||
func (ps *ProviderService) GetUserDefaultProvider(userID uint) (*app.AIProvider, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("user_id = ? AND is_default = ? AND is_enabled = ?", userID, true, true).
|
||||
First(&provider).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("请先配置 AI 接口")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// GetDecryptedAPIKey 获取解密后的API密钥(内部方法,给AI调用用)
|
||||
func (ps *ProviderService) GetDecryptedAPIKey(provider *app.AIProvider) string {
|
||||
return decryptAPIKey(provider.APIKey)
|
||||
}
|
||||
|
||||
// ==================== 内部辅助函数 ====================
|
||||
|
||||
// toProviderResponse 转换为响应对象
|
||||
func toProviderResponse(p *app.AIProvider, models []app.AIModel) response.ProviderResponse {
|
||||
apiConfig := json.RawMessage(p.APIConfig)
|
||||
if len(apiConfig) == 0 {
|
||||
apiConfig = json.RawMessage("{}")
|
||||
}
|
||||
|
||||
capabilities := json.RawMessage(p.Capabilities)
|
||||
if len(capabilities) == 0 {
|
||||
capabilities = json.RawMessage("[]")
|
||||
}
|
||||
|
||||
// 模型列表
|
||||
modelList := make([]response.ModelResponse, len(models))
|
||||
for i, m := range models {
|
||||
modelList[i] = toModelResponse(&m)
|
||||
}
|
||||
|
||||
// API Key 提示
|
||||
apiKeyHint := ""
|
||||
apiKeySet := false
|
||||
if p.APIKey != "" {
|
||||
apiKeySet = true
|
||||
decrypted := decryptAPIKey(p.APIKey)
|
||||
if len(decrypted) > 8 {
|
||||
apiKeyHint = decrypted[:4] + "****" + decrypted[len(decrypted)-4:]
|
||||
} else if len(decrypted) > 0 {
|
||||
apiKeyHint = "****"
|
||||
}
|
||||
}
|
||||
|
||||
return response.ProviderResponse{
|
||||
ID: p.ID,
|
||||
ProviderName: p.ProviderName,
|
||||
ProviderType: p.ProviderType,
|
||||
BaseURL: p.BaseURL,
|
||||
APIKeySet: apiKeySet,
|
||||
APIKeyHint: apiKeyHint,
|
||||
APIConfig: apiConfig,
|
||||
Capabilities: capabilities,
|
||||
IsEnabled: p.IsEnabled,
|
||||
IsDefault: p.IsDefault,
|
||||
SortOrder: p.SortOrder,
|
||||
Models: modelList,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// toModelResponse 转换模型为响应对象
|
||||
func toModelResponse(m *app.AIModel) response.ModelResponse {
|
||||
config := json.RawMessage(m.Config)
|
||||
if len(config) == 0 {
|
||||
config = json.RawMessage("{}")
|
||||
}
|
||||
return response.ModelResponse{
|
||||
ID: m.ID,
|
||||
ProviderID: m.ProviderID,
|
||||
ModelName: m.ModelName,
|
||||
DisplayName: m.DisplayName,
|
||||
ModelType: m.ModelType,
|
||||
Config: config,
|
||||
IsEnabled: m.IsEnabled,
|
||||
CreatedAt: m.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// encryptAPIKey 加密API密钥
|
||||
// TODO: 后续可以替换为更安全的加密方式(如 AES),当前使用简单的 Base64 编码
|
||||
func encryptAPIKey(key string) string {
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
// 简单的混淆处理,生产环境应替换为 AES 加密
|
||||
import_encoding := []byte(key)
|
||||
for i := range import_encoding {
|
||||
import_encoding[i] ^= 0x5A
|
||||
}
|
||||
return fmt.Sprintf("enc:%x", import_encoding)
|
||||
}
|
||||
|
||||
// decryptAPIKey 解密API密钥
|
||||
func decryptAPIKey(encrypted string) string {
|
||||
if encrypted == "" {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(encrypted, "enc:") {
|
||||
return encrypted // 未加密的旧数据,直接返回
|
||||
}
|
||||
|
||||
hexStr := encrypted[4:]
|
||||
var data []byte
|
||||
fmt.Sscanf(hexStr, "%x", &data)
|
||||
for i := range data {
|
||||
data[i] ^= 0x5A
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// getDefaultBaseURL 获取默认API基础地址
|
||||
func getDefaultBaseURL(providerType string) string {
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return "https://api.openai.com/v1"
|
||||
case "claude":
|
||||
return "https://api.anthropic.com"
|
||||
case "gemini":
|
||||
return "https://generativelanguage.googleapis.com"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultCapabilities 获取默认能力列表
|
||||
func getDefaultCapabilities(providerType string) []string {
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return []string{"chat", "image_gen"}
|
||||
case "claude":
|
||||
return []string{"chat"}
|
||||
case "gemini":
|
||||
return []string{"chat", "image_gen"}
|
||||
case "custom":
|
||||
return []string{"chat"}
|
||||
default:
|
||||
return []string{"chat"}
|
||||
}
|
||||
}
|
||||
|
||||
// presetModel 预设模型内部结构
|
||||
type presetModel struct {
|
||||
ModelName string
|
||||
DisplayName string
|
||||
ModelType string
|
||||
}
|
||||
|
||||
// getPresetModels 获取预设模型列表
|
||||
func getPresetModels(providerType string) []presetModel {
|
||||
switch providerType {
|
||||
case "openai":
|
||||
return []presetModel{
|
||||
{ModelName: "gpt-4o", DisplayName: "GPT-4o", ModelType: "chat"},
|
||||
{ModelName: "gpt-4o-mini", DisplayName: "GPT-4o Mini", ModelType: "chat"},
|
||||
{ModelName: "gpt-4.1", DisplayName: "GPT-4.1", ModelType: "chat"},
|
||||
{ModelName: "gpt-4.1-mini", DisplayName: "GPT-4.1 Mini", ModelType: "chat"},
|
||||
{ModelName: "gpt-4.1-nano", DisplayName: "GPT-4.1 Nano", ModelType: "chat"},
|
||||
{ModelName: "o3-mini", DisplayName: "o3-mini", ModelType: "chat"},
|
||||
{ModelName: "dall-e-3", DisplayName: "DALL·E 3", ModelType: "image_gen"},
|
||||
}
|
||||
case "claude":
|
||||
return []presetModel{
|
||||
{ModelName: "claude-sonnet-4-20250514", DisplayName: "Claude Sonnet 4", ModelType: "chat"},
|
||||
{ModelName: "claude-3-5-sonnet-20241022", DisplayName: "Claude 3.5 Sonnet", ModelType: "chat"},
|
||||
{ModelName: "claude-3-5-haiku-20241022", DisplayName: "Claude 3.5 Haiku", ModelType: "chat"},
|
||||
{ModelName: "claude-3-opus-20240229", DisplayName: "Claude 3 Opus", ModelType: "chat"},
|
||||
}
|
||||
case "gemini":
|
||||
return []presetModel{
|
||||
{ModelName: "gemini-2.5-flash-preview-05-20", DisplayName: "Gemini 2.5 Flash", ModelType: "chat"},
|
||||
{ModelName: "gemini-2.5-pro-preview-05-06", DisplayName: "Gemini 2.5 Pro", ModelType: "chat"},
|
||||
{ModelName: "gemini-2.0-flash", DisplayName: "Gemini 2.0 Flash", ModelType: "chat"},
|
||||
{ModelName: "imagen-3.0-generate-002", DisplayName: "Imagen 3", ModelType: "image_gen"},
|
||||
}
|
||||
case "custom":
|
||||
return []presetModel{} // 自定义不提供预设
|
||||
default:
|
||||
return []presetModel{}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 获取远程模型列表 ====================
|
||||
|
||||
// FetchRemoteModels 从远程API获取可用模型列表
|
||||
// 所有提供商类型统一使用 baseURL + /models 端点(OpenAI 兼容格式)
|
||||
func (ps *ProviderService) FetchRemoteModels(req request.FetchRemoteModelsRequest) (response.FetchRemoteModelsResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||||
}
|
||||
|
||||
result := fetchModelsUniversal(baseURL, req.APIKey)
|
||||
result.Latency = time.Since(startTime).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
|
||||
func (ps *ProviderService) FetchRemoteModelsExisting(providerID uint, userID uint) (response.FetchRemoteModelsResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
return response.FetchRemoteModelsResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
|
||||
return ps.FetchRemoteModels(request.FetchRemoteModelsRequest{
|
||||
ProviderType: provider.ProviderType,
|
||||
BaseURL: provider.BaseURL,
|
||||
APIKey: apiKey,
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 发送测试消息 ====================
|
||||
|
||||
// SendTestMessage 发送测试消息(使用指定的 provider 配置)
|
||||
// 所有提供商类型统一使用 baseURL + /chat/completions 端点(OpenAI 兼容格式)
|
||||
func (ps *ProviderService) SendTestMessage(req request.SendTestMessageRequest) (response.SendTestMessageResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||||
}
|
||||
|
||||
message := req.Message
|
||||
if message == "" {
|
||||
message = "你好,请用一句话介绍你自己。"
|
||||
}
|
||||
|
||||
result := sendTestMessageUniversal(baseURL, req.APIKey, req.ModelName, message)
|
||||
result.Latency = time.Since(startTime).Milliseconds()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SendTestMessageExisting 发送测试消息(已保存的提供商)
|
||||
func (ps *ProviderService) SendTestMessageExisting(providerID uint, userID uint, modelName string, message string) (response.SendTestMessageResponse, error) {
|
||||
var provider app.AIProvider
|
||||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||||
if err != nil {
|
||||
return response.SendTestMessageResponse{}, errors.New("提供商不存在")
|
||||
}
|
||||
|
||||
apiKey := decryptAPIKey(provider.APIKey)
|
||||
|
||||
return ps.SendTestMessage(request.SendTestMessageRequest{
|
||||
ProviderType: provider.ProviderType,
|
||||
BaseURL: provider.BaseURL,
|
||||
APIKey: apiKey,
|
||||
ModelName: modelName,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 连通性测试实现 ====================
|
||||
|
||||
// testOpenAICompatible 测试 OpenAI 兼容接口
|
||||
func testOpenAICompatible(baseURL, apiKey, modelName string) response.TestProviderResponse {
|
||||
url := strings.TrimRight(baseURL, "/") + "/models"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return response.TestProviderResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return response.TestProviderResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == 401 {
|
||||
return response.TestProviderResponse{Success: false, Message: "API 密钥无效"}
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return response.TestProviderResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
// 解析模型列表
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
var modelNames []string
|
||||
if err := json.Unmarshal(body, &modelsResp); err == nil {
|
||||
for _, m := range modelsResp.Data {
|
||||
modelNames = append(modelNames, m.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return response.TestProviderResponse{
|
||||
Success: true,
|
||||
Message: "连接成功",
|
||||
Models: modelNames,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 获取远程模型列表实现 ====================
|
||||
|
||||
// fetchModelsUniversal 统一获取模型列表(所有提供商通用)
|
||||
// 使用 baseURL + /models 端点,Authorization: Bearer 鉴权
|
||||
func fetchModelsUniversal(baseURL, apiKey string) response.FetchRemoteModelsResponse {
|
||||
url := strings.TrimRight(baseURL, "/") + "/models"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return response.FetchRemoteModelsResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return response.FetchRemoteModelsResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 401 {
|
||||
return response.FetchRemoteModelsResponse{Success: false, Message: "API 密钥无效"}
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return response.FetchRemoteModelsResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 解析 OpenAI 兼容格式: { "data": [{ "id": "xxx", "owned_by": "xxx" }] }
|
||||
var modelsData struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
var models []response.RemoteModel
|
||||
if err := json.Unmarshal(body, &modelsData); err == nil {
|
||||
for _, m := range modelsData.Data {
|
||||
models = append(models, response.RemoteModel{
|
||||
ID: m.ID,
|
||||
OwnedBy: m.OwnedBy,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return response.FetchRemoteModelsResponse{
|
||||
Success: true,
|
||||
Message: fmt.Sprintf("获取成功,共 %d 个模型", len(models)),
|
||||
Models: models,
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 发送测试消息实现 ====================
|
||||
|
||||
// sendTestMessageUniversal 统一发送测试消息(所有提供商通用)
|
||||
// 使用 baseURL + /chat/completions 端点,Authorization: Bearer 鉴权
|
||||
func sendTestMessageUniversal(baseURL, apiKey, modelName, message string) response.SendTestMessageResponse {
|
||||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"max_tokens": 100,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": message},
|
||||
},
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == 401 {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "API 密钥无效"}
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
var errResp struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "API 错误: " + errResp.Error.Message}
|
||||
}
|
||||
return response.SendTestMessageResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 OpenAI 兼容格式的 chat completion 响应
|
||||
var chatResp struct {
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||
return response.SendTestMessageResponse{Success: false, Message: "解析响应失败"}
|
||||
}
|
||||
|
||||
reply := ""
|
||||
if len(chatResp.Choices) > 0 {
|
||||
reply = chatResp.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
return response.SendTestMessageResponse{
|
||||
Success: true,
|
||||
Message: "测试成功",
|
||||
Reply: reply,
|
||||
Model: chatResp.Model,
|
||||
Tokens: chatResp.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user