Files
ai_proxy/server/service/app/ai_preset.go

274 lines
8.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"encoding/json"
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
req "git.echol.cn/loser/ai_proxy/server/model/common/request"
)
// AiPresetService groups the AI preset operations defined in this file
// (create, update, delete, lookup, list, import, export). The type itself
// holds no state.
type AiPresetService struct{}
// CreateAiPreset inserts a new AI preset owned by userId.
//
// Every configurable field is copied from req onto the new record, which is
// then written with a single Create call. The record is returned together
// with the error reported by that call.
func (s *AiPresetService) CreateAiPreset(userId uint, req *request.CreateAiPresetRequest) (preset app.AiPreset, err error) {
	// Ownership and descriptive fields.
	preset.UserID = userId
	preset.Name = req.Name
	preset.Description = req.Description

	// Prompt and regex content.
	preset.Prompts = req.Prompts
	preset.RegexScripts = req.RegexScripts

	// Sampling parameters.
	preset.Temperature = req.Temperature
	preset.TopP = req.TopP
	preset.MaxTokens = req.MaxTokens
	preset.FrequencyPenalty = req.FrequencyPenalty
	preset.PresencePenalty = req.PresencePenalty

	// Flags.
	preset.StreamEnabled = req.StreamEnabled
	preset.IsDefault = req.IsDefault
	preset.IsPublic = req.IsPublic

	if err = global.GVA_DB.Create(&preset).Error; err != nil {
		return preset, err
	}
	return preset, nil
}
// DeleteAiPreset deletes the AI preset with the given id.
//
// When userId is non-zero the delete is additionally restricted to rows whose
// user_id equals userId. When userId is 0 the ownership filter is skipped and
// the row is matched by id alone.
//
// NOTE(review): the previous comment on the userId == 0 branch stated that
// unauthenticated callers must NOT be allowed to delete, yet the code performs
// the delete without any ownership check. Confirm which behaviour is intended
// before relying on this method for access control.
func (s *AiPresetService) DeleteAiPreset(id uint, userId uint) (err error) {
	scope := global.GVA_DB
	switch userId {
	case 0:
		scope = scope.Where("id = ?", id)
	default:
		scope = scope.Where("id = ? AND user_id = ?", id, userId)
	}
	return scope.Delete(&app.AiPreset{}).Error
}
// UpdateAiPreset loads an existing preset, applies the request to it and saves
// it back.
//
// Lookup: with a non-zero userId the preset must match both req.ID and that
// user; with userId == 0 it is looked up by req.ID only, so any preset can be
// updated. A failed lookup returns the lookup error unchanged.
//
// Field handling: Name, Prompts and RegexScripts are replaced only when the
// request supplies them (non-empty / non-nil). Description, the sampling
// parameters and the boolean flags are always overwritten with the request
// values, including zero values.
func (s *AiPresetService) UpdateAiPreset(userId uint, req *request.UpdateAiPresetRequest) (preset app.AiPreset, err error) {
	lookup := global.GVA_DB
	if userId != 0 {
		lookup = lookup.Where("id = ? AND user_id = ?", req.ID, userId)
	} else {
		lookup = lookup.Where("id = ?", req.ID)
	}
	if err = lookup.First(&preset).Error; err != nil {
		return preset, err
	}

	// Fields that are always taken from the request.
	preset.Description = req.Description
	preset.Temperature = req.Temperature
	preset.TopP = req.TopP
	preset.MaxTokens = req.MaxTokens
	preset.FrequencyPenalty = req.FrequencyPenalty
	preset.PresencePenalty = req.PresencePenalty
	preset.StreamEnabled = req.StreamEnabled
	preset.IsDefault = req.IsDefault
	preset.IsPublic = req.IsPublic

	// Fields that keep their stored value when the request omits them.
	if req.Name != "" {
		preset.Name = req.Name
	}
	if req.Prompts != nil {
		preset.Prompts = req.Prompts
	}
	if req.RegexScripts != nil {
		preset.RegexScripts = req.RegexScripts
	}

	err = global.GVA_DB.Save(&preset).Error
	return preset, err
}
// GetAiPreset fetches a single preset by id, subject to visibility.
//
// With userId == 0 only a public preset can be returned. Otherwise the preset
// is returned when it is either owned by userId or public. The error from the
// lookup is returned as-is.
func (s *AiPresetService) GetAiPreset(id uint, userId uint) (preset app.AiPreset, err error) {
	lookup := global.GVA_DB
	switch userId {
	case 0:
		lookup = lookup.Where("id = ? AND is_public = ?", id, true)
	default:
		lookup = lookup.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true)
	}
	err = lookup.First(&preset).Error
	return preset, err
}
// GetAiPresetList returns one page of presets visible to userId plus the total
// number of visible presets.
//
// Visibility: userId == 0 sees only public presets; any other userId sees its
// own presets and public ones. The page is selected with
// limit = info.PageSize and offset = info.PageSize * (info.Page - 1), ordered
// by created_at descending. If counting fails, that error is returned and no
// rows are fetched.
func (s *AiPresetService) GetAiPresetList(userId uint, info req.PageInfo) (list []app.AiPreset, total int64, err error) {
	query := global.GVA_DB.Model(&app.AiPreset{})
	if userId != 0 {
		query = query.Where("user_id = ? OR is_public = ?", userId, true)
	} else {
		query = query.Where("is_public = ?", true)
	}

	// Count first so total reflects the full filtered set, not just one page.
	if err = query.Count(&total).Error; err != nil {
		return list, total, err
	}

	pageSize := info.PageSize
	skip := pageSize * (info.Page - 1)
	err = query.Limit(pageSize).Offset(skip).Order("created_at DESC").Find(&list).Error
	return list, total, err
}
// ImportAiPreset creates a private preset for userId from a SillyTavern-style
// JSON document.
//
// req.Data may be:
//   - a string or []byte holding the raw JSON text, or
//   - an already-decoded JSON object (map[string]interface{}), which is what
//     encoding/json produces when the document is sent as a nested object and
//     bound into an interface-typed field.
//
// Any other type yields an error. The preset name comes from req.Name; the
// description, sampling parameters, prompts and regex scripts are read from
// the document. Keys that are missing or hold a value of an unexpected type
// are skipped and leave the corresponding field at its zero value. The
// imported preset is always stored with IsPublic = false.
func (s *AiPresetService) ImportAiPreset(userId uint, req *request.ImportAiPresetRequest) (preset app.AiPreset, err error) {
	// Obtain the document as a generic map, decoding only when we were handed
	// raw JSON text.
	var stData map[string]interface{}
	switch v := req.Data.(type) {
	case map[string]interface{}:
		stData = v
	case string:
		err = json.Unmarshal([]byte(v), &stData)
	case []byte:
		err = json.Unmarshal(v, &stData)
	default:
		return preset, fmt.Errorf("不支持的数据类型")
	}
	if err != nil {
		return preset, fmt.Errorf("JSON 解析失败: %w", err)
	}

	// Basic information.
	preset = app.AiPreset{
		UserID:      userId,
		Name:        req.Name,
		Description: getStringValue(stData, "description"),
		IsPublic:    false,
	}

	// Sampling parameters. encoding/json decodes every JSON number into
	// float64 when the target is interface{}, hence the float64 assertions.
	if temp, ok := stData["temperature"].(float64); ok {
		preset.Temperature = temp
	}
	if topP, ok := stData["top_p"].(float64); ok {
		preset.TopP = topP
	}
	// Prefer the OpenAI-specific key and fall back to the generic one.
	if maxTokens, ok := stData["openai_max_tokens"].(float64); ok {
		preset.MaxTokens = int(maxTokens)
	} else if maxTokens, ok := stData["max_tokens"].(float64); ok {
		preset.MaxTokens = int(maxTokens)
	}
	if freqPenalty, ok := stData["frequency_penalty"].(float64); ok {
		preset.FrequencyPenalty = freqPenalty
	}
	if presPenalty, ok := stData["presence_penalty"].(float64); ok {
		preset.PresencePenalty = presPenalty
	}
	if stream, ok := stData["stream_openai"].(bool); ok {
		preset.StreamEnabled = stream
	}

	// Prompts. InjectionOrder is the entry's index in the source array, so
	// entries that are not JSON objects are skipped but still consume an index.
	promptsData, _ := stData["prompts"].([]interface{})
	prompts := make([]app.Prompt, 0, len(promptsData))
	for i, p := range promptsData {
		promptMap, ok := p.(map[string]interface{})
		if !ok {
			continue
		}
		prompts = append(prompts, app.Prompt{
			Name:              getStringValue(promptMap, "name"),
			Role:              getStringValue(promptMap, "role"),
			Content:           getStringValue(promptMap, "content"),
			SystemPrompt:      getBoolValue(promptMap, "system_prompt"),
			Marker:            getBoolValue(promptMap, "marker"),
			InjectionOrder:    i,
			InjectionDepth:    int(getFloatValue(promptMap, "injection_depth")),
			InjectionPosition: int(getFloatValue(promptMap, "injection_position")),
		})
	}
	preset.Prompts = prompts

	// firstString returns the value of the first listed key that holds a
	// string, or "" when none does.
	firstString := func(m map[string]interface{}, keys ...string) string {
		for _, k := range keys {
			if v, ok := m[k].(string); ok {
				return v
			}
		}
		return ""
	}

	// Regex scripts live under extensions.regex_scripts. The snake_case keys
	// are tried first (unchanged behaviour); the camelCase spellings
	// (scriptName, findRegex, replaceString) are accepted as a fallback because
	// that is how SillyTavern's regex extension serialises its scripts.
	var scripts []interface{}
	if extensions, ok := stData["extensions"].(map[string]interface{}); ok {
		scripts, _ = extensions["regex_scripts"].([]interface{})
	}
	regexScripts := make([]app.RegexScript, 0, len(scripts))
	for _, entry := range scripts {
		scriptMap, ok := entry.(map[string]interface{})
		if !ok {
			continue
		}
		regexScripts = append(regexScripts, app.RegexScript{
			ScriptName:    firstString(scriptMap, "script_name", "scriptName"),
			FindRegex:     firstString(scriptMap, "find_regex", "findRegex"),
			ReplaceString: firstString(scriptMap, "replace_string", "replaceString"),
			Disabled:      getBoolValue(scriptMap, "disabled"),
			Placement:     getIntArray(scriptMap, "placement"),
		})
	}
	preset.RegexScripts = regexScripts

	// Persist the assembled preset.
	err = global.GVA_DB.Create(&preset).Error
	return preset, err
}
// Helper functions for reading loosely-typed values out of decoded JSON maps.

