Files
st-react/server/service/app/preset.go
Echo f4e166c5ee 🎉 初始化项目
Signed-off-by: Echo <1711788888@qq.com>
2026-02-27 21:52:00 +08:00

354 lines
10 KiB
Go

package app
import (
"encoding/json"
"errors"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"go.uber.org/zap"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// PresetService groups the AI preset operations (create, list, fetch, update,
// delete, set default, import/export, usage counting). It carries no state.
type PresetService struct{}
// CreatePreset stores a new preset owned by userID.
//
// StopSequences and Extensions from the request are serialized to JSON only
// when they are non-empty; otherwise the corresponding columns are left at
// their zero value. Returns the created record, or the serialization/database
// error after logging it.
func (s *PresetService) CreatePreset(userID uint, req *request.CreatePresetRequest) (*app.AIPreset, error) {
	var (
		stopSeqColumn    datatypes.JSON
		extensionsColumn datatypes.JSON
	)

	// Encode the stop sequences, if any were supplied.
	if len(req.StopSequences) > 0 {
		encoded, marshalErr := json.Marshal(req.StopSequences)
		if marshalErr != nil {
			global.GVA_LOG.Error("序列化 StopSequences 失败", zap.Error(marshalErr))
			return nil, marshalErr
		}
		stopSeqColumn = encoded
	}

	// Encode the extensions, if any were supplied.
	if len(req.Extensions) > 0 {
		encoded, marshalErr := json.Marshal(req.Extensions)
		if marshalErr != nil {
			global.GVA_LOG.Error("序列化 Extensions 失败", zap.Error(marshalErr))
			return nil, marshalErr
		}
		extensionsColumn = encoded
	}

	record := &app.AIPreset{
		UserID:            userID,
		Name:              req.Name,
		Description:       req.Description,
		IsPublic:          req.IsPublic,
		Temperature:       req.Temperature,
		TopP:              req.TopP,
		TopK:              req.TopK,
		FrequencyPenalty:  req.FrequencyPenalty,
		PresencePenalty:   req.PresencePenalty,
		MaxTokens:         req.MaxTokens,
		RepetitionPenalty: req.RepetitionPenalty,
		MinP:              req.MinP,
		TopA:              req.TopA,
		SystemPrompt:      req.SystemPrompt,
		StopSequences:     stopSeqColumn,
		Extensions:        extensionsColumn,
	}

	createErr := global.GVA_DB.Create(record).Error
	if createErr != nil {
		global.GVA_LOG.Error("创建预设失败", zap.Error(createErr))
		return nil, createErr
	}
	return record, nil
}
// GetPresetList returns one page of presets visible to userID together with
// the total number of matching rows.
//
// Visibility is limited to presets owned by userID or marked public. The
// optional Keyword filter matches name or description with LIKE, and the
// optional IsPublic filter narrows by visibility. Results are ordered by
// is_default (descending) and then updated_at (descending).
func (s *PresetService) GetPresetList(userID uint, req *request.GetPresetListRequest) ([]app.AIPreset, int64, error) {
	var (
		rows  []app.AIPreset
		count int64
	)

	// Base scope: own presets plus public ones.
	query := global.GVA_DB.Model(&app.AIPreset{}).
		Where("user_id = ? OR is_public = ?", userID, true)

	// Optional keyword search over name and description.
	if req.Keyword != "" {
		pattern := "%" + req.Keyword + "%"
		query = query.Where("name LIKE ? OR description LIKE ?", pattern, pattern)
	}

	// Optional public/private filter.
	if visibility := req.IsPublic; visibility != nil {
		query = query.Where("is_public = ?", *visibility)
	}

	// Count before applying ordering and pagination.
	if countErr := query.Count(&count).Error; countErr != nil {
		global.GVA_LOG.Error("获取预设总数失败", zap.Error(countErr))
		return nil, 0, countErr
	}

	skip := (req.Page - 1) * req.PageSize
	findErr := query.
		Order("is_default DESC, updated_at DESC").
		Offset(skip).
		Limit(req.PageSize).
		Find(&rows).Error
	if findErr != nil {
		global.GVA_LOG.Error("获取预设列表失败", zap.Error(findErr))
		return nil, 0, findErr
	}

	return rows, count, nil
}
// GetPresetByID loads the preset with the given id, provided it is owned by
// userID or marked public. A missing or inaccessible preset yields a generic
// "not found or no access" error; other database errors are logged and
// returned as-is.
func (s *PresetService) GetPresetByID(userID uint, id uint) (*app.AIPreset, error) {
	found := new(app.AIPreset)

	// Access rule: own presets or public presets only.
	err := global.GVA_DB.
		Where("id = ? AND (user_id = ? OR is_public = ?)", id, userID, true).
		First(found).Error

	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.New("预设不存在或无权访问")
	default:
		global.GVA_LOG.Error("获取预设失败", zap.Error(err))
		return nil, err
	}
}
// UpdatePreset applies a partial update to a preset owned by userID.
//
// Name and Description are written only when non-empty; every other request
// field is written only when it is non-nil. Returns an error when the preset
// does not exist or is not owned by userID, when JSON serialization fails, or
// when the database update fails.
func (s *PresetService) UpdatePreset(userID uint, id uint, req *request.UpdatePresetRequest) error {
	// Ownership check: the lookup is scoped to the caller's own presets.
	var target app.AIPreset
	lookupErr := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&target).Error
	if lookupErr != nil {
		if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			return errors.New("预设不存在或无权修改")
		}
		global.GVA_LOG.Error("查询预设失败", zap.Error(lookupErr))
		return lookupErr
	}

	// Collect column -> value pairs for the fields present in the request.
	changes := map[string]interface{}{}
	set := func(column string, value interface{}) { changes[column] = value }

	if req.Name != "" {
		set("name", req.Name)
	}
	if req.Description != "" {
		set("description", req.Description)
	}
	if req.IsPublic != nil {
		set("is_public", *req.IsPublic)
	}
	if req.Temperature != nil {
		set("temperature", *req.Temperature)
	}
	if req.TopP != nil {
		set("top_p", *req.TopP)
	}
	if req.TopK != nil {
		set("top_k", *req.TopK)
	}
	if req.FrequencyPenalty != nil {
		set("frequency_penalty", *req.FrequencyPenalty)
	}
	if req.PresencePenalty != nil {
		set("presence_penalty", *req.PresencePenalty)
	}
	if req.MaxTokens != nil {
		set("max_tokens", *req.MaxTokens)
	}
	if req.RepetitionPenalty != nil {
		set("repetition_penalty", *req.RepetitionPenalty)
	}
	if req.MinP != nil {
		set("min_p", *req.MinP)
	}
	if req.TopA != nil {
		set("top_a", *req.TopA)
	}
	if req.SystemPrompt != nil {
		set("system_prompt", *req.SystemPrompt)
	}

	// StopSequences is stored as a JSON column.
	if req.StopSequences != nil {
		encoded, marshalErr := json.Marshal(req.StopSequences)
		if marshalErr != nil {
			global.GVA_LOG.Error("序列化 StopSequences 失败", zap.Error(marshalErr))
			return marshalErr
		}
		set("stop_sequences", datatypes.JSON(encoded))
	}

	// Extensions is stored as a JSON column.
	if req.Extensions != nil {
		encoded, marshalErr := json.Marshal(req.Extensions)
		if marshalErr != nil {
			global.GVA_LOG.Error("序列化 Extensions 失败", zap.Error(marshalErr))
			return marshalErr
		}
		set("extensions", datatypes.JSON(encoded))
	}

	saveErr := global.GVA_DB.Model(&target).Updates(changes).Error
	if saveErr != nil {
		global.GVA_LOG.Error("更新预设失败", zap.Error(saveErr))
		return saveErr
	}
	return nil
}
// DeletePreset removes the preset with the given id if it is owned by userID.
// When no row is affected, it reports that the preset does not exist or may
// not be deleted by this user.
func (s *PresetService) DeletePreset(userID uint, id uint) error {
	// Ownership check is part of the delete condition itself.
	outcome := global.GVA_DB.
		Where("id = ? AND user_id = ?", id, userID).
		Delete(&app.AIPreset{})

	switch {
	case outcome.Error != nil:
		global.GVA_LOG.Error("删除预设失败", zap.Error(outcome.Error))
		return outcome.Error
	case outcome.RowsAffected == 0:
		return errors.New("预设不存在或无权删除")
	}
	return nil
}
// SetDefaultPreset marks the preset with the given id as userID's default.
//
// Inside a single transaction it verifies the preset exists and is owned by
// userID, clears is_default on all of that user's current default presets,
// and then sets is_default on the requested preset.
func (s *PresetService) SetDefaultPreset(userID uint, id uint) error {
	promote := func(tx *gorm.DB) error {
		// Step 1: the preset must exist and belong to the caller.
		var target app.AIPreset
		lookupErr := tx.Where("id = ? AND user_id = ?", id, userID).First(&target).Error
		if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			return errors.New("预设不存在或无权访问")
		}
		if lookupErr != nil {
			return lookupErr
		}

		// Step 2: clear the default flag on the user's current defaults.
		clearErr := tx.Model(&app.AIPreset{}).
			Where("user_id = ? AND is_default = ?", userID, true).
			Update("is_default", false).Error
		if clearErr != nil {
			return clearErr
		}

		// Step 3: flag the requested preset as the default.
		return tx.Model(&target).Update("is_default", true).Error
	}

	return global.GVA_DB.Transaction(promote)
}
// ImportPresetFromJSON parses jsonData as a SillyTavern-style preset and
// creates a new preset for userID from it via CreatePreset.
//
// The preset name is taken from filename with a trailing ".json" removed
// (only when something remains in front of the suffix). The parsed prompts
// and prompt_order are stored under the preset's extensions. Returns an
// "invalid preset format" error when jsonData cannot be decoded.
func (s *PresetService) ImportPresetFromJSON(userID uint, jsonData []byte, filename string) (*app.AIPreset, error) {
	// Subset of the SillyTavern preset fields that this importer reads.
	type sillyTavernPreset struct {
		Temperature       float64                  `json:"temperature"`
		TopP              float64                  `json:"top_p"`
		TopK              int                      `json:"top_k"`
		FrequencyPenalty  float64                  `json:"frequency_penalty"`
		PresencePenalty   float64                  `json:"presence_penalty"`
		MaxTokens         int                      `json:"openai_max_tokens"`
		RepetitionPenalty float64                  `json:"repetition_penalty"`
		MinP              float64                  `json:"min_p"`
		TopA              float64                  `json:"top_a"`
		StopSequences     []string                 `json:"stop_sequences"`
		Prompts           []map[string]interface{} `json:"prompts"`
		PromptOrder       []map[string]interface{} `json:"prompt_order"`
	}

	var parsed sillyTavernPreset
	if decodeErr := json.Unmarshal(jsonData, &parsed); decodeErr != nil {
		global.GVA_LOG.Error("解析预设JSON失败", zap.Error(decodeErr))
		return nil, errors.New("无效的预设格式")
	}

	// Derive the preset name from the file name, dropping a ".json" suffix.
	const jsonSuffix = ".json"
	presetName := filename
	if cut := len(filename) - len(jsonSuffix); cut > 0 && filename[cut:] == jsonSuffix {
		presetName = filename[:cut]
	}

	// prompts and prompt_order travel inside the extensions payload.
	return s.CreatePreset(userID, &request.CreatePresetRequest{
		Name:              presetName,
		Description:       "从 SillyTavern 导入",
		Temperature:       parsed.Temperature,
		TopP:              parsed.TopP,
		TopK:              parsed.TopK,
		FrequencyPenalty:  parsed.FrequencyPenalty,
		PresencePenalty:   parsed.PresencePenalty,
		MaxTokens:         parsed.MaxTokens,
		RepetitionPenalty: parsed.RepetitionPenalty,
		MinP:              parsed.MinP,
		TopA:              parsed.TopA,
		SystemPrompt:      "",
		StopSequences:     parsed.StopSequences,
		Extensions: map[string]interface{}{
			"prompts":      parsed.Prompts,
			"prompt_order": parsed.PromptOrder,
		},
	})
}
// ExportPresetToJSON renders the preset with the given id as indented JSON.
//
// The preset is loaded through GetPresetByID, so the caller must own it or it
// must be public. The stored StopSequences and Extensions JSON columns are
// decoded first; a decoding failure is logged and returned instead of being
// silently exported as null.
//
// In addition to the existing keys, the output carries "openai_max_tokens"
// and, when present in the stored extensions, top-level "prompts" and
// "prompt_order". Those are the keys ImportPresetFromJSON reads, so an
// exported file can be imported again without losing them.
func (s *PresetService) ExportPresetToJSON(userID uint, id uint) ([]byte, error) {
	preset, err := s.GetPresetByID(userID, id)
	if err != nil {
		return nil, err
	}

	// Decode the stored stop sequences.
	var stopSequences []string
	if len(preset.StopSequences) > 0 {
		if err := json.Unmarshal(preset.StopSequences, &stopSequences); err != nil {
			global.GVA_LOG.Error("解析 StopSequences 失败", zap.Error(err))
			return nil, err
		}
	}

	// Decode the stored extensions.
	var extensions map[string]interface{}
	if len(preset.Extensions) > 0 {
		if err := json.Unmarshal(preset.Extensions, &extensions); err != nil {
			global.GVA_LOG.Error("解析 Extensions 失败", zap.Error(err))
			return nil, err
		}
	}

	stPreset := map[string]interface{}{
		"name":               preset.Name,
		"description":        preset.Description,
		"temperature":        preset.Temperature,
		"top_p":              preset.TopP,
		"top_k":              preset.TopK,
		"frequency_penalty":  preset.FrequencyPenalty,
		"presence_penalty":   preset.PresencePenalty,
		"max_tokens":         preset.MaxTokens,
		"openai_max_tokens":  preset.MaxTokens, // key read by ImportPresetFromJSON
		"repetition_penalty": preset.RepetitionPenalty,
		"min_p":              preset.MinP,
		"top_a":              preset.TopA,
		"system_prompt":      preset.SystemPrompt,
		"stop_sequences":     stopSequences,
		"extensions":         extensions,
	}

	// Mirror prompts/prompt_order at the top level, where the importer looks
	// for them. Reading from a nil map is safe and simply finds nothing.
	for _, key := range []string{"prompts", "prompt_order"} {
		if value, ok := extensions[key]; ok && value != nil {
			stPreset[key] = value
		}
	}

	return json.MarshalIndent(stPreset, "", " ")
}
// IncrementUseCount adds one to the use_count column of the preset with the
// given id, using a SQL expression so the increment happens in the database.
func (s *PresetService) IncrementUseCount(id uint) error {
	plusOne := gorm.Expr("use_count + ?", 1)
	outcome := global.GVA_DB.
		Model(&app.AIPreset{}).
		Where("id = ?", id).
		Update("use_count", plusOne)
	return outcome.Error
}