438 lines
12 KiB
Go
438 lines
12 KiB
Go
package app
|
||
|
||
import (
|
||
"fmt"
|
||
"regexp"
|
||
"sort"
|
||
"strings"
|
||
|
||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||
)
|
||
|
||
// PresetInjector 预设注入器
|
||
type PresetInjector struct {
|
||
preset *app.AiPreset
|
||
regexLogs *response.RegexExecutionLogs
|
||
}
|
||
|
||
// NewPresetInjector 创建预设注入器
|
||
func NewPresetInjector(preset *app.AiPreset) *PresetInjector {
|
||
return &PresetInjector{
|
||
preset: preset,
|
||
regexLogs: &response.RegexExecutionLogs{
|
||
InputScripts: []response.RegexScriptLog{},
|
||
OutputScripts: []response.RegexScriptLog{},
|
||
},
|
||
}
|
||
}
|
||
|
||
// GetRegexLogs 获取正则脚本执行日志
|
||
func (p *PresetInjector) GetRegexLogs() *response.RegexExecutionLogs {
|
||
return p.regexLogs
|
||
}
|
||
|
||
// InjectMessages 注入预设到消息列表
|
||
func (p *PresetInjector) InjectMessages(messages []request.ChatMessage) []request.ChatMessage {
|
||
if p.preset == nil || len(p.preset.Prompts) == 0 {
|
||
return messages
|
||
}
|
||
|
||
// 1. 应用用户输入前的正则替换
|
||
messages = p.applyRegexScripts(messages, 1)
|
||
|
||
// 2. 获取启用的提示词并排序
|
||
enabledPrompts := p.getEnabledPrompts()
|
||
|
||
// 3. 构建注入后的消息列表
|
||
injectedMessages := p.buildInjectedMessages(messages, enabledPrompts)
|
||
|
||
return injectedMessages
|
||
}
|
||
|
||
// getEnabledPrompts 获取启用的提示词并按注入顺序排序
|
||
func (p *PresetInjector) getEnabledPrompts() []app.PresetPrompt {
|
||
var prompts []app.PresetPrompt
|
||
for _, prompt := range p.preset.Prompts {
|
||
if prompt.Enabled && !prompt.Marker {
|
||
prompts = append(prompts, prompt)
|
||
}
|
||
}
|
||
|
||
// 按 injection_order 和 injection_depth 排序
|
||
sort.Slice(prompts, func(i, j int) bool {
|
||
if prompts[i].InjectionOrder != prompts[j].InjectionOrder {
|
||
return prompts[i].InjectionOrder < prompts[j].InjectionOrder
|
||
}
|
||
return prompts[i].InjectionDepth < prompts[j].InjectionDepth
|
||
})
|
||
|
||
return prompts
|
||
}
|
||
|
||
// buildInjectedMessages 构建注入后的消息列表
|
||
// 参考 SillyTavern 的实现逻辑
|
||
func (p *PresetInjector) buildInjectedMessages(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
|
||
result := make([]request.ChatMessage, 0)
|
||
|
||
// 按照 injection_position 分组
|
||
// 0 = 相对位置(从对话历史的特定深度注入)
|
||
// 1 = 绝对位置(在消息列表的固定位置注入)
|
||
var relativePrompts []app.PresetPrompt
|
||
var absolutePrompts []app.PresetPrompt
|
||
|
||
for _, prompt := range prompts {
|
||
if prompt.InjectionPosition == 0 {
|
||
relativePrompts = append(relativePrompts, prompt)
|
||
} else {
|
||
absolutePrompts = append(absolutePrompts, prompt)
|
||
}
|
||
}
|
||
|
||
// 处理绝对位置的提示词(通常是系统提示)
|
||
for _, prompt := range absolutePrompts {
|
||
if prompt.InjectionDepth == 0 {
|
||
// depth=0 表示在最开始
|
||
result = append(result, request.ChatMessage{
|
||
Role: prompt.Role,
|
||
Content: p.processPromptContent(prompt.Content),
|
||
})
|
||
}
|
||
}
|
||
|
||
// 处理相对位置的提示词和对话历史
|
||
// 按 injection_depth 从大到小排序(深度越大越靠前)
|
||
sort.Slice(relativePrompts, func(i, j int) bool {
|
||
if relativePrompts[i].InjectionDepth != relativePrompts[j].InjectionDepth {
|
||
return relativePrompts[i].InjectionDepth > relativePrompts[j].InjectionDepth
|
||
}
|
||
return relativePrompts[i].InjectionOrder < relativePrompts[j].InjectionOrder
|
||
})
|
||
|
||
// 注入相对位置的提示词到对话历史中
|
||
injectedMessages := p.injectRelativePrompts(messages, relativePrompts)
|
||
result = append(result, injectedMessages...)
|
||
|
||
// 处理绝对位置在末尾的提示词
|
||
for _, prompt := range absolutePrompts {
|
||
if prompt.InjectionDepth > 0 {
|
||
// depth>0 表示在末尾
|
||
result = append(result, request.ChatMessage{
|
||
Role: prompt.Role,
|
||
Content: p.processPromptContent(prompt.Content),
|
||
})
|
||
}
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// injectRelativePrompts 将相对位置的提示词注入到对话历史中
|
||
// 注入深度从前往后计算:depth=0 表示索引0(最前面),depth=1 表示索引1,以此类推
|
||
func (p *PresetInjector) injectRelativePrompts(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
|
||
if len(prompts) == 0 {
|
||
return messages
|
||
}
|
||
|
||
// 按深度分组提示词,并按 injection_order 排序
|
||
depthMap := make(map[int][]app.PresetPrompt)
|
||
for _, prompt := range prompts {
|
||
depthMap[prompt.InjectionDepth] = append(depthMap[prompt.InjectionDepth], prompt)
|
||
}
|
||
|
||
// 对每个深度的提示词按 injection_order 排序(从大到小,优先级高的在前)
|
||
for depth := range depthMap {
|
||
sort.Slice(depthMap[depth], func(i, j int) bool {
|
||
return depthMap[depth][i].InjectionOrder > depthMap[depth][j].InjectionOrder
|
||
})
|
||
}
|
||
|
||
result := make([]request.ChatMessage, 0, len(messages)+len(prompts))
|
||
totalInserted := 0
|
||
|
||
// 找出最大深度
|
||
maxDepth := 0
|
||
for depth := range depthMap {
|
||
if depth > maxDepth {
|
||
maxDepth = depth
|
||
}
|
||
}
|
||
|
||
// 从 depth=0 开始,逐个深度注入
|
||
for depth := 0; depth <= maxDepth; depth++ {
|
||
// 计算实际注入位置(考虑之前已注入的消息数量)
|
||
injectIdx := depth + totalInserted
|
||
|
||
// 如果注入位置超出当前消息列表,先添加原始消息直到该位置
|
||
for len(result) < injectIdx && len(result)-totalInserted < len(messages) {
|
||
result = append(result, messages[len(result)-totalInserted])
|
||
}
|
||
|
||
// 注入当前深度的所有提示词
|
||
if promptsAtDepth, exists := depthMap[depth]; exists {
|
||
for _, prompt := range promptsAtDepth {
|
||
result = append(result, request.ChatMessage{
|
||
Role: prompt.Role,
|
||
Content: p.processPromptContent(prompt.Content),
|
||
})
|
||
totalInserted++
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加剩余的原始消息
|
||
for len(result)-totalInserted < len(messages) {
|
||
result = append(result, messages[len(result)-totalInserted])
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// findMarkerIndex 查找标记位置
|
||
func (p *PresetInjector) findMarkerIndex(identifier string) int {
|
||
for i, prompt := range p.preset.Prompts {
|
||
if prompt.Identifier == identifier && prompt.Marker {
|
||
return i
|
||
}
|
||
}
|
||
return -1
|
||
}
|
||
|
||
// processPromptContent 处理提示词内容(变量替换等)
|
||
func (p *PresetInjector) processPromptContent(content string) string {
|
||
// 处理 {{user}} 和 {{char}} 等变量
|
||
content = strings.ReplaceAll(content, "{{user}}", "User")
|
||
content = strings.ReplaceAll(content, "{{char}}", "Assistant")
|
||
|
||
// 处理 {{getvar::key}} 语法
|
||
getvarRegex := regexp.MustCompile(`\{\{getvar::(\w+)\}\}`)
|
||
content = getvarRegex.ReplaceAllString(content, "")
|
||
|
||
// 处理 {{setvar::key::value}} 语法
|
||
setvarRegex := regexp.MustCompile(`\{\{setvar::(\w+)::(.*?)\}\}`)
|
||
content = setvarRegex.ReplaceAllString(content, "")
|
||
|
||
// 处理注释 {{//...}}
|
||
commentRegex := regexp.MustCompile(`\{\{//.*?\}\}`)
|
||
content = commentRegex.ReplaceAllString(content, "")
|
||
|
||
return strings.TrimSpace(content)
|
||
}
|
||
|
||
// applyRegexScripts 应用正则替换脚本
|
||
func (p *PresetInjector) applyRegexScripts(messages []request.ChatMessage, placement int) []request.ChatMessage {
|
||
if p.preset.Extensions.RegexBinding == nil {
|
||
return messages
|
||
}
|
||
|
||
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
|
||
if script.Disabled {
|
||
continue
|
||
}
|
||
|
||
// 检查 placement
|
||
hasPlacement := false
|
||
for _, pl := range script.Placement {
|
||
if pl == placement {
|
||
hasPlacement = true
|
||
break
|
||
}
|
||
}
|
||
if !hasPlacement {
|
||
continue
|
||
}
|
||
|
||
// 应用正则替换并记录日志
|
||
var matchCount int
|
||
var err error
|
||
messages, matchCount, err = p.applyRegexScriptWithLog(messages, script)
|
||
|
||
// 记录执行日志
|
||
log := response.RegexScriptLog{
|
||
ScriptName: script.ScriptName,
|
||
ScriptID: script.ID,
|
||
Executed: true,
|
||
MatchCount: matchCount,
|
||
}
|
||
if err != nil {
|
||
log.ErrorMessage = err.Error()
|
||
}
|
||
|
||
// 根据 placement 添加到对应的日志列表
|
||
if placement == 1 {
|
||
p.regexLogs.InputScripts = append(p.regexLogs.InputScripts, log)
|
||
}
|
||
p.regexLogs.TotalMatches += matchCount
|
||
}
|
||
|
||
return messages
|
||
}
|
||
|
||
// applyRegexScriptWithLog 应用单个正则脚本并返回匹配次数
|
||
func (p *PresetInjector) applyRegexScriptWithLog(messages []request.ChatMessage, script app.RegexScript) ([]request.ChatMessage, int, error) {
|
||
// 解析正则表达式
|
||
pattern := script.FindRegex
|
||
// 移除正则标志(如 /pattern/g)
|
||
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
|
||
pattern = pattern[1 : len(pattern)-2]
|
||
} else if strings.HasPrefix(pattern, "/") {
|
||
lastSlash := strings.LastIndex(pattern, "/")
|
||
if lastSlash > 0 {
|
||
pattern = pattern[1:lastSlash]
|
||
}
|
||
}
|
||
|
||
re, err := regexp.Compile(pattern)
|
||
if err != nil {
|
||
return messages, 0, fmt.Errorf("正则编译失败: %v", err)
|
||
}
|
||
|
||
matchCount := 0
|
||
// 对每条消息应用替换
|
||
for i := range messages {
|
||
if script.PromptOnly && messages[i].Role != "user" {
|
||
continue
|
||
}
|
||
|
||
// 统计匹配次数
|
||
matches := re.FindAllString(messages[i].Content, -1)
|
||
matchCount += len(matches)
|
||
|
||
// 执行替换
|
||
messages[i].Content = re.ReplaceAllString(messages[i].Content, script.ReplaceString)
|
||
}
|
||
|
||
return messages, matchCount, nil
|
||
}
|
||
|
||
// ProcessResponse 处理AI响应(应用输出后的正则)
|
||
func (p *PresetInjector) ProcessResponse(content string) string {
|
||
if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
|
||
return content
|
||
}
|
||
|
||
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
|
||
if script.Disabled {
|
||
continue
|
||
}
|
||
|
||
// 检查是否应用于输出(placement=2)
|
||
hasPlacement := false
|
||
for _, placement := range script.Placement {
|
||
if placement == 2 {
|
||
hasPlacement = true
|
||
break
|
||
}
|
||
}
|
||
if !hasPlacement {
|
||
continue
|
||
}
|
||
|
||
// 解析正则表达式
|
||
pattern := script.FindRegex
|
||
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
|
||
pattern = pattern[1 : len(pattern)-2]
|
||
} else if strings.HasPrefix(pattern, "/") {
|
||
lastSlash := strings.LastIndex(pattern, "/")
|
||
if lastSlash > 0 {
|
||
pattern = pattern[1:lastSlash]
|
||
}
|
||
}
|
||
|
||
re, err := regexp.Compile(pattern)
|
||
if err != nil {
|
||
// 记录错误日志
|
||
p.regexLogs.OutputScripts = append(p.regexLogs.OutputScripts, response.RegexScriptLog{
|
||
ScriptName: script.ScriptName,
|
||
ScriptID: script.ID,
|
||
Executed: false,
|
||
ErrorMessage: fmt.Sprintf("正则编译失败: %v", err),
|
||
})
|
||
continue
|
||
}
|
||
|
||
// 统计匹配次数
|
||
matches := re.FindAllString(content, -1)
|
||
matchCount := len(matches)
|
||
|
||
// 执行替换
|
||
content = re.ReplaceAllString(content, script.ReplaceString)
|
||
|
||
// 记录执行日志
|
||
p.regexLogs.OutputScripts = append(p.regexLogs.OutputScripts, response.RegexScriptLog{
|
||
ScriptName: script.ScriptName,
|
||
ScriptID: script.ID,
|
||
Executed: true,
|
||
MatchCount: matchCount,
|
||
})
|
||
p.regexLogs.TotalMatches += matchCount
|
||
}
|
||
|
||
return content
|
||
}
|
||
|
||
// ApplyPresetParameters 应用预设参数到请求
|
||
func (p *PresetInjector) ApplyPresetParameters(req *request.ChatCompletionRequest) {
|
||
if p.preset == nil {
|
||
return
|
||
}
|
||
|
||
// 如果请求中没有指定参数,使用预设的参数
|
||
if req.Temperature == nil && p.preset.Temperature > 0 {
|
||
temp := p.preset.Temperature
|
||
req.Temperature = &temp
|
||
}
|
||
|
||
if req.TopP == nil && p.preset.TopP > 0 {
|
||
topP := p.preset.TopP
|
||
req.TopP = &topP
|
||
}
|
||
|
||
if req.MaxTokens == nil && p.preset.MaxTokens > 0 {
|
||
maxTokens := p.preset.MaxTokens
|
||
req.MaxTokens = &maxTokens
|
||
}
|
||
|
||
if req.PresencePenalty == nil && p.preset.PresencePenalty != 0 {
|
||
pp := p.preset.PresencePenalty
|
||
req.PresencePenalty = &pp
|
||
}
|
||
|
||
if req.FrequencyPenalty == nil && p.preset.FrequencyPenalty != 0 {
|
||
fp := p.preset.FrequencyPenalty
|
||
req.FrequencyPenalty = &fp
|
||
}
|
||
}
|
||
|
||
// ValidatePreset 验证预设配置
|
||
func ValidatePreset(preset *app.AiPreset) error {
|
||
if preset == nil {
|
||
return fmt.Errorf("预设不能为空")
|
||
}
|
||
|
||
if preset.Name == "" {
|
||
return fmt.Errorf("预设名称不能为空")
|
||
}
|
||
|
||
// 验证正则表达式
|
||
if preset.Extensions.RegexBinding != nil {
|
||
for _, script := range preset.Extensions.RegexBinding.Regexes {
|
||
pattern := script.FindRegex
|
||
if strings.HasPrefix(pattern, "/") {
|
||
lastSlash := strings.LastIndex(pattern, "/")
|
||
if lastSlash > 0 {
|
||
pattern = pattern[1:lastSlash]
|
||
}
|
||
}
|
||
|
||
_, err := regexp.Compile(pattern)
|
||
if err != nil {
|
||
return fmt.Errorf("正则表达式 '%s' 无效: %v", script.ScriptName, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|