package app

import (
	"encoding/json"
	"errors"
	"strings"

	"git.echol.cn/loser/st/server/global"
	"git.echol.cn/loser/st/server/model/app"
	"git.echol.cn/loser/st/server/model/app/request"
	"go.uber.org/zap"
	"gorm.io/datatypes"
	"gorm.io/gorm"
)

// PresetService groups the operations on AI presets. It carries no state of
// its own; every method works through global.GVA_DB and logs failures through
// global.GVA_LOG.
type PresetService struct{}

// CreatePreset inserts a new preset owned by userID, populated from req.
//
// StopSequences and Extensions are marshalled to JSON only when they are
// non-empty; otherwise the corresponding column is left at its zero value.
// It returns the created record, or the marshalling/database error (which is
// also logged).
func (s *PresetService) CreatePreset(userID uint, req *request.CreatePresetRequest) (*app.AIPreset, error) {
	// Serialize StopSequences.
	var stopSequencesJSON datatypes.JSON
	if len(req.StopSequences) > 0 {
		data, err := json.Marshal(req.StopSequences)
		if err != nil {
			global.GVA_LOG.Error("序列化 StopSequences 失败", zap.Error(err))
			return nil, err
		}
		stopSequencesJSON = data
	}

	// Serialize Extensions.
	var extensionsJSON datatypes.JSON
	if len(req.Extensions) > 0 {
		data, err := json.Marshal(req.Extensions)
		if err != nil {
			global.GVA_LOG.Error("序列化 Extensions 失败", zap.Error(err))
			return nil, err
		}
		extensionsJSON = data
	}

	preset := &app.AIPreset{
		UserID:            userID,
		Name:              req.Name,
		Description:       req.Description,
		IsPublic:          req.IsPublic,
		Temperature:       req.Temperature,
		TopP:              req.TopP,
		TopK:              req.TopK,
		FrequencyPenalty:  req.FrequencyPenalty,
		PresencePenalty:   req.PresencePenalty,
		MaxTokens:         req.MaxTokens,
		RepetitionPenalty: req.RepetitionPenalty,
		MinP:              req.MinP,
		TopA:              req.TopA,
		SystemPrompt:      req.SystemPrompt,
		StopSequences:     stopSequencesJSON,
		Extensions:        extensionsJSON,
	}

	if err := global.GVA_DB.Create(preset).Error; err != nil {
		global.GVA_LOG.Error("创建预设失败", zap.Error(err))
		return nil, err
	}

	return preset, nil
}

// GetPresetList returns one page of the presets visible to userID together
// with the total number of matching rows (counted before pagination).
//
// A preset is visible when it belongs to userID or is public. req.Keyword,
// when non-empty, is matched with LIKE against name and description;
// req.IsPublic, when non-nil, further restricts the result to public or
// private presets. Rows are ordered default-first, then by most recent update.
//
// NOTE(review): req.Page and req.PageSize are used as-is; presumably they are
// validated by the caller — confirm.
func (s *PresetService) GetPresetList(userID uint, req *request.GetPresetListRequest) ([]app.AIPreset, int64, error) {
	var presets []app.AIPreset
	var total int64

	db := global.GVA_DB.Model(&app.AIPreset{})

	// Permission filter: only the caller's own presets or public ones.
	db = db.Where("user_id = ? OR is_public = ?", userID, true)

	// Keyword search.
	if req.Keyword != "" {
		pattern := "%" + req.Keyword + "%"
		db = db.Where("name LIKE ? OR description LIKE ?", pattern, pattern)
	}

	// Public/private filter.
	if req.IsPublic != nil {
		db = db.Where("is_public = ?", *req.IsPublic)
	}

	// Total count of matching rows.
	if err := db.Count(&total).Error; err != nil {
		global.GVA_LOG.Error("获取预设总数失败", zap.Error(err))
		return nil, 0, err
	}

	// Paged query; Page is 1-based.
	offset := (req.Page - 1) * req.PageSize
	err := db.Order("is_default DESC, updated_at DESC").
		Offset(offset).
		Limit(req.PageSize).
		Find(&presets).Error
	if err != nil {
		global.GVA_LOG.Error("获取预设列表失败", zap.Error(err))
		return nil, 0, err
	}

	return presets, total, nil
}

// GetPresetByID loads the preset with the given id if it belongs to userID or
// is public. A missing or inaccessible preset yields a dedicated error; any
// other database error is logged and returned unchanged.
func (s *PresetService) GetPresetByID(userID uint, id uint) (*app.AIPreset, error) {
	var preset app.AIPreset

	// Permission check: only the caller's own presets or public ones.
	err := global.GVA_DB.
		Where("id = ? AND (user_id = ? OR is_public = ?)", id, userID, true).
		First(&preset).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("预设不存在或无权访问")
		}
		global.GVA_LOG.Error("获取预设失败", zap.Error(err))
		return nil, err
	}

	return &preset, nil
}

