151 lines
4.4 KiB
Go
151 lines
4.4 KiB
Go
package app
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"git.echol.cn/loser/ai_proxy/server/global"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app"
|
|
"git.echol.cn/loser/ai_proxy/server/model/common/request"
|
|
)
|
|
|
|
// AiPresetService groups the persistence and import operations for AI
// presets. It holds no state of its own; every method works through the
// shared global.GVA_DB handle, so the zero value is ready to use.
type AiPresetService struct{}
|
|
|
|
// CreateAiPreset 创建预设
|
|
func (s *AiPresetService) CreateAiPreset(preset *app.AiPreset) error {
|
|
return global.GVA_DB.Create(preset).Error
|
|
}
|
|
|
|
// DeleteAiPreset 删除预设
|
|
func (s *AiPresetService) DeleteAiPreset(id uint, userID uint) error {
|
|
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiPreset{}).Error
|
|
}
|
|
|
|
// UpdateAiPreset 更新预设
|
|
func (s *AiPresetService) UpdateAiPreset(preset *app.AiPreset, userID uint) error {
|
|
return global.GVA_DB.Where("user_id = ?", userID).Updates(preset).Error
|
|
}
|
|
|
|
// GetAiPreset 查询预设
|
|
func (s *AiPresetService) GetAiPreset(id uint, userID uint) (preset app.AiPreset, err error) {
|
|
err = global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error
|
|
return
|
|
}
|
|
|
|
// GetAiPresetList 获取预设列表
|
|
func (s *AiPresetService) GetAiPresetList(info request.PageInfo, userID uint) (list []app.AiPreset, total int64, err error) {
|
|
limit := info.PageSize
|
|
offset := info.PageSize * (info.Page - 1)
|
|
db := global.GVA_DB.Model(&app.AiPreset{}).Where("user_id = ?", userID)
|
|
err = db.Count(&total).Error
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
|
|
return
|
|
}
|
|
|
|
// ParseImportedPreset 解析导入的预设,支持 SillyTavern 格式
|
|
// defaultName: 当 JSON 中没有名称时使用的默认名称(通常是文件名)
|
|
func (s *AiPresetService) ParseImportedPreset(rawData map[string]interface{}, defaultName string) (*app.AiPreset, error) {
|
|
preset := &app.AiPreset{
|
|
Enabled: true,
|
|
}
|
|
|
|
// 处理名称字段 - 支持多种格式
|
|
if name, ok := rawData["name"].(string); ok && name != "" {
|
|
preset.Name = name
|
|
} else if name, ok := rawData["preset_name"].(string); ok && name != "" {
|
|
preset.Name = name
|
|
} else if name, ok := rawData["presetName"].(string); ok && name != "" {
|
|
preset.Name = name
|
|
} else if defaultName != "" {
|
|
// 使用默认名称(文件名)
|
|
preset.Name = defaultName
|
|
} else {
|
|
return nil, fmt.Errorf("预设名称不能为空")
|
|
}
|
|
|
|
// 处理描述
|
|
if desc, ok := rawData["description"].(string); ok {
|
|
preset.Description = desc
|
|
}
|
|
|
|
// 处理参数
|
|
if temp, ok := rawData["temperature"].(float64); ok {
|
|
preset.Temperature = temp
|
|
} else {
|
|
preset.Temperature = 1.0
|
|
}
|
|
|
|
if topP, ok := rawData["top_p"].(float64); ok {
|
|
preset.TopP = topP
|
|
} else {
|
|
preset.TopP = 0.9
|
|
}
|
|
|
|
if topK, ok := rawData["top_k"].(float64); ok {
|
|
preset.TopK = int(topK)
|
|
}
|
|
|
|
if maxTokens, ok := rawData["max_tokens"].(float64); ok {
|
|
preset.MaxTokens = int(maxTokens)
|
|
} else {
|
|
preset.MaxTokens = 4096
|
|
}
|
|
|
|
if freqPenalty, ok := rawData["frequency_penalty"].(float64); ok {
|
|
preset.FrequencyPenalty = freqPenalty
|
|
}
|
|
|
|
if presPenalty, ok := rawData["presence_penalty"].(float64); ok {
|
|
preset.PresencePenalty = presPenalty
|
|
}
|
|
|
|
// 处理提示词
|
|
if prompts, ok := rawData["prompts"].([]interface{}); ok {
|
|
promptsData, _ := json.Marshal(prompts)
|
|
json.Unmarshal(promptsData, &preset.Prompts)
|
|
}
|
|
|
|
// 处理提示词顺序
|
|
if promptOrder, ok := rawData["prompt_order"].([]interface{}); ok {
|
|
orderData, _ := json.Marshal(promptOrder)
|
|
json.Unmarshal(orderData, &preset.PromptOrder)
|
|
}
|
|
|
|
// 处理正则脚本 - 支持两种格式
|
|
// 格式1: 顶层 regex_scripts
|
|
if regexScripts, ok := rawData["regex_scripts"].([]interface{}); ok {
|
|
scriptsData, _ := json.Marshal(regexScripts)
|
|
json.Unmarshal(scriptsData, &preset.RegexScripts)
|
|
}
|
|
|
|
// 处理扩展配置
|
|
if extensions, ok := rawData["extensions"].(map[string]interface{}); ok {
|
|
// 格式2: extensions.regex_scripts (SillyTavern 格式)
|
|
if regexScripts, ok := extensions["regex_scripts"].([]interface{}); ok {
|
|
scriptsData, _ := json.Marshal(regexScripts)
|
|
// 同时填充到 RegexScripts 和 Extensions.RegexBinding.Regexes
|
|
json.Unmarshal(scriptsData, &preset.RegexScripts)
|
|
|
|
// 确保 Extensions.RegexBinding 被初始化
|
|
if preset.Extensions.RegexBinding == nil {
|
|
preset.Extensions.RegexBinding = &app.RegexBindingConfig{}
|
|
}
|
|
json.Unmarshal(scriptsData, &preset.Extensions.RegexBinding.Regexes)
|
|
}
|
|
|
|
// 解析其他扩展配置
|
|
extData, _ := json.Marshal(extensions)
|
|
json.Unmarshal(extData, &preset.Extensions)
|
|
}
|
|
|
|
// 处理启用状态
|
|
if enabled, ok := rawData["enabled"].(bool); ok {
|
|
preset.Enabled = enabled
|
|
}
|
|
|
|
return preset, nil
|
|
}
|