Files
st/server/service/app/extension.go

1040 lines
33 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/app/response"
"go.uber.org/zap"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// extensionDataDir is the root directory for locally stored third-party
// extensions.
//
// It mirrors SillyTavern's own layout, scripts/extensions/third-party/{name}/,
// because extension JavaScript uses relative imports (for example
// ../../../../../script.js) that only resolve correctly at this directory
// depth. All SillyTavern core scripts and extension files live under
// data/st-core-scripts/, independent of web-app/.
//
// Extension code is shared by all users (it is not partitioned per user); what
// differs between users is the configuration and enabled state stored in the
// database.
const extensionDataDir = "data/st-core-scripts" + "/scripts/extensions/third-party"

// getExtensionStorePath returns the local directory of an extension,
// {extensionDataDir}/{extensionName}, cleaned by filepath.Join.
// The name is not validated here.
func getExtensionStorePath(extensionName string) string {
	segments := []string{extensionDataDir, extensionName}
	return filepath.Join(segments...)
}
// GetExtensionAssetLocalPath resolves assetPath, relative to the extension's
// local store directory, to a path on disk and checks that the file exists.
//
// The returned path has the same form as the store path (relative to the
// working directory when extensionDataDir is relative).
//
// It returns an error when:
//   - assetPath would resolve to the store directory itself or to anything
//     outside it (for example via ".."); the comparison uses a trailing
//     separator so a sibling such as "foobar" never passes for "foo";
//   - the file does not exist or is a directory;
//   - the file cannot be stat'ed.
func (es *ExtensionService) GetExtensionAssetLocalPath(extensionName string, assetPath string) (string, error) {
	storePath := getExtensionStorePath(extensionName)
	fullPath := filepath.Join(storePath, assetPath)

	// Path-traversal guard (安全检查:防止路径遍历攻击).
	absStore, err := filepath.Abs(storePath)
	if err != nil {
		return "", errors.New("非法的资源路径")
	}
	absFile, err := filepath.Abs(fullPath)
	if err != nil {
		return "", errors.New("非法的资源路径")
	}
	if !strings.HasPrefix(absFile, absStore+string(filepath.Separator)) {
		return "", errors.New("非法的资源路径")
	}

	info, err := os.Stat(fullPath)
	switch {
	case os.IsNotExist(err):
		return "", fmt.Errorf("资源文件不存在: %s", assetPath)
	case err != nil:
		return "", fmt.Errorf("读取资源文件失败: %w", err)
	case info.IsDir():
		// A directory is not a servable asset.
		return "", fmt.Errorf("资源文件不存在: %s", assetPath)
	}
	return fullPath, nil
}
// resolveExtensionStorePath returns getExtensionStorePath(extensionName) after
// checking that it names a directory strictly inside extensionDataDir.
//
// Empty names, "." and names that climb out via ".." are rejected. This
// ensures callers never create or delete the store root itself, or anything
// outside it. Extension names originate from user requests and remote
// manifests, so they must be treated as untrusted.
func resolveExtensionStorePath(extensionName string) (string, error) {
	storePath := getExtensionStorePath(extensionName)
	rel, err := filepath.Rel(extensionDataDir, storePath)
	if err != nil || rel == "." || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("非法的扩展名称: %q", extensionName)
	}
	return storePath, nil
}

// ensureExtensionDir creates the extension's store directory (and any missing
// parents) and returns its path. Names that would escape extensionDataDir are
// rejected.
func ensureExtensionDir(extensionName string) (string, error) {
	storePath, err := resolveExtensionStorePath(extensionName)
	if err != nil {
		return "", err
	}
	if err := os.MkdirAll(storePath, 0755); err != nil {
		return "", fmt.Errorf("创建扩展存储目录失败: %w", err)
	}
	return storePath, nil
}

