Files
ai_proxy/server/service/app/ai_preset_injector.go
2026-03-03 15:39:23 +08:00

306 lines
7.7 KiB
Go

package app
import (
"fmt"
"regexp"
"sort"
"strings"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
)
// PresetInjector applies an AI preset (prompts, regex scripts and sampling
// parameters) to chat completion requests and responses.
type PresetInjector struct {
	preset *app.AiPreset // may be nil; the exported methods return early in that case
}

// NewPresetInjector returns an injector bound to preset. A nil preset is
// accepted; InjectMessages, ProcessResponse and ApplyPresetParameters then
// leave their inputs unchanged.
func NewPresetInjector(preset *app.AiPreset) *PresetInjector {
	injector := new(PresetInjector)
	injector.preset = preset
	return injector
}
// InjectMessages applies the preset to a chat history and returns the
// message list to send upstream.
//
// Steps:
//
//  1. regex scripts with placement 1 (user input) are applied to messages;
//  2. the enabled, non-marker prompts are collected and ordered;
//  3. those prompts are merged with the regex-processed history.
//
// With a nil preset, messages is returned unchanged. A preset without any
// prompts still has its placement-1 regex scripts applied, consistent with
// ProcessResponse, which applies output scripts regardless of prompts.
func (p *PresetInjector) InjectMessages(messages []request.ChatMessage) []request.ChatMessage {
	if p.preset == nil {
		return messages
	}

	// Input-side regex scripts run even when the preset defines no prompts;
	// previously they were skipped entirely in that case.
	messages = p.applyRegexScripts(messages, 1)
	if len(p.preset.Prompts) == 0 {
		return messages
	}

	return p.buildInjectedMessages(messages, p.getEnabledPrompts())
}
// getEnabledPrompts 获取启用的提示词并按注入顺序排序
func (p *PresetInjector) getEnabledPrompts() []app.PresetPrompt {
var prompts []app.PresetPrompt
for _, prompt := range p.preset.Prompts {
if prompt.Enabled && !prompt.Marker {
prompts = append(prompts, prompt)
}
}
// 按 injection_order 和 injection_depth 排序
sort.Slice(prompts, func(i, j int) bool {
if prompts[i].InjectionOrder != prompts[j].InjectionOrder {
return prompts[i].InjectionOrder < prompts[j].InjectionOrder
}
return prompts[i].InjectionDepth < prompts[j].InjectionDepth
})
return prompts
}
// buildInjectedMessages merges the ordered prompts with the chat history.
//
// Output layout:
//
//  1. every prompt whose Role is "system", emitted as a system message;
//  2. the chat history (messages), unchanged;
//  3. every other prompt, emitted with its own Role.
//
// Prompt order is preserved within each group, and prompt content is passed
// through processPromptContent.
//
// NOTE(review): the previous version looked up the chatHistory marker but
// did the same thing whether or not it was found, so the marker's position
// never affected the output. That dead branch has been removed. Honoring the
// marker would require each prompt's position in p.preset.Prompts, which the
// sorted prompts slice no longer carries. TODO: decide the intended semantics.
func (p *PresetInjector) buildInjectedMessages(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
	result := make([]request.ChatMessage, 0, len(prompts)+len(messages))

	// 1. System prompts come first.
	for _, prompt := range prompts {
		if prompt.Role == "system" {
			result = append(result, request.ChatMessage{
				Role:    "system",
				Content: p.processPromptContent(prompt.Content),
			})
		}
	}

	// 2. Then the chat history.
	result = append(result, messages...)

	// 3. Finally the non-system prompts (assistant etc.).
	for _, prompt := range prompts {
		if prompt.Role != "system" {
			result = append(result, request.ChatMessage{
				Role:    prompt.Role,
				Content: p.processPromptContent(prompt.Content),
			})
		}
	}
	return result
}
// findMarkerIndex returns the position in p.preset.Prompts of the first
// prompt that is a marker with the given identifier, or -1 when no such
// marker exists.
func (p *PresetInjector) findMarkerIndex(identifier string) int {
	prompts := p.preset.Prompts
	for idx := range prompts {
		if !prompts[idx].Marker {
			continue
		}
		if prompts[idx].Identifier == identifier {
			return idx
		}
	}
	return -1
}
// Macro handling for processPromptContent. The regexes are compiled once at
// package initialization instead of on every call. The replacer and the
// compiled regexes are safe for concurrent use.
var (
	presetNameMacros = strings.NewReplacer("{{user}}", "User", "{{char}}", "Assistant")
	presetGetvarRe   = regexp.MustCompile(`\{\{getvar::(\w+)\}\}`)
	presetSetvarRe   = regexp.MustCompile(`\{\{setvar::(\w+)::(.*?)\}\}`)
	presetCommentRe  = regexp.MustCompile(`\{\{//.*?\}\}`)
)

