🎨 优化扩展模块,完成ai接入和对话功能
This commit is contained in:
470
server/api/v1/app/chat.go
Normal file
470
server/api/v1/app/chat.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/middleware"
|
||||
appModel "git.echol.cn/loser/st/server/model/app"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/common/response"
|
||||
appService "git.echol.cn/loser/st/server/service/app"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ChatApi struct{}
|
||||
|
||||
// ==================== 对话管理 ====================
|
||||
|
||||
// CreateChat 创建对话
|
||||
// @Tags 对话
|
||||
// @Summary 创建新对话
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.CreateChatRequest true "角色卡ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatResponse} "创建成功"
|
||||
// @Router /app/chat [post]
|
||||
func (ca *ChatApi) CreateChat(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.CreateChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := chatService.CreateChat(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("创建对话失败", zap.Error(err))
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(chat, c)
|
||||
}
|
||||
|
||||
// GetChatList 获取对话列表
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话列表
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param page query int true "页码"
|
||||
// @Param pageSize query int true "每页数量"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatListResponse} "获取成功"
|
||||
// @Router /app/chat/list [get]
|
||||
func (ca *ChatApi) GetChatList(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.ChatListRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Page == 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize == 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
|
||||
list, err := chatService.GetChatList(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取对话列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败", c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(list, c)
|
||||
}
|
||||
|
||||
// GetChatDetail 获取对话详情(含消息)
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话详情
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ChatDetailResponse} "获取成功"
|
||||
// @Router /app/chat/:id [get]
|
||||
func (ca *ChatApi) GetChatDetail(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
detail, err := chatService.GetChatDetail(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(detail, c)
|
||||
}
|
||||
|
||||
// GetChatMessages 分页获取消息
|
||||
// @Tags 对话
|
||||
// @Summary 获取对话消息
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Param page query int true "页码"
|
||||
// @Param pageSize query int true "每页数量"
|
||||
// @Success 200 {object} response.Response{data=appResponse.MessageListResponse} "获取成功"
|
||||
// @Router /app/chat/:id/messages [get]
|
||||
func (ca *ChatApi) GetChatMessages(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
var req request.ChatMessagesRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Page == 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize == 0 {
|
||||
req.PageSize = 50
|
||||
}
|
||||
|
||||
messages, err := chatService.GetChatMessages(uint(id), req, userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(messages, c)
|
||||
}
|
||||
|
||||
// DeleteChat 删除对话
|
||||
// @Tags 对话
|
||||
// @Summary 删除对话
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "对话ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/chat/:id [delete]
|
||||
func (ca *ChatApi) DeleteChat(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := chatService.DeleteChat(uint(id), userID); err != nil {
|
||||
global.GVA_LOG.Error("删除对话失败", zap.Error(err))
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// ==================== 消息操作 ====================
|
||||
|
||||
// SendMessage 发送消息并获取AI回复(SSE 流式响应)
|
||||
// @Tags 对话
|
||||
// @Summary 发送消息(SSE流式)
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce text/event-stream
|
||||
// @Param data body request.SendMessageRequest true "消息内容"
|
||||
// @Router /app/chat/send [post]
|
||||
func (ca *ChatApi) SendMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.SendMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 获取对话上下文
|
||||
chat, character, historyMsgs, err := chatService.GetChatForAI(req.ChatID, userID)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 保存用户消息
|
||||
_, err = chatService.SaveUserMessage(chat.ID, userID, req.Content)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"code": 7, "msg": "保存消息失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2.1 世界书匹配(根据角色卡 + 历史消息)
|
||||
var worldInfoTexts []string
|
||||
if character != nil && chat.CharacterID != nil {
|
||||
// 收集最近的消息文本(用于关键词匹配),附加当前用户输入
|
||||
messagesForMatch := make([]string, 0, len(historyMsgs)+1)
|
||||
for _, m := range historyMsgs {
|
||||
if m.Content != "" {
|
||||
messagesForMatch = append(messagesForMatch, m.Content)
|
||||
}
|
||||
}
|
||||
if req.Content != "" {
|
||||
messagesForMatch = append(messagesForMatch, req.Content)
|
||||
}
|
||||
|
||||
matchReq := request.MatchWorldInfoRequest{
|
||||
CharacterID: *chat.CharacterID,
|
||||
Messages: messagesForMatch,
|
||||
ScanDepth: 10,
|
||||
MaxTokens: 2000,
|
||||
}
|
||||
|
||||
if result, wErr := worldInfoService.MatchWorldInfo(userID, &matchReq); wErr != nil {
|
||||
global.GVA_LOG.Warn("匹配世界书失败",
|
||||
zap.Uint("chatID", chat.ID),
|
||||
zap.Uint("characterID", *chat.CharacterID),
|
||||
zap.Error(wErr))
|
||||
} else if result != nil && len(result.Entries) > 0 {
|
||||
// 按 position 分组,暂时统一作为 System 级别补充上下文
|
||||
beforeChar := make([]string, 0)
|
||||
afterChar := make([]string, 0)
|
||||
for _, entry := range result.Entries {
|
||||
text := strings.TrimSpace(entry.Content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
if entry.Position == "after_char" {
|
||||
afterChar = append(afterChar, text)
|
||||
} else {
|
||||
// 默认 before_char
|
||||
beforeChar = append(beforeChar, text)
|
||||
}
|
||||
}
|
||||
|
||||
if len(beforeChar) > 0 {
|
||||
worldInfoTexts = append(worldInfoTexts, strings.Join(beforeChar, "\n\n"))
|
||||
}
|
||||
if len(afterChar) > 0 {
|
||||
worldInfoTexts = append(worldInfoTexts, strings.Join(afterChar, "\n\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 获取 AI 提供商和模型
|
||||
var provider *appModel.AIProvider
|
||||
var modelName string
|
||||
|
||||
if req.ProviderID != nil {
|
||||
// 指定了提供商
|
||||
p, pErr := providerService.GetUserDefaultProvider(userID) // 简化:暂时仅用默认
|
||||
if pErr != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": "指定的 AI 接口不可用"})
|
||||
return
|
||||
}
|
||||
provider = p
|
||||
} else {
|
||||
// 使用默认提供商
|
||||
p, pErr := providerService.GetUserDefaultProvider(userID)
|
||||
if pErr != nil {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": pErr.Error()})
|
||||
return
|
||||
}
|
||||
provider = p
|
||||
}
|
||||
|
||||
if req.ModelName != "" {
|
||||
modelName = req.ModelName
|
||||
} else {
|
||||
// 使用提供商的第一个启用的聊天模型
|
||||
modelName = getDefaultChatModelForProvider(provider.ID)
|
||||
if modelName == "" {
|
||||
c.JSON(400, gin.H{"code": 7, "msg": "该 AI 接口没有可用的聊天模型"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 构建 Prompt(角色卡 + 世界书 + 历史消息)
|
||||
prompt := appService.BuildPrompt(character, historyMsgs)
|
||||
|
||||
// 将世界书内容插入为额外 System 提示(靠近角色定义)
|
||||
if len(worldInfoTexts) > 0 {
|
||||
worldInfoContent := "World Info:\n" + strings.Join(worldInfoTexts, "\n\n")
|
||||
inserted := false
|
||||
for i, msg := range prompt {
|
||||
if msg.Role == "system" {
|
||||
// 在第一个 system 消息之后插入一条世界书消息
|
||||
newPrompt := make([]appService.AIMessagePayload, 0, len(prompt)+1)
|
||||
newPrompt = append(newPrompt, prompt[:i+1]...)
|
||||
newPrompt = append(newPrompt, appService.AIMessagePayload{
|
||||
Role: "system",
|
||||
Content: worldInfoContent,
|
||||
})
|
||||
newPrompt = append(newPrompt, prompt[i+1:]...)
|
||||
prompt = newPrompt
|
||||
inserted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !inserted {
|
||||
prompt = append([]appService.AIMessagePayload{{
|
||||
Role: "system",
|
||||
Content: worldInfoContent,
|
||||
}}, prompt...)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加用户的新消息
|
||||
prompt = append(prompt, appService.AIMessagePayload{
|
||||
Role: "user",
|
||||
Content: req.Content,
|
||||
})
|
||||
|
||||
// 5. 设置 SSE 响应头
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
// 6. 流式调用 AI
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
ch := make(chan appService.AIStreamChunk, 100)
|
||||
go appService.StreamAIResponse(ctx, provider, modelName, prompt, ch)
|
||||
|
||||
var fullContent string
|
||||
var promptTokens, completionTokens int
|
||||
|
||||
flusher, ok := c.Writer.(interface{ Flush() })
|
||||
if !ok {
|
||||
c.JSON(500, gin.H{"code": 7, "msg": "服务器不支持流式响应"})
|
||||
return
|
||||
}
|
||||
|
||||
for chunk := range ch {
|
||||
if chunk.Error != "" {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": chunk.Error,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
if chunk.Content != "" {
|
||||
fullContent += chunk.Content
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "content",
|
||||
"content": chunk.Content,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if chunk.Done {
|
||||
promptTokens = chunk.PromptTokens
|
||||
completionTokens = chunk.CompletionTokens
|
||||
|
||||
// 保存 AI 回复
|
||||
if fullContent != "" {
|
||||
chatService.SaveAssistantMessage(
|
||||
chat.ID, chat.CharacterID, fullContent,
|
||||
modelName, promptTokens, completionTokens,
|
||||
)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "done",
|
||||
"model": modelName,
|
||||
"promptTokens": promptTokens,
|
||||
"completionTokens": completionTokens,
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EditMessage 编辑消息
|
||||
// @Tags 对话
|
||||
// @Summary 编辑消息内容
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.EditMessageRequest true "编辑信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.MessageResponse} "编辑成功"
|
||||
// @Router /app/chat/message/edit [post]
|
||||
func (ca *ChatApi) EditMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.EditMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := chatService.EditMessage(req, userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(msg, c)
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息
|
||||
// @Tags 对话
|
||||
// @Summary 删除消息
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.DeleteMessageRequest true "消息ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/chat/message/delete [post]
|
||||
func (ca *ChatApi) DeleteMessage(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.DeleteMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := chatService.DeleteMessage(req.MessageID, userID); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// ==================== 内部辅助 ====================
|
||||
|
||||
// getDefaultChatModelForProvider 获取提供商的默认聊天模型
|
||||
func getDefaultChatModelForProvider(providerID uint) string {
|
||||
var model appModel.AIModel
|
||||
err := global.GVA_DB.Where("provider_id = ? AND model_type = ? AND is_enabled = ?",
|
||||
providerID, "chat", true).
|
||||
Order("created_at ASC").
|
||||
First(&model).Error
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return model.ModelName
|
||||
}
|
||||
@@ -8,9 +8,13 @@ type ApiGroup struct {
|
||||
WorldInfoApi
|
||||
ExtensionApi
|
||||
RegexScriptApi
|
||||
ProviderApi
|
||||
ChatApi
|
||||
}
|
||||
|
||||
var (
|
||||
authService = service.ServiceGroupApp.AppServiceGroup.AuthService
|
||||
characterService = service.ServiceGroupApp.AppServiceGroup.CharacterService
|
||||
providerService = service.ServiceGroupApp.AppServiceGroup.ProviderService
|
||||
chatService = service.ServiceGroupApp.AppServiceGroup.ChatService
|
||||
)
|
||||
|
||||
@@ -3,7 +3,10 @@ package app
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
@@ -563,3 +566,72 @@ func (a *ExtensionApi) InstallExtensionFromGit(c *gin.Context) {
|
||||
|
||||
sysResponse.OkWithData(response.ToExtensionResponse(extension), c)
|
||||
}
|
||||
|
||||
// ProxyExtensionAsset 获取扩展资源文件(从本地文件系统读取)
|
||||
// @Summary 获取扩展资源文件
|
||||
// @Description 从本地存储读取扩展的 JS/CSS 等资源文件(与原版 SillyTavern 一致,扩展文件存储在本地)
|
||||
// @Tags 扩展管理
|
||||
// @Produce octet-stream
|
||||
// @Param id path int true "扩展ID"
|
||||
// @Param path path string true "资源文件路径"
|
||||
// @Success 200 {file} binary
|
||||
// @Router /app/extension/:id/asset/*path [get]
|
||||
func (a *ExtensionApi) ProxyExtensionAsset(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
sysResponse.FailWithMessage("无效的扩展ID", c)
|
||||
return
|
||||
}
|
||||
extensionID := uint(id)
|
||||
|
||||
// 获取资源路径(去掉前导 /)
|
||||
assetPath := c.Param("path")
|
||||
if len(assetPath) > 0 && assetPath[0] == '/' {
|
||||
assetPath = assetPath[1:]
|
||||
}
|
||||
if assetPath == "" {
|
||||
sysResponse.FailWithMessage("资源路径不能为空", c)
|
||||
return
|
||||
}
|
||||
|
||||
// 通过扩展 ID 查库获取信息(公开路由,不做 userID 过滤)
|
||||
extInfo, err := extensionService.GetExtensionByID(extensionID)
|
||||
if err != nil {
|
||||
sysResponse.FailWithMessage("扩展不存在", c)
|
||||
return
|
||||
}
|
||||
|
||||
// 从本地文件系统读取资源
|
||||
localPath, err := extensionService.GetExtensionAssetLocalPath(extInfo.Name, assetPath)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取扩展资源失败",
|
||||
zap.Error(err),
|
||||
zap.String("name", extInfo.Name),
|
||||
zap.String("asset", assetPath))
|
||||
sysResponse.FailWithMessage("资源不存在: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
data, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("读取扩展资源文件失败", zap.Error(err), zap.String("path", localPath))
|
||||
sysResponse.FailWithMessage("资源读取失败", c)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据文件扩展名设置正确的 Content-Type
|
||||
fileExt := filepath.Ext(assetPath)
|
||||
contentType := mime.TypeByExtension(fileExt)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
// 设置缓存和 CORS 头
|
||||
c.Header("Content-Type", contentType)
|
||||
c.Header("Cache-Control", "public, max-age=3600")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
c.Data(http.StatusOK, contentType, data)
|
||||
}
|
||||
|
||||
476
server/api/v1/app/provider.go
Normal file
476
server/api/v1/app/provider.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/middleware"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"git.echol.cn/loser/st/server/model/common/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ProviderApi struct{}
|
||||
|
||||
// ==================== 提供商接口 ====================
|
||||
|
||||
// CreateProvider 创建AI提供商
|
||||
// @Tags AI配置
|
||||
// @Summary 创建AI提供商
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.CreateProviderRequest true "提供商信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "创建成功"
|
||||
// @Router /app/provider [post]
|
||||
func (pa *ProviderApi) CreateProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.CreateProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := providerService.CreateProvider(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("创建AI提供商失败", zap.Error(err))
|
||||
response.FailWithMessage("创建失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(provider, c)
|
||||
}
|
||||
|
||||
// GetProviderList 获取AI提供商列表
|
||||
// @Tags AI配置
|
||||
// @Summary 获取AI提供商列表
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param page query int true "页码"
|
||||
// @Param pageSize query int true "每页数量"
|
||||
// @Param keyword query string false "搜索关键词"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ProviderListResponse} "获取成功"
|
||||
// @Router /app/provider/list [get]
|
||||
func (pa *ProviderApi) GetProviderList(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.ProviderListRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 默认分页
|
||||
if req.Page == 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize == 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
|
||||
list, err := providerService.GetProviderList(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取AI提供商列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败", c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(list, c)
|
||||
}
|
||||
|
||||
// GetProviderDetail 获取AI提供商详情
|
||||
// @Tags AI配置
|
||||
// @Summary 获取AI提供商详情
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "获取成功"
|
||||
// @Router /app/provider/:id [get]
|
||||
func (pa *ProviderApi) GetProviderDetail(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := providerService.GetProviderDetail(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(provider, c)
|
||||
}
|
||||
|
||||
// UpdateProvider 更新AI提供商
|
||||
// @Tags AI配置
|
||||
// @Summary 更新AI提供商
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.UpdateProviderRequest true "更新信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ProviderResponse} "更新成功"
|
||||
// @Router /app/provider [put]
|
||||
func (pa *ProviderApi) UpdateProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.UpdateProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := providerService.UpdateProvider(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("更新AI提供商失败", zap.Error(err))
|
||||
response.FailWithMessage("更新失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(provider, c)
|
||||
}
|
||||
|
||||
// DeleteProvider 删除AI提供商
|
||||
// @Tags AI配置
|
||||
// @Summary 删除AI提供商
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/provider/:id [delete]
|
||||
func (pa *ProviderApi) DeleteProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := providerService.DeleteProvider(uint(id), userID); err != nil {
|
||||
global.GVA_LOG.Error("删除AI提供商失败", zap.Error(err))
|
||||
response.FailWithMessage("删除失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// SetDefaultProvider 设置默认提供商
|
||||
// @Tags AI配置
|
||||
// @Summary 设置默认AI提供商
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.SetDefaultProviderRequest true "提供商ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "设置成功"
|
||||
// @Router /app/provider/setDefault [post]
|
||||
func (pa *ProviderApi) SetDefaultProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.SetDefaultProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := providerService.SetDefaultProvider(req.ProviderID, userID); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("设置成功", c)
|
||||
}
|
||||
|
||||
// TestProvider 测试提供商连通性(不需要先保存)
|
||||
// @Tags AI配置
|
||||
// @Summary 测试AI提供商连通性
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.TestProviderRequest true "测试参数"
|
||||
// @Success 200 {object} response.Response{data=appResponse.TestProviderResponse} "测试结果"
|
||||
// @Router /app/provider/test [post]
|
||||
func (pa *ProviderApi) TestProvider(c *gin.Context) {
|
||||
var req request.TestProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.TestProvider(req)
|
||||
if err != nil {
|
||||
response.FailWithMessage("测试失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// TestExistingProvider 测试已保存的提供商连通性
|
||||
// @Tags AI配置
|
||||
// @Summary 测试已保存的提供商连通性
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.TestProviderResponse} "测试结果"
|
||||
// @Router /app/provider/test/:id [get]
|
||||
func (pa *ProviderApi) TestExistingProvider(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.TestExistingProvider(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// ==================== 模型接口 ====================
|
||||
|
||||
// AddModel 为提供商添加模型
|
||||
// @Tags AI配置
|
||||
// @Summary 添加AI模型
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.CreateModelRequest true "模型信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ModelResponse} "创建成功"
|
||||
// @Router /app/provider/model [post]
|
||||
func (pa *ProviderApi) AddModel(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.CreateModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
model, err := providerService.AddModel(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("添加模型失败", zap.Error(err))
|
||||
response.FailWithMessage("添加失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(model, c)
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
// @Tags AI配置
|
||||
// @Summary 更新AI模型
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.UpdateModelRequest true "更新信息"
|
||||
// @Success 200 {object} response.Response{data=appResponse.ModelResponse} "更新成功"
|
||||
// @Router /app/provider/model [put]
|
||||
func (pa *ProviderApi) UpdateModel(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
var req request.UpdateModelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
model, err := providerService.UpdateModel(req, userID)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("更新模型失败", zap.Error(err))
|
||||
response.FailWithMessage("更新失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(model, c)
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
// @Tags AI配置
|
||||
// @Summary 删除AI模型
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "模型ID"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /app/provider/model/:id [delete]
|
||||
func (pa *ProviderApi) DeleteModel(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := providerService.DeleteModel(uint(id), userID); err != nil {
|
||||
global.GVA_LOG.Error("删除模型失败", zap.Error(err))
|
||||
response.FailWithMessage("删除失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
// ==================== 远程模型获取 ====================
|
||||
|
||||
// FetchRemoteModels 获取远程可用模型列表(不需要先保存)
|
||||
// @Tags AI配置
|
||||
// @Summary 从远程获取可用模型列表
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.FetchRemoteModelsRequest true "提供商配置"
|
||||
// @Success 200 {object} response.Response{data=appResponse.FetchRemoteModelsResponse} "获取结果"
|
||||
// @Router /app/provider/fetchModels [post]
|
||||
func (pa *ProviderApi) FetchRemoteModels(c *gin.Context) {
|
||||
var req request.FetchRemoteModelsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.FetchRemoteModels(req)
|
||||
if err != nil {
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
|
||||
// @Tags AI配置
|
||||
// @Summary 获取已保存提供商的远程可用模型
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Success 200 {object} response.Response{data=appResponse.FetchRemoteModelsResponse} "获取结果"
|
||||
// @Router /app/provider/fetchModels/:id [get]
|
||||
func (pa *ProviderApi) FetchRemoteModelsExisting(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.FetchRemoteModelsExisting(uint(id), userID)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// ==================== 测试消息 ====================
|
||||
|
||||
// SendTestMessage 发送测试消息(不需要先保存)
|
||||
// @Tags AI配置
|
||||
// @Summary 发送测试消息验证AI接口
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.SendTestMessageRequest true "测试参数"
|
||||
// @Success 200 {object} response.Response{data=appResponse.SendTestMessageResponse} "测试结果"
|
||||
// @Router /app/provider/testMessage [post]
|
||||
func (pa *ProviderApi) SendTestMessage(c *gin.Context) {
|
||||
var req request.SendTestMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.SendTestMessage(req)
|
||||
if err != nil {
|
||||
response.FailWithMessage("测试失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// SendTestMessageExisting 使用已保存的提供商发送测试消息
|
||||
// @Tags AI配置
|
||||
// @Summary 使用已保存的提供商发送测试消息
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param id path uint true "提供商ID"
|
||||
// @Param data body object{modelName=string,message=string} true "测试参数"
|
||||
// @Success 200 {object} response.Response{data=appResponse.SendTestMessageResponse} "测试结果"
|
||||
// @Router /app/provider/testMessage/:id [post]
|
||||
func (pa *ProviderApi) SendTestMessageExisting(c *gin.Context) {
|
||||
userID := middleware.GetAppUserID(c)
|
||||
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
response.FailWithMessage("无效的ID", c)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
ModelName string `json:"modelName" binding:"required"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := providerService.SendTestMessageExisting(uint(id), userID, body.ModelName, body.Message)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// ==================== 辅助接口 ====================
|
||||
|
||||
// GetProviderTypes 获取支持的提供商类型列表
|
||||
// @Tags AI配置
|
||||
// @Summary 获取支持的AI提供商类型
|
||||
// @Produce application/json
|
||||
// @Success 200 {object} response.Response{data=[]appResponse.ProviderTypeOption} "获取成功"
|
||||
// @Router /app/provider/types [get]
|
||||
func (pa *ProviderApi) GetProviderTypes(c *gin.Context) {
|
||||
types := providerService.GetProviderTypes()
|
||||
response.OkWithData(types, c)
|
||||
}
|
||||
|
||||
// GetPresetModels 获取预设模型列表
|
||||
// @Tags AI配置
|
||||
// @Summary 获取指定提供商类型的预设模型
|
||||
// @Produce application/json
|
||||
// @Param type query string true "提供商类型(openai/claude/gemini/custom)"
|
||||
// @Success 200 {object} response.Response{data=[]appResponse.PresetModelOption} "获取成功"
|
||||
// @Router /app/provider/presetModels [get]
|
||||
func (pa *ProviderApi) GetPresetModels(c *gin.Context) {
|
||||
providerType := c.Query("type")
|
||||
if providerType == "" {
|
||||
response.FailWithMessage("请指定提供商类型", c)
|
||||
return
|
||||
}
|
||||
|
||||
models := providerService.GetPresetModels(providerType)
|
||||
response.OkWithData(models, c)
|
||||
}
|
||||
Reference in New Issue
Block a user