🎨 优化项目结构 && 完善ai配置
This commit is contained in:
@@ -1,273 +1,43 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
req "git.echol.cn/loser/ai_proxy/server/model/common/request"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/request"
|
||||
)
|
||||
|
||||
type AiPresetService struct{}
|
||||
|
||||
// CreateAiPreset 创建AI预设
|
||||
func (s *AiPresetService) CreateAiPreset(userId uint, req *request.CreateAiPresetRequest) (preset app.AiPreset, err error) {
|
||||
preset = app.AiPreset{
|
||||
UserID: userId,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Prompts: req.Prompts,
|
||||
RegexScripts: req.RegexScripts,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
MaxTokens: req.MaxTokens,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
StreamEnabled: req.StreamEnabled,
|
||||
IsDefault: req.IsDefault,
|
||||
IsPublic: req.IsPublic,
|
||||
}
|
||||
err = global.GVA_DB.Create(&preset).Error
|
||||
return preset, err
|
||||
// CreateAiPreset 创建预设
|
||||
func (s *AiPresetService) CreateAiPreset(preset *app.AiPreset) error {
|
||||
return global.GVA_DB.Create(preset).Error
|
||||
}
|
||||
|
||||
// DeleteAiPreset 删除AI预设
|
||||
func (s *AiPresetService) DeleteAiPreset(id uint, userId uint) (err error) {
|
||||
// 如果 userId 为 0(未登录),不允许删除
|
||||
if userId == 0 {
|
||||
return global.GVA_DB.Where("id = ?", id).Delete(&app.AiPreset{}).Error
|
||||
}
|
||||
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userId).Delete(&app.AiPreset{}).Error
|
||||
// DeleteAiPreset 删除预设
|
||||
func (s *AiPresetService) DeleteAiPreset(id uint, userID uint) error {
|
||||
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiPreset{}).Error
|
||||
}
|
||||
|
||||
// UpdateAiPreset 更新AI预设
|
||||
func (s *AiPresetService) UpdateAiPreset(userId uint, req *request.UpdateAiPresetRequest) (preset app.AiPreset, err error) {
|
||||
// 如果 userId 为 0(未登录),允许更新任何预设
|
||||
if userId == 0 {
|
||||
err = global.GVA_DB.Where("id = ?", req.ID).First(&preset).Error
|
||||
} else {
|
||||
err = global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userId).First(&preset).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return preset, err
|
||||
}
|
||||
|
||||
if req.Name != "" {
|
||||
preset.Name = req.Name
|
||||
}
|
||||
preset.Description = req.Description
|
||||
if req.Prompts != nil {
|
||||
preset.Prompts = req.Prompts
|
||||
}
|
||||
if req.RegexScripts != nil {
|
||||
preset.RegexScripts = req.RegexScripts
|
||||
}
|
||||
preset.Temperature = req.Temperature
|
||||
preset.TopP = req.TopP
|
||||
preset.MaxTokens = req.MaxTokens
|
||||
preset.FrequencyPenalty = req.FrequencyPenalty
|
||||
preset.PresencePenalty = req.PresencePenalty
|
||||
preset.StreamEnabled = req.StreamEnabled
|
||||
preset.IsDefault = req.IsDefault
|
||||
preset.IsPublic = req.IsPublic
|
||||
|
||||
err = global.GVA_DB.Save(&preset).Error
|
||||
return preset, err
|
||||
// UpdateAiPreset 更新预设
|
||||
func (s *AiPresetService) UpdateAiPreset(preset *app.AiPreset, userID uint) error {
|
||||
return global.GVA_DB.Where("user_id = ?", userID).Updates(preset).Error
|
||||
}
|
||||
|
||||
// GetAiPreset 获取AI预设详情
|
||||
func (s *AiPresetService) GetAiPreset(id uint, userId uint) (preset app.AiPreset, err error) {
|
||||
// 如果 userId 为 0(未登录),只能获取公开的预设
|
||||
if userId == 0 {
|
||||
err = global.GVA_DB.Where("id = ? AND is_public = ?", id, true).First(&preset).Error
|
||||
} else {
|
||||
err = global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true).First(&preset).Error
|
||||
}
|
||||
return preset, err
|
||||
// GetAiPreset 查询预设
|
||||
func (s *AiPresetService) GetAiPreset(id uint, userID uint) (preset app.AiPreset, err error) {
|
||||
err = global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error
|
||||
return
|
||||
}
|
||||
|
||||
// GetAiPresetList 获取AI预设列表
|
||||
func (s *AiPresetService) GetAiPresetList(userId uint, info req.PageInfo) (list []app.AiPreset, total int64, err error) {
|
||||
// GetAiPresetList 获取预设列表
|
||||
func (s *AiPresetService) GetAiPresetList(info request.PageInfo, userID uint) (list []app.AiPreset, total int64, err error) {
|
||||
limit := info.PageSize
|
||||
offset := info.PageSize * (info.Page - 1)
|
||||
db := global.GVA_DB.Model(&app.AiPreset{})
|
||||
|
||||
// 如果 userId 为 0(未登录),只返回公开的预设
|
||||
if userId == 0 {
|
||||
db = db.Where("is_public = ?", true)
|
||||
} else {
|
||||
db = db.Where("user_id = ? OR is_public = ?", userId, true)
|
||||
}
|
||||
|
||||
db := global.GVA_DB.Model(&app.AiPreset{}).Where("user_id = ?", userID)
|
||||
err = db.Count(&total).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = db.Limit(limit).Offset(offset).Order("created_at DESC").Find(&list).Error
|
||||
return list, total, err
|
||||
}
|
||||
|
||||
// ImportAiPreset 导入AI预设(支持SillyTavern格式)
|
||||
func (s *AiPresetService) ImportAiPreset(userId uint, req *request.ImportAiPresetRequest) (preset app.AiPreset, err error) {
|
||||
// 解析 SillyTavern JSON 格式
|
||||
var stData map[string]interface{}
|
||||
var jsonData []byte
|
||||
|
||||
// 类型断言处理 req.Data
|
||||
switch v := req.Data.(type) {
|
||||
case string:
|
||||
jsonData = []byte(v)
|
||||
case []byte:
|
||||
jsonData = v
|
||||
default:
|
||||
return preset, fmt.Errorf("不支持的数据类型")
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &stData); err != nil {
|
||||
return preset, fmt.Errorf("JSON 解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 提取基本信息
|
||||
preset = app.AiPreset{
|
||||
UserID: userId,
|
||||
Name: req.Name,
|
||||
Description: getStringValue(stData, "description"),
|
||||
IsPublic: false,
|
||||
}
|
||||
|
||||
// 提取参数
|
||||
if temp, ok := stData["temperature"].(float64); ok {
|
||||
preset.Temperature = temp
|
||||
}
|
||||
if topP, ok := stData["top_p"].(float64); ok {
|
||||
preset.TopP = topP
|
||||
}
|
||||
if maxTokens, ok := stData["openai_max_tokens"].(float64); ok {
|
||||
preset.MaxTokens = int(maxTokens)
|
||||
} else if maxTokens, ok := stData["max_tokens"].(float64); ok {
|
||||
preset.MaxTokens = int(maxTokens)
|
||||
}
|
||||
if freqPenalty, ok := stData["frequency_penalty"].(float64); ok {
|
||||
preset.FrequencyPenalty = freqPenalty
|
||||
}
|
||||
if presPenalty, ok := stData["presence_penalty"].(float64); ok {
|
||||
preset.PresencePenalty = presPenalty
|
||||
}
|
||||
if stream, ok := stData["stream_openai"].(bool); ok {
|
||||
preset.StreamEnabled = stream
|
||||
}
|
||||
|
||||
// 提取提示词
|
||||
prompts := make([]app.Prompt, 0)
|
||||
if promptsData, ok := stData["prompts"].([]interface{}); ok {
|
||||
for i, p := range promptsData {
|
||||
if promptMap, ok := p.(map[string]interface{}); ok {
|
||||
prompt := app.Prompt{
|
||||
Name: getStringValue(promptMap, "name"),
|
||||
Role: getStringValue(promptMap, "role"),
|
||||
Content: getStringValue(promptMap, "content"),
|
||||
SystemPrompt: getBoolValue(promptMap, "system_prompt"),
|
||||
Marker: getBoolValue(promptMap, "marker"),
|
||||
InjectionOrder: i,
|
||||
InjectionDepth: int(getFloatValue(promptMap, "injection_depth")),
|
||||
InjectionPosition: int(getFloatValue(promptMap, "injection_position")),
|
||||
}
|
||||
prompts = append(prompts, prompt)
|
||||
}
|
||||
}
|
||||
}
|
||||
preset.Prompts = prompts
|
||||
|
||||
// 提取正则脚本
|
||||
regexScripts := make([]app.RegexScript, 0)
|
||||
if extensions, ok := stData["extensions"].(map[string]interface{}); ok {
|
||||
if scripts, ok := extensions["regex_scripts"].([]interface{}); ok {
|
||||
for _, s := range scripts {
|
||||
if scriptMap, ok := s.(map[string]interface{}); ok {
|
||||
script := app.RegexScript{
|
||||
ScriptName: getStringValue(scriptMap, "script_name"),
|
||||
FindRegex: getStringValue(scriptMap, "find_regex"),
|
||||
ReplaceString: getStringValue(scriptMap, "replace_string"),
|
||||
Disabled: getBoolValue(scriptMap, "disabled"),
|
||||
Placement: getIntArray(scriptMap, "placement"),
|
||||
}
|
||||
regexScripts = append(regexScripts, script)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
preset.RegexScripts = regexScripts
|
||||
|
||||
// 保存到数据库
|
||||
err = global.GVA_DB.Create(&preset).Error
|
||||
return preset, err
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func getStringValue(m map[string]interface{}, key string) string {
|
||||
if v, ok := m[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getBoolValue(m map[string]interface{}, key string) bool {
|
||||
if v, ok := m[key].(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getFloatValue(m map[string]interface{}, key string) float64 {
|
||||
if v, ok := m[key].(float64); ok {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getIntArray(m map[string]interface{}, key string) []int {
|
||||
result := make([]int, 0)
|
||||
if arr, ok := m[key].([]interface{}); ok {
|
||||
for _, v := range arr {
|
||||
if num, ok := v.(float64); ok {
|
||||
result = append(result, int(num))
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ExportAiPreset 导出AI预设
|
||||
func (s *AiPresetService) ExportAiPreset(id uint, userId uint) (data map[string]interface{}, err error) {
|
||||
var preset app.AiPreset
|
||||
|
||||
// 如果 userId 为 0(未登录),只能导出公开的预设
|
||||
if userId == 0 {
|
||||
err = global.GVA_DB.Where("id = ? AND is_public = ?", id, true).First(&preset).Error
|
||||
} else {
|
||||
err = global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true).First(&preset).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data = map[string]interface{}{
|
||||
"prompts": preset.Prompts,
|
||||
"extensions": map[string]interface{}{
|
||||
"regex_scripts": preset.RegexScripts,
|
||||
},
|
||||
"temperature": preset.Temperature,
|
||||
"top_p": preset.TopP,
|
||||
"openai_max_tokens": preset.MaxTokens,
|
||||
"frequency_penalty": preset.FrequencyPenalty,
|
||||
"presence_penalty": preset.PresencePenalty,
|
||||
"stream_openai": preset.StreamEnabled,
|
||||
}
|
||||
|
||||
return data, nil
|
||||
err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,154 +1,43 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||||
"gorm.io/gorm"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/request"
|
||||
)
|
||||
|
||||
type PresetBindingService struct{}
|
||||
type AiPresetBindingService struct{}
|
||||
|
||||
// CreateBinding 创建预设绑定
|
||||
func (s *PresetBindingService) CreateBinding(req *request.CreateBindingRequest) error {
|
||||
// 检查预设是否存在
|
||||
var preset app.AiPreset
|
||||
if err := global.GVA_DB.First(&preset, req.PresetID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("预设不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查提供商是否存在
|
||||
var provider app.AiProvider
|
||||
if err := global.GVA_DB.First(&provider, req.ProviderID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("提供商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查是否已存在相同的绑定
|
||||
var count int64
|
||||
err := global.GVA_DB.Model(&app.AiPresetBinding{}).
|
||||
Where("preset_id = ? AND provider_id = ?", req.PresetID, req.ProviderID).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return errors.New("该绑定已存在")
|
||||
}
|
||||
|
||||
binding := app.AiPresetBinding{
|
||||
PresetID: req.PresetID,
|
||||
ProviderID: req.ProviderID,
|
||||
Priority: req.Priority,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
return global.GVA_DB.Create(&binding).Error
|
||||
// CreateAiPresetBinding 创建绑定
|
||||
func (s *AiPresetBindingService) CreateAiPresetBinding(binding *app.AiPresetBinding) error {
|
||||
return global.GVA_DB.Create(binding).Error
|
||||
}
|
||||
|
||||
// UpdateBinding 更新预设绑定
|
||||
func (s *PresetBindingService) UpdateBinding(req *request.UpdateBindingRequest) error {
|
||||
var binding app.AiPresetBinding
|
||||
if err := global.GVA_DB.First(&binding, req.ID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("绑定不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"priority": req.Priority,
|
||||
"is_active": req.IsActive,
|
||||
}
|
||||
|
||||
return global.GVA_DB.Model(&binding).Updates(updates).Error
|
||||
// DeleteAiPresetBinding 删除绑定
|
||||
func (s *AiPresetBindingService) DeleteAiPresetBinding(id uint, userID uint) error {
|
||||
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiPresetBinding{}).Error
|
||||
}
|
||||
|
||||
// DeleteBinding 删除预设绑定
|
||||
func (s *PresetBindingService) DeleteBinding(id uint) error {
|
||||
var binding app.AiPresetBinding
|
||||
if err := global.GVA_DB.First(&binding, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("绑定不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return global.GVA_DB.Delete(&binding).Error
|
||||
// UpdateAiPresetBinding 更新绑定
|
||||
func (s *AiPresetBindingService) UpdateAiPresetBinding(binding *app.AiPresetBinding, userID uint) error {
|
||||
return global.GVA_DB.Where("user_id = ?", userID).Updates(binding).Error
|
||||
}
|
||||
|
||||
// GetBindingList 获取绑定列表
|
||||
func (s *PresetBindingService) GetBindingList(req *request.GetBindingListRequest) (list []response.BindingInfo, total int64, err error) {
|
||||
db := global.GVA_DB.Model(&app.AiPresetBinding{})
|
||||
// GetAiPresetBinding 查询绑定
|
||||
func (s *AiPresetBindingService) GetAiPresetBinding(id uint, userID uint) (binding app.AiPresetBinding, err error) {
|
||||
err = global.GVA_DB.Preload("Preset").Preload("Provider").Where("id = ? AND user_id = ?", id, userID).First(&binding).Error
|
||||
return
|
||||
}
|
||||
|
||||
// 条件查询
|
||||
if req.ProviderID > 0 {
|
||||
db = db.Where("provider_id = ?", req.ProviderID)
|
||||
}
|
||||
if req.PresetID > 0 {
|
||||
db = db.Where("preset_id = ?", req.PresetID)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
// GetAiPresetBindingList 获取绑定列表
|
||||
func (s *AiPresetBindingService) GetAiPresetBindingList(info request.PageInfo, userID uint) (list []app.AiPresetBinding, total int64, err error) {
|
||||
limit := info.PageSize
|
||||
offset := info.PageSize * (info.Page - 1)
|
||||
db := global.GVA_DB.Model(&app.AiPresetBinding{}).Preload("Preset").Preload("Provider").Where("user_id = ?", userID)
|
||||
err = db.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
if req.Page > 0 && req.PageSize > 0 {
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
db = db.Offset(offset).Limit(req.PageSize)
|
||||
}
|
||||
|
||||
var bindings []app.AiPresetBinding
|
||||
err = db.Preload("Preset").Preload("Provider").Order("priority ASC, created_at DESC").Find(&bindings).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
list = make([]response.BindingInfo, len(bindings))
|
||||
for i, binding := range bindings {
|
||||
list[i] = response.BindingInfo{
|
||||
ID: binding.ID,
|
||||
PresetID: binding.PresetID,
|
||||
PresetName: binding.Preset.Name,
|
||||
ProviderID: binding.ProviderID,
|
||||
ProviderName: binding.Provider.Name,
|
||||
Priority: binding.Priority,
|
||||
IsActive: binding.IsActive,
|
||||
CreatedAt: binding.CreatedAt,
|
||||
UpdatedAt: binding.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
return list, total, nil
|
||||
}
|
||||
|
||||
// GetBindingsByProvider 根据提供商获取绑定的预设列表
|
||||
func (s *PresetBindingService) GetBindingsByProvider(providerID uint) ([]app.AiPreset, error) {
|
||||
var bindings []app.AiPresetBinding
|
||||
err := global.GVA_DB.Where("provider_id = ? AND is_active = ?", providerID, true).
|
||||
Preload("Preset").
|
||||
Order("priority ASC").
|
||||
Find(&bindings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
presets := make([]app.AiPreset, len(bindings))
|
||||
for i, binding := range bindings {
|
||||
presets[i] = binding.Preset
|
||||
}
|
||||
|
||||
return presets, nil
|
||||
err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
|
||||
return
|
||||
}
|
||||
|
||||
305
server/service/app/ai_preset_injector.go
Normal file
305
server/service/app/ai_preset_injector.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
)
|
||||
|
||||
// PresetInjector 预设注入器
|
||||
type PresetInjector struct {
|
||||
preset *app.AiPreset
|
||||
}
|
||||
|
||||
// NewPresetInjector 创建预设注入器
|
||||
func NewPresetInjector(preset *app.AiPreset) *PresetInjector {
|
||||
return &PresetInjector{preset: preset}
|
||||
}
|
||||
|
||||
// InjectMessages 注入预设到消息列表
|
||||
func (p *PresetInjector) InjectMessages(messages []request.ChatMessage) []request.ChatMessage {
|
||||
if p.preset == nil || len(p.preset.Prompts) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// 1. 应用用户输入前的正则替换
|
||||
messages = p.applyRegexScripts(messages, 1)
|
||||
|
||||
// 2. 获取启用的提示词并排序
|
||||
enabledPrompts := p.getEnabledPrompts()
|
||||
|
||||
// 3. 构建注入后的消息列表
|
||||
injectedMessages := p.buildInjectedMessages(messages, enabledPrompts)
|
||||
|
||||
return injectedMessages
|
||||
}
|
||||
|
||||
// getEnabledPrompts 获取启用的提示词并按注入顺序排序
|
||||
func (p *PresetInjector) getEnabledPrompts() []app.PresetPrompt {
|
||||
var prompts []app.PresetPrompt
|
||||
for _, prompt := range p.preset.Prompts {
|
||||
if prompt.Enabled && !prompt.Marker {
|
||||
prompts = append(prompts, prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// 按 injection_order 和 injection_depth 排序
|
||||
sort.Slice(prompts, func(i, j int) bool {
|
||||
if prompts[i].InjectionOrder != prompts[j].InjectionOrder {
|
||||
return prompts[i].InjectionOrder < prompts[j].InjectionOrder
|
||||
}
|
||||
return prompts[i].InjectionDepth < prompts[j].InjectionDepth
|
||||
})
|
||||
|
||||
return prompts
|
||||
}
|
||||
|
||||
// buildInjectedMessages 构建注入后的消息列表
|
||||
func (p *PresetInjector) buildInjectedMessages(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
|
||||
result := make([]request.ChatMessage, 0)
|
||||
|
||||
// 分离系统提示词和对话消息
|
||||
var systemPrompts []app.PresetPrompt
|
||||
var otherPrompts []app.PresetPrompt
|
||||
|
||||
for _, prompt := range prompts {
|
||||
if prompt.Role == "system" {
|
||||
systemPrompts = append(systemPrompts, prompt)
|
||||
} else {
|
||||
otherPrompts = append(otherPrompts, prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// 1. 先添加系统提示词
|
||||
for _, prompt := range systemPrompts {
|
||||
result = append(result, request.ChatMessage{
|
||||
Role: "system",
|
||||
Content: p.processPromptContent(prompt.Content),
|
||||
})
|
||||
}
|
||||
|
||||
// 2. 处理对话历史注入
|
||||
chatHistoryIndex := p.findMarkerIndex("chatHistory")
|
||||
if chatHistoryIndex >= 0 {
|
||||
// 在 chatHistory 标记位置注入原始消息
|
||||
result = append(result, messages...)
|
||||
} else {
|
||||
// 如果没有 chatHistory 标记,直接添加到末尾
|
||||
result = append(result, messages...)
|
||||
}
|
||||
|
||||
// 3. 添加其他角色的提示词(assistant等)
|
||||
for _, prompt := range otherPrompts {
|
||||
result = append(result, request.ChatMessage{
|
||||
Role: prompt.Role,
|
||||
Content: p.processPromptContent(prompt.Content),
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// findMarkerIndex 查找标记位置
|
||||
func (p *PresetInjector) findMarkerIndex(identifier string) int {
|
||||
for i, prompt := range p.preset.Prompts {
|
||||
if prompt.Identifier == identifier && prompt.Marker {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// processPromptContent 处理提示词内容(变量替换等)
|
||||
func (p *PresetInjector) processPromptContent(content string) string {
|
||||
// 处理 {{user}} 和 {{char}} 等变量
|
||||
content = strings.ReplaceAll(content, "{{user}}", "User")
|
||||
content = strings.ReplaceAll(content, "{{char}}", "Assistant")
|
||||
|
||||
// 处理 {{getvar::key}} 语法
|
||||
getvarRegex := regexp.MustCompile(`\{\{getvar::(\w+)\}\}`)
|
||||
content = getvarRegex.ReplaceAllString(content, "")
|
||||
|
||||
// 处理 {{setvar::key::value}} 语法
|
||||
setvarRegex := regexp.MustCompile(`\{\{setvar::(\w+)::(.*?)\}\}`)
|
||||
content = setvarRegex.ReplaceAllString(content, "")
|
||||
|
||||
// 处理注释 {{//...}}
|
||||
commentRegex := regexp.MustCompile(`\{\{//.*?\}\}`)
|
||||
content = commentRegex.ReplaceAllString(content, "")
|
||||
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
// applyRegexScripts 应用正则替换脚本
|
||||
func (p *PresetInjector) applyRegexScripts(messages []request.ChatMessage, placement int) []request.ChatMessage {
|
||||
if p.preset.Extensions.RegexBinding == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
|
||||
if script.Disabled {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查 placement
|
||||
hasPlacement := false
|
||||
for _, p := range script.Placement {
|
||||
if p == placement {
|
||||
hasPlacement = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasPlacement {
|
||||
continue
|
||||
}
|
||||
|
||||
// 应用正则替换
|
||||
messages = p.applyRegexScript(messages, script)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// applyRegexScript 应用单个正则脚本
|
||||
func (p *PresetInjector) applyRegexScript(messages []request.ChatMessage, script app.RegexScript) []request.ChatMessage {
|
||||
// 解析正则表达式
|
||||
pattern := script.FindRegex
|
||||
// 移除正则标志(如 /pattern/g)
|
||||
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
|
||||
pattern = pattern[1 : len(pattern)-2]
|
||||
} else if strings.HasPrefix(pattern, "/") {
|
||||
lastSlash := strings.LastIndex(pattern, "/")
|
||||
if lastSlash > 0 {
|
||||
pattern = pattern[1:lastSlash]
|
||||
}
|
||||
}
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
// 对每条消息应用替换
|
||||
for i := range messages {
|
||||
if script.PromptOnly && messages[i].Role != "user" {
|
||||
continue
|
||||
}
|
||||
|
||||
messages[i].Content = re.ReplaceAllString(messages[i].Content, script.ReplaceString)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// ProcessResponse 处理AI响应(应用输出后的正则)
|
||||
func (p *PresetInjector) ProcessResponse(content string) string {
|
||||
if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
|
||||
return content
|
||||
}
|
||||
|
||||
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
|
||||
if script.Disabled {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否应用于输出(placement=2)
|
||||
hasPlacement := false
|
||||
for _, placement := range script.Placement {
|
||||
if placement == 2 {
|
||||
hasPlacement = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasPlacement {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析正则表达式
|
||||
pattern := script.FindRegex
|
||||
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
|
||||
pattern = pattern[1 : len(pattern)-2]
|
||||
} else if strings.HasPrefix(pattern, "/") {
|
||||
lastSlash := strings.LastIndex(pattern, "/")
|
||||
if lastSlash > 0 {
|
||||
pattern = pattern[1:lastSlash]
|
||||
}
|
||||
}
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
content = re.ReplaceAllString(content, script.ReplaceString)
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// ApplyPresetParameters 应用预设参数到请求
|
||||
func (p *PresetInjector) ApplyPresetParameters(req *request.ChatCompletionRequest) {
|
||||
if p.preset == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果请求中没有指定参数,使用预设的参数
|
||||
if req.Temperature == nil && p.preset.Temperature > 0 {
|
||||
temp := p.preset.Temperature
|
||||
req.Temperature = &temp
|
||||
}
|
||||
|
||||
if req.TopP == nil && p.preset.TopP > 0 {
|
||||
topP := p.preset.TopP
|
||||
req.TopP = &topP
|
||||
}
|
||||
|
||||
if req.MaxTokens == nil && p.preset.MaxTokens > 0 {
|
||||
maxTokens := p.preset.MaxTokens
|
||||
req.MaxTokens = &maxTokens
|
||||
}
|
||||
|
||||
if req.PresencePenalty == nil && p.preset.PresencePenalty != 0 {
|
||||
pp := p.preset.PresencePenalty
|
||||
req.PresencePenalty = &pp
|
||||
}
|
||||
|
||||
if req.FrequencyPenalty == nil && p.preset.FrequencyPenalty != 0 {
|
||||
fp := p.preset.FrequencyPenalty
|
||||
req.FrequencyPenalty = &fp
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePreset 验证预设配置
|
||||
func ValidatePreset(preset *app.AiPreset) error {
|
||||
if preset == nil {
|
||||
return fmt.Errorf("预设不能为空")
|
||||
}
|
||||
|
||||
if preset.Name == "" {
|
||||
return fmt.Errorf("预设名称不能为空")
|
||||
}
|
||||
|
||||
// 验证正则表达式
|
||||
if preset.Extensions.RegexBinding != nil {
|
||||
for _, script := range preset.Extensions.RegexBinding.Regexes {
|
||||
pattern := script.FindRegex
|
||||
if strings.HasPrefix(pattern, "/") {
|
||||
lastSlash := strings.LastIndex(pattern, "/")
|
||||
if lastSlash > 0 {
|
||||
pattern = pattern[1:lastSlash]
|
||||
}
|
||||
}
|
||||
|
||||
_, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("正则表达式 '%s' 无效: %v", script.ScriptName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,230 +1,43 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/request"
|
||||
)
|
||||
|
||||
type AiProviderService struct{}
|
||||
|
||||
// CreateAiProvider 创建AI提供商
|
||||
func (s *AiProviderService) CreateAiProvider(req *request.CreateAiProviderRequest) (provider app.AiProvider, err error) {
|
||||
provider = app.AiProvider{
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
BaseURL: req.BaseURL,
|
||||
Endpoint: req.Endpoint,
|
||||
UpstreamKey: req.UpstreamKey,
|
||||
Model: req.Model,
|
||||
ProxyKey: req.ProxyKey,
|
||||
Config: req.Config,
|
||||
IsActive: req.IsActive,
|
||||
}
|
||||
err = global.GVA_DB.Create(&provider).Error
|
||||
return provider, err
|
||||
// CreateAiProvider 创建提供商
|
||||
func (s *AiProviderService) CreateAiProvider(provider *app.AiProvider) error {
|
||||
return global.GVA_DB.Create(provider).Error
|
||||
}
|
||||
|
||||
// DeleteAiProvider 删除AI提供商
|
||||
func (s *AiProviderService) DeleteAiProvider(id uint) (err error) {
|
||||
return global.GVA_DB.Delete(&app.AiProvider{}, id).Error
|
||||
// DeleteAiProvider 删除提供商
|
||||
func (s *AiProviderService) DeleteAiProvider(id uint, userID uint) error {
|
||||
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiProvider{}).Error
|
||||
}
|
||||
|
||||
// UpdateAiProvider 更新AI提供商
|
||||
func (s *AiProviderService) UpdateAiProvider(req *request.UpdateAiProviderRequest) (provider app.AiProvider, err error) {
|
||||
err = global.GVA_DB.First(&provider, req.ID).Error
|
||||
// UpdateAiProvider 更新提供商
|
||||
func (s *AiProviderService) UpdateAiProvider(provider *app.AiProvider, userID uint) error {
|
||||
return global.GVA_DB.Where("user_id = ?", userID).Updates(provider).Error
|
||||
}
|
||||
|
||||
// GetAiProvider 查询提供商
|
||||
func (s *AiProviderService) GetAiProvider(id uint, userID uint) (provider app.AiProvider, err error) {
|
||||
err = global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&provider).Error
|
||||
return
|
||||
}
|
||||
|
||||
// GetAiProviderList 获取提供商列表
|
||||
func (s *AiProviderService) GetAiProviderList(info request.PageInfo, userID uint) (list []app.AiProvider, total int64, err error) {
|
||||
limit := info.PageSize
|
||||
offset := info.PageSize * (info.Page - 1)
|
||||
db := global.GVA_DB.Model(&app.AiProvider{}).Where("user_id = ?", userID)
|
||||
err = db.Count(&total).Error
|
||||
if err != nil {
|
||||
return provider, err
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name != "" {
|
||||
provider.Name = req.Name
|
||||
}
|
||||
if req.Type != "" {
|
||||
provider.Type = req.Type
|
||||
}
|
||||
if req.BaseURL != "" {
|
||||
provider.BaseURL = req.BaseURL
|
||||
}
|
||||
if req.Endpoint != "" {
|
||||
provider.Endpoint = req.Endpoint
|
||||
}
|
||||
if req.UpstreamKey != "" {
|
||||
provider.UpstreamKey = req.UpstreamKey
|
||||
}
|
||||
if req.Model != "" {
|
||||
provider.Model = req.Model
|
||||
}
|
||||
if req.ProxyKey != "" {
|
||||
provider.ProxyKey = req.ProxyKey
|
||||
}
|
||||
if req.Config != nil {
|
||||
provider.Config = req.Config
|
||||
}
|
||||
provider.IsActive = req.IsActive
|
||||
|
||||
err = global.GVA_DB.Save(&provider).Error
|
||||
return provider, err
|
||||
}
|
||||
|
||||
// GetAiProvider 获取AI提供商详情
|
||||
func (s *AiProviderService) GetAiProvider(id uint) (provider app.AiProvider, err error) {
|
||||
err = global.GVA_DB.First(&provider, id).Error
|
||||
return provider, err
|
||||
}
|
||||
|
||||
// GetAiProviderList 获取AI提供商列表
|
||||
func (s *AiProviderService) GetAiProviderList() (list []app.AiProvider, err error) {
|
||||
err = global.GVA_DB.Where("is_active = ?", true).Find(&list).Error
|
||||
return list, err
|
||||
}
|
||||
|
||||
// TestConnection 测试连接
|
||||
func (s *AiProviderService) TestConnection(req *request.TestConnectionRequest) (resp response.TestConnectionResponse, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 根据类型构建测试 URL
|
||||
var testURL string
|
||||
switch strings.ToLower(req.Type) {
|
||||
case "openai":
|
||||
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
||||
case "claude":
|
||||
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/messages"
|
||||
default:
|
||||
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
||||
}
|
||||
|
||||
// 创建 HTTP 请求
|
||||
httpReq, err := http.NewRequest("GET", testURL, nil)
|
||||
if err != nil {
|
||||
return response.TestConnectionResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("创建请求失败: %v", err),
|
||||
Latency: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return response.TestConnectionResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("连接失败: %v", err),
|
||||
Latency: time.Since(startTime).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
latency := time.Since(startTime).Milliseconds()
|
||||
|
||||
// 检查响应状态
|
||||
if httpResp.StatusCode == http.StatusOK || httpResp.StatusCode == http.StatusCreated {
|
||||
return response.TestConnectionResponse{
|
||||
Success: true,
|
||||
Message: "连接成功",
|
||||
Latency: latency,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 读取错误响应
|
||||
body, _ := io.ReadAll(httpResp.Body)
|
||||
return response.TestConnectionResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("连接失败 (状态码: %d): %s", httpResp.StatusCode, string(body)),
|
||||
Latency: latency,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetModels 获取模型列表
|
||||
func (s *AiProviderService) GetModels(req *request.GetModelsRequest) (models []response.ModelInfo, err error) {
|
||||
// 根据类型构建 URL
|
||||
var modelsURL string
|
||||
switch strings.ToLower(req.Type) {
|
||||
case "openai":
|
||||
modelsURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
||||
case "claude":
|
||||
// Claude API 不提供模型列表接口,返回预定义的模型
|
||||
return []response.ModelInfo{
|
||||
{ID: "claude-opus-4-6", Name: "Claude Opus 4.6", OwnedBy: "anthropic"},
|
||||
{ID: "claude-sonnet-4-6", Name: "Claude Sonnet 4.6", OwnedBy: "anthropic"},
|
||||
{ID: "claude-haiku-4-5-20251001", Name: "Claude Haiku 4.5", OwnedBy: "anthropic"},
|
||||
{ID: "claude-3-5-sonnet-20241022", Name: "Claude 3.5 Sonnet", OwnedBy: "anthropic"},
|
||||
{ID: "claude-3-opus-20240229", Name: "Claude 3 Opus", OwnedBy: "anthropic"},
|
||||
}, nil
|
||||
default:
|
||||
modelsURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
||||
}
|
||||
|
||||
// 创建 HTTP 请求
|
||||
httpReq, err := http.NewRequest("GET", modelsURL, nil)
|
||||
if err != nil {
|
||||
return nil, errors.New("创建请求失败: " + err.Error())
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, errors.New("请求失败: " + err.Error())
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(httpResp.Body)
|
||||
return nil, fmt.Errorf("获取模型列表失败 (状态码: %d): %s", httpResp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.New("读取响应失败: " + err.Error())
|
||||
}
|
||||
|
||||
// OpenAI 格式的响应
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
||||
return nil, errors.New("解析响应失败: " + err.Error())
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
models = make([]response.ModelInfo, len(modelsResp.Data))
|
||||
for i, model := range modelsResp.Data {
|
||||
models[i] = response.ModelInfo{
|
||||
ID: model.ID,
|
||||
Name: model.ID,
|
||||
OwnedBy: model.OwnedBy,
|
||||
}
|
||||
}
|
||||
|
||||
return models, nil
|
||||
err = db.Limit(limit).Offset(offset).Order("priority desc, id desc").Find(&list).Error
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -16,264 +15,252 @@ import (
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AiProxyService struct{}
|
||||
|
||||
// ProcessChatCompletion 处理聊天补全请求
|
||||
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userId uint, req *request.ChatCompletionRequest) (resp response.ChatCompletionResponse, err error) {
|
||||
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userID uint, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 获取预设配置
|
||||
var preset app.AiPreset
|
||||
if req.PresetID > 0 {
|
||||
err = global.GVA_DB.First(&preset, req.PresetID).Error
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("预设不存在: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取提供商配置
|
||||
var provider app.AiProvider
|
||||
|
||||
// 根据 binding_key 或预设绑定获取 provider
|
||||
if req.BindingKey != "" {
|
||||
// 通过 binding_key 查找绑定关系
|
||||
var binding app.AiPresetBinding
|
||||
err = global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
|
||||
Order("priority ASC").
|
||||
First(&binding).Error
|
||||
if err == nil {
|
||||
err = global.GVA_DB.First(&provider, binding.ProviderID).Error
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到,使用默认的活跃提供商
|
||||
if provider.ID == 0 {
|
||||
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 构建注入后的消息
|
||||
messages, err := s.buildInjectedMessages(req, &preset)
|
||||
// 1. 获取绑定配置
|
||||
binding, err := s.getBinding(userID, req)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("构建消息失败: %w", err)
|
||||
return nil, fmt.Errorf("获取绑定配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 转发到上游AI
|
||||
resp, err = s.forwardToAI(ctx, &provider, &preset, messages)
|
||||
// 2. 注入预设
|
||||
injector := NewPresetInjector(&binding.Preset)
|
||||
req.Messages = injector.InjectMessages(req.Messages)
|
||||
injector.ApplyPresetParameters(req)
|
||||
|
||||
// 3. 转发请求到上游
|
||||
resp, err := s.forwardRequest(ctx, &binding.Provider, req)
|
||||
if err != nil {
|
||||
// 记录失败日志
|
||||
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, "", err, time.Since(startTime))
|
||||
return resp, err
|
||||
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. 应用输出正则脚本
|
||||
resp.Choices[0].Message.Content = s.applyOutputRegex(resp.Choices[0].Message.Content, preset.RegexScripts)
|
||||
// 4. 处理响应
|
||||
if len(resp.Choices) > 0 {
|
||||
resp.Choices[0].Message.Content = injector.ProcessResponse(resp.Choices[0].Message.Content)
|
||||
}
|
||||
|
||||
// 6. 记录成功日志
|
||||
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, resp.Choices[0].Message.Content, nil, time.Since(startTime))
|
||||
// 5. 记录日志
|
||||
s.logRequest(userID, binding, req, resp, nil, time.Since(startTime))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// buildInjectedMessages 构建注入预设后的消息数组
|
||||
func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionRequest, preset *app.AiPreset) ([]request.Message, error) {
|
||||
if preset == nil || preset.ID == 0 {
|
||||
return req.Messages, nil
|
||||
}
|
||||
// ProcessChatCompletionStream 处理流式聊天补全请求
|
||||
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint, req *request.ChatCompletionRequest) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 按 injection_order 排序 prompts
|
||||
sortedPrompts := make([]app.Prompt, len(preset.Prompts))
|
||||
copy(sortedPrompts, preset.Prompts)
|
||||
sort.Slice(sortedPrompts, func(i, j int) bool {
|
||||
return sortedPrompts[i].InjectionOrder < sortedPrompts[j].InjectionOrder
|
||||
})
|
||||
|
||||
messages := make([]request.Message, 0)
|
||||
|
||||
// 2. 根据 injection_depth 插入到对话历史中
|
||||
for _, prompt := range sortedPrompts {
|
||||
if prompt.Marker {
|
||||
continue // 跳过标记提示词
|
||||
}
|
||||
|
||||
// 替换变量
|
||||
content := s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard)
|
||||
|
||||
// 根据 injection_depth 决定插入位置
|
||||
// depth=0: 插入到最前面(系统提示词)
|
||||
// depth>0: 从对话历史末尾往前数 depth 条消息的位置插入
|
||||
if prompt.InjectionDepth == 0 || prompt.SystemPrompt {
|
||||
messages = append(messages, request.Message{
|
||||
Role: prompt.Role,
|
||||
Content: content,
|
||||
})
|
||||
} else {
|
||||
// 先添加用户消息,稍后根据 depth 插入
|
||||
// 这里简化处理,将非系统提示词也添加到前面
|
||||
messages = append(messages, request.Message{
|
||||
Role: prompt.Role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 添加用户消息
|
||||
messages = append(messages, req.Messages...)
|
||||
|
||||
// 4. 应用输入正则脚本 (placement=1)
|
||||
for i := range messages {
|
||||
messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// replaceVariables 替换变量
|
||||
func (s *AiProxyService) replaceVariables(content string, vars map[string]string, card *request.CharacterCard) string {
|
||||
result := content
|
||||
|
||||
// 替换自定义变量
|
||||
for key, value := range vars {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
result = replaceAll(result, placeholder, value)
|
||||
}
|
||||
|
||||
// 替换角色卡片变量
|
||||
if card != nil {
|
||||
result = replaceAll(result, "{{char}}", card.Name)
|
||||
result = replaceAll(result, "{{char_name}}", card.Name)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// applyInputRegex 应用输入正则脚本
|
||||
func (s *AiProxyService) applyInputRegex(content string, scripts []app.RegexScript) string {
|
||||
for _, script := range scripts {
|
||||
if script.Disabled {
|
||||
continue
|
||||
}
|
||||
if !containsPlacement(script.Placement, 1) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 编译正则表达式
|
||||
re, err := regexp.Compile(script.FindRegex)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行替换
|
||||
content = re.ReplaceAllString(content, script.ReplaceString)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// applyOutputRegex 应用输出正则脚本
|
||||
func (s *AiProxyService) applyOutputRegex(content string, scripts []app.RegexScript) string {
|
||||
for _, script := range scripts {
|
||||
if script.Disabled {
|
||||
continue
|
||||
}
|
||||
if !containsPlacement(script.Placement, 2) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 编译正则表达式
|
||||
re, err := regexp.Compile(script.FindRegex)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行替换
|
||||
content = re.ReplaceAllString(content, script.ReplaceString)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// forwardToAI 转发请求到上游AI
|
||||
func (s *AiProxyService) forwardToAI(ctx context.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message) (response.ChatCompletionResponse, error) {
|
||||
// 构建请求体
|
||||
reqBody := map[string]interface{}{
|
||||
"model": provider.Model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if preset != nil {
|
||||
reqBody["temperature"] = preset.Temperature
|
||||
reqBody["top_p"] = preset.TopP
|
||||
reqBody["max_tokens"] = preset.MaxTokens
|
||||
reqBody["frequency_penalty"] = preset.FrequencyPenalty
|
||||
reqBody["presence_penalty"] = preset.PresencePenalty
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
// 1. 获取绑定配置
|
||||
binding, err := s.getBinding(userID, req)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
// 2. 注入预设
|
||||
injector := NewPresetInjector(&binding.Preset)
|
||||
req.Messages = injector.InjectMessages(req.Messages)
|
||||
injector.ApplyPresetParameters(req)
|
||||
|
||||
// 3. 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// 4. 转发流式请求
|
||||
err = s.forwardStreamRequest(c, &binding.Provider, req, injector)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
|
||||
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
|
||||
// getBinding 获取绑定配置
|
||||
func (s *AiProxyService) getBinding(userID uint, req *request.ChatCompletionRequest) (*app.AiPresetBinding, error) {
|
||||
var binding app.AiPresetBinding
|
||||
|
||||
query := global.GVA_DB.Preload("Preset").Preload("Provider").Where("user_id = ? AND enabled = ?", userID, true)
|
||||
|
||||
// 优先使用 binding_name
|
||||
if req.BindingName != "" {
|
||||
query = query.Where("name = ?", req.BindingName)
|
||||
} else if req.PresetName != "" && req.ProviderName != "" {
|
||||
// 使用 preset_name 和 provider_name
|
||||
query = query.Joins("JOIN ai_presets ON ai_presets.id = ai_preset_bindings.preset_id").
|
||||
Joins("JOIN ai_providers ON ai_providers.id = ai_preset_bindings.provider_id").
|
||||
Where("ai_presets.name = ? AND ai_providers.name = ?", req.PresetName, req.ProviderName)
|
||||
} else {
|
||||
// 使用默认绑定(第一个启用的)
|
||||
query = query.Order("id ASC")
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if provider.UpstreamKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
|
||||
if err := query.First(&binding).Error; err != nil {
|
||||
return nil, fmt.Errorf("未找到可用的绑定配置")
|
||||
}
|
||||
|
||||
if !binding.Provider.Enabled {
|
||||
return nil, fmt.Errorf("提供商已禁用")
|
||||
}
|
||||
|
||||
if !binding.Preset.Enabled {
|
||||
return nil, fmt.Errorf("预设已禁用")
|
||||
}
|
||||
|
||||
return &binding, nil
|
||||
}
|
||||
|
||||
// forwardRequest 转发请求到上游 AI 服务
|
||||
func (s *AiProxyService) forwardRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
|
||||
// 使用提供商的默认模型(如果请求中没有指定)
|
||||
if req.Model == "" && provider.Model != "" {
|
||||
req.Model = provider.Model
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return response.ChatCompletionResponse{}, fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(httpResp.Body)
|
||||
return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var aiResp response.ChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &aiResp); err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
var resp response.ChatCompletionResponse
|
||||
if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
return aiResp, nil
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// forwardStreamRequest 转发流式请求
|
||||
func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ChatCompletionRequest, injector *PresetInjector) error {
|
||||
// 使用提供商的默认模型
|
||||
if req.Model == "" && provider.Model != "" {
|
||||
req.Model = provider.Model
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(httpResp.Body)
|
||||
return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 读取并转发流式响应
|
||||
reader := bufio.NewReader(httpResp.Body)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return fmt.Errorf("不支持流式响应")
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 跳过空行
|
||||
if len(bytes.TrimSpace(line)) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理 SSE 数据
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||||
data = bytes.TrimSpace(data)
|
||||
|
||||
// 检查是否是结束标记
|
||||
if string(data) == "[DONE]" {
|
||||
c.Writer.Write([]byte("data: [DONE]\n\n"))
|
||||
flusher.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// 解析并处理响应
|
||||
var chunk response.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal(data, &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 应用输出正则处理
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
chunk.Choices[0].Delta.Content = injector.ProcessResponse(chunk.Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
// 重新序列化并发送
|
||||
processedData, _ := json.Marshal(chunk)
|
||||
c.Writer.Write([]byte("data: "))
|
||||
c.Writer.Write(processedData)
|
||||
c.Writer.Write([]byte("\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logRequest 记录请求日志
|
||||
func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider *app.AiProvider, originalMsg, responseText string, err error, latency time.Duration) {
|
||||
func (s *AiProxyService) logRequest(userID uint, binding *app.AiPresetBinding, req *request.ChatCompletionRequest, resp *response.ChatCompletionResponse, err error, duration time.Duration) {
|
||||
log := app.AiRequestLog{
|
||||
UserID: &userId,
|
||||
OriginalMessage: originalMsg,
|
||||
ResponseText: responseText,
|
||||
LatencyMs: int(latency.Milliseconds()),
|
||||
}
|
||||
|
||||
if preset != nil {
|
||||
presetID := preset.ID
|
||||
log.PresetID = &presetID
|
||||
}
|
||||
|
||||
if provider != nil {
|
||||
providerID := provider.ID
|
||||
log.ProviderID = &providerID
|
||||
UserID: userID,
|
||||
BindingID: binding.ID,
|
||||
ProviderID: binding.ProviderID,
|
||||
PresetID: binding.PresetID,
|
||||
Model: req.Model,
|
||||
Duration: duration.Milliseconds(),
|
||||
RequestTime: time.Now(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -281,21 +268,12 @@ func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider
|
||||
log.ErrorMessage = err.Error()
|
||||
} else {
|
||||
log.Status = "success"
|
||||
if resp != nil {
|
||||
log.PromptTokens = resp.Usage.PromptTokens
|
||||
log.CompletionTokens = resp.Usage.CompletionTokens
|
||||
log.TotalTokens = resp.Usage.TotalTokens
|
||||
}
|
||||
}
|
||||
|
||||
global.GVA_DB.Create(&log)
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func replaceAll(s, old, new string) string {
|
||||
return strings.ReplaceAll(s, old, new)
|
||||
}
|
||||
|
||||
func containsPlacement(placements []int, target int) bool {
|
||||
for _, p := range placements {
|
||||
if p == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProcessChatCompletionStream 处理流式聊天补全请求
|
||||
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userId uint, req *request.ChatCompletionRequest) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 获取预设配置
|
||||
var preset app.AiPreset
|
||||
if req.PresetID > 0 {
|
||||
err := global.GVA_DB.First(&preset, req.PresetID).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "预设不存在"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取提供商配置
|
||||
var provider app.AiProvider
|
||||
if req.BindingKey != "" {
|
||||
var binding app.AiPresetBinding
|
||||
err := global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
|
||||
Order("priority ASC").
|
||||
First(&binding).Error
|
||||
if err == nil {
|
||||
global.GVA_DB.First(&provider, binding.ProviderID)
|
||||
}
|
||||
}
|
||||
|
||||
if provider.ID == 0 {
|
||||
err := global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到可用的AI提供商"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 构建注入后的消息
|
||||
messages, err := s.buildInjectedMessages(req, &preset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "构建消息失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 转发流式请求到上游AI
|
||||
err = s.forwardStreamToAI(c, &provider, &preset, messages, userId, startTime)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// forwardStreamToAI 转发流式请求到上游AI
|
||||
func (s *AiProxyService) forwardStreamToAI(c *gin.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message, userId uint, startTime time.Time) error {
|
||||
// 构建请求体
|
||||
reqBody := map[string]interface{}{
|
||||
"model": provider.Model,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
if preset != nil {
|
||||
reqBody["temperature"] = preset.Temperature
|
||||
reqBody["top_p"] = preset.TopP
|
||||
reqBody["max_tokens"] = preset.MaxTokens
|
||||
reqBody["frequency_penalty"] = preset.FrequencyPenalty
|
||||
reqBody["presence_penalty"] = preset.PresencePenalty
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
|
||||
req, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if provider.UpstreamKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{Timeout: 300 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// 读取并转发流式响应
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return fmt.Errorf("streaming not supported")
|
||||
}
|
||||
|
||||
var fullResponse strings.Builder
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 跳过空行
|
||||
if len(bytes.TrimSpace(line)) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析SSE数据
|
||||
lineStr := string(line)
|
||||
if strings.HasPrefix(lineStr, "data: ") {
|
||||
data := strings.TrimPrefix(lineStr, "data: ")
|
||||
data = strings.TrimSpace(data)
|
||||
|
||||
// 检查是否是结束标记
|
||||
if data == "[DONE]" {
|
||||
c.Writer.Write([]byte("data: [DONE]\n\n"))
|
||||
flusher.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// 解析JSON并提取内容
|
||||
var chunk map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
|
||||
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||
if delta, ok := choice["delta"].(map[string]interface{}); ok {
|
||||
if content, ok := delta["content"].(string); ok {
|
||||
fullResponse.WriteString(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 转发原始数据
|
||||
c.Writer.Write(line)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// 应用输出正则脚本
|
||||
finalContent := fullResponse.String()
|
||||
if preset != nil {
|
||||
finalContent = s.applyOutputRegex(finalContent, preset.RegexScripts)
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
var originalMsg string
|
||||
if len(messages) > 0 {
|
||||
originalMsg = messages[len(messages)-1].Content
|
||||
}
|
||||
s.logRequest(userId, preset, provider, originalMsg, finalContent, nil, time.Since(startTime))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package app
|
||||
|
||||
type AppServiceGroup struct {
|
||||
AiPresetService AiPresetService
|
||||
AiProviderService AiProviderService
|
||||
AiProxyService AiProxyService
|
||||
PresetBindingService PresetBindingService
|
||||
AiProxyService
|
||||
AiPresetService
|
||||
AiProviderService
|
||||
AiPresetBindingService
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user