354 lines
10 KiB
Go
354 lines
10 KiB
Go
package app
|
|
|
|
import (
	"encoding/json"
	"errors"
	"strings"

	"git.echol.cn/loser/st/server/global"
	"git.echol.cn/loser/st/server/model/app"
	"git.echol.cn/loser/st/server/model/app/request"
	"go.uber.org/zap"
	"gorm.io/datatypes"
	"gorm.io/gorm"
)
|
|
|
|
// PresetService groups the operations on AI presets defined in this file:
// create, list, fetch, update, delete, default selection, JSON import/export
// and use-count tracking. The struct has no fields; its methods go through the
// package-level global.GVA_DB handle.
type PresetService struct{}
|
|
|
|
// CreatePreset 创建预设
|
|
func (s *PresetService) CreatePreset(userID uint, req *request.CreatePresetRequest) (*app.AIPreset, error) {
|
|
// 序列化 StopSequences
|
|
var stopSequencesJSON datatypes.JSON
|
|
if len(req.StopSequences) > 0 {
|
|
data, err := json.Marshal(req.StopSequences)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("序列化 StopSequences 失败", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
stopSequencesJSON = data
|
|
}
|
|
|
|
// 序列化 Extensions
|
|
var extensionsJSON datatypes.JSON
|
|
if len(req.Extensions) > 0 {
|
|
data, err := json.Marshal(req.Extensions)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("序列化 Extensions 失败", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
extensionsJSON = data
|
|
}
|
|
|
|
preset := &app.AIPreset{
|
|
UserID: userID,
|
|
Name: req.Name,
|
|
Description: req.Description,
|
|
IsPublic: req.IsPublic,
|
|
Temperature: req.Temperature,
|
|
TopP: req.TopP,
|
|
TopK: req.TopK,
|
|
FrequencyPenalty: req.FrequencyPenalty,
|
|
PresencePenalty: req.PresencePenalty,
|
|
MaxTokens: req.MaxTokens,
|
|
RepetitionPenalty: req.RepetitionPenalty,
|
|
MinP: req.MinP,
|
|
TopA: req.TopA,
|
|
SystemPrompt: req.SystemPrompt,
|
|
StopSequences: stopSequencesJSON,
|
|
Extensions: extensionsJSON,
|
|
}
|
|
|
|
if err := global.GVA_DB.Create(preset).Error; err != nil {
|
|
global.GVA_LOG.Error("创建预设失败", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
return preset, nil
|
|
}
|
|
|
|
// GetPresetList 获取预设列表
|
|
func (s *PresetService) GetPresetList(userID uint, req *request.GetPresetListRequest) ([]app.AIPreset, int64, error) {
|
|
var presets []app.AIPreset
|
|
var total int64
|
|
|
|
db := global.GVA_DB.Model(&app.AIPreset{})
|
|
|
|
// 权限过滤:只能看到自己的预设或公开的预设
|
|
db = db.Where("user_id = ? OR is_public = ?", userID, true)
|
|
|
|
// 关键词搜索
|
|
if req.Keyword != "" {
|
|
db = db.Where("name LIKE ? OR description LIKE ?", "%"+req.Keyword+"%", "%"+req.Keyword+"%")
|
|
}
|
|
|
|
// 公开/私有过滤
|
|
if req.IsPublic != nil {
|
|
db = db.Where("is_public = ?", *req.IsPublic)
|
|
}
|
|
|
|
// 获取总数
|
|
if err := db.Count(&total).Error; err != nil {
|
|
global.GVA_LOG.Error("获取预设总数失败", zap.Error(err))
|
|
return nil, 0, err
|
|
}
|
|
|
|
// 分页查询
|
|
offset := (req.Page - 1) * req.PageSize
|
|
if err := db.Order("is_default DESC, updated_at DESC").
|
|
Offset(offset).
|
|
Limit(req.PageSize).
|
|
Find(&presets).Error; err != nil {
|
|
global.GVA_LOG.Error("获取预设列表失败", zap.Error(err))
|
|
return nil, 0, err
|
|
}
|
|
|
|
return presets, total, nil
|
|
}
|
|
|
|
// GetPresetByID 根据ID获取预设
|
|
func (s *PresetService) GetPresetByID(userID uint, id uint) (*app.AIPreset, error) {
|
|
var preset app.AIPreset
|
|
|
|
// 权限检查:只能访问自己的预设或公开的预设
|
|
if err := global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userID, true).
|
|
First(&preset).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, errors.New("预设不存在或无权访问")
|
|
}
|
|
global.GVA_LOG.Error("获取预设失败", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
return &preset, nil
|
|
}
|
|
|
|
// UpdatePreset 更新预设
|
|
func (s *PresetService) UpdatePreset(userID uint, id uint, req *request.UpdatePresetRequest) error {
|
|
var preset app.AIPreset
|
|
|
|
// 权限检查:只能更新自己的预设
|
|
if err := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return errors.New("预设不存在或无权修改")
|
|
}
|
|
global.GVA_LOG.Error("查询预设失败", zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
// 构建更新数据
|
|
updates := make(map[string]interface{})
|
|
|
|
if req.Name != "" {
|
|
updates["name"] = req.Name
|
|
}
|
|
if req.Description != "" {
|
|
updates["description"] = req.Description
|
|
}
|
|
if req.IsPublic != nil {
|
|
updates["is_public"] = *req.IsPublic
|
|
}
|
|
if req.Temperature != nil {
|
|
updates["temperature"] = *req.Temperature
|
|
}
|
|
if req.TopP != nil {
|
|
updates["top_p"] = *req.TopP
|
|
}
|
|
if req.TopK != nil {
|
|
updates["top_k"] = *req.TopK
|
|
}
|
|
if req.FrequencyPenalty != nil {
|
|
updates["frequency_penalty"] = *req.FrequencyPenalty
|
|
}
|
|
if req.PresencePenalty != nil {
|
|
updates["presence_penalty"] = *req.PresencePenalty
|
|
}
|
|
if req.MaxTokens != nil {
|
|
updates["max_tokens"] = *req.MaxTokens
|
|
}
|
|
if req.RepetitionPenalty != nil {
|
|
updates["repetition_penalty"] = *req.RepetitionPenalty
|
|
}
|
|
if req.MinP != nil {
|
|
updates["min_p"] = *req.MinP
|
|
}
|
|
if req.TopA != nil {
|
|
updates["top_a"] = *req.TopA
|
|
}
|
|
if req.SystemPrompt != nil {
|
|
updates["system_prompt"] = *req.SystemPrompt
|
|
}
|
|
|
|
// 更新 StopSequences
|
|
if req.StopSequences != nil {
|
|
data, err := json.Marshal(req.StopSequences)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("序列化 StopSequences 失败", zap.Error(err))
|
|
return err
|
|
}
|
|
updates["stop_sequences"] = datatypes.JSON(data)
|
|
}
|
|
|
|
// 更新 Extensions
|
|
if req.Extensions != nil {
|
|
data, err := json.Marshal(req.Extensions)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("序列化 Extensions 失败", zap.Error(err))
|
|
return err
|
|
}
|
|
updates["extensions"] = datatypes.JSON(data)
|
|
}
|
|
|
|
if err := global.GVA_DB.Model(&preset).Updates(updates).Error; err != nil {
|
|
global.GVA_LOG.Error("更新预设失败", zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeletePreset 删除预设
|
|
func (s *PresetService) DeletePreset(userID uint, id uint) error {
|
|
// 权限检查:只能删除自己的预设
|
|
result := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AIPreset{})
|
|
if result.Error != nil {
|
|
global.GVA_LOG.Error("删除预设失败", zap.Error(result.Error))
|
|
return result.Error
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
return errors.New("预设不存在或无权删除")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetDefaultPreset 设置默认预设
|
|
func (s *PresetService) SetDefaultPreset(userID uint, id uint) error {
|
|
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
|
// 检查预设是否存在且属于当前用户
|
|
var preset app.AIPreset
|
|
if err := tx.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return errors.New("预设不存在或无权访问")
|
|
}
|
|
return err
|
|
}
|
|
|
|
// 取消当前用户的所有默认预设
|
|
if err := tx.Model(&app.AIPreset{}).
|
|
Where("user_id = ? AND is_default = ?", userID, true).
|
|
Update("is_default", false).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// 设置新的默认预设
|
|
if err := tx.Model(&preset).Update("is_default", true).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// ImportPresetFromJSON 从JSON导入预设
|
|
func (s *PresetService) ImportPresetFromJSON(userID uint, jsonData []byte, filename string) (*app.AIPreset, error) {
|
|
// 尝试解析为 SillyTavern 格式
|
|
var stPreset struct {
|
|
Temperature float64 `json:"temperature"`
|
|
TopP float64 `json:"top_p"`
|
|
TopK int `json:"top_k"`
|
|
FrequencyPenalty float64 `json:"frequency_penalty"`
|
|
PresencePenalty float64 `json:"presence_penalty"`
|
|
MaxTokens int `json:"openai_max_tokens"`
|
|
RepetitionPenalty float64 `json:"repetition_penalty"`
|
|
MinP float64 `json:"min_p"`
|
|
TopA float64 `json:"top_a"`
|
|
StopSequences []string `json:"stop_sequences"`
|
|
Prompts []map[string]interface{} `json:"prompts"`
|
|
PromptOrder []map[string]interface{} `json:"prompt_order"`
|
|
}
|
|
|
|
if err := json.Unmarshal(jsonData, &stPreset); err != nil {
|
|
global.GVA_LOG.Error("解析预设JSON失败", zap.Error(err))
|
|
return nil, errors.New("无效的预设格式")
|
|
}
|
|
|
|
// 从文件名提取预设名称(去掉 .json 后缀)
|
|
name := filename
|
|
if len(name) > 5 && name[len(name)-5:] == ".json" {
|
|
name = name[:len(name)-5]
|
|
}
|
|
|
|
// 构建 extensions 对象,包含 prompts 和 prompt_order
|
|
extensions := map[string]interface{}{
|
|
"prompts": stPreset.Prompts,
|
|
"prompt_order": stPreset.PromptOrder,
|
|
}
|
|
|
|
// 转换为创建请求
|
|
req := &request.CreatePresetRequest{
|
|
Name: name,
|
|
Description: "从 SillyTavern 导入",
|
|
Temperature: stPreset.Temperature,
|
|
TopP: stPreset.TopP,
|
|
TopK: stPreset.TopK,
|
|
FrequencyPenalty: stPreset.FrequencyPenalty,
|
|
PresencePenalty: stPreset.PresencePenalty,
|
|
MaxTokens: stPreset.MaxTokens,
|
|
RepetitionPenalty: stPreset.RepetitionPenalty,
|
|
MinP: stPreset.MinP,
|
|
TopA: stPreset.TopA,
|
|
SystemPrompt: "",
|
|
StopSequences: stPreset.StopSequences,
|
|
Extensions: extensions,
|
|
}
|
|
|
|
return s.CreatePreset(userID, req)
|
|
}
|
|
|
|
// ExportPresetToJSON 导出预设为JSON
|
|
func (s *PresetService) ExportPresetToJSON(userID uint, id uint) ([]byte, error) {
|
|
preset, err := s.GetPresetByID(userID, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 解析 StopSequences
|
|
var stopSequences []string
|
|
if len(preset.StopSequences) > 0 {
|
|
json.Unmarshal(preset.StopSequences, &stopSequences)
|
|
}
|
|
|
|
// 解析 Extensions
|
|
var extensions map[string]interface{}
|
|
if len(preset.Extensions) > 0 {
|
|
json.Unmarshal(preset.Extensions, &extensions)
|
|
}
|
|
|
|
// 转换为 SillyTavern 格式
|
|
stPreset := map[string]interface{}{
|
|
"name": preset.Name,
|
|
"description": preset.Description,
|
|
"temperature": preset.Temperature,
|
|
"top_p": preset.TopP,
|
|
"top_k": preset.TopK,
|
|
"frequency_penalty": preset.FrequencyPenalty,
|
|
"presence_penalty": preset.PresencePenalty,
|
|
"max_tokens": preset.MaxTokens,
|
|
"repetition_penalty": preset.RepetitionPenalty,
|
|
"min_p": preset.MinP,
|
|
"top_a": preset.TopA,
|
|
"system_prompt": preset.SystemPrompt,
|
|
"stop_sequences": stopSequences,
|
|
"extensions": extensions,
|
|
}
|
|
|
|
return json.MarshalIndent(stPreset, "", " ")
|
|
}
|
|
|
|
// IncrementUseCount 增加使用次数
|
|
func (s *PresetService) IncrementUseCount(id uint) error {
|
|
return global.GVA_DB.Model(&app.AIPreset{}).
|
|
Where("id = ?", id).
|
|
Update("use_count", gorm.Expr("use_count + ?", 1)).Error
|
|
}
|