355 lines
9.4 KiB
Go
355 lines
9.4 KiB
Go
package app
|
|
|
|
import (
|
|
"fmt"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
|
|
"git.echol.cn/loser/ai_proxy/server/model/app"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
|
)
|
|
|
|
// PresetInjector 预设注入器
|
|
type PresetInjector struct {
|
|
preset *app.AiPreset
|
|
}
|
|
|
|
// NewPresetInjector 创建预设注入器
|
|
func NewPresetInjector(preset *app.AiPreset) *PresetInjector {
|
|
return &PresetInjector{preset: preset}
|
|
}
|
|
|
|
// InjectMessages 注入预设到消息列表
|
|
func (p *PresetInjector) InjectMessages(messages []request.ChatMessage) []request.ChatMessage {
|
|
if p.preset == nil || len(p.preset.Prompts) == 0 {
|
|
return messages
|
|
}
|
|
|
|
// 1. 应用用户输入前的正则替换
|
|
messages = p.applyRegexScripts(messages, 1)
|
|
|
|
// 2. 获取启用的提示词并排序
|
|
enabledPrompts := p.getEnabledPrompts()
|
|
|
|
// 3. 构建注入后的消息列表
|
|
injectedMessages := p.buildInjectedMessages(messages, enabledPrompts)
|
|
|
|
return injectedMessages
|
|
}
|
|
|
|
// getEnabledPrompts 获取启用的提示词并按注入顺序排序
|
|
func (p *PresetInjector) getEnabledPrompts() []app.PresetPrompt {
|
|
var prompts []app.PresetPrompt
|
|
for _, prompt := range p.preset.Prompts {
|
|
if prompt.Enabled && !prompt.Marker {
|
|
prompts = append(prompts, prompt)
|
|
}
|
|
}
|
|
|
|
// 按 injection_order 和 injection_depth 排序
|
|
sort.Slice(prompts, func(i, j int) bool {
|
|
if prompts[i].InjectionOrder != prompts[j].InjectionOrder {
|
|
return prompts[i].InjectionOrder < prompts[j].InjectionOrder
|
|
}
|
|
return prompts[i].InjectionDepth < prompts[j].InjectionDepth
|
|
})
|
|
|
|
return prompts
|
|
}
|
|
|
|
// buildInjectedMessages 构建注入后的消息列表
|
|
// 参考 SillyTavern 的实现逻辑
|
|
func (p *PresetInjector) buildInjectedMessages(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
|
|
result := make([]request.ChatMessage, 0)
|
|
|
|
// 按照 injection_position 分组
|
|
// 0 = 相对位置(从对话历史的特定深度注入)
|
|
// 1 = 绝对位置(在消息列表的固定位置注入)
|
|
var relativePrompts []app.PresetPrompt
|
|
var absolutePrompts []app.PresetPrompt
|
|
|
|
for _, prompt := range prompts {
|
|
if prompt.InjectionPosition == 0 {
|
|
relativePrompts = append(relativePrompts, prompt)
|
|
} else {
|
|
absolutePrompts = append(absolutePrompts, prompt)
|
|
}
|
|
}
|
|
|
|
// 处理绝对位置的提示词(通常是系统提示)
|
|
for _, prompt := range absolutePrompts {
|
|
if prompt.InjectionDepth == 0 {
|
|
// depth=0 表示在最开始
|
|
result = append(result, request.ChatMessage{
|
|
Role: prompt.Role,
|
|
Content: p.processPromptContent(prompt.Content),
|
|
})
|
|
}
|
|
}
|
|
|
|
// 处理相对位置的提示词和对话历史
|
|
// 按 injection_depth 从大到小排序(深度越大越靠前)
|
|
sort.Slice(relativePrompts, func(i, j int) bool {
|
|
if relativePrompts[i].InjectionDepth != relativePrompts[j].InjectionDepth {
|
|
return relativePrompts[i].InjectionDepth > relativePrompts[j].InjectionDepth
|
|
}
|
|
return relativePrompts[i].InjectionOrder < relativePrompts[j].InjectionOrder
|
|
})
|
|
|
|
// 注入相对位置的提示词到对话历史中
|
|
injectedMessages := p.injectRelativePrompts(messages, relativePrompts)
|
|
result = append(result, injectedMessages...)
|
|
|
|
// 处理绝对位置在末尾的提示词
|
|
for _, prompt := range absolutePrompts {
|
|
if prompt.InjectionDepth > 0 {
|
|
// depth>0 表示在末尾
|
|
result = append(result, request.ChatMessage{
|
|
Role: prompt.Role,
|
|
Content: p.processPromptContent(prompt.Content),
|
|
})
|
|
}
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// injectRelativePrompts 将相对位置的提示词注入到对话历史中
|
|
func (p *PresetInjector) injectRelativePrompts(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
|
|
if len(prompts) == 0 {
|
|
return messages
|
|
}
|
|
|
|
result := make([]request.ChatMessage, 0, len(messages)+len(prompts))
|
|
messageCount := len(messages)
|
|
|
|
// 按深度分组提示词
|
|
depthMap := make(map[int][]app.PresetPrompt)
|
|
for _, prompt := range prompts {
|
|
depthMap[prompt.InjectionDepth] = append(depthMap[prompt.InjectionDepth], prompt)
|
|
}
|
|
|
|
// 遍历消息,在指定深度注入提示词
|
|
for i, msg := range messages {
|
|
// 计算当前位置的深度(从末尾开始计数)
|
|
depth := messageCount - i
|
|
|
|
// 在当前消息之前注入对应深度的提示词
|
|
if promptsAtDepth, exists := depthMap[depth]; exists {
|
|
for _, prompt := range promptsAtDepth {
|
|
result = append(result, request.ChatMessage{
|
|
Role: prompt.Role,
|
|
Content: p.processPromptContent(prompt.Content),
|
|
})
|
|
}
|
|
}
|
|
|
|
// 添加当前消息
|
|
result = append(result, msg)
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// findMarkerIndex 查找标记位置
|
|
func (p *PresetInjector) findMarkerIndex(identifier string) int {
|
|
for i, prompt := range p.preset.Prompts {
|
|
if prompt.Identifier == identifier && prompt.Marker {
|
|
return i
|
|
}
|
|
}
|
|
return -1
|
|
}
|
|
|
|
// processPromptContent 处理提示词内容(变量替换等)
|
|
func (p *PresetInjector) processPromptContent(content string) string {
|
|
// 处理 {{user}} 和 {{char}} 等变量
|
|
content = strings.ReplaceAll(content, "{{user}}", "User")
|
|
content = strings.ReplaceAll(content, "{{char}}", "Assistant")
|
|
|
|
// 处理 {{getvar::key}} 语法
|
|
getvarRegex := regexp.MustCompile(`\{\{getvar::(\w+)\}\}`)
|
|
content = getvarRegex.ReplaceAllString(content, "")
|
|
|
|
// 处理 {{setvar::key::value}} 语法
|
|
setvarRegex := regexp.MustCompile(`\{\{setvar::(\w+)::(.*?)\}\}`)
|
|
content = setvarRegex.ReplaceAllString(content, "")
|
|
|
|
// 处理注释 {{//...}}
|
|
commentRegex := regexp.MustCompile(`\{\{//.*?\}\}`)
|
|
content = commentRegex.ReplaceAllString(content, "")
|
|
|
|
return strings.TrimSpace(content)
|
|
}
|
|
|
|
// applyRegexScripts 应用正则替换脚本
|
|
func (p *PresetInjector) applyRegexScripts(messages []request.ChatMessage, placement int) []request.ChatMessage {
|
|
if p.preset.Extensions.RegexBinding == nil {
|
|
return messages
|
|
}
|
|
|
|
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
|
|
if script.Disabled {
|
|
continue
|
|
}
|
|
|
|
// 检查 placement
|
|
hasPlacement := false
|
|
for _, p := range script.Placement {
|
|
if p == placement {
|
|
hasPlacement = true
|
|
break
|
|
}
|
|
}
|
|
if !hasPlacement {
|
|
continue
|
|
}
|
|
|
|
// 应用正则替换
|
|
messages = p.applyRegexScript(messages, script)
|
|
}
|
|
|
|
return messages
|
|
}
|
|
|
|
// applyRegexScript 应用单个正则脚本
|
|
func (p *PresetInjector) applyRegexScript(messages []request.ChatMessage, script app.RegexScript) []request.ChatMessage {
|
|
// 解析正则表达式
|
|
pattern := script.FindRegex
|
|
// 移除正则标志(如 /pattern/g)
|
|
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
|
|
pattern = pattern[1 : len(pattern)-2]
|
|
} else if strings.HasPrefix(pattern, "/") {
|
|
lastSlash := strings.LastIndex(pattern, "/")
|
|
if lastSlash > 0 {
|
|
pattern = pattern[1:lastSlash]
|
|
}
|
|
}
|
|
|
|
re, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
return messages
|
|
}
|
|
|
|
// 对每条消息应用替换
|
|
for i := range messages {
|
|
if script.PromptOnly && messages[i].Role != "user" {
|
|
continue
|
|
}
|
|
|
|
messages[i].Content = re.ReplaceAllString(messages[i].Content, script.ReplaceString)
|
|
}
|
|
|
|
return messages
|
|
}
|
|
|
|
// ProcessResponse 处理AI响应(应用输出后的正则)
|
|
func (p *PresetInjector) ProcessResponse(content string) string {
|
|
if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
|
|
return content
|
|
}
|
|
|
|
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
|
|
if script.Disabled {
|
|
continue
|
|
}
|
|
|
|
// 检查是否应用于输出(placement=2)
|
|
hasPlacement := false
|
|
for _, placement := range script.Placement {
|
|
if placement == 2 {
|
|
hasPlacement = true
|
|
break
|
|
}
|
|
}
|
|
if !hasPlacement {
|
|
continue
|
|
}
|
|
|
|
// 解析正则表达式
|
|
pattern := script.FindRegex
|
|
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
|
|
pattern = pattern[1 : len(pattern)-2]
|
|
} else if strings.HasPrefix(pattern, "/") {
|
|
lastSlash := strings.LastIndex(pattern, "/")
|
|
if lastSlash > 0 {
|
|
pattern = pattern[1:lastSlash]
|
|
}
|
|
}
|
|
|
|
re, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
content = re.ReplaceAllString(content, script.ReplaceString)
|
|
}
|
|
|
|
return content
|
|
}
|
|
|
|
// ApplyPresetParameters 应用预设参数到请求
|
|
func (p *PresetInjector) ApplyPresetParameters(req *request.ChatCompletionRequest) {
|
|
if p.preset == nil {
|
|
return
|
|
}
|
|
|
|
// 如果请求中没有指定参数,使用预设的参数
|
|
if req.Temperature == nil && p.preset.Temperature > 0 {
|
|
temp := p.preset.Temperature
|
|
req.Temperature = &temp
|
|
}
|
|
|
|
if req.TopP == nil && p.preset.TopP > 0 {
|
|
topP := p.preset.TopP
|
|
req.TopP = &topP
|
|
}
|
|
|
|
if req.MaxTokens == nil && p.preset.MaxTokens > 0 {
|
|
maxTokens := p.preset.MaxTokens
|
|
req.MaxTokens = &maxTokens
|
|
}
|
|
|
|
if req.PresencePenalty == nil && p.preset.PresencePenalty != 0 {
|
|
pp := p.preset.PresencePenalty
|
|
req.PresencePenalty = &pp
|
|
}
|
|
|
|
if req.FrequencyPenalty == nil && p.preset.FrequencyPenalty != 0 {
|
|
fp := p.preset.FrequencyPenalty
|
|
req.FrequencyPenalty = &fp
|
|
}
|
|
}
|
|
|
|
// ValidatePreset 验证预设配置
|
|
func ValidatePreset(preset *app.AiPreset) error {
|
|
if preset == nil {
|
|
return fmt.Errorf("预设不能为空")
|
|
}
|
|
|
|
if preset.Name == "" {
|
|
return fmt.Errorf("预设名称不能为空")
|
|
}
|
|
|
|
// 验证正则表达式
|
|
if preset.Extensions.RegexBinding != nil {
|
|
for _, script := range preset.Extensions.RegexBinding.Regexes {
|
|
pattern := script.FindRegex
|
|
if strings.HasPrefix(pattern, "/") {
|
|
lastSlash := strings.LastIndex(pattern, "/")
|
|
if lastSlash > 0 {
|
|
pattern = pattern[1:lastSlash]
|
|
}
|
|
}
|
|
|
|
_, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
return fmt.Errorf("正则表达式 '%s' 无效: %v", script.ScriptName, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|