// UpdatePreset applies a partial update to a preset owned by userID.
//
// Only fields that are set in req are written: non-empty Name/Description and
// non-nil pointer/collection fields. As a consequence Name and Description
// cannot be cleared to the empty string through this method. StopSequences
// and Extensions are re-serialized to JSON when present.
func (s *PresetService) UpdatePreset(userID uint, id uint, req *request.UpdatePresetRequest) error {
	var preset app.AIPreset

	// Permission check: only the owner may update.
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("预设不存在或无权修改")
		}
		global.GVA_LOG.Error("查询预设失败", zap.Error(err))
		return err
	}

	// Collect the columns to update. A map is used so that zero values
	// (false, 0, "") supplied through pointers are still written.
	updates := make(map[string]interface{})

	if req.Name != "" {
		updates["name"] = req.Name
	}
	if req.Description != "" {
		updates["description"] = req.Description
	}
	if req.IsPublic != nil {
		updates["is_public"] = *req.IsPublic
	}
	if req.Temperature != nil {
		updates["temperature"] = *req.Temperature
	}
	if req.TopP != nil {
		updates["top_p"] = *req.TopP
	}
	if req.TopK != nil {
		updates["top_k"] = *req.TopK
	}
	if req.FrequencyPenalty != nil {
		updates["frequency_penalty"] = *req.FrequencyPenalty
	}
	if req.PresencePenalty != nil {
		updates["presence_penalty"] = *req.PresencePenalty
	}
	if req.MaxTokens != nil {
		updates["max_tokens"] = *req.MaxTokens
	}
	if req.RepetitionPenalty != nil {
		updates["repetition_penalty"] = *req.RepetitionPenalty
	}
	if req.MinP != nil {
		updates["min_p"] = *req.MinP
	}
	if req.TopA != nil {
		updates["top_a"] = *req.TopA
	}
	if req.SystemPrompt != nil {
		updates["system_prompt"] = *req.SystemPrompt
	}

	// Update StopSequences.
	if req.StopSequences != nil {
		data, err := json.Marshal(req.StopSequences)
		if err != nil {
			global.GVA_LOG.Error("序列化 StopSequences 失败", zap.Error(err))
			return err
		}
		updates["stop_sequences"] = datatypes.JSON(data)
	}

	// Update Extensions.
	if req.Extensions != nil {
		data, err := json.Marshal(req.Extensions)
		if err != nil {
			global.GVA_LOG.Error("序列化 Extensions 失败", zap.Error(err))
			return err
		}
		updates["extensions"] = datatypes.JSON(data)
	}

	if err := global.GVA_DB.Model(&preset).Updates(updates).Error; err != nil {
		global.GVA_LOG.Error("更新预设失败", zap.Error(err))
		return err
	}

	return nil
}

// DeletePreset deletes the preset with the given id if it is owned by userID.
// When no row is affected (unknown id or not the owner) a dedicated error is
// returned.
func (s *PresetService) DeletePreset(userID uint, id uint) error {
	// Permission check: only the owner may delete.
	result := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AIPreset{})
	if result.Error != nil {
		global.GVA_LOG.Error("删除预设失败", zap.Error(result.Error))
		return result.Error
	}

	if result.RowsAffected == 0 {
		return errors.New("预设不存在或无权删除")
	}

	return nil
}

// SetDefaultPreset marks the preset with the given id as userID's default.
//
// The ownership check, the clearing of any previous default and the setting
// of the new default all run inside one transaction, so a failure in any step
// rolls the others back.
func (s *PresetService) SetDefaultPreset(userID uint, id uint) error {
	return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		// Make sure the preset exists and belongs to the current user.
		var preset app.AIPreset
		if err := tx.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return errors.New("预设不存在或无权访问")
			}
			return err
		}

		// Clear every existing default of the current user.
		if err := tx.Model(&app.AIPreset{}).
			Where("user_id = ? AND is_default = ?", userID, true).
			Update("is_default", false).Error; err != nil {
			return err
		}

		// Mark the requested preset as the new default.
		return tx.Model(&preset).Update("is_default", true).Error
	})
}

