Files
ai_proxy/server/service/app/ai_preset_injector.go
2026-03-03 20:33:46 +08:00

414 lines
11 KiB
Go

package app
import (
"fmt"
"regexp"
"sort"
"strings"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
)
// PresetInjector applies an AI preset to chat traffic: it injects the
// preset's prompts into outgoing message lists, runs the preset's regex
// scripts over input and output text, and fills in default sampling
// parameters on requests.
type PresetInjector struct {
	// preset is the preset being applied. The exported methods that read it
	// return early when it is nil.
	preset *app.AiPreset
	// regexLogs accumulates one record per regex script this injector runs,
	// plus the running total of matches.
	regexLogs *response.RegexExecutionLogs
}
// NewPresetInjector returns an injector bound to preset. Its regex execution
// log starts out with empty (non-nil) input and output script lists.
func NewPresetInjector(preset *app.AiPreset) *PresetInjector {
	logs := new(response.RegexExecutionLogs)
	logs.InputScripts = make([]response.RegexScriptLog, 0)
	logs.OutputScripts = make([]response.RegexScriptLog, 0)

	injector := new(PresetInjector)
	injector.preset = preset
	injector.regexLogs = logs
	return injector
}
// GetRegexLogs exposes the regex execution log collected by this injector.
// The pointer is shared, so later script runs keep appending to the same log.
func (p *PresetInjector) GetRegexLogs() *response.RegexExecutionLogs {
	logs := p.regexLogs
	return logs
}
// InjectMessages applies the preset to an outgoing chat message list and
// returns the resulting list.
//
// Steps:
//  1. Run the preset's user-input regex scripts (placement 1). This happens
//     even when the preset defines no prompts, so a regex-only preset still
//     takes effect, mirroring ProcessResponse, which runs output scripts
//     without looking at Prompts.
//  2. Collect the enabled, non-marker prompts in injection order.
//  3. Merge those prompts with the regex-processed chat history.
//
// The caller's slice is never modified: the regex pass works on a copy.
// A nil preset returns messages unchanged.
func (p *PresetInjector) InjectMessages(messages []request.ChatMessage) []request.ChatMessage {
	if p.preset == nil {
		return messages
	}

	// The regex pass rewrites Content in place; copy first so the caller's
	// request messages are not silently mutated.
	working := append([]request.ChatMessage(nil), messages...)
	working = p.applyRegexScripts(working, 1)

	if len(p.preset.Prompts) == 0 {
		return working
	}
	return p.buildInjectedMessages(working, p.getEnabledPrompts())
}
// getEnabledPrompts returns the preset's enabled, non-marker prompts sorted by
// InjectionOrder ascending, then InjectionDepth ascending. Prompts that tie on
// both keys keep the order in which they are defined in the preset, so the
// result is deterministic.
func (p *PresetInjector) getEnabledPrompts() []app.PresetPrompt {
	prompts := make([]app.PresetPrompt, 0, len(p.preset.Prompts))
	for _, prompt := range p.preset.Prompts {
		// Skip disabled prompts and marker entries.
		if prompt.Enabled && !prompt.Marker {
			prompts = append(prompts, prompt)
		}
	}

	// SliceStable: plain sort.Slice may reorder equal elements arbitrarily,
	// which would make the injected prompt order vary between requests.
	sort.SliceStable(prompts, func(i, j int) bool {
		if prompts[i].InjectionOrder != prompts[j].InjectionOrder {
			return prompts[i].InjectionOrder < prompts[j].InjectionOrder
		}
		return prompts[i].InjectionDepth < prompts[j].InjectionDepth
	})
	return prompts
}
// buildInjectedMessages merges the enabled prompts with the chat history,
// loosely following SillyTavern's prompt assembly.
//
// Prompts are partitioned by InjectionPosition:
//   - 0 ("relative"): woven into the chat history at InjectionDepth by
//     injectRelativePrompts, deeper prompts first.
//   - any other value ("absolute"): InjectionDepth <= 0 is placed before the
//     whole history and InjectionDepth > 0 after it. Negative depths are
//     treated like 0 instead of being dropped.
//
// Absolute prompts whose content is empty after macro processing are skipped,
// so no empty messages are sent upstream.
//
// NOTE(review): SillyTavern's PromptManager defines INJECTION_POSITION as
// RELATIVE = 0 (placed by prompt order) and ABSOLUTE = 1 (in-chat at
// injection_depth). That looks inverted relative to the handling here.
// Confirm the intended mapping before relying on imported ST presets.
func (p *PresetInjector) buildInjectedMessages(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
	var relativePrompts, leadingPrompts, trailingPrompts []app.PresetPrompt
	for _, prompt := range prompts {
		switch {
		case prompt.InjectionPosition == 0:
			relativePrompts = append(relativePrompts, prompt)
		case prompt.InjectionDepth <= 0:
			leadingPrompts = append(leadingPrompts, prompt)
		default:
			trailingPrompts = append(trailingPrompts, prompt)
		}
	}

	// Deeper prompts first; ties on depth and order keep the incoming
	// (injection-order) sequence thanks to the stable sort.
	sort.SliceStable(relativePrompts, func(i, j int) bool {
		if relativePrompts[i].InjectionDepth != relativePrompts[j].InjectionDepth {
			return relativePrompts[i].InjectionDepth > relativePrompts[j].InjectionDepth
		}
		return relativePrompts[i].InjectionOrder < relativePrompts[j].InjectionOrder
	})

	result := make([]request.ChatMessage, 0, len(messages)+len(prompts))
	appendPrompts := func(list []app.PresetPrompt) {
		for _, prompt := range list {
			if content := p.processPromptContent(prompt.Content); content != "" {
				result = append(result, request.ChatMessage{
					Role:    prompt.Role,
					Content: content,
				})
			}
		}
	}

	appendPrompts(leadingPrompts)
	result = append(result, p.injectRelativePrompts(messages, relativePrompts)...)
	appendPrompts(trailingPrompts)
	return result
}
// injectRelativePrompts weaves depth-based prompts into the chat history.
//
// A prompt with InjectionDepth d is inserted before the d-th message counted
// from the end, so d=1 lands just before the last message. Out-of-range
// depths are clamped rather than silently dropped:
//   - d <= 0 goes after the last message.
//   - d >= len(messages) goes before the first message.
//
// As a result, every prompt is still emitted when the history is empty.
// Prompts sharing an insertion point keep the order of the incoming slice.
// Prompts whose content is empty after macro processing are skipped.
func (p *PresetInjector) injectRelativePrompts(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
	if len(prompts) == 0 {
		return messages
	}
	messageCount := len(messages)

	// Group prompts by the index of the message they precede; the index
	// messageCount means "after the last message".
	byIndex := make(map[int][]app.PresetPrompt, len(prompts))
	for _, prompt := range prompts {
		idx := messageCount - prompt.InjectionDepth
		if idx < 0 {
			idx = 0
		} else if idx > messageCount {
			idx = messageCount
		}
		byIndex[idx] = append(byIndex[idx], prompt)
	}

	result := make([]request.ChatMessage, 0, messageCount+len(prompts))
	emitAt := func(idx int) {
		for _, prompt := range byIndex[idx] {
			content := p.processPromptContent(prompt.Content)
			if content == "" {
				continue
			}
			result = append(result, request.ChatMessage{
				Role:    prompt.Role,
				Content: content,
			})
		}
	}

	for i, msg := range messages {
		emitAt(i)
		result = append(result, msg)
	}
	emitAt(messageCount)
	return result
}
// findMarkerIndex returns the index within preset.Prompts of the first prompt
// that is a marker and whose Identifier equals identifier, or -1 when no such
// prompt exists.
func (p *PresetInjector) findMarkerIndex(identifier string) int {
	found := -1
	for idx, prompt := range p.preset.Prompts {
		if !prompt.Marker || prompt.Identifier != identifier {
			continue
		}
		found = idx
		break
	}
	return found
}
// Macro patterns used by processPromptContent. They are compiled once at
// package initialisation rather than on every prompt.
var (
	// {{getvar::key}}: stripped, because no variable store is kept.
	promptGetvarRegex = regexp.MustCompile(`\{\{getvar::(\w+)\}\}`)
	// {{setvar::key::value}}: stripped, because no variable store is kept.
	promptSetvarRegex = regexp.MustCompile(`\{\{setvar::(\w+)::(.*?)\}\}`)
	// {{// comment}}: stripped. (?s) lets a comment span multiple lines.
	promptCommentRegex = regexp.MustCompile(`(?s)\{\{//.*?\}\}`)
)