// getStringValue returns m[key] when it holds a string. A missing key, a nil
// map, or a value of any other type yields the empty string.
func getStringValue(m map[string]interface{}, key string) string {
	// A failed comma-ok assertion leaves text at its zero value, "".
	text, _ := m[key].(string)
	return text
}
// getBoolValue returns m[key] when it holds a bool. A missing key, a nil map,
// or a value of any other type yields false.
func getBoolValue(m map[string]interface{}, key string) bool {
	// A failed comma-ok assertion leaves flag at its zero value, false.
	flag, _ := m[key].(bool)
	return flag
}
// getFloatValue returns m[key] when it holds a float64. A missing key, a nil
// map, or a value of any other type (including Go ints) yields 0.
func getFloatValue(m map[string]interface{}, key string) float64 {
	// A failed comma-ok assertion leaves number at its zero value, 0.
	number, _ := m[key].(float64)
	return number
}
// getIntArray reads m[key] as a JSON array and returns its numeric elements
// converted to int (fractional parts are truncated). Elements that are not
// float64 are dropped. The result is always a non-nil slice; it is empty when
// the key is missing or does not hold an array.
func getIntArray(m map[string]interface{}, key string) []int {
	items, _ := m[key].([]interface{})
	ints := make([]int, 0, len(items))
	for _, item := range items {
		number, isNumber := item.(float64)
		if !isNumber {
			continue
		}
		ints = append(ints, int(number))
	}
	return ints
}
// ExportAiPreset loads a preset and returns it as a generic map using the same
// key names that ImportAiPreset reads for the sampling parameters, prompts and
// extensions.regex_scripts.
//
// Visibility matches GetAiPreset: userId == 0 can export only public presets;
// any other userId can export its own presets and public ones. When the lookup
// fails, a nil map and the lookup error are returned.
func (s *AiPresetService) ExportAiPreset(id uint, userId uint) (data map[string]interface{}, err error) {
	lookup := global.GVA_DB
	if userId != 0 {
		lookup = lookup.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true)
	} else {
		lookup = lookup.Where("id = ? AND is_public = ?", id, true)
	}

	var preset app.AiPreset
	if err = lookup.First(&preset).Error; err != nil {
		return nil, err
	}

	// Regex scripts are nested under "extensions".
	extensions := map[string]interface{}{
		"regex_scripts": preset.RegexScripts,
	}

	data = make(map[string]interface{}, 8)
	data["prompts"] = preset.Prompts
	data["extensions"] = extensions
	data["temperature"] = preset.Temperature
	data["top_p"] = preset.TopP
	data["openai_max_tokens"] = preset.MaxTokens
	data["frequency_penalty"] = preset.FrequencyPenalty
	data["presence_penalty"] = preset.PresencePenalty
	data["stream_openai"] = preset.StreamEnabled
	return data, nil
}