353
server/service/app/preset.go
Normal file
353
server/service/app/preset.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"git.echol.cn/loser/st/server/global"
|
||||
"git.echol.cn/loser/st/server/model/app"
|
||||
"git.echol.cn/loser/st/server/model/app/request"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type PresetService struct{}
|
||||
|
||||
// CreatePreset 创建预设
|
||||
func (s *PresetService) CreatePreset(userID uint, req *request.CreatePresetRequest) (*app.AIPreset, error) {
|
||||
// 序列化 StopSequences
|
||||
var stopSequencesJSON datatypes.JSON
|
||||
if len(req.StopSequences) > 0 {
|
||||
data, err := json.Marshal(req.StopSequences)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("序列化 StopSequences 失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
stopSequencesJSON = data
|
||||
}
|
||||
|
||||
// 序列化 Extensions
|
||||
var extensionsJSON datatypes.JSON
|
||||
if len(req.Extensions) > 0 {
|
||||
data, err := json.Marshal(req.Extensions)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("序列化 Extensions 失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
extensionsJSON = data
|
||||
}
|
||||
|
||||
preset := &app.AIPreset{
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
IsPublic: req.IsPublic,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
TopK: req.TopK,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
MaxTokens: req.MaxTokens,
|
||||
RepetitionPenalty: req.RepetitionPenalty,
|
||||
MinP: req.MinP,
|
||||
TopA: req.TopA,
|
||||
SystemPrompt: req.SystemPrompt,
|
||||
StopSequences: stopSequencesJSON,
|
||||
Extensions: extensionsJSON,
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Create(preset).Error; err != nil {
|
||||
global.GVA_LOG.Error("创建预设失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return preset, nil
|
||||
}
|
||||
|
||||
// GetPresetList 获取预设列表
|
||||
func (s *PresetService) GetPresetList(userID uint, req *request.GetPresetListRequest) ([]app.AIPreset, int64, error) {
|
||||
var presets []app.AIPreset
|
||||
var total int64
|
||||
|
||||
db := global.GVA_DB.Model(&app.AIPreset{})
|
||||
|
||||
// 权限过滤:只能看到自己的预设或公开的预设
|
||||
db = db.Where("user_id = ? OR is_public = ?", userID, true)
|
||||
|
||||
// 关键词搜索
|
||||
if req.Keyword != "" {
|
||||
db = db.Where("name LIKE ? OR description LIKE ?", "%"+req.Keyword+"%", "%"+req.Keyword+"%")
|
||||
}
|
||||
|
||||
// 公开/私有过滤
|
||||
if req.IsPublic != nil {
|
||||
db = db.Where("is_public = ?", *req.IsPublic)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
global.GVA_LOG.Error("获取预设总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
if err := db.Order("is_default DESC, updated_at DESC").
|
||||
Offset(offset).
|
||||
Limit(req.PageSize).
|
||||
Find(&presets).Error; err != nil {
|
||||
global.GVA_LOG.Error("获取预设列表失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return presets, total, nil
|
||||
}
|
||||
|
||||
// GetPresetByID 根据ID获取预设
|
||||
func (s *PresetService) GetPresetByID(userID uint, id uint) (*app.AIPreset, error) {
|
||||
var preset app.AIPreset
|
||||
|
||||
// 权限检查:只能访问自己的预设或公开的预设
|
||||
if err := global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userID, true).
|
||||
First(&preset).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("预设不存在或无权访问")
|
||||
}
|
||||
global.GVA_LOG.Error("获取预设失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &preset, nil
|
||||
}
|
||||
|
||||
// UpdatePreset 更新预设
|
||||
func (s *PresetService) UpdatePreset(userID uint, id uint, req *request.UpdatePresetRequest) error {
|
||||
var preset app.AIPreset
|
||||
|
||||
// 权限检查:只能更新自己的预设
|
||||
if err := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("预设不存在或无权修改")
|
||||
}
|
||||
global.GVA_LOG.Error("查询预设失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 构建更新数据
|
||||
updates := make(map[string]interface{})
|
||||
|
||||
if req.Name != "" {
|
||||
updates["name"] = req.Name
|
||||
}
|
||||
if req.Description != "" {
|
||||
updates["description"] = req.Description
|
||||
}
|
||||
if req.IsPublic != nil {
|
||||
updates["is_public"] = *req.IsPublic
|
||||
}
|
||||
if req.Temperature != nil {
|
||||
updates["temperature"] = *req.Temperature
|
||||
}
|
||||
if req.TopP != nil {
|
||||
updates["top_p"] = *req.TopP
|
||||
}
|
||||
if req.TopK != nil {
|
||||
updates["top_k"] = *req.TopK
|
||||
}
|
||||
if req.FrequencyPenalty != nil {
|
||||
updates["frequency_penalty"] = *req.FrequencyPenalty
|
||||
}
|
||||
if req.PresencePenalty != nil {
|
||||
updates["presence_penalty"] = *req.PresencePenalty
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
updates["max_tokens"] = *req.MaxTokens
|
||||
}
|
||||
if req.RepetitionPenalty != nil {
|
||||
updates["repetition_penalty"] = *req.RepetitionPenalty
|
||||
}
|
||||
if req.MinP != nil {
|
||||
updates["min_p"] = *req.MinP
|
||||
}
|
||||
if req.TopA != nil {
|
||||
updates["top_a"] = *req.TopA
|
||||
}
|
||||
if req.SystemPrompt != nil {
|
||||
updates["system_prompt"] = *req.SystemPrompt
|
||||
}
|
||||
|
||||
// 更新 StopSequences
|
||||
if req.StopSequences != nil {
|
||||
data, err := json.Marshal(req.StopSequences)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("序列化 StopSequences 失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
updates["stop_sequences"] = datatypes.JSON(data)
|
||||
}
|
||||
|
||||
// 更新 Extensions
|
||||
if req.Extensions != nil {
|
||||
data, err := json.Marshal(req.Extensions)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("序列化 Extensions 失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
updates["extensions"] = datatypes.JSON(data)
|
||||
}
|
||||
|
||||
if err := global.GVA_DB.Model(&preset).Updates(updates).Error; err != nil {
|
||||
global.GVA_LOG.Error("更新预设失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePreset 删除预设
|
||||
func (s *PresetService) DeletePreset(userID uint, id uint) error {
|
||||
// 权限检查:只能删除自己的预设
|
||||
result := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AIPreset{})
|
||||
if result.Error != nil {
|
||||
global.GVA_LOG.Error("删除预设失败", zap.Error(result.Error))
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return errors.New("预设不存在或无权删除")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultPreset 设置默认预设
|
||||
func (s *PresetService) SetDefaultPreset(userID uint, id uint) error {
|
||||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 检查预设是否存在且属于当前用户
|
||||
var preset app.AIPreset
|
||||
if err := tx.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("预设不存在或无权访问")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 取消当前用户的所有默认预设
|
||||
if err := tx.Model(&app.AIPreset{}).
|
||||
Where("user_id = ? AND is_default = ?", userID, true).
|
||||
Update("is_default", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置新的默认预设
|
||||
if err := tx.Model(&preset).Update("is_default", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ImportPresetFromJSON 从JSON导入预设
|
||||
func (s *PresetService) ImportPresetFromJSON(userID uint, jsonData []byte, filename string) (*app.AIPreset, error) {
|
||||
// 尝试解析为 SillyTavern 格式
|
||||
var stPreset struct {
|
||||
Temperature float64 `json:"temperature"`
|
||||
TopP float64 `json:"top_p"`
|
||||
TopK int `json:"top_k"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
MaxTokens int `json:"openai_max_tokens"`
|
||||
RepetitionPenalty float64 `json:"repetition_penalty"`
|
||||
MinP float64 `json:"min_p"`
|
||||
TopA float64 `json:"top_a"`
|
||||
StopSequences []string `json:"stop_sequences"`
|
||||
Prompts []map[string]interface{} `json:"prompts"`
|
||||
PromptOrder []map[string]interface{} `json:"prompt_order"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &stPreset); err != nil {
|
||||
global.GVA_LOG.Error("解析预设JSON失败", zap.Error(err))
|
||||
return nil, errors.New("无效的预设格式")
|
||||
}
|
||||
|
||||
// 从文件名提取预设名称(去掉 .json 后缀)
|
||||
name := filename
|
||||
if len(name) > 5 && name[len(name)-5:] == ".json" {
|
||||
name = name[:len(name)-5]
|
||||
}
|
||||
|
||||
// 构建 extensions 对象,包含 prompts 和 prompt_order
|
||||
extensions := map[string]interface{}{
|
||||
"prompts": stPreset.Prompts,
|
||||
"prompt_order": stPreset.PromptOrder,
|
||||
}
|
||||
|
||||
// 转换为创建请求
|
||||
req := &request.CreatePresetRequest{
|
||||
Name: name,
|
||||
Description: "从 SillyTavern 导入",
|
||||
Temperature: stPreset.Temperature,
|
||||
TopP: stPreset.TopP,
|
||||
TopK: stPreset.TopK,
|
||||
FrequencyPenalty: stPreset.FrequencyPenalty,
|
||||
PresencePenalty: stPreset.PresencePenalty,
|
||||
MaxTokens: stPreset.MaxTokens,
|
||||
RepetitionPenalty: stPreset.RepetitionPenalty,
|
||||
MinP: stPreset.MinP,
|
||||
TopA: stPreset.TopA,
|
||||
SystemPrompt: "",
|
||||
StopSequences: stPreset.StopSequences,
|
||||
Extensions: extensions,
|
||||
}
|
||||
|
||||
return s.CreatePreset(userID, req)
|
||||
}
|
||||
|
||||
// ExportPresetToJSON 导出预设为JSON
|
||||
func (s *PresetService) ExportPresetToJSON(userID uint, id uint) ([]byte, error) {
|
||||
preset, err := s.GetPresetByID(userID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析 StopSequences
|
||||
var stopSequences []string
|
||||
if len(preset.StopSequences) > 0 {
|
||||
json.Unmarshal(preset.StopSequences, &stopSequences)
|
||||
}
|
||||
|
||||
// 解析 Extensions
|
||||
var extensions map[string]interface{}
|
||||
if len(preset.Extensions) > 0 {
|
||||
json.Unmarshal(preset.Extensions, &extensions)
|
||||
}
|
||||
|
||||
// 转换为 SillyTavern 格式
|
||||
stPreset := map[string]interface{}{
|
||||
"name": preset.Name,
|
||||
"description": preset.Description,
|
||||
"temperature": preset.Temperature,
|
||||
"top_p": preset.TopP,
|
||||
"top_k": preset.TopK,
|
||||
"frequency_penalty": preset.FrequencyPenalty,
|
||||
"presence_penalty": preset.PresencePenalty,
|
||||
"max_tokens": preset.MaxTokens,
|
||||
"repetition_penalty": preset.RepetitionPenalty,
|
||||
"min_p": preset.MinP,
|
||||
"top_a": preset.TopA,
|
||||
"system_prompt": preset.SystemPrompt,
|
||||
"stop_sequences": stopSequences,
|
||||
"extensions": extensions,
|
||||
}
|
||||
|
||||
return json.MarshalIndent(stPreset, "", " ")
|
||||
}
|
||||
|
||||
// IncrementUseCount 增加使用次数
|
||||
func (s *PresetService) IncrementUseCount(id uint) error {
|
||||
return global.GVA_DB.Model(&app.AIPreset{}).
|
||||
Where("id = ?", id).
|
||||
Update("use_count", gorm.Expr("use_count + ?", 1)).Error
|
||||
}
|
||||
Reference in New Issue
Block a user