// processPromptContent expands or strips the template macros in a preset
// prompt and then trims surrounding whitespace:
//
//   - {{user}} becomes "User" and {{char}} becomes "Assistant";
//   - {{getvar::key}} and {{setvar::key::value}} are removed (there is no
//     variable store here, so they expand to nothing);
//   - {{// ...}} comments are removed.
//
// The setvar and comment patterns do not span newlines, because "." does
// not match "\n" without the s flag.
func (p *PresetInjector) processPromptContent(content string) string {
	content = presetNameMacros.Replace(content)
	content = presetGetvarRe.ReplaceAllString(content, "")
	content = presetSetvarRe.ReplaceAllString(content, "")
	content = presetCommentRe.ReplaceAllString(content, "")
	return strings.TrimSpace(content)
}
// applyRegexScripts 应用正则替换脚本
func (p *PresetInjector) applyRegexScripts(messages []request.ChatMessage, placement int) []request.ChatMessage {
if p.preset.Extensions.RegexBinding == nil {
return messages
}
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
if script.Disabled {
continue
}
// 检查 placement
hasPlacement := false
for _, p := range script.Placement {
if p == placement {
hasPlacement = true
break
}
}
if !hasPlacement {
continue
}
// 应用正则替换
messages = p.applyRegexScript(messages, script)
}
return messages
}
// applyRegexScript 应用单个正则脚本
func (p *PresetInjector) applyRegexScript(messages []request.ChatMessage, script app.RegexScript) []request.ChatMessage {
// 解析正则表达式
pattern := script.FindRegex
// 移除正则标志(如 /pattern/g)
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
pattern = pattern[1 : len(pattern)-2]
} else if strings.HasPrefix(pattern, "/") {
lastSlash := strings.LastIndex(pattern, "/")
if lastSlash > 0 {
pattern = pattern[1:lastSlash]
}
}
re, err := regexp.Compile(pattern)
if err != nil {
return messages
}
// 对每条消息应用替换
for i := range messages {
if script.PromptOnly && messages[i].Role != "user" {
continue
}
messages[i].Content = re.ReplaceAllString(messages[i].Content, script.ReplaceString)
}
return messages
}
// ProcessResponse 处理AI响应(应用输出后的正则)
func (p *PresetInjector) ProcessResponse(content string) string {
if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
return content
}
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
if script.Disabled {
continue
}
// 检查是否应用于输出(placement=2)
hasPlacement := false
for _, placement := range script.Placement {
if placement == 2 {
hasPlacement = true
break
}
}
if !hasPlacement {
continue
}
// 解析正则表达式
pattern := script.FindRegex
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
pattern = pattern[1 : len(pattern)-2]
} else if strings.HasPrefix(pattern, "/") {
lastSlash := strings.LastIndex(pattern, "/")
if lastSlash > 0 {
pattern = pattern[1:lastSlash]
}
}
re, err := regexp.Compile(pattern)
if err != nil {
continue
}
content = re.ReplaceAllString(content, script.ReplaceString)
}
return content
}
// ApplyPresetParameters copies the preset's sampling parameters into req for
// every parameter the request left nil. Values already set on the request
// always win.
//
// Temperature, TopP and MaxTokens are copied only when the preset value is
// positive. PresencePenalty and FrequencyPenalty are copied only when it is
// non-zero. A zero preset value is therefore never applied. A nil preset
// leaves req untouched.
func (p *PresetInjector) ApplyPresetParameters(req *request.ChatCompletionRequest) {
	preset := p.preset
	if preset == nil {
		return
	}

	if req.Temperature == nil {
		if v := preset.Temperature; v > 0 {
			req.Temperature = &v
		}
	}
	if req.TopP == nil {
		if v := preset.TopP; v > 0 {
			req.TopP = &v
		}
	}
	if req.MaxTokens == nil {
		if v := preset.MaxTokens; v > 0 {
			req.MaxTokens = &v
		}
	}
	if req.PresencePenalty == nil {
		if v := preset.PresencePenalty; v != 0 {
			req.PresencePenalty = &v
		}
	}
	if req.FrequencyPenalty == nil {
		if v := preset.FrequencyPenalty; v != 0 {
			req.FrequencyPenalty = &v
		}
	}
}
// ValidatePreset checks that preset is usable. It must be non-nil and have a
// non-empty Name, and every regex script's FindRegex must compile. Disabled
// scripts are validated too.
//
// A pattern starting with "/" is validated without its leading "/" and
// without everything from the last "/" onwards, i.e. the body of a
// "/pattern/flags" literal.
//
// The regex compile error is wrapped with %w, so callers can unwrap it
// (presumably a *regexp/syntax.Error) via errors.As. Its message text is
// unchanged.
func ValidatePreset(preset *app.AiPreset) error {
	if preset == nil {
		return fmt.Errorf("预设不能为空")
	}
	if preset.Name == "" {
		return fmt.Errorf("预设名称不能为空")
	}
	if preset.Extensions.RegexBinding == nil {
		return nil
	}

	for _, script := range preset.Extensions.RegexBinding.Regexes {
		pattern := script.FindRegex
		if strings.HasPrefix(pattern, "/") {
			if last := strings.LastIndex(pattern, "/"); last > 0 {
				pattern = pattern[1:last]
			}
		}
		if _, err := regexp.Compile(pattern); err != nil {
			return fmt.Errorf("正则表达式 '%s' 无效: %w", script.ScriptName, err)
		}
	}
	return nil
}