Files
ai_proxy/server/service/app/ai_preset_injector.go

438 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"fmt"
"regexp"
"sort"
"strings"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
)
// PresetInjector applies an AI preset to chat traffic. It injects the preset's
// prompts into outgoing messages, runs the preset's regex scripts over input
// and output text, copies its sampling parameters onto requests, and records
// what the regex scripts did.
type PresetInjector struct {
	// preset is the preset being applied. It may be nil; the public entry
	// points check for nil and then leave their input untouched.
	preset *app.AiPreset

	// regexLogs accumulates one record per regex script that was run. It is
	// exposed through GetRegexLogs.
	regexLogs *response.RegexExecutionLogs
}
// NewPresetInjector returns an injector for the given preset. It starts with an
// empty regex execution log whose input and output script lists are
// initialized to empty, non-nil slices.
//
// preset may be nil. In that case InjectMessages, ProcessResponse and
// ApplyPresetParameters leave their inputs unchanged.
func NewPresetInjector(preset *app.AiPreset) *PresetInjector {
	logs := &response.RegexExecutionLogs{
		InputScripts:  make([]response.RegexScriptLog, 0),
		OutputScripts: make([]response.RegexScriptLog, 0),
	}
	return &PresetInjector{preset: preset, regexLogs: logs}
}
// GetRegexLogs returns the regex execution log collected so far by this
// injector. The pointer is shared rather than copied, so later regex runs on
// the same injector keep appending to the value the caller holds.
func (p *PresetInjector) GetRegexLogs() *response.RegexExecutionLogs {
	logs := p.regexLogs
	return logs
}
// InjectMessages applies the preset to an outgoing chat history.
//
// Steps:
//  1. run the input-side regex scripts (placement 1) over the messages;
//  2. collect the enabled, non-marker prompts in injection order;
//  3. splice those prompts into the history (see buildInjectedMessages).
//
// The caller's slice is never modified. The regex pass rewrites Content in
// place, so it operates on a shallow copy. With a nil preset the input is
// returned unchanged. A preset without enabled prompts still gets its input
// regex scripts applied; only the prompt splicing is skipped.
func (p *PresetInjector) InjectMessages(messages []request.ChatMessage) []request.ChatMessage {
	if p.preset == nil {
		return messages
	}

	// Shallow copy: ChatMessage values are duplicated, so the in-place
	// Content rewrites below cannot leak into the caller's slice.
	working := make([]request.ChatMessage, len(messages))
	copy(working, messages)
	working = p.applyRegexScripts(working, 1)

	enabledPrompts := p.getEnabledPrompts()
	if len(enabledPrompts) == 0 {
		return working
	}
	return p.buildInjectedMessages(working, enabledPrompts)
}
// getEnabledPrompts returns the preset's prompts that are Enabled and not
// flagged as Marker. They are sorted by InjectionOrder ascending, then by
// InjectionDepth ascending.
//
// The sort is stable, so prompts that tie on both keys keep the order in which
// they appear in the preset. sort.Slice makes no such guarantee, so tied prompts
// could come out in an order unrelated to the preset definition.
func (p *PresetInjector) getEnabledPrompts() []app.PresetPrompt {
	var prompts []app.PresetPrompt
	for _, prompt := range p.preset.Prompts {
		if prompt.Enabled && !prompt.Marker {
			prompts = append(prompts, prompt)
		}
	}

	sort.SliceStable(prompts, func(i, j int) bool {
		if prompts[i].InjectionOrder != prompts[j].InjectionOrder {
			return prompts[i].InjectionOrder < prompts[j].InjectionOrder
		}
		return prompts[i].InjectionDepth < prompts[j].InjectionDepth
	})
	return prompts
}
// buildInjectedMessages splices the given prompts into the chat history and
// returns the combined message list. The approach is loosely modelled on
// SillyTavern.
//
// Prompts are split by InjectionPosition:
//   - 0: relative prompts, inserted inside the history by injectRelativePrompts;
//   - any other value: absolute prompts, placed outside the history. Those with
//     InjectionDepth <= 0 go before it and those with InjectionDepth > 0 go
//     after it. Negative depths are treated like 0 instead of being dropped.
//
// Absolute prompts keep the order in which they arrive.
func (p *PresetInjector) buildInjectedMessages(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
	var relativePrompts, leadingPrompts, trailingPrompts []app.PresetPrompt
	for _, prompt := range prompts {
		switch {
		case prompt.InjectionPosition == 0:
			relativePrompts = append(relativePrompts, prompt)
		case prompt.InjectionDepth > 0:
			trailingPrompts = append(trailingPrompts, prompt)
		default:
			leadingPrompts = append(leadingPrompts, prompt)
		}
	}

	// Deepest relative prompts first, ties broken by InjectionOrder. The sort
	// is stable so full ties keep their incoming order.
	sort.SliceStable(relativePrompts, func(i, j int) bool {
		if relativePrompts[i].InjectionDepth != relativePrompts[j].InjectionDepth {
			return relativePrompts[i].InjectionDepth > relativePrompts[j].InjectionDepth
		}
		return relativePrompts[i].InjectionOrder < relativePrompts[j].InjectionOrder
	})

	toMessage := func(prompt app.PresetPrompt) request.ChatMessage {
		return request.ChatMessage{
			Role:    prompt.Role,
			Content: p.processPromptContent(prompt.Content),
		}
	}

	result := make([]request.ChatMessage, 0, len(messages)+len(prompts))
	for _, prompt := range leadingPrompts {
		result = append(result, toMessage(prompt))
	}
	result = append(result, p.injectRelativePrompts(messages, relativePrompts)...)
	for _, prompt := range trailingPrompts {
		result = append(result, toMessage(prompt))
	}
	return result
}
// injectRelativePrompts inserts relative-position prompts into the chat
// history and returns the combined list.
//
// Depth counts from the FRONT of the history. A prompt at depth d is placed
// after the first d original messages, so depth 0 means before the first
// message. A depth beyond the history length places the prompt after the last
// message. Negative depths are clamped to 0 rather than dropped. Prompts that
// share a depth are emitted in descending InjectionOrder; the sort is stable,
// so full ties keep their incoming order.
//
// Only the depths actually in use are visited. Cost is therefore proportional
// to the number of messages and prompts, not to the largest depth value. When
// prompts is empty, messages is returned as-is.
func (p *PresetInjector) injectRelativePrompts(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
	if len(prompts) == 0 {
		return messages
	}

	// Group prompts by clamped depth.
	byDepth := make(map[int][]app.PresetPrompt)
	for _, prompt := range prompts {
		depth := prompt.InjectionDepth
		if depth < 0 {
			depth = 0
		}
		byDepth[depth] = append(byDepth[depth], prompt)
	}

	// Order each group (higher InjectionOrder first). Collect the depths in
	// ascending order so insertion proceeds front to back.
	depths := make([]int, 0, len(byDepth))
	for depth, group := range byDepth {
		sort.SliceStable(group, func(i, j int) bool {
			return group[i].InjectionOrder > group[j].InjectionOrder
		})
		depths = append(depths, depth)
	}
	sort.Ints(depths)

	result := make([]request.ChatMessage, 0, len(messages)+len(prompts))
	next := 0 // index of the next original message not yet copied
	for _, depth := range depths {
		// Copy original messages until `depth` of them precede this group,
		// or until the history runs out.
		for next < depth && next < len(messages) {
			result = append(result, messages[next])
			next++
		}
		for _, prompt := range byDepth[depth] {
			result = append(result, request.ChatMessage{
				Role:    prompt.Role,
				Content: p.processPromptContent(prompt.Content),
			})
		}
	}
	// Whatever history remains goes after the deepest group.
	return append(result, messages[next:]...)
}
// findMarkerIndex returns the index within p.preset.Prompts of the first prompt
// that is flagged as Marker and whose Identifier equals identifier. It returns
// -1 when there is no such prompt.
//
// NOTE(review): there is no nil check on p.preset here, so callers must ensure
// a preset is set. This helper is not referenced in the visible part of this
// file.
func (p *PresetInjector) findMarkerIndex(identifier string) int {
	prompts := p.preset.Prompts
	for idx := 0; idx < len(prompts); idx++ {
		if prompts[idx].Marker && prompts[idx].Identifier == identifier {
			return idx
		}
	}
	return -1
}
// Macro handling for processPromptContent. These are compiled once at package
// initialization instead of on every prompt.
var (
	// presetNameReplacer substitutes the fixed {{user}} / {{char}} names.
	presetNameReplacer = strings.NewReplacer("{{user}}", "User", "{{char}}", "Assistant")
	// presetGetvarRegex matches {{getvar::key}}.
	presetGetvarRegex = regexp.MustCompile(`\{\{getvar::(\w+)\}\}`)
	// presetSetvarRegex matches {{setvar::key::value}}.
	presetSetvarRegex = regexp.MustCompile(`\{\{setvar::(\w+)::(.*?)\}\}`)
	// presetCommentRegex matches {{// ... }} comments. (?s) lets a comment
	// span several lines, which the previous pattern missed.
	presetCommentRegex = regexp.MustCompile(`(?s)\{\{//.*?\}\}`)
)

