274 lines
8.0 KiB
Go
274 lines
8.0 KiB
Go
package app
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
|
||
"git.echol.cn/loser/ai_proxy/server/global"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||
req "git.echol.cn/loser/ai_proxy/server/model/common/request"
|
||
)
|
||
|
||
// AiPresetService bundles the create/read/update/delete and
// import/export operations for AI presets. It carries no state of its own;
// every method talks to the database through global.GVA_DB.
type AiPresetService struct{}
|
||
|
||
// CreateAiPreset 创建AI预设
|
||
func (s *AiPresetService) CreateAiPreset(userId uint, req *request.CreateAiPresetRequest) (preset app.AiPreset, err error) {
|
||
preset = app.AiPreset{
|
||
UserID: userId,
|
||
Name: req.Name,
|
||
Description: req.Description,
|
||
Prompts: req.Prompts,
|
||
RegexScripts: req.RegexScripts,
|
||
Temperature: req.Temperature,
|
||
TopP: req.TopP,
|
||
MaxTokens: req.MaxTokens,
|
||
FrequencyPenalty: req.FrequencyPenalty,
|
||
PresencePenalty: req.PresencePenalty,
|
||
StreamEnabled: req.StreamEnabled,
|
||
IsDefault: req.IsDefault,
|
||
IsPublic: req.IsPublic,
|
||
}
|
||
err = global.GVA_DB.Create(&preset).Error
|
||
return preset, err
|
||
}
|
||
|
||
// DeleteAiPreset 删除AI预设
|
||
func (s *AiPresetService) DeleteAiPreset(id uint, userId uint) (err error) {
|
||
// 如果 userId 为 0(未登录),不允许删除
|
||
if userId == 0 {
|
||
return global.GVA_DB.Where("id = ?", id).Delete(&app.AiPreset{}).Error
|
||
}
|
||
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userId).Delete(&app.AiPreset{}).Error
|
||
}
|
||
|
||
// UpdateAiPreset 更新AI预设
|
||
func (s *AiPresetService) UpdateAiPreset(userId uint, req *request.UpdateAiPresetRequest) (preset app.AiPreset, err error) {
|
||
// 如果 userId 为 0(未登录),允许更新任何预设
|
||
if userId == 0 {
|
||
err = global.GVA_DB.Where("id = ?", req.ID).First(&preset).Error
|
||
} else {
|
||
err = global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userId).First(&preset).Error
|
||
}
|
||
|
||
if err != nil {
|
||
return preset, err
|
||
}
|
||
|
||
if req.Name != "" {
|
||
preset.Name = req.Name
|
||
}
|
||
preset.Description = req.Description
|
||
if req.Prompts != nil {
|
||
preset.Prompts = req.Prompts
|
||
}
|
||
if req.RegexScripts != nil {
|
||
preset.RegexScripts = req.RegexScripts
|
||
}
|
||
preset.Temperature = req.Temperature
|
||
preset.TopP = req.TopP
|
||
preset.MaxTokens = req.MaxTokens
|
||
preset.FrequencyPenalty = req.FrequencyPenalty
|
||
preset.PresencePenalty = req.PresencePenalty
|
||
preset.StreamEnabled = req.StreamEnabled
|
||
preset.IsDefault = req.IsDefault
|
||
preset.IsPublic = req.IsPublic
|
||
|
||
err = global.GVA_DB.Save(&preset).Error
|
||
return preset, err
|
||
}
|
||
|
||
// GetAiPreset 获取AI预设详情
|
||
func (s *AiPresetService) GetAiPreset(id uint, userId uint) (preset app.AiPreset, err error) {
|
||
// 如果 userId 为 0(未登录),只能获取公开的预设
|
||
if userId == 0 {
|
||
err = global.GVA_DB.Where("id = ? AND is_public = ?", id, true).First(&preset).Error
|
||
} else {
|
||
err = global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true).First(&preset).Error
|
||
}
|
||
return preset, err
|
||
}
|
||
|
||
// GetAiPresetList 获取AI预设列表
|
||
func (s *AiPresetService) GetAiPresetList(userId uint, info req.PageInfo) (list []app.AiPreset, total int64, err error) {
|
||
limit := info.PageSize
|
||
offset := info.PageSize * (info.Page - 1)
|
||
db := global.GVA_DB.Model(&app.AiPreset{})
|
||
|
||
// 如果 userId 为 0(未登录),只返回公开的预设
|
||
if userId == 0 {
|
||
db = db.Where("is_public = ?", true)
|
||
} else {
|
||
db = db.Where("user_id = ? OR is_public = ?", userId, true)
|
||
}
|
||
|
||
err = db.Count(&total).Error
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
err = db.Limit(limit).Offset(offset).Order("created_at DESC").Find(&list).Error
|
||
return list, total, err
|
||
}
|
||
|
||
// ImportAiPreset 导入AI预设(支持SillyTavern格式)
|
||
func (s *AiPresetService) ImportAiPreset(userId uint, req *request.ImportAiPresetRequest) (preset app.AiPreset, err error) {
|
||
// 解析 SillyTavern JSON 格式
|
||
var stData map[string]interface{}
|
||
var jsonData []byte
|
||
|
||
// 类型断言处理 req.Data
|
||
switch v := req.Data.(type) {
|
||
case string:
|
||
jsonData = []byte(v)
|
||
case []byte:
|
||
jsonData = v
|
||
default:
|
||
return preset, fmt.Errorf("不支持的数据类型")
|
||
}
|
||
|
||
if err := json.Unmarshal(jsonData, &stData); err != nil {
|
||
return preset, fmt.Errorf("JSON 解析失败: %w", err)
|
||
}
|
||
|
||
// 提取基本信息
|
||
preset = app.AiPreset{
|
||
UserID: userId,
|
||
Name: req.Name,
|
||
Description: getStringValue(stData, "description"),
|
||
IsPublic: false,
|
||
}
|
||
|
||
// 提取参数
|
||
if temp, ok := stData["temperature"].(float64); ok {
|
||
preset.Temperature = temp
|
||
}
|
||
if topP, ok := stData["top_p"].(float64); ok {
|
||
preset.TopP = topP
|
||
}
|
||
if maxTokens, ok := stData["openai_max_tokens"].(float64); ok {
|
||
preset.MaxTokens = int(maxTokens)
|
||
} else if maxTokens, ok := stData["max_tokens"].(float64); ok {
|
||
preset.MaxTokens = int(maxTokens)
|
||
}
|
||
if freqPenalty, ok := stData["frequency_penalty"].(float64); ok {
|
||
preset.FrequencyPenalty = freqPenalty
|
||
}
|
||
if presPenalty, ok := stData["presence_penalty"].(float64); ok {
|
||
preset.PresencePenalty = presPenalty
|
||
}
|
||
if stream, ok := stData["stream_openai"].(bool); ok {
|
||
preset.StreamEnabled = stream
|
||
}
|
||
|
||
// 提取提示词
|
||
prompts := make([]app.Prompt, 0)
|
||
if promptsData, ok := stData["prompts"].([]interface{}); ok {
|
||
for i, p := range promptsData {
|
||
if promptMap, ok := p.(map[string]interface{}); ok {
|
||
prompt := app.Prompt{
|
||
Name: getStringValue(promptMap, "name"),
|
||
Role: getStringValue(promptMap, "role"),
|
||
Content: getStringValue(promptMap, "content"),
|
||
SystemPrompt: getBoolValue(promptMap, "system_prompt"),
|
||
Marker: getBoolValue(promptMap, "marker"),
|
||
InjectionOrder: i,
|
||
InjectionDepth: int(getFloatValue(promptMap, "injection_depth")),
|
||
InjectionPosition: int(getFloatValue(promptMap, "injection_position")),
|
||
}
|
||
prompts = append(prompts, prompt)
|
||
}
|
||
}
|
||
}
|
||
preset.Prompts = prompts
|
||
|
||
// 提取正则脚本
|
||
regexScripts := make([]app.RegexScript, 0)
|
||
if extensions, ok := stData["extensions"].(map[string]interface{}); ok {
|
||
if scripts, ok := extensions["regex_scripts"].([]interface{}); ok {
|
||
for _, s := range scripts {
|
||
if scriptMap, ok := s.(map[string]interface{}); ok {
|
||
script := app.RegexScript{
|
||
ScriptName: getStringValue(scriptMap, "script_name"),
|
||
FindRegex: getStringValue(scriptMap, "find_regex"),
|
||
ReplaceString: getStringValue(scriptMap, "replace_string"),
|
||
Disabled: getBoolValue(scriptMap, "disabled"),
|
||
Placement: getIntArray(scriptMap, "placement"),
|
||
}
|
||
regexScripts = append(regexScripts, script)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
preset.RegexScripts = regexScripts
|
||
|
||
// 保存到数据库
|
||
err = global.GVA_DB.Create(&preset).Error
|
||
return preset, err
|
||
}
|
||
|
||
// getStringValue returns m[key] if it holds a string and "" otherwise,
// including when the key is absent or m is nil.
func getStringValue(m map[string]interface{}, key string) string {
	str, _ := m[key].(string)
	return str
}
|
||
|
||
// getBoolValue returns m[key] if it holds a bool and false otherwise,
// including when the key is absent or m is nil.
func getBoolValue(m map[string]interface{}, key string) bool {
	flag, _ := m[key].(bool)
	return flag
}
|
||
|
||
// getFloatValue returns m[key] if it holds a float64 (the type encoding/json
// uses for numbers decoded into interface{}) and 0 otherwise, including when
// the key is absent or m is nil.
func getFloatValue(m map[string]interface{}, key string) float64 {
	num, _ := m[key].(float64)
	return num
}
|
||
|
||
// getIntArray reads m[key] as a JSON array and returns its float64 elements
// truncated to int, silently skipping elements of any other type. The result
// is never nil: a missing or non-array value yields an empty slice.
func getIntArray(m map[string]interface{}, key string) []int {
	raw, _ := m[key].([]interface{})
	ints := make([]int, 0, len(raw))
	for i := range raw {
		if f, isNum := raw[i].(float64); isNum {
			ints = append(ints, int(f))
		}
	}
	return ints
}
|
||
|
||
// ExportAiPreset 导出AI预设
|
||
func (s *AiPresetService) ExportAiPreset(id uint, userId uint) (data map[string]interface{}, err error) {
|
||
var preset app.AiPreset
|
||
|
||
// 如果 userId 为 0(未登录),只能导出公开的预设
|
||
if userId == 0 {
|
||
err = global.GVA_DB.Where("id = ? AND is_public = ?", id, true).First(&preset).Error
|
||
} else {
|
||
err = global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true).First(&preset).Error
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
data = map[string]interface{}{
|
||
"prompts": preset.Prompts,
|
||
"extensions": map[string]interface{}{
|
||
"regex_scripts": preset.RegexScripts,
|
||
},
|
||
"temperature": preset.Temperature,
|
||
"top_p": preset.TopP,
|
||
"openai_max_tokens": preset.MaxTokens,
|
||
"frequency_penalty": preset.FrequencyPenalty,
|
||
"presence_penalty": preset.PresencePenalty,
|
||
"stream_openai": preset.StreamEnabled,
|
||
}
|
||
|
||
return data, nil
|
||
}
|