Files
st/server/service/app/extension_installer.go
2026-02-14 06:20:05 +08:00

765 lines
21 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"archive/zip"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"go.uber.org/zap"
"gorm.io/datatypes"
)
// extensionsBaseDir 扩展文件存放根目录(与 router.go 中的静态服务路径一致)
const extensionsBaseDir = "data/st-core-scripts/scripts/extensions/third-party"
// STManifest SillyTavern 扩展 manifest.json 结构
type STManifest struct {
DisplayName string `json:"display_name"`
Loading string `json:"loading_order"` // 加载顺序
Requires []string `json:"requires"`
Optional []string `json:"optional"`
Js string `json:"js"` // 入口 JS 文件
Css string `json:"css"` // 入口 CSS 文件
Author string `json:"author"`
Version string `json:"version"`
Homepages string `json:"homepages"`
Repository string `json:"repository"`
AutoUpdate bool `json:"auto_update"`
Description string `json:"description"`
Tags []string `json:"tags"`
Settings map[string]interface{} `json:"settings"`
Raw map[string]interface{} `json:"-"` // 原始 JSON 数据
}
// getExtensionDir 获取指定扩展的文件系统目录
func getExtensionDir(extName string) string {
return filepath.Join(extensionsBaseDir, extName)
}
// ensureExtensionsBaseDir 确保扩展基础目录存在
func ensureExtensionsBaseDir() error {
return os.MkdirAll(extensionsBaseDir, 0755)
}
// parseManifestFile 从扩展目录中读取并解析 manifest.json
func parseManifestFile(dir string) (*STManifest, error) {
manifestPath := filepath.Join(dir, "manifest.json")
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, fmt.Errorf("无法读取 manifest.json: %w", err)
}
var manifest STManifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
}
// 保留原始 JSON 用于存储到数据库
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err == nil {
manifest.Raw = raw
}
return &manifest, nil
}
// parseManifestBytes 从字节数组解析 manifest.json
func parseManifestBytes(data []byte) (*STManifest, error) {
var manifest STManifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, fmt.Errorf("解析 manifest.json 失败: %w", err)
}
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err == nil {
manifest.Raw = raw
}
return &manifest, nil
}
// InstallExtensionFromGit 从 Git 仓库安装扩展
func (s *ExtensionService) InstallExtensionFromGit(userID uint, gitURL string, branch string) (*app.AIExtension, error) {
if branch == "" {
branch = "main"
}
if err := ensureExtensionsBaseDir(); err != nil {
return nil, fmt.Errorf("创建扩展目录失败: %w", err)
}
// 从 URL 提取扩展名
extName := extractRepoName(gitURL)
if extName == "" {
return nil, errors.New("无法从 URL 中提取扩展名")
}
extDir := getExtensionDir(extName)
// 检查目录是否已存在
if _, err := os.Stat(extDir); err == nil {
return nil, fmt.Errorf("扩展 '%s' 已存在,请先删除或使用升级功能", extName)
}
global.GVA_LOG.Info("从 Git 安装扩展",
zap.String("url", gitURL),
zap.String("branch", branch),
zap.String("dir", extDir),
)
// 执行 git clone
cmd := exec.Command("git", "clone", "--depth", "1", "--branch", branch, gitURL, extDir)
output, err := cmd.CombinedOutput()
if err != nil {
global.GVA_LOG.Error("git clone 失败",
zap.String("output", string(output)),
zap.Error(err),
)
// 清理失败的目录
_ = os.RemoveAll(extDir)
return nil, fmt.Errorf("git clone 失败: %s", strings.TrimSpace(string(output)))
}
global.GVA_LOG.Info("git clone 成功", zap.String("name", extName))
// 如果扩展需要构建(有 package.json 的 build 脚本且 dist 不存在),执行构建
if err := buildExtensionIfNeeded(extDir); err != nil {
global.GVA_LOG.Warn("扩展构建失败(不影响安装)",
zap.String("name", extName),
zap.Error(err),
)
}
// 创建数据库记录
return s.createExtensionFromDir(userID, extDir, extName, "git", gitURL, branch)
}
// InstallExtensionFromManifestURL 从 manifest URL 安装扩展
func (s *ExtensionService) InstallExtensionFromManifestURL(userID uint, manifestURL string, branch string) (*app.AIExtension, error) {
if err := ensureExtensionsBaseDir(); err != nil {
return nil, fmt.Errorf("创建扩展目录失败: %w", err)
}
global.GVA_LOG.Info("从 Manifest URL 安装扩展", zap.String("url", manifestURL))
// 下载 manifest.json
manifestData, err := httpGet(manifestURL)
if err != nil {
return nil, fmt.Errorf("下载 manifest.json 失败: %w", err)
}
manifest, err := parseManifestBytes(manifestData)
if err != nil {
return nil, err
}
// 确定扩展名
extName := sanitizeName(manifest.DisplayName)
if extName == "" {
extName = extractNameFromURL(manifestURL)
}
if extName == "" {
return nil, errors.New("无法确定扩展名manifest 中缺少 display_name")
}
extDir := getExtensionDir(extName)
if _, err := os.Stat(extDir); err == nil {
return nil, fmt.Errorf("扩展 '%s' 已存在,请先删除或使用升级功能", extName)
}
if err := os.MkdirAll(extDir, 0755); err != nil {
return nil, fmt.Errorf("创建扩展目录失败: %w", err)
}
// 保存 manifest.json
if err := os.WriteFile(filepath.Join(extDir, "manifest.json"), manifestData, 0644); err != nil {
_ = os.RemoveAll(extDir)
return nil, fmt.Errorf("保存 manifest.json 失败: %w", err)
}
// 获取 manifest URL 的基础路径
baseURL := manifestURL[:strings.LastIndex(manifestURL, "/")+1]
// 下载 JS 入口文件
if manifest.Js != "" {
jsURL := baseURL + manifest.Js
jsData, err := httpGet(jsURL)
if err != nil {
global.GVA_LOG.Warn("下载 JS 文件失败", zap.String("url", jsURL), zap.Error(err))
} else {
if err := os.WriteFile(filepath.Join(extDir, manifest.Js), jsData, 0644); err != nil {
global.GVA_LOG.Warn("保存 JS 文件失败", zap.Error(err))
}
}
}
// 下载 CSS 文件
if manifest.Css != "" {
cssURL := baseURL + manifest.Css
cssData, err := httpGet(cssURL)
if err != nil {
global.GVA_LOG.Warn("下载 CSS 文件失败", zap.String("url", cssURL), zap.Error(err))
} else {
if err := os.WriteFile(filepath.Join(extDir, manifest.Css), cssData, 0644); err != nil {
global.GVA_LOG.Warn("保存 CSS 文件失败", zap.Error(err))
}
}
}
// 创建数据库记录
return s.createExtensionFromDir(userID, extDir, extName, "url", manifestURL, branch)
}
// ImportExtensionFromZip 从 zip 文件导入扩展
func (s *ExtensionService) ImportExtensionFromZip(userID uint, filename string, zipData []byte) (*app.AIExtension, error) {
if err := ensureExtensionsBaseDir(); err != nil {
return nil, fmt.Errorf("创建扩展目录失败: %w", err)
}
// 先解压到临时目录
tmpDir, err := os.MkdirTemp("", "ext-import-*")
if err != nil {
return nil, fmt.Errorf("创建临时目录失败: %w", err)
}
defer os.RemoveAll(tmpDir)
// 解压 zip
if err := extractZip(zipData, tmpDir); err != nil {
return nil, fmt.Errorf("解压 zip 失败: %w", err)
}
// 找到 manifest.json 所在目录(可能在根目录或子目录)
manifestDir, err := findManifestDir(tmpDir)
if err != nil {
return nil, err
}
// 解析 manifest
manifest, err := parseManifestFile(manifestDir)
if err != nil {
return nil, err
}
// 确定扩展名
extName := sanitizeName(manifest.DisplayName)
if extName == "" {
extName = strings.TrimSuffix(filename, filepath.Ext(filename))
}
extDir := getExtensionDir(extName)
if _, err := os.Stat(extDir); err == nil {
return nil, fmt.Errorf("扩展 '%s' 已存在,请先删除或使用升级功能", extName)
}
// 移动文件到目标目录
if err := os.Rename(manifestDir, extDir); err != nil {
// 如果跨分区移动失败,回退为复制
if err := copyDir(manifestDir, extDir); err != nil {
return nil, fmt.Errorf("移动扩展文件失败: %w", err)
}
}
global.GVA_LOG.Info("ZIP 扩展导入成功",
zap.String("name", extName),
zap.String("dir", extDir),
)
return s.createExtensionFromDir(userID, extDir, extName, "file", "", "")
}
// UpgradeExtensionFromSource 从源地址升级扩展
func (s *ExtensionService) UpgradeExtensionFromSource(userID, extID uint) (*app.AIExtension, error) {
ext, err := s.GetExtension(userID, extID)
if err != nil {
return nil, err
}
if ext.SourceURL == "" {
return nil, errors.New("该扩展没有源地址,无法升级")
}
extDir := getExtensionDir(ext.Name)
switch ext.InstallSource {
case "git":
// Git 扩展:执行 git pull
global.GVA_LOG.Info("从 Git 升级扩展",
zap.String("name", ext.Name),
zap.String("dir", extDir),
)
cmd := exec.Command("git", "-C", extDir, "pull", "--ff-only")
output, err := cmd.CombinedOutput()
if err != nil {
global.GVA_LOG.Error("git pull 失败",
zap.String("output", string(output)),
zap.Error(err),
)
return nil, fmt.Errorf("git pull 失败: %s", strings.TrimSpace(string(output)))
}
global.GVA_LOG.Info("git pull 成功", zap.String("output", string(output)))
// 如果扩展需要构建,执行构建
if err := buildExtensionIfNeeded(extDir); err != nil {
global.GVA_LOG.Warn("升级后扩展构建失败",
zap.String("name", ext.Name),
zap.Error(err),
)
}
case "url":
// URL 扩展:重新下载 manifest 和文件
manifestData, err := httpGet(ext.SourceURL)
if err != nil {
return nil, fmt.Errorf("下载 manifest.json 失败: %w", err)
}
manifest, err := parseManifestBytes(manifestData)
if err != nil {
return nil, err
}
// 覆盖写入 manifest.json
if err := os.WriteFile(filepath.Join(extDir, "manifest.json"), manifestData, 0644); err != nil {
return nil, fmt.Errorf("保存 manifest.json 失败: %w", err)
}
baseURL := ext.SourceURL[:strings.LastIndex(ext.SourceURL, "/")+1]
// 重新下载 JS
if manifest.Js != "" {
if jsData, err := httpGet(baseURL + manifest.Js); err == nil {
_ = os.WriteFile(filepath.Join(extDir, manifest.Js), jsData, 0644)
}
}
// 重新下载 CSS
if manifest.Css != "" {
if cssData, err := httpGet(baseURL + manifest.Css); err == nil {
_ = os.WriteFile(filepath.Join(extDir, manifest.Css), cssData, 0644)
}
}
default:
return nil, errors.New("该扩展的安装来源不支持升级")
}
// 重新解析 manifest 并更新数据库
manifest, _ := parseManifestFile(extDir)
if manifest != nil {
now := time.Now().Unix()
updates := map[string]interface{}{
"last_update_check": &now,
}
if manifest.Version != "" {
updates["version"] = manifest.Version
}
if manifest.Description != "" {
updates["description"] = manifest.Description
}
if manifest.Author != "" {
updates["author"] = manifest.Author
}
if manifest.Js != "" {
updates["script_path"] = manifest.Js
}
if manifest.Css != "" {
updates["style_path"] = manifest.Css
}
if manifest.Raw != nil {
if raw, err := json.Marshal(manifest.Raw); err == nil {
updates["manifest_data"] = datatypes.JSON(raw)
}
}
global.GVA_DB.Model(&app.AIExtension{}).Where("id = ? AND user_id = ?", extID, userID).Updates(updates)
}
return s.GetExtension(userID, extID)
}
// InstallExtensionFromURL 智能安装:根据 URL 判断是 Git 仓库还是 Manifest URL
func (s *ExtensionService) InstallExtensionFromURL(userID uint, url string, branch string) (*app.AIExtension, error) {
if isGitURL(url) {
return s.InstallExtensionFromGit(userID, url, branch)
}
return s.InstallExtensionFromManifestURL(userID, url, branch)
}
// ---------------------
// 辅助函数
// ---------------------
// createExtensionFromDir 从扩展目录创建数据库记录
func (s *ExtensionService) createExtensionFromDir(userID uint, extDir, extName, installSource, sourceURL, branch string) (*app.AIExtension, error) {
manifest, err := parseManifestFile(extDir)
if err != nil {
// manifest 解析失败不阻止安装,使用基本信息
global.GVA_LOG.Warn("解析 manifest 失败,使用基本信息", zap.Error(err))
manifest = &STManifest{}
}
now := time.Now().Unix()
displayName := manifest.DisplayName
if displayName == "" {
displayName = extName
}
version := manifest.Version
if version == "" {
version = "1.0.0"
}
ext := app.AIExtension{
UserID: userID,
Name: extName,
DisplayName: displayName,
Version: version,
Author: manifest.Author,
Description: manifest.Description,
Homepage: manifest.Homepages,
Repository: manifest.Repository,
ExtensionType: "ui",
ScriptPath: manifest.Js,
StylePath: manifest.Css,
InstallSource: installSource,
SourceURL: sourceURL,
Branch: branch,
IsInstalled: true,
IsEnabled: false,
InstallDate: &now,
AutoUpdate: manifest.AutoUpdate,
}
// 存储 manifest 原始数据
if manifest.Raw != nil {
if raw, err := json.Marshal(manifest.Raw); err == nil {
ext.ManifestData = datatypes.JSON(raw)
}
}
// 存储标签
if len(manifest.Tags) > 0 {
if tags, err := json.Marshal(manifest.Tags); err == nil {
ext.Tags = datatypes.JSON(tags)
}
}
// 存储默认设置
if manifest.Settings != nil {
if settings, err := json.Marshal(manifest.Settings); err == nil {
ext.Settings = datatypes.JSON(settings)
}
}
// 存储依赖
if len(manifest.Requires) > 0 {
deps := make(map[string]string)
for _, r := range manifest.Requires {
deps[r] = "*"
}
if depsJSON, err := json.Marshal(deps); err == nil {
ext.Dependencies = datatypes.JSON(depsJSON)
}
}
if err := global.GVA_DB.Create(&ext).Error; err != nil {
return nil, fmt.Errorf("创建扩展记录失败: %w", err)
}
global.GVA_LOG.Info("扩展安装成功",
zap.String("name", extName),
zap.String("source", installSource),
zap.String("version", version),
)
return &ext, nil
}
// buildExtensionIfNeeded 如果扩展目录中有 package.json 且包含 build 脚本,
// 而 manifest 中指定的入口 JS 文件不存在,则自动执行 npm/pnpm install && build
func buildExtensionIfNeeded(extDir string) error {
// 读取 manifest 获取入口文件路径
manifest, err := parseManifestFile(extDir)
if err != nil || manifest.Js == "" {
return nil // 无 manifest 或无 JS 入口,不需要构建
}
// 检查入口 JS 文件是否存在
jsPath := filepath.Join(extDir, manifest.Js)
if _, err := os.Stat(jsPath); err == nil {
return nil // 入口文件已存在,无需构建
}
// 检查 package.json 是否存在
pkgPath := filepath.Join(extDir, "package.json")
pkgData, err := os.ReadFile(pkgPath)
if err != nil {
return nil // 无 package.json不是需要构建的扩展
}
// 检查是否有 build 脚本
var pkg struct {
Scripts map[string]string `json:"scripts"`
}
if err := json.Unmarshal(pkgData, &pkg); err != nil {
return nil
}
if _, hasBuild := pkg.Scripts["build"]; !hasBuild {
return nil // 没有 build 脚本
}
global.GVA_LOG.Info("扩展需要构建,开始安装依赖及构建",
zap.String("dir", extDir),
zap.String("entry", manifest.Js),
)
// 判断使用 pnpm 还是 npm
var pkgManager string
if _, err := os.Stat(filepath.Join(extDir, "pnpm-lock.yaml")); err == nil {
pkgManager = "pnpm"
} else if _, err := os.Stat(filepath.Join(extDir, "pnpm-workspace.yaml")); err == nil {
pkgManager = "pnpm"
} else {
pkgManager = "npm"
}
// 确认包管理器可用
if _, err := exec.LookPath(pkgManager); err != nil {
// 回退到 npm
pkgManager = "npm"
if _, err := exec.LookPath("npm"); err != nil {
return fmt.Errorf("未找到 npm 或 pnpm无法构建扩展")
}
}
global.GVA_LOG.Info("使用包管理器", zap.String("manager", pkgManager))
// 第一步:安装依赖
installCmd := exec.Command(pkgManager, "install")
installCmd.Dir = extDir
installOutput, err := installCmd.CombinedOutput()
if err != nil {
global.GVA_LOG.Error("依赖安装失败",
zap.String("output", string(installOutput)),
zap.Error(err),
)
return fmt.Errorf("%s install 失败: %s", pkgManager, strings.TrimSpace(string(installOutput)))
}
global.GVA_LOG.Info("依赖安装成功", zap.String("manager", pkgManager))
// 第二步:执行构建
buildCmd := exec.Command(pkgManager, "run", "build")
buildCmd.Dir = extDir
buildOutput, err := buildCmd.CombinedOutput()
if err != nil {
global.GVA_LOG.Error("构建失败",
zap.String("output", string(buildOutput)),
zap.Error(err),
)
return fmt.Errorf("%s run build 失败: %s", pkgManager, strings.TrimSpace(string(buildOutput)))
}
global.GVA_LOG.Info("扩展构建成功", zap.String("dir", extDir))
// 验证入口文件是否已生成
if _, err := os.Stat(jsPath); err != nil {
return fmt.Errorf("构建完成但入口文件仍不存在: %s", manifest.Js)
}
return nil
}
// isGitURL 判断 URL 是否为 Git 仓库
func isGitURL(url string) bool {
url = strings.ToLower(url)
if strings.HasSuffix(url, ".git") {
return true
}
if strings.Contains(url, "github.com/") ||
strings.Contains(url, "gitlab.com/") ||
strings.Contains(url, "gitee.com/") ||
strings.Contains(url, "bitbucket.org/") {
// 排除以 .json 结尾的 URL
if strings.HasSuffix(url, ".json") {
return false
}
return true
}
return false
}
// extractRepoName 从 Git URL 提取仓库名
func extractRepoName(gitURL string) string {
gitURL = strings.TrimSuffix(gitURL, ".git")
gitURL = strings.TrimRight(gitURL, "/")
parts := strings.Split(gitURL, "/")
if len(parts) == 0 {
return ""
}
return parts[len(parts)-1]
}
// extractNameFromURL 从 URL 路径中提取名称
func extractNameFromURL(url string) string {
// 对于 manifest URLhttps://example.com/extensions/my-ext/manifest.json
// 提取上一级目录名
url = strings.TrimRight(url, "/")
parts := strings.Split(url, "/")
if len(parts) >= 2 {
filename := parts[len(parts)-1]
if strings.Contains(filename, "manifest") {
return parts[len(parts)-2]
}
}
return ""
}
// sanitizeName 清理扩展名(移除不安全字符)
func sanitizeName(name string) string {
name = strings.TrimSpace(name)
// 将空格替换为连字符
name = strings.ReplaceAll(name, " ", "-")
// 只保留字母、数字、连字符、下划线
var result strings.Builder
for _, c := range name {
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
result.WriteRune(c)
}
}
return result.String()
}
// httpGet 发送 HTTP GET 请求
func httpGet(url string) ([]byte, error) {
client := &http.Client{
Timeout: 60 * time.Second,
}
resp, err := client.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
}
return io.ReadAll(resp.Body)
}
// extractZip 解压 zip 文件到指定目录
func extractZip(zipData []byte, destDir string) error {
reader, err := zip.NewReader(bytes.NewReader(zipData), int64(len(zipData)))
if err != nil {
return fmt.Errorf("打开 zip 文件失败: %w", err)
}
for _, file := range reader.File {
// 安全检查:防止 zip slip 攻击
destPath := filepath.Join(destDir, file.Name)
if !strings.HasPrefix(filepath.Clean(destPath), filepath.Clean(destDir)+string(os.PathSeparator)) {
return fmt.Errorf("非法的 zip 文件路径: %s", file.Name)
}
if file.FileInfo().IsDir() {
if err := os.MkdirAll(destPath, 0755); err != nil {
return err
}
continue
}
// 确保父目录存在
if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
return err
}
rc, err := file.Open()
if err != nil {
return err
}
outFile, err := os.Create(destPath)
if err != nil {
rc.Close()
return err
}
_, err = io.Copy(outFile, rc)
outFile.Close()
rc.Close()
if err != nil {
return err
}
}
return nil
}
// findManifestDir 在解压的目录中查找 manifest.json 所在目录
func findManifestDir(rootDir string) (string, error) {
// 先检查根目录
if _, err := os.Stat(filepath.Join(rootDir, "manifest.json")); err == nil {
return rootDir, nil
}
// 检查一级子目录(常见的 zip 结构是 zip 内包含一个项目文件夹)
entries, err := os.ReadDir(rootDir)
if err != nil {
return "", fmt.Errorf("读取目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
subDir := filepath.Join(rootDir, entry.Name())
if _, err := os.Stat(filepath.Join(subDir, "manifest.json")); err == nil {
return subDir, nil
}
}
}
return "", errors.New("未找到 manifest.json请确保 zip 文件包含有效的 SillyTavern 扩展")
}
// copyDir 递归复制目录
func copyDir(src, dst string) error {
if err := os.MkdirAll(dst, 0755); err != nil {
return err
}
entries, err := os.ReadDir(src)
if err != nil {
return err
}
for _, entry := range entries {
srcPath := filepath.Join(src, entry.Name())
dstPath := filepath.Join(dst, entry.Name())
if entry.IsDir() {
if err := copyDir(srcPath, dstPath); err != nil {
return err
}
} else {
data, err := os.ReadFile(srcPath)
if err != nil {
return err
}
if err := os.WriteFile(dstPath, data, 0644); err != nil {
return err
}
}
}
return nil
}