// processPromptContent expands or strips the template macros in a prompt body
// and trims surrounding whitespace.
//
//   - {{user}} becomes "User" and {{char}} becomes "Assistant" (fixed names).
//   - {{getvar::key}} and {{setvar::key::value}} are removed. Variables are not
//     supported, so both expand to "".
//   - {{// ... }} comments are removed, including multi-line ones.
//
// getvar is stripped before setvar. A setvar whose value contains a getvar,
// e.g. {{setvar::a::{{getvar::b}}}}, therefore reduces to an empty setvar
// that is then removed whole.
func (p *PresetInjector) processPromptContent(content string) string {
	content = presetNameReplacer.Replace(content)
	content = presetGetvarRegex.ReplaceAllString(content, "")
	content = presetSetvarRegex.ReplaceAllString(content, "")
	content = presetCommentRegex.ReplaceAllString(content, "")
	return strings.TrimSpace(content)
}
// applyRegexScripts runs every enabled regex script registered for the given
// placement (1 = input side, 2 = output side) over messages, in preset order.
// It returns the resulting slice.
//
// Each script that runs appends one entry to the execution log: to
// InputScripts for placement 1 and to OutputScripts for placement 2. Its
// matches are also added to TotalMatches. A script whose pattern fails to
// compile is recorded with Executed=false and its error message, matching how
// ProcessResponse records such failures.
func (p *PresetInjector) applyRegexScripts(messages []request.ChatMessage, placement int) []request.ChatMessage {
	if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
		return messages
	}

	for _, script := range p.preset.Extensions.RegexBinding.Regexes {
		if script.Disabled {
			continue
		}

		applies := false
		for _, pl := range script.Placement {
			if pl == placement {
				applies = true
				break
			}
		}
		if !applies {
			continue
		}

		var matchCount int
		var err error
		messages, matchCount, err = p.applyRegexScriptWithLog(messages, script)

		entry := response.RegexScriptLog{
			ScriptName: script.ScriptName,
			ScriptID:   script.ID,
			Executed:   err == nil,
			MatchCount: matchCount,
		}
		if err != nil {
			entry.ErrorMessage = err.Error()
		}

		switch placement {
		case 1:
			p.regexLogs.InputScripts = append(p.regexLogs.InputScripts, entry)
		case 2:
			p.regexLogs.OutputScripts = append(p.regexLogs.OutputScripts, entry)
		}
		p.regexLogs.TotalMatches += matchCount
	}
	return messages
}
// applyRegexScriptWithLog applies one regex script to messages. It returns the
// slice, the number of matches it replaced, and any compile error.
//
// FindRegex may be written JavaScript-style as "/body/flags". The body is
// unwrapped at the last slash. The flags i, m and s become Go inline flags;
// other flags (g, u, y, ...) are ignored, and every match is always replaced.
// Slicing at the last slash means the literal pattern "/g" is compiled as-is
// instead of panicking. When PromptOnly is set, only messages with Role "user"
// are touched.
//
// Content is rewritten in place, so the returned slice shares its backing array
// with the argument. On a compile error the messages are returned untouched,
// with a zero count.
func (p *PresetInjector) applyRegexScriptWithLog(messages []request.ChatMessage, script app.RegexScript) ([]request.ChatMessage, int, error) {
	pattern := script.FindRegex
	if strings.HasPrefix(pattern, "/") {
		if last := strings.LastIndex(pattern, "/"); last > 0 {
			inline := ""
			for _, flag := range pattern[last+1:] {
				if strings.ContainsRune("ims", flag) && !strings.ContainsRune(inline, flag) {
					inline += string(flag)
				}
			}
			pattern = pattern[1:last]
			if inline != "" {
				pattern = "(?" + inline + ")" + pattern
			}
		}
	}

	re, err := regexp.Compile(pattern)
	if err != nil {
		return messages, 0, fmt.Errorf("正则编译失败: %w", err)
	}

	matchCount := 0
	for i := range messages {
		if script.PromptOnly && messages[i].Role != "user" {
			continue
		}
		found := len(re.FindAllStringIndex(messages[i].Content, -1))
		if found == 0 {
			// Nothing to replace; ReplaceAllString would return the same text.
			continue
		}
		matchCount += found
		messages[i].Content = re.ReplaceAllString(messages[i].Content, script.ReplaceString)
	}
	return messages, matchCount, nil
}
// ProcessResponse applies the preset's output-side regex scripts (placement 2)
// to an AI response and returns the rewritten text.
//
// Scripts run in preset order, each on the output of the previous one.
// Slash-delimited patterns ("/body/flags") are unwrapped at the last slash, so
// a bare "/g" is compiled literally instead of panicking. The JavaScript flags
// i, m and s are translated to Go inline flags. Other flags (g, u, y, ...) are
// ignored, and every match is always replaced.
//
// Each processed script appends one entry to the OutputScripts log. A script
// whose pattern does not compile is logged with Executed=false and skipped.
// Matches are added to TotalMatches.
func (p *PresetInjector) ProcessResponse(content string) string {
	if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
		return content
	}

	for _, script := range p.preset.Extensions.RegexBinding.Regexes {
		if script.Disabled {
			continue
		}

		onOutput := false
		for _, placement := range script.Placement {
			if placement == 2 {
				onOutput = true
				break
			}
		}
		if !onOutput {
			continue
		}

		pattern := script.FindRegex
		if strings.HasPrefix(pattern, "/") {
			if last := strings.LastIndex(pattern, "/"); last > 0 {
				inline := ""
				for _, flag := range pattern[last+1:] {
					if strings.ContainsRune("ims", flag) && !strings.ContainsRune(inline, flag) {
						inline += string(flag)
					}
				}
				pattern = pattern[1:last]
				if inline != "" {
					pattern = "(?" + inline + ")" + pattern
				}
			}
		}

		re, err := regexp.Compile(pattern)
		if err != nil {
			p.regexLogs.OutputScripts = append(p.regexLogs.OutputScripts, response.RegexScriptLog{
				ScriptName:   script.ScriptName,
				ScriptID:     script.ID,
				Executed:     false,
				ErrorMessage: fmt.Sprintf("正则编译失败: %v", err),
			})
			continue
		}

		matchCount := len(re.FindAllStringIndex(content, -1))
		if matchCount > 0 {
			content = re.ReplaceAllString(content, script.ReplaceString)
		}

		p.regexLogs.OutputScripts = append(p.regexLogs.OutputScripts, response.RegexScriptLog{
			ScriptName: script.ScriptName,
			ScriptID:   script.ID,
			Executed:   true,
			MatchCount: matchCount,
		})
		p.regexLogs.TotalMatches += matchCount
	}
	return content
}
// ApplyPresetParameters fills in any sampling parameters the request left
// unset (nil) with the preset's values. Values the request already carries
// always win, and a nil preset leaves the request untouched.
//
// The preset stores plain values, so "not configured" on the preset side is
// inferred from the value. Temperature, TopP and MaxTokens are copied only when
// > 0; PresencePenalty and FrequencyPenalty only when non-zero.
// NOTE(review): as a consequence a preset cannot force Temperature 0.
func (p *PresetInjector) ApplyPresetParameters(req *request.ChatCompletionRequest) {
	preset := p.preset
	if preset == nil {
		return
	}

	if req.Temperature == nil && preset.Temperature > 0 {
		value := preset.Temperature
		req.Temperature = &value
	}
	if req.TopP == nil && preset.TopP > 0 {
		value := preset.TopP
		req.TopP = &value
	}
	if req.MaxTokens == nil && preset.MaxTokens > 0 {
		value := preset.MaxTokens
		req.MaxTokens = &value
	}
	if req.PresencePenalty == nil && preset.PresencePenalty != 0 {
		value := preset.PresencePenalty
		req.PresencePenalty = &value
	}
	if req.FrequencyPenalty == nil && preset.FrequencyPenalty != 0 {
		value := preset.FrequencyPenalty
		req.FrequencyPenalty = &value
	}
}
// ValidatePreset checks that a preset is usable. It rejects a nil preset and an
// empty Name, and verifies that every regex script's FindRegex compiles. This
// includes disabled scripts.
//
// Before compiling, a "/body/flags" wrapper is unwrapped at the last slash and
// the JavaScript flags i, m and s are turned into Go inline flags. Other flags
// are ignored. Compile errors are wrapped with %w so callers can inspect the
// underlying *syntax.Error.
func ValidatePreset(preset *app.AiPreset) error {
	if preset == nil {
		return fmt.Errorf("预设不能为空")
	}
	if preset.Name == "" {
		return fmt.Errorf("预设名称不能为空")
	}
	if preset.Extensions.RegexBinding == nil {
		return nil
	}

	for _, script := range preset.Extensions.RegexBinding.Regexes {
		pattern := script.FindRegex
		if strings.HasPrefix(pattern, "/") {
			if last := strings.LastIndex(pattern, "/"); last > 0 {
				inline := ""
				for _, flag := range pattern[last+1:] {
					if strings.ContainsRune("ims", flag) && !strings.ContainsRune(inline, flag) {
						inline += string(flag)
					}
				}
				pattern = pattern[1:last]
				if inline != "" {
					pattern = "(?" + inline + ")" + pattern
				}
			}
		}
		if _, err := regexp.Compile(pattern); err != nil {
			return fmt.Errorf("正则表达式 '%s' 无效: %w", script.ScriptName, err)
		}
	}
	return nil
}