// ImportPresetFromJSON parses jsonData as a SillyTavern-style preset and
// stores it as a new preset owned by userID.
//
// The preset name is taken from filename with a trailing ".json" removed
// (only when something remains in front of the suffix). The "prompts" and
// "prompt_order" entries of the input are stored under the preset's
// Extensions. Input that is not valid JSON for the expected shape yields a
// generic "invalid format" error; the underlying parse error is only logged.
func (s *PresetService) ImportPresetFromJSON(userID uint, jsonData []byte, filename string) (*app.AIPreset, error) {
	// Try to parse the payload as a SillyTavern preset.
	var stPreset struct {
		Temperature       float64                  `json:"temperature"`
		TopP              float64                  `json:"top_p"`
		TopK              int                      `json:"top_k"`
		FrequencyPenalty  float64                  `json:"frequency_penalty"`
		PresencePenalty   float64                  `json:"presence_penalty"`
		MaxTokens         int                      `json:"openai_max_tokens"`
		RepetitionPenalty float64                  `json:"repetition_penalty"`
		MinP              float64                  `json:"min_p"`
		TopA              float64                  `json:"top_a"`
		StopSequences     []string                 `json:"stop_sequences"`
		Prompts           []map[string]interface{} `json:"prompts"`
		PromptOrder       []map[string]interface{} `json:"prompt_order"`
	}

	if err := json.Unmarshal(jsonData, &stPreset); err != nil {
		global.GVA_LOG.Error("解析预设JSON失败", zap.Error(err))
		return nil, errors.New("无效的预设格式")
	}

	// Derive the preset name from the file name by dropping the ".json"
	// suffix. A file name that is exactly ".json" is kept unchanged so the
	// name never becomes empty because of the trimming.
	const jsonExt = ".json"
	name := filename
	if len(name) > len(jsonExt) && strings.HasSuffix(name, jsonExt) {
		name = strings.TrimSuffix(name, jsonExt)
	}

	// Build the extensions object holding prompts and prompt_order.
	extensions := map[string]interface{}{
		"prompts":      stPreset.Prompts,
		"prompt_order": stPreset.PromptOrder,
	}

	// Convert to a regular create request and reuse CreatePreset.
	req := &request.CreatePresetRequest{
		Name:              name,
		Description:       "从 SillyTavern 导入",
		Temperature:       stPreset.Temperature,
		TopP:              stPreset.TopP,
		TopK:              stPreset.TopK,
		FrequencyPenalty:  stPreset.FrequencyPenalty,
		PresencePenalty:   stPreset.PresencePenalty,
		MaxTokens:         stPreset.MaxTokens,
		RepetitionPenalty: stPreset.RepetitionPenalty,
		MinP:              stPreset.MinP,
		TopA:              stPreset.TopA,
		SystemPrompt:      "",
		StopSequences:     stPreset.StopSequences,
		Extensions:        extensions,
	}

	return s.CreatePreset(userID, req)
}

// ExportPresetToJSON renders a preset that userID may access (see
// GetPresetByID) as indented JSON.
//
// The stored StopSequences and Extensions columns are decoded first; if
// either contains invalid JSON the error is logged and returned instead of
// silently exporting a document with missing data.
//
// NOTE(review): the exported keys do not mirror what ImportPresetFromJSON
// reads — this writes "max_tokens" and nests prompts under "extensions",
// while import reads "openai_max_tokens" and top-level "prompts" /
// "prompt_order". An exported file re-imported through this service will not
// restore those fields; confirm whether that is intended.
func (s *PresetService) ExportPresetToJSON(userID uint, id uint) ([]byte, error) {
	preset, err := s.GetPresetByID(userID, id)
	if err != nil {
		return nil, err
	}

	// Decode StopSequences.
	var stopSequences []string
	if len(preset.StopSequences) > 0 {
		if err := json.Unmarshal(preset.StopSequences, &stopSequences); err != nil {
			global.GVA_LOG.Error("解析 StopSequences 失败", zap.Error(err))
			return nil, err
		}
	}

	// Decode Extensions.
	var extensions map[string]interface{}
	if len(preset.Extensions) > 0 {
		if err := json.Unmarshal(preset.Extensions, &extensions); err != nil {
			global.GVA_LOG.Error("解析 Extensions 失败", zap.Error(err))
			return nil, err
		}
	}

	// Assemble the export document.
	stPreset := map[string]interface{}{
		"name":               preset.Name,
		"description":        preset.Description,
		"temperature":        preset.Temperature,
		"top_p":              preset.TopP,
		"top_k":              preset.TopK,
		"frequency_penalty":  preset.FrequencyPenalty,
		"presence_penalty":   preset.PresencePenalty,
		"max_tokens":         preset.MaxTokens,
		"repetition_penalty": preset.RepetitionPenalty,
		"min_p":              preset.MinP,
		"top_a":              preset.TopA,
		"system_prompt":      preset.SystemPrompt,
		"stop_sequences":     stopSequences,
		"extensions":         extensions,
	}

	return json.MarshalIndent(stPreset, "", " ")
}

// IncrementUseCount adds one to the use_count column of the preset with the
// given id. The increment is expressed as a SQL expression so it is computed
// by the database rather than read and written back from Go. No ownership
// check is performed and an unknown id is not reported as an error.
func (s *PresetService) IncrementUseCount(id uint) error {
	return global.GVA_DB.Model(&app.AIPreset{}).
		Where("id = ?", id).
		Update("use_count", gorm.Expr("use_count + ?", 1)).Error
}