// processPromptContent expands the supported macros in a prompt and trims
// surrounding whitespace:
//   - {{user}} becomes "User" and {{char}} becomes "Assistant".
//   - {{getvar::key}} and {{setvar::key::value}} are removed.
//   - {{// ...}} comments are removed, including multi-line ones.
func (p *PresetInjector) processPromptContent(content string) string {
	content = strings.ReplaceAll(content, "{{user}}", "User")
	content = strings.ReplaceAll(content, "{{char}}", "Assistant")
	content = promptGetvarRegex.ReplaceAllString(content, "")
	content = promptSetvarRegex.ReplaceAllString(content, "")
	content = promptCommentRegex.ReplaceAllString(content, "")
	return strings.TrimSpace(content)
}
// applyRegexScripts runs, in order, every enabled regex script whose Placement
// list contains placement, and returns the updated messages.
//
// Each attempted script is recorded in the execution log:
//   - placement 1 (user input) is logged to InputScripts.
//   - placement 2 (AI output) is logged to OutputScripts.
//
// A script whose pattern fails to compile is logged with Executed=false and
// its error message, matching how ProcessResponse logs compile failures.
// Match counts are added to TotalMatches.
func (p *PresetInjector) applyRegexScripts(messages []request.ChatMessage, placement int) []request.ChatMessage {
	if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
		return messages
	}

	for _, script := range p.preset.Extensions.RegexBinding.Regexes {
		if script.Disabled {
			continue
		}

		applies := false
		for _, pl := range script.Placement {
			if pl == placement {
				applies = true
				break
			}
		}
		if !applies {
			continue
		}

		var matchCount int
		var err error
		messages, matchCount, err = p.applyRegexScriptWithLog(messages, script)

		entry := response.RegexScriptLog{
			ScriptName: script.ScriptName,
			ScriptID:   script.ID,
			Executed:   err == nil,
			MatchCount: matchCount,
		}
		if err != nil {
			entry.ErrorMessage = err.Error()
		}

		switch placement {
		case 1:
			p.regexLogs.InputScripts = append(p.regexLogs.InputScripts, entry)
		case 2:
			p.regexLogs.OutputScripts = append(p.regexLogs.OutputScripts, entry)
		}
		p.regexLogs.TotalMatches += matchCount
	}
	return messages
}
// applyRegexScriptWithLog applies one regex script to messages, rewriting
// Content in place, and returns the messages with the total match count.
//
// FindRegex may be a bare Go pattern or a JavaScript-style literal such as
// "/pattern/gi". For a literal, the delimiters are stripped. When every
// trailing character is a valid JS flag, the flags i, m and s are turned into
// Go inline flags. The remaining flags need no translation: g in particular is
// implicit, because ReplaceAllString replaces every match.
//
// When script.PromptOnly is set, only messages with role "user" are touched.
// If the pattern fails to compile, the messages are returned unchanged with a
// zero count and a wrapped error.
func (p *PresetInjector) applyRegexScriptWithLog(messages []request.ChatMessage, script app.RegexScript) ([]request.ChatMessage, int, error) {
	pattern := script.FindRegex
	// Requiring lastSlash > 0 keeps the slice bounds valid, so a FindRegex of
	// just "/g" is compiled literally rather than sliced as pattern[1:0].
	if strings.HasPrefix(pattern, "/") {
		if lastSlash := strings.LastIndex(pattern, "/"); lastSlash > 0 {
			flags := pattern[lastSlash+1:]
			pattern = pattern[1:lastSlash]
			if strings.Trim(flags, "dgimsuvy") == "" {
				var goFlags strings.Builder
				for _, f := range "ims" {
					if strings.ContainsRune(flags, f) {
						goFlags.WriteRune(f)
					}
				}
				if goFlags.Len() > 0 {
					pattern = "(?" + goFlags.String() + ")" + pattern
				}
			}
		}
	}

	re, err := regexp.Compile(pattern)
	if err != nil {
		return messages, 0, fmt.Errorf("正则编译失败: %w", err)
	}

	matchCount := 0
	for i := range messages {
		if script.PromptOnly && messages[i].Role != "user" {
			continue
		}
		found := len(re.FindAllStringIndex(messages[i].Content, -1))
		if found == 0 {
			// No match means ReplaceAllString would return the input unchanged.
			continue
		}
		matchCount += found
		messages[i].Content = re.ReplaceAllString(messages[i].Content, script.ReplaceString)
	}
	return messages, matchCount, nil
}
// ProcessResponse runs the preset's AI-output regex scripts (placement 2)
// over a response text and returns the rewritten text.
//
// FindRegex may be a bare Go pattern or a JavaScript-style literal such as
// "/pattern/gi". For a literal, the delimiters are stripped. When every
// trailing character is a valid JS flag, the flags i, m and s become Go inline
// flags. g is implicit, because ReplaceAllString replaces every match.
//
// Every attempted script is appended to the OutputScripts log. Scripts that
// fail to compile are logged with Executed=false and skipped. PromptOnly is
// not consulted here.
func (p *PresetInjector) ProcessResponse(content string) string {
	if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
		return content
	}

	for _, script := range p.preset.Extensions.RegexBinding.Regexes {
		if script.Disabled {
			continue
		}

		isOutput := false
		for _, placement := range script.Placement {
			if placement == 2 {
				isOutput = true
				break
			}
		}
		if !isOutput {
			continue
		}

		pattern := script.FindRegex
		// Requiring lastSlash > 0 keeps the slice bounds valid, so a FindRegex
		// of just "/g" is compiled literally rather than sliced as pattern[1:0].
		if strings.HasPrefix(pattern, "/") {
			if lastSlash := strings.LastIndex(pattern, "/"); lastSlash > 0 {
				flags := pattern[lastSlash+1:]
				pattern = pattern[1:lastSlash]
				if strings.Trim(flags, "dgimsuvy") == "" {
					var goFlags strings.Builder
					for _, f := range "ims" {
						if strings.ContainsRune(flags, f) {
							goFlags.WriteRune(f)
						}
					}
					if goFlags.Len() > 0 {
						pattern = "(?" + goFlags.String() + ")" + pattern
					}
				}
			}
		}

		re, err := regexp.Compile(pattern)
		if err != nil {
			p.regexLogs.OutputScripts = append(p.regexLogs.OutputScripts, response.RegexScriptLog{
				ScriptName:   script.ScriptName,
				ScriptID:     script.ID,
				Executed:     false,
				ErrorMessage: fmt.Sprintf("正则编译失败: %v", err),
			})
			continue
		}

		matchCount := len(re.FindAllStringIndex(content, -1))
		if matchCount > 0 {
			content = re.ReplaceAllString(content, script.ReplaceString)
		}

		p.regexLogs.OutputScripts = append(p.regexLogs.OutputScripts, response.RegexScriptLog{
			ScriptName: script.ScriptName,
			ScriptID:   script.ID,
			Executed:   true,
			MatchCount: matchCount,
		})
		p.regexLogs.TotalMatches += matchCount
	}
	return content
}
// ApplyPresetParameters fills sampling parameters that the request left unset
// (nil) with the preset's values. Values already present on the request are
// never overwritten.
//
// Temperature, TopP and MaxTokens are copied only when the preset value is
// positive. PresencePenalty and FrequencyPenalty are copied only when the
// preset value is non-zero. A nil preset leaves the request untouched.
func (p *PresetInjector) ApplyPresetParameters(req *request.ChatCompletionRequest) {
	cfg := p.preset
	if cfg == nil {
		return
	}

	// Sampling controls.
	if req.Temperature == nil {
		if v := cfg.Temperature; v > 0 {
			req.Temperature = &v
		}
	}
	if req.TopP == nil {
		if v := cfg.TopP; v > 0 {
			req.TopP = &v
		}
	}

	// Output length.
	if req.MaxTokens == nil {
		if v := cfg.MaxTokens; v > 0 {
			req.MaxTokens = &v
		}
	}

	// Repetition penalties (may legitimately be negative).
	if req.PresencePenalty == nil {
		if v := cfg.PresencePenalty; v != 0 {
			req.PresencePenalty = &v
		}
	}
	if req.FrequencyPenalty == nil {
		if v := cfg.FrequencyPenalty; v != 0 {
			req.FrequencyPenalty = &v
		}
	}
}
// ValidatePreset checks a preset before it is stored or used.
//
// It reports an error when:
//   - preset is nil;
//   - preset.Name is empty;
//   - any regex script's FindRegex (enabled or not) fails to compile.
//
// JavaScript-style "/pattern/flags" literals are validated on the text
// between the first and last slash. The regexp compile error is wrapped with
// %w, so callers can inspect it with errors.As / errors.Unwrap; the message
// text is unchanged.
func ValidatePreset(preset *app.AiPreset) error {
	if preset == nil {
		return fmt.Errorf("预设不能为空")
	}
	if preset.Name == "" {
		return fmt.Errorf("预设名称不能为空")
	}

	if preset.Extensions.RegexBinding != nil {
		for _, script := range preset.Extensions.RegexBinding.Regexes {
			pattern := script.FindRegex
			if strings.HasPrefix(pattern, "/") {
				if lastSlash := strings.LastIndex(pattern, "/"); lastSlash > 0 {
					pattern = pattern[1:lastSlash]
				}
			}
			if _, err := regexp.Compile(pattern); err != nil {
				return fmt.Errorf("正则表达式 '%s' 无效: %w", script.ScriptName, err)
			}
		}
	}
	return nil
}