// removeExtensionDir deletes the extension's store directory recursively.
//
// A missing directory is not an error. Names that would resolve to the store
// root or outside it are rejected rather than deleted.
func removeExtensionDir(extensionName string) error {
	storePath, err := resolveExtensionStorePath(extensionName)
	if err != nil {
		return err
	}
	if _, err := os.Stat(storePath); os.IsNotExist(err) {
		return nil // nothing to delete
	}
	return os.RemoveAll(storePath)
}
// ExtensionService implements installation, configuration and lifecycle
// management of AI extensions. It carries no fields of its own: its methods
// work against global.GVA_DB and the local extension store on disk.
type ExtensionService struct {
}
// CreateExtension creates (installs) an extension record for userID from req.
//
// The record is created disabled (IsEnabled=false), marked installed and
// non-system, with InstallDate set to now. Slice/map fields of req are stored
// as JSON columns.
//
// It returns an error when:
//   - req.Name is empty;
//   - userID already has an extension with that name;
//   - a JSON field cannot be serialized;
//   - the database fails.
func (es *ExtensionService) CreateExtension(userID uint, req *request.CreateExtensionRequest) (*app.AIExtension, error) {
	if req.Name == "" {
		return nil, errors.New("扩展名称不能为空")
	}

	// Reject duplicates per user; only "record not found" means we may proceed.
	var existing app.AIExtension
	err := global.GVA_DB.Where("user_id = ? AND name = ?", userID, req.Name).First(&existing).Error
	if err == nil {
		return nil, fmt.Errorf("扩展 %s 已存在", req.Name)
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, err
	}

	// toJSON serializes v for a JSON column, remembering the first failure so
	// that a value that cannot be encoded is reported instead of being stored
	// as an empty column.
	var marshalErr error
	toJSON := func(field string, v interface{}) datatypes.JSON {
		if marshalErr != nil {
			return nil
		}
		raw, err := json.Marshal(v)
		if err != nil {
			marshalErr = fmt.Errorf("序列化字段 %s 失败: %w", field, err)
			return nil
		}
		return datatypes.JSON(raw)
	}

	extension := &app.AIExtension{
		UserID:        userID,
		Name:          req.Name,
		DisplayName:   req.DisplayName,
		Version:       req.Version,
		Author:        req.Author,
		Description:   req.Description,
		Homepage:      req.Homepage,
		Repository:    req.Repository,
		License:       req.License,
		Tags:          toJSON("tags", req.Tags),
		ExtensionType: req.ExtensionType,
		Category:      req.Category,
		Dependencies:  toJSON("dependencies", req.Dependencies),
		Conflicts:     toJSON("conflicts", req.Conflicts),
		ManifestData:  toJSON("manifest_data", req.ManifestData),
		ScriptPath:    req.ScriptPath,
		StylePath:     req.StylePath,
		AssetsPaths:   toJSON("assets_paths", req.AssetsPaths),
		Settings:      toJSON("settings", req.Settings),
		Options:       toJSON("options", req.Options),
		IsEnabled:     false,
		IsInstalled:   true,
		IsSystemExt:   false,
		InstallSource: req.InstallSource,
		SourceURL:     req.SourceURL,
		Branch:        req.Branch,
		AutoUpdate:    req.AutoUpdate,
		InstallDate:   time.Now(),
		Metadata:      toJSON("metadata", req.Metadata),
	}
	if marshalErr != nil {
		return nil, marshalErr
	}

	if err := global.GVA_DB.Create(extension).Error; err != nil {
		return nil, err
	}
	global.GVA_LOG.Info("扩展安装成功", zap.Uint("extensionID", extension.ID), zap.String("name", extension.Name))
	return extension, nil
}
// UpdateExtension applies the non-empty fields of req to extension extensionID
// owned by userID.
//
// Empty strings and nil Settings/Options/Metadata are left untouched. System
// extensions cannot be modified. Any lookup failure is reported as
// "扩展不存在".
func (es *ExtensionService) UpdateExtension(userID, extensionID uint, req *request.UpdateExtensionRequest) error {
	var ext app.AIExtension
	lookup := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(&ext)
	if lookup.Error != nil {
		return errors.New("扩展不存在")
	}
	if ext.IsSystemExt {
		return errors.New("系统内置扩展不允许修改")
	}

	changes := make(map[string]interface{})
	setText := func(column, value string) {
		if value != "" {
			changes[column] = value
		}
	}
	// Marshal errors are deliberately ignored here, as before.
	setJSON := func(column string, present bool, value interface{}) {
		if present {
			encoded, _ := json.Marshal(value)
			changes[column] = datatypes.JSON(encoded)
		}
	}

	setText("display_name", req.DisplayName)
	setText("description", req.Description)
	setJSON("settings", req.Settings != nil, req.Settings)
	setJSON("options", req.Options != nil, req.Options)
	setJSON("metadata", req.Metadata != nil, req.Metadata)

	return global.GVA_DB.Model(&ext).Updates(changes).Error
}
// DeleteExtension uninstalls extension extensionID owned by userID.
//
// System (built-in) extensions cannot be deleted. The database record is
// deleted first, so a database failure leaves the extension fully intact.
//
// Extension files on disk are shared by every user who installed an extension
// with the same name. The local directory is therefore removed only once no
// other record with that name remains. Failures while checking references or
// removing files are logged and do not fail the call.
//
// NOTE(review): deleteFiles is not consulted. As before, files are always
// cleaned up, subject to the sharing check above. Confirm what callers expect.
func (es *ExtensionService) DeleteExtension(userID, extensionID uint, deleteFiles bool) error {
	var extension app.AIExtension
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(&extension).Error; err != nil {
		return errors.New("扩展不存在")
	}
	if extension.IsSystemExt {
		return errors.New("系统内置扩展不允许删除")
	}

	if err := global.GVA_DB.Delete(&extension).Error; err != nil {
		return err
	}

	var remaining int64
	if err := global.GVA_DB.Model(&app.AIExtension{}).Where("name = ?", extension.Name).Count(&remaining).Error; err != nil {
		// We cannot tell whether other users still need the files: keep them.
		global.GVA_LOG.Warn("检查扩展引用失败,保留本地文件", zap.Error(err), zap.String("name", extension.Name))
	} else if remaining == 0 {
		if err := removeExtensionDir(extension.Name); err != nil {
			global.GVA_LOG.Warn("删除扩展本地文件失败", zap.Error(err), zap.String("name", extension.Name))
		}
	}

	global.GVA_LOG.Info("扩展卸载成功", zap.Uint("extensionID", extensionID), zap.String("name", extension.Name))
	return nil
}
// GetExtension returns extension extensionID if it belongs to userID.
// Any lookup failure is reported as "扩展不存在".
func (es *ExtensionService) GetExtension(userID, extensionID uint) (*app.AIExtension, error) {
	ext := new(app.AIExtension)
	err := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(ext).Error
	if err != nil {
		return nil, errors.New("扩展不存在")
	}
	return ext, nil
}
// GetExtensionByID returns extension extensionID regardless of its owner.
// It is intended for public asset routes. Any lookup failure is reported as
// "扩展不存在".
func (es *ExtensionService) GetExtensionByID(extensionID uint) (*app.AIExtension, error) {
	ext := new(app.AIExtension)
	err := global.GVA_DB.Where("id = ?", extensionID).First(ext).Error
	if err != nil {
		return nil, errors.New("扩展不存在")
	}
	return ext, nil
}
// GetExtensionList returns one page of userID's extensions, newest first,
// together with the total count matching the filters in req.
//
// Name matches name or display_name case-insensitively (ILIKE). Tag matches
// extensions whose JSON tags array contains the tag (jsonb @> containment).
func (es *ExtensionService) GetExtensionList(userID uint, req *request.ExtensionListRequest) (*response.ExtensionListResponse, error) {
	db := global.GVA_DB.Model(&app.AIExtension{}).Where("user_id = ?", userID)

	if req.Name != "" {
		pattern := "%" + req.Name + "%"
		db = db.Where("name ILIKE ? OR display_name ILIKE ?", pattern, pattern)
	}
	if req.ExtensionType != "" {
		db = db.Where("extension_type = ?", req.ExtensionType)
	}
	if req.Category != "" {
		db = db.Where("category = ?", req.Category)
	}
	if req.IsEnabled != nil {
		db = db.Where("is_enabled = ?", *req.IsEnabled)
	}
	if req.IsInstalled != nil {
		db = db.Where("is_installed = ?", *req.IsInstalled)
	}
	if req.Tag != "" {
		// Encode the containment operand with encoding/json. A tag containing
		// quotes or backslashes would otherwise yield invalid JSON and a query
		// error.
		tagFilter, err := json.Marshal([]string{req.Tag})
		if err != nil {
			return nil, err
		}
		db = db.Where("tags @> ?", string(tagFilter))
	}

	var total int64
	if err := db.Count(&total).Error; err != nil {
		return nil, err
	}

	var extensions []app.AIExtension
	if err := db.Scopes(req.Paginate()).Order("created_at DESC").Find(&extensions).Error; err != nil {
		return nil, err
	}

	result := make([]response.ExtensionResponse, 0, len(extensions))
	for i := range extensions {
		result = append(result, response.ToExtensionResponse(&extensions[i]))
	}

	return &response.ExtensionListResponse{
		List:     result,
		Total:    total,
		Page:     req.Page,
		PageSize: req.PageSize,
	}, nil
}
// ToggleExtension enables or disables extension extensionID owned by userID.
//
// Enabling first verifies dependencies and conflicts, then records
// last_enabled as the current time. Disabling performs no checks.
func (es *ExtensionService) ToggleExtension(userID, extensionID uint, isEnabled bool) error {
	var ext app.AIExtension
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(&ext).Error; err != nil {
		return errors.New("扩展不存在")
	}

	changes := map[string]interface{}{"is_enabled": isEnabled}
	if isEnabled {
		if err := es.checkDependencies(userID, &ext); err != nil {
			return err
		}
		if err := es.checkConflicts(userID, &ext); err != nil {
			return err
		}
		changes["last_enabled"] = time.Now()
	}

	if err := global.GVA_DB.Model(&ext).Updates(changes).Error; err != nil {
		return err
	}
	global.GVA_LOG.Info("扩展状态更新", zap.Uint("extensionID", extensionID), zap.Bool("enabled", isEnabled))
	return nil
}
// UpdateExtensionSettings replaces the settings JSON column of extension
// extensionID owned by userID with the serialized settings map.
func (es *ExtensionService) UpdateExtensionSettings(userID, extensionID uint, settings map[string]interface{}) error {
	var ext app.AIExtension
	if lookupErr := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(&ext).Error; lookupErr != nil {
		return errors.New("扩展不存在")
	}

	encoded, marshalErr := json.Marshal(settings)
	if marshalErr != nil {
		return errors.New("序列化配置失败")
	}
	column := datatypes.JSON(encoded)
	return global.GVA_DB.Model(&ext).Update("settings", column).Error
}
// GetExtensionSettings returns the effective settings of extension extensionID
// owned by userID.
//
// User values come from the settings column. Any keys missing there are
// filled from manifest_data["settings"] when that is a JSON object. A
// malformed settings column is an error; a malformed manifest is ignored. The
// result is never nil.
func (es *ExtensionService) GetExtensionSettings(userID, extensionID uint) (map[string]interface{}, error) {
	var ext app.AIExtension
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(&ext).Error; err != nil {
		return nil, errors.New("扩展不存在")
	}

	var merged map[string]interface{}
	if len(ext.Settings) != 0 {
		if err := json.Unmarshal([]byte(ext.Settings), &merged); err != nil {
			return nil, errors.New("解析配置失败: " + err.Error())
		}
	}
	if merged == nil {
		merged = map[string]interface{}{}
	}

	if len(ext.ManifestData) == 0 {
		return merged, nil
	}
	var manifest map[string]interface{}
	if json.Unmarshal([]byte(ext.ManifestData), &manifest) != nil {
		return merged, nil
	}
	defaults, isObject := manifest["settings"].(map[string]interface{})
	if !isObject {
		return merged, nil
	}
	// User-provided values always win over manifest defaults.
	for key, fallback := range defaults {
		if _, userSet := merged[key]; !userSet {
			merged[key] = fallback
		}
	}
	return merged, nil
}
// UpdateExtensionStats records a statistic for extension extensionID owned by
// userID. action selects the statistic:
//   - "usage": increments usage_count by value;
//   - "error": increments error_count by value;
//   - "load": folds value into load_time as an average weighted by the current
//     usage_count.
//
// Any other action is rejected.
//
// NOTE(review): the "load" case does not increment usage_count itself, so the
// weighting depends on how many "usage" events were recorded, not on how many
// loads were measured. The read-modify-write is also not atomic. Confirm this
// is acceptable.
func (es *ExtensionService) UpdateExtensionStats(userID, extensionID uint, action string, value int) error {
	var ext app.AIExtension
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(&ext).Error; err != nil {
		return errors.New("扩展不存在")
	}

	var (
		column   string
		newValue interface{}
	)
	switch action {
	case "usage":
		column, newValue = "usage_count", gorm.Expr("usage_count + ?", value)
	case "error":
		column, newValue = "error_count", gorm.Expr("error_count + ?", value)
	case "load":
		column = "load_time"
		newValue = (ext.LoadTime*ext.UsageCount + value) / (ext.UsageCount + 1)
	default:
		return errors.New("未知的统计类型")
	}

	return global.GVA_DB.Model(&ext).Updates(map[string]interface{}{column: newValue}).Error
}
// GetExtensionManifest builds a manifest response for extension extensionID
// owned by userID.
//
// Scalar fields are copied from the record. The JSON columns (tags,
// dependencies, conflicts, assets, settings, options, metadata) are decoded
// into the response. A column that is NULL or fails to decode leaves the
// corresponding field at its zero value.
func (es *ExtensionService) GetExtensionManifest(userID, extensionID uint) (*response.ExtensionManifestResponse, error) {
	var extension app.AIExtension
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(&extension).Error; err != nil {
		return nil, errors.New("扩展不存在")
	}

	manifest := &response.ExtensionManifestResponse{
		Name:        extension.Name,
		DisplayName: extension.DisplayName,
		Version:     extension.Version,
		Description: extension.Description,
		Author:      extension.Author,
		Homepage:    extension.Homepage,
		Repository:  extension.Repository,
		License:     extension.License,
		Type:        extension.ExtensionType,
		Category:    extension.Category,
		Entry:       extension.ScriptPath,
		Style:       extension.StylePath,
	}

	// Best-effort decoding: errors are intentionally ignored (see doc comment).
	decode := func(raw []byte, dst interface{}) {
		if raw != nil {
			_ = json.Unmarshal(raw, dst)
		}
	}
	decode(extension.Tags, &manifest.Tags)
	decode(extension.Dependencies, &manifest.Dependencies)
	decode(extension.Conflicts, &manifest.Conflicts)
	decode(extension.AssetsPaths, &manifest.Assets)
	decode(extension.Settings, &manifest.Settings)
	decode(extension.Options, &manifest.Options)
	decode(extension.Metadata, &manifest.Metadata)

	return manifest, nil
}
// ImportExtension installs an extension for userID from the raw bytes of a
// manifest.json file.
//
// Field resolution follows the Git and manifest-URL install paths: it uses the
// manifest's GetEffective* helpers for name, homepage, entry script and style,
// and defaults an empty type to "ui". The manifest must provide a name (or the
// fallback used by GetEffectiveName) and a version. The raw manifest is stored
// as ManifestData.
func (es *ExtensionService) ImportExtension(userID uint, manifestData []byte) (*app.AIExtension, error) {
	var manifest app.AIExtensionManifest
	if err := json.Unmarshal(manifestData, &manifest); err != nil {
		return nil, errors.New("无效的 manifest.json 格式")
	}

	name := manifest.GetEffectiveName()
	if name == "" || manifest.Version == "" {
		return nil, errors.New("manifest 缺少必填字段")
	}

	// Keep the original manifest as a generic map as well.
	var manifestMap map[string]interface{}
	if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
		return nil, errors.New("无效的 manifest.json 格式")
	}

	extensionType := manifest.Type
	if extensionType == "" {
		extensionType = "ui"
	}

	req := &request.CreateExtensionRequest{
		Name:          name,
		DisplayName:   manifest.DisplayName,
		Version:       manifest.Version,
		Author:        manifest.Author,
		Description:   manifest.Description,
		Homepage:      manifest.GetEffectiveHomepage(),
		Repository:    manifest.Repository,
		License:       manifest.License,
		Tags:          manifest.Tags,
		ExtensionType: extensionType,
		Category:      manifest.Category,
		Dependencies:  manifest.Dependencies,
		Conflicts:     manifest.Conflicts,
		ManifestData:  manifestMap,
		ScriptPath:    manifest.GetEffectiveEntry(),
		StylePath:     manifest.GetEffectiveStyle(),
		AssetsPaths:   manifest.Assets,
		Settings:      manifest.Settings,
		Options:       manifest.Options,
		InstallSource: "file",
		AutoUpdate:    manifest.AutoUpdate,
		Metadata:      manifest.Metadata,
	}
	return es.CreateExtension(userID, req)
}
// ExportExtension returns the manifest of extension extensionID (owned by
// userID) as indented JSON, suitable for re-import via ImportExtension.
func (es *ExtensionService) ExportExtension(userID, extensionID uint) ([]byte, error) {
	exported, lookupErr := es.GetExtensionManifest(userID, extensionID)
	if lookupErr == nil {
		return json.MarshalIndent(exported, "", " ")
	}
	return nil, lookupErr
}
// checkDependencies verifies that every dependency of extension is installed
// and enabled for userID.
//
// The dependencies column may be either:
//   - a JSON object keyed by extension name (values, such as version
//     constraints, are currently not checked); or
//   - a JSON array of extension names.
//
// A column in neither form is treated as "no dependencies". A dependency that
// is not found yields "缺少依赖扩展"; other database errors are returned as-is
// rather than being misreported as a missing dependency.
func (es *ExtensionService) checkDependencies(userID uint, extension *app.AIExtension) error {
	if len(extension.Dependencies) == 0 {
		return nil
	}

	raw := []byte(extension.Dependencies)
	var names []string
	var byName map[string]interface{}
	if err := json.Unmarshal(raw, &byName); err == nil {
		for name := range byName {
			names = append(names, name)
		}
	} else if err := json.Unmarshal(raw, &names); err != nil {
		// Unrecognised format: keep the previous lenient behaviour.
		return nil
	}

	for _, depName := range names {
		if depName == "" {
			continue
		}
		var depExt app.AIExtension
		err := global.GVA_DB.Where("user_id = ? AND name = ? AND is_enabled = true", userID, depName).First(&depExt).Error
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return fmt.Errorf("缺少依赖扩展: %s", depName)
		}
		if err != nil {
			return err
		}
		// TODO: check that the dependency's version satisfies the requirement.
	}
	return nil
}
// checkConflicts reports an error if any extension listed in extension's
// conflicts column (a JSON array of names) is currently enabled for userID.
// A malformed column, or a failed lookup, is treated as "no conflict".
func (es *ExtensionService) checkConflicts(userID uint, extension *app.AIExtension) error {
	if len(extension.Conflicts) == 0 {
		return nil
	}

	var names []string
	_ = json.Unmarshal([]byte(extension.Conflicts), &names)

	for i := range names {
		var other app.AIExtension
		lookupErr := global.GVA_DB.Where("user_id = ? AND name = ? AND is_enabled = true", userID, names[i]).First(&other).Error
		if lookupErr != nil {
			continue
		}
		return fmt.Errorf("扩展 %s 与 %s 冲突", extension.Name, names[i])
	}
	return nil
}
// GetEnabledExtensions lists userID's enabled and installed extensions in
// installation order (created_at ascending), for the frontend to load.
func (es *ExtensionService) GetEnabledExtensions(userID uint) ([]response.ExtensionResponse, error) {
	var rows []app.AIExtension
	query := global.GVA_DB.
		Where("user_id = ? AND is_enabled = true AND is_installed = true", userID).
		Order("created_at ASC")
	if err := query.Find(&rows).Error; err != nil {
		return nil, err
	}

	out := make([]response.ExtensionResponse, len(rows))
	for i := range rows {
		out[i] = response.ToExtensionResponse(&rows[i])
	}
	return out, nil
}
// InstallExtensionFromURL installs an extension from url, choosing the
// strategy automatically:
//   - Git repository URLs (see isGitURL) are cloned, using branch "main"
//     when branch is empty;
//   - anything else is treated as a manifest.json URL.
func (es *ExtensionService) InstallExtensionFromURL(userID uint, url string, branch string) (*app.AIExtension, error) {
	global.GVA_LOG.Info("开始从 URL 安装扩展", zap.String("url", url), zap.String("branch", branch))

	if !isGitURL(url) {
		global.GVA_LOG.Info("作为 Manifest URL 处理")
		return es.downloadAndInstallFromManifestURL(userID, url)
	}

	global.GVA_LOG.Info("检测到 Git 仓库 URL使用 Git 安装")
	targetBranch := branch
	if targetBranch == "" {
		targetBranch = "main"
	}
	return es.InstallExtensionFromGit(userID, url, targetBranch)
}
// isGitURL heuristically decides whether url (case-insensitively) points at a
// Git repository rather than at a manifest file. Rules, in order:
//  1. a ".json" suffix means a manifest, never Git;
//  2. a ".git" suffix means Git;
//  3. URLs containing "/raw/" or "/blob/" are file links, not repositories;
//  4. otherwise, the URL is Git if it mentions a known Git hosting domain.
func isGitURL(url string) bool {
	lower := strings.ToLower(url)

	switch {
	case strings.HasSuffix(lower, ".json"):
		return false
	case strings.HasSuffix(lower, ".git"):
		return true
	case strings.Contains(lower, "/raw/"), strings.Contains(lower, "/blob/"):
		return false
	}

	knownHosts := [...]string{
		"github.com",
		"gitlab.com",
		"gitee.com",
		"bitbucket.org",
		"gitea.io",
		"codeberg.org",
	}
	for _, host := range knownHosts {
		if strings.Contains(lower, host) {
			return true
		}
	}
	return false
}
// GetExtensionAssetURL builds the remote URL of assetPath for extension, based
// on how the extension was installed.
//
// For InstallSource "url", SourceURL is the manifest.json URL itself, so
// assets resolve relative to the manifest's directory. This is the same base
// used when the assets were downloaded at install time.
//
// For repository sources, the known hosts map to their raw-file URL layouts
// at extension.Branch (default "main"):
//   - GitLab: {repo}/-/raw/{branch}/{path}
//   - GitHub: raw.githubusercontent.com/{user}/{repo}/{branch}/{path}
//   - Gitee:  {repo}/raw/{branch}/{path}
//
// Other hosts get {repo}/{path}. A leading "/" on assetPath is ignored. An
// empty SourceURL is an error.
func (es *ExtensionService) GetExtensionAssetURL(extension *app.AIExtension, assetPath string) (string, error) {
	if extension.SourceURL == "" {
		return "", errors.New("扩展没有源地址")
	}
	assetPath = strings.TrimPrefix(assetPath, "/")

	if extension.InstallSource == "url" {
		base := extension.SourceURL[:strings.LastIndex(extension.SourceURL, "/")+1]
		return base + assetPath, nil
	}

	sourceURL := strings.TrimSuffix(strings.TrimSuffix(extension.SourceURL, "/"), ".git")
	branch := extension.Branch
	if branch == "" {
		branch = "main"
	}

	switch {
	case strings.Contains(sourceURL, "gitlab.com"):
		return fmt.Sprintf("%s/-/raw/%s/%s", sourceURL, branch, assetPath), nil
	case strings.Contains(sourceURL, "github.com"):
		rawURL := strings.Replace(sourceURL, "github.com", "raw.githubusercontent.com", 1)
		return fmt.Sprintf("%s/%s/%s", rawURL, branch, assetPath), nil
	case strings.Contains(sourceURL, "gitee.com"):
		return fmt.Sprintf("%s/raw/%s/%s", sourceURL, branch, assetPath), nil
	}
	return fmt.Sprintf("%s/%s", sourceURL, assetPath), nil
}
// downloadAndInstallFromManifestURL downloads manifest.json from manifestURL,
// stores it and its entry script, stylesheet and assets under the extension's
// local directory, and creates the extension record for userID.
//
// Assets are fetched relative to the manifest's directory. A failure to fetch
// an individual asset is logged but does not abort the install. The manifest
// is untrusted, so:
//   - an extension name that would escape extensionDataDir is rejected;
//   - any asset path resolving outside the extension's own directory is
//     skipped.
//
// Extension files are shared across users. On failure the local directory is
// removed only if no user had an extension with this name before the install
// started.
func (es *ExtensionService) downloadAndInstallFromManifestURL(userID uint, manifestURL string) (*app.AIExtension, error) {
	client := &http.Client{
		Timeout: 30 * time.Second,
	}

	// inside reports whether p lies strictly below root (lexical check).
	inside := func(root, p string) bool {
		rel, err := filepath.Rel(root, p)
		return err == nil && rel != "." && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
	}

	resp, err := client.Get(manifestURL)
	if err != nil {
		return nil, fmt.Errorf("下载 manifest.json 失败: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("下载 manifest.json 失败: HTTP %d", resp.StatusCode)
	}
	manifestData, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取 manifest.json 失败: %w", err)
	}

	var manifest app.AIExtensionManifest
	if err := json.Unmarshal(manifestData, &manifest); err != nil {
		return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
	}
	var manifestMap map[string]interface{}
	if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
		return nil, fmt.Errorf("转换 manifest 失败: %w", err)
	}

	effectiveName := manifest.GetEffectiveName()
	if effectiveName == "" {
		return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
	}
	if !inside(extensionDataDir, getExtensionStorePath(effectiveName)) {
		return nil, fmt.Errorf("非法的扩展名称: %s", effectiveName)
	}

	var existing app.AIExtension
	err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
	if err == nil {
		return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, err
	}

	// Other users may already use these (shared) files; never delete them
	// when cleaning up after a failure.
	var sharedCount int64
	if err := global.GVA_DB.Model(&app.AIExtension{}).Where("name = ?", effectiveName).Count(&sharedCount).Error; err != nil {
		return nil, err
	}
	cleanup := func() {
		if sharedCount == 0 {
			_ = removeExtensionDir(effectiveName)
		}
	}

	storePath, err := ensureExtensionDir(effectiveName)
	if err != nil {
		return nil, err
	}
	if err := os.WriteFile(filepath.Join(storePath, "manifest.json"), manifestData, 0644); err != nil {
		cleanup()
		return nil, fmt.Errorf("保存 manifest.json 失败: %w", err)
	}

	// Related files live next to the manifest.
	baseURL := manifestURL[:strings.LastIndex(manifestURL, "/")+1]

	filesToDownload := []string{}
	if entry := manifest.GetEffectiveEntry(); entry != "" {
		filesToDownload = append(filesToDownload, entry)
	}
	if style := manifest.GetEffectiveStyle(); style != "" {
		filesToDownload = append(filesToDownload, style)
	}
	filesToDownload = append(filesToDownload, manifest.Assets...)

	for _, file := range filesToDownload {
		if file == "" {
			continue
		}
		target := filepath.Join(storePath, file)
		if !inside(storePath, target) {
			global.GVA_LOG.Warn("跳过非法的扩展资源路径", zap.String("file", file))
			continue
		}
		fileURL := baseURL + file
		if err := downloadFileToLocal(client, fileURL, target); err != nil {
			global.GVA_LOG.Warn("下载扩展资源文件失败(非致命)",
				zap.String("file", file),
				zap.String("url", fileURL),
				zap.Error(err))
		}
	}
	global.GVA_LOG.Info("扩展文件已保存到本地",
		zap.String("name", effectiveName),
		zap.String("path", storePath))

	createReq := &request.CreateExtensionRequest{
		Name:          effectiveName,
		DisplayName:   manifest.DisplayName,
		Version:       manifest.Version,
		Author:        manifest.Author,
		Description:   manifest.Description,
		Homepage:      manifest.GetEffectiveHomepage(),
		Repository:    manifest.Repository,
		License:       manifest.License,
		Tags:          manifest.Tags,
		ExtensionType: manifest.Type,
		Category:      manifest.Category,
		Dependencies:  manifest.Dependencies,
		Conflicts:     manifest.Conflicts,
		ManifestData:  manifestMap,
		ScriptPath:    manifest.GetEffectiveEntry(),
		StylePath:     manifest.GetEffectiveStyle(),
		AssetsPaths:   manifest.Assets,
		Settings:      manifest.Settings,
		Options:       manifest.Options,
		InstallSource: "url",
		SourceURL:     manifestURL,
		AutoUpdate:    manifest.AutoUpdate,
		Metadata:      manifest.Metadata,
	}
	if createReq.ExtensionType == "" {
		createReq.ExtensionType = "ui"
	}

	extension, err := es.CreateExtension(userID, createReq)
	if err != nil {
		cleanup()
		return nil, fmt.Errorf("创建扩展失败: %w", err)
	}

	global.GVA_LOG.Info("从 URL 安装扩展成功",
		zap.Uint("extensionID", extension.ID),
		zap.String("name", extension.Name),
		zap.String("url", manifestURL))
	return extension, nil
}
// downloadFileToLocal fetches url with client and writes the body to
// localPath, creating missing parent directories (mode 0755).
//
// The body is streamed to disk (file mode 0644) rather than buffered in
// memory. Any status other than 200 is reported as "HTTP <code>". If copying
// or closing fails, the partially written file is removed, so no truncated
// file is left behind.
func downloadFileToLocal(client *http.Client, url string, localPath string) error {
	if err := os.MkdirAll(filepath.Dir(localPath), 0755); err != nil {
		return err
	}

	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	out, err := os.OpenFile(localPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	_, copyErr := io.Copy(out, resp.Body)
	if closeErr := out.Close(); copyErr == nil {
		copyErr = closeErr
	}
	if copyErr != nil {
		_ = os.Remove(localPath)
		return copyErr
	}
	return nil
}
// UpgradeExtension upgrades extension extensionID owned by userID, dispatching
// on its InstallSource:
//   - "git": re-clone from the repository (force is passed through);
//   - "url": re-download from the manifest URL.
//
// Any other source is rejected.
func (es *ExtensionService) UpgradeExtension(userID, extensionID uint, force bool) (*app.AIExtension, error) {
	ext := &app.AIExtension{}
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", extensionID, userID).First(ext).Error; err != nil {
		return nil, errors.New("扩展不存在")
	}

	global.GVA_LOG.Info("开始升级扩展",
		zap.Uint("extensionID", extensionID),
		zap.String("name", ext.Name),
		zap.String("installSource", ext.InstallSource),
		zap.String("sourceUrl", ext.SourceURL))

	source := ext.InstallSource
	if source == "git" {
		return es.updateExtensionFromGit(userID, ext, force)
	}
	if source == "url" {
		return es.updateExtensionFromURL(userID, ext)
	}
	return nil, fmt.Errorf("不支持的安装来源: %s", source)
}
// updateExtensionFromGit upgrades a Git-installed extension by re-cloning
// extension.SourceURL at extension.Branch. See reinstallExtension for how the
// existing record and files are replaced and restored on failure.
//
// NOTE(review): force is accepted but not used; every call performs a full
// reinstall.
func (es *ExtensionService) updateExtensionFromGit(userID uint, extension *app.AIExtension, force bool) (*app.AIExtension, error) {
	if extension.SourceURL == "" {
		return nil, errors.New("缺少 Git 仓库 URL")
	}
	global.GVA_LOG.Info("从 Git 更新扩展",
		zap.String("name", extension.Name),
		zap.String("sourceUrl", extension.SourceURL),
		zap.String("branch", extension.Branch))

	sourceURL, branch := extension.SourceURL, extension.Branch
	return es.reinstallExtension(userID, extension, func() (*app.AIExtension, error) {
		return es.InstallExtensionFromGit(userID, sourceURL, branch)
	})
}

// updateExtensionFromURL upgrades a manifest-URL-installed extension by
// re-downloading it from extension.SourceURL. See reinstallExtension.
func (es *ExtensionService) updateExtensionFromURL(userID uint, extension *app.AIExtension) (*app.AIExtension, error) {
	if extension.SourceURL == "" {
		return nil, errors.New("缺少 Manifest URL")
	}
	global.GVA_LOG.Info("从 URL 更新扩展",
		zap.String("name", extension.Name),
		zap.String("sourceUrl", extension.SourceURL))

	sourceURL := extension.SourceURL
	return es.reinstallExtension(userID, extension, func() (*app.AIExtension, error) {
		return es.downloadAndInstallFromManifestURL(userID, sourceURL)
	})
}

// reinstallExtension replaces extension with the result of install:
//  1. the current local files are moved aside to "<store>.upgrade-backup";
//  2. the old record is hard-deleted, so the installer's "already installed"
//     check passes;
//  3. install runs.
//
// On failure, the backed-up files and the old record are restored and the
// install error is returned. On success, the backup is discarded. The user's
// previous settings are carried over, and so is the enabled state — but only
// if the new version's dependency and conflict checks pass.
func (es *ExtensionService) reinstallExtension(userID uint, extension *app.AIExtension, install func() (*app.AIExtension, error)) (*app.AIExtension, error) {
	storePath := getExtensionStorePath(extension.Name)
	// Never move or delete the store root (empty name) or anything outside it.
	if rel, err := filepath.Rel(extensionDataDir, storePath); err != nil || rel == "." || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return nil, fmt.Errorf("非法的扩展名称: %s", extension.Name)
	}
	backupPath := storePath + ".upgrade-backup"
	_ = os.RemoveAll(backupPath) // stale backup from an interrupted upgrade

	movedAside := false
	if _, err := os.Stat(storePath); err == nil {
		if err := os.Rename(storePath, backupPath); err != nil {
			return nil, fmt.Errorf("备份扩展文件失败: %w", err)
		}
		movedAside = true
	}
	restoreFiles := func() {
		if !movedAside {
			return
		}
		_ = os.RemoveAll(storePath) // drop anything a failed install left behind
		if err := os.Rename(backupPath, storePath); err != nil {
			global.GVA_LOG.Error("恢复扩展文件失败", zap.Error(err), zap.String("name", extension.Name))
		}
	}

	if err := global.GVA_DB.Unscoped().Delete(extension).Error; err != nil {
		restoreFiles()
		return nil, fmt.Errorf("删除旧扩展记录失败: %w", err)
	}

	upgraded, err := install()
	if err != nil {
		restoreFiles()
		if restoreErr := global.GVA_DB.Create(extension).Error; restoreErr != nil {
			global.GVA_LOG.Error("恢复旧扩展记录失败", zap.Error(restoreErr), zap.String("name", extension.Name))
		}
		return nil, err
	}
	if movedAside {
		_ = os.RemoveAll(backupPath)
	}

	carried := map[string]interface{}{}
	if len(extension.Settings) > 0 {
		carried["settings"] = extension.Settings
	}
	if extension.IsEnabled && es.checkDependencies(userID, upgraded) == nil && es.checkConflicts(userID, upgraded) == nil {
		carried["is_enabled"] = true
	}
	if len(carried) > 0 {
		if err := global.GVA_DB.Model(upgraded).Updates(carried).Error; err != nil {
			global.GVA_LOG.Warn("迁移扩展用户配置失败", zap.Error(err), zap.String("name", upgraded.Name))
		} else {
			if _, ok := carried["settings"]; ok {
				upgraded.Settings = extension.Settings
			}
			if _, ok := carried["is_enabled"]; ok {
				upgraded.IsEnabled = true
			}
		}
	}
	return upgraded, nil
}
// InstallExtensionFromGit shallow-clones gitUrl (at branch, or the remote's
// default branch when branch is empty) and copies the checkout, minus .git, to
// {extensionDataDir}/{name}/. It then creates the extension record for userID
// from the repository's manifest.json.
//
// gitUrl is untrusted input:
//   - values starting with "-" are rejected, and "--" separates options from
//     positional arguments, so git can never interpret the URL as an option
//     (e.g. --upload-pack);
//   - interactive credential prompts are disabled;
//   - a manifest name that would escape extensionDataDir is rejected.
//
// Extension files are shared across users. On failure the local directory is
// removed only if no user had an extension with this name before the install
// started.
func (es *ExtensionService) InstallExtensionFromGit(userID uint, gitUrl, branch string) (*app.AIExtension, error) {
	if strings.HasPrefix(gitUrl, "-") || (!strings.Contains(gitUrl, "://") && !strings.HasSuffix(gitUrl, ".git")) {
		return nil, errors.New("无效的 Git URL")
	}

	// Clone into a temporary directory first: the extension name is only known
	// after reading manifest.json.
	tempDir, err := os.MkdirTemp("", "extension-*")
	if err != nil {
		return nil, fmt.Errorf("创建临时目录失败: %w", err)
	}
	defer os.RemoveAll(tempDir)

	global.GVA_LOG.Info("开始从 Git 克隆扩展",
		zap.String("gitUrl", gitUrl),
		zap.String("branch", branch))

	args := []string{"clone", "--depth=1"}
	if branch != "" {
		args = append(args, "--branch="+branch)
	}
	args = append(args, "--", gitUrl, tempDir)
	cmd := exec.Command("git", args...)
	cmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0")
	output, err := cmd.CombinedOutput()
	if err != nil {
		global.GVA_LOG.Error("Git clone 失败",
			zap.String("gitUrl", gitUrl),
			zap.String("output", string(output)),
			zap.Error(err))
		return nil, fmt.Errorf("Git clone 失败: %s", string(output))
	}

	manifestData, err := os.ReadFile(filepath.Join(tempDir, "manifest.json"))
	if err != nil {
		return nil, fmt.Errorf("读取 manifest.json 失败: %w", err)
	}
	var manifest app.AIExtensionManifest
	if err := json.Unmarshal(manifestData, &manifest); err != nil {
		return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
	}
	var manifestMap map[string]interface{}
	if err := json.Unmarshal(manifestData, &manifestMap); err != nil {
		return nil, fmt.Errorf("转换 manifest 失败: %w", err)
	}

	// SillyTavern manifests may lack "name"; GetEffectiveName supplies the
	// fallback.
	effectiveName := manifest.GetEffectiveName()
	if effectiveName == "" {
		return nil, errors.New("manifest.json 缺少 name 或 display_name 字段")
	}
	if rel, relErr := filepath.Rel(extensionDataDir, getExtensionStorePath(effectiveName)); relErr != nil || rel == "." || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return nil, fmt.Errorf("非法的扩展名称: %s", effectiveName)
	}

	var existing app.AIExtension
	err = global.GVA_DB.Where("user_id = ? AND name = ?", userID, effectiveName).First(&existing).Error
	if err == nil {
		return nil, fmt.Errorf("扩展 %s 已安装", effectiveName)
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, err
	}

	// Other users may already use these (shared) files; never delete them
	// when cleaning up after a failure.
	var sharedCount int64
	if err := global.GVA_DB.Model(&app.AIExtension{}).Where("name = ?", effectiveName).Count(&sharedCount).Error; err != nil {
		return nil, err
	}
	cleanup := func() {
		if sharedCount == 0 {
			_ = removeExtensionDir(effectiveName)
		}
	}

	// Save into {extensionDataDir}/{effectiveName}/, replacing leftover files.
	storePath, err := ensureExtensionDir(effectiveName)
	if err != nil {
		return nil, err
	}
	_ = os.RemoveAll(storePath)
	if err := copyDir(tempDir, storePath); err != nil {
		cleanup()
		return nil, fmt.Errorf("保存扩展文件失败: %w", err)
	}
	global.GVA_LOG.Info("扩展文件已保存到本地",
		zap.String("name", effectiveName),
		zap.String("path", storePath))

	createReq := &request.CreateExtensionRequest{
		Name:          effectiveName,
		DisplayName:   manifest.DisplayName,
		Version:       manifest.Version,
		Author:        manifest.Author,
		Description:   manifest.Description,
		Homepage:      manifest.GetEffectiveHomepage(),
		Repository:    manifest.Repository,
		License:       manifest.License,
		Tags:          manifest.Tags,
		ExtensionType: manifest.Type,
		Category:      manifest.Category,
		Dependencies:  manifest.Dependencies,
		Conflicts:     manifest.Conflicts,
		ManifestData:  manifestMap,
		ScriptPath:    manifest.GetEffectiveEntry(),
		StylePath:     manifest.GetEffectiveStyle(),
		AssetsPaths:   manifest.Assets,
		Settings:      manifest.Settings,
		Options:       manifest.Options,
		InstallSource: "git",
		SourceURL:     gitUrl,
		Branch:        branch,
		AutoUpdate:    manifest.AutoUpdate,
		Metadata:      manifest.Metadata,
	}
	if createReq.ExtensionType == "" {
		createReq.ExtensionType = "ui"
	}

	extension, err := es.CreateExtension(userID, createReq)
	if err != nil {
		cleanup()
		return nil, fmt.Errorf("创建扩展记录失败: %w", err)
	}

	global.GVA_LOG.Info("从 Git 安装扩展成功",
		zap.Uint("extensionID", extension.ID),
		zap.String("name", extension.Name),
		zap.String("version", extension.Version),
		zap.String("localPath", storePath))
	return extension, nil
}
// copyDir recursively copies the tree rooted at src into dst.
//
// Directories named ".git" are skipped to save space. Only directories and
// regular files are copied, using their permission bits. Symlinks, devices,
// sockets and other special files are skipped, so a repository cannot use a
// symlink to copy files from outside itself (e.g. /etc/passwd) into the
// extension store. File contents are streamed, not loaded into memory.
func copyDir(src, dst string) error {
	return filepath.Walk(src, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() && info.Name() == ".git" {
			return filepath.SkipDir
		}

		rel, err := filepath.Rel(src, path)
		if err != nil {
			return err
		}
		target := filepath.Join(dst, rel)

		switch {
		case info.IsDir():
			return os.MkdirAll(target, info.Mode().Perm())
		case !info.Mode().IsRegular():
			return nil // see doc comment: special files are not copied
		default:
			return copyRegularFile(path, target, info.Mode().Perm())
		}
	})
}

// copyRegularFile streams srcPath into dstPath (created or truncated with
// perm), reporting errors from both copying and closing the destination.
func copyRegularFile(srcPath, dstPath string, perm os.FileMode) (err error) {
	in, err := os.Open(srcPath)
	if err != nil {
		return err
	}
	defer in.Close()

	out, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := out.Close(); err == nil {
			err = closeErr
		}
	}()

	_, err = io.Copy(out, in)
	return err
}