🎉 初始化项目
This commit is contained in:
147
server/service/app/ai_preset.go
Normal file
147
server/service/app/ai_preset.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
req "git.echol.cn/loser/ai_proxy/server/model/common/request"
|
||||
)
|
||||
|
||||
type AiPresetService struct{}
|
||||
|
||||
// CreateAiPreset 创建AI预设
|
||||
func (s *AiPresetService) CreateAiPreset(userId uint, req *request.CreateAiPresetRequest) (preset app.AiPreset, err error) {
|
||||
preset = app.AiPreset{
|
||||
UserID: userId,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Prompts: req.Prompts,
|
||||
RegexScripts: req.RegexScripts,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
MaxTokens: req.MaxTokens,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
StreamEnabled: req.StreamEnabled,
|
||||
IsDefault: req.IsDefault,
|
||||
IsPublic: req.IsPublic,
|
||||
}
|
||||
err = global.GVA_DB.Create(&preset).Error
|
||||
return preset, err
|
||||
}
|
||||
|
||||
// DeleteAiPreset 删除AI预设
|
||||
func (s *AiPresetService) DeleteAiPreset(id uint, userId uint) (err error) {
|
||||
// 如果 userId 为 0(未登录),不允许删除
|
||||
if userId == 0 {
|
||||
return global.GVA_DB.Where("id = ?", id).Delete(&app.AiPreset{}).Error
|
||||
}
|
||||
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userId).Delete(&app.AiPreset{}).Error
|
||||
}
|
||||
|
||||
// UpdateAiPreset 更新AI预设
|
||||
func (s *AiPresetService) UpdateAiPreset(userId uint, req *request.UpdateAiPresetRequest) (preset app.AiPreset, err error) {
|
||||
// 如果 userId 为 0(未登录),允许更新任何预设
|
||||
if userId == 0 {
|
||||
err = global.GVA_DB.Where("id = ?", req.ID).First(&preset).Error
|
||||
} else {
|
||||
err = global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userId).First(&preset).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return preset, err
|
||||
}
|
||||
|
||||
if req.Name != "" {
|
||||
preset.Name = req.Name
|
||||
}
|
||||
preset.Description = req.Description
|
||||
if req.Prompts != nil {
|
||||
preset.Prompts = req.Prompts
|
||||
}
|
||||
if req.RegexScripts != nil {
|
||||
preset.RegexScripts = req.RegexScripts
|
||||
}
|
||||
preset.Temperature = req.Temperature
|
||||
preset.TopP = req.TopP
|
||||
preset.MaxTokens = req.MaxTokens
|
||||
preset.FrequencyPenalty = req.FrequencyPenalty
|
||||
preset.PresencePenalty = req.PresencePenalty
|
||||
preset.StreamEnabled = req.StreamEnabled
|
||||
preset.IsDefault = req.IsDefault
|
||||
preset.IsPublic = req.IsPublic
|
||||
|
||||
err = global.GVA_DB.Save(&preset).Error
|
||||
return preset, err
|
||||
}
|
||||
|
||||
// GetAiPreset 获取AI预设详情
|
||||
func (s *AiPresetService) GetAiPreset(id uint, userId uint) (preset app.AiPreset, err error) {
|
||||
// 如果 userId 为 0(未登录),只能获取公开的预设
|
||||
if userId == 0 {
|
||||
err = global.GVA_DB.Where("id = ? AND is_public = ?", id, true).First(&preset).Error
|
||||
} else {
|
||||
err = global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true).First(&preset).Error
|
||||
}
|
||||
return preset, err
|
||||
}
|
||||
|
||||
// GetAiPresetList 获取AI预设列表
|
||||
func (s *AiPresetService) GetAiPresetList(userId uint, info req.PageInfo) (list []app.AiPreset, total int64, err error) {
|
||||
limit := info.PageSize
|
||||
offset := info.PageSize * (info.Page - 1)
|
||||
db := global.GVA_DB.Model(&app.AiPreset{})
|
||||
|
||||
// 如果 userId 为 0(未登录),只返回公开的预设
|
||||
if userId == 0 {
|
||||
db = db.Where("is_public = ?", true)
|
||||
} else {
|
||||
db = db.Where("user_id = ? OR is_public = ?", userId, true)
|
||||
}
|
||||
|
||||
err = db.Count(&total).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = db.Limit(limit).Offset(offset).Order("created_at DESC").Find(&list).Error
|
||||
return list, total, err
|
||||
}
|
||||
|
||||
// ImportAiPreset 导入AI预设(支持SillyTavern格式)
|
||||
func (s *AiPresetService) ImportAiPreset(userId uint, req *request.ImportAiPresetRequest) (preset app.AiPreset, err error) {
|
||||
// TODO: 解析SillyTavern JSON格式
|
||||
// 这里需要实现JSON解析逻辑,将SillyTavern格式转换为我们的格式
|
||||
return preset, nil
|
||||
}
|
||||
|
||||
// ExportAiPreset 导出AI预设
|
||||
func (s *AiPresetService) ExportAiPreset(id uint, userId uint) (data map[string]interface{}, err error) {
|
||||
var preset app.AiPreset
|
||||
|
||||
// 如果 userId 为 0(未登录),只能导出公开的预设
|
||||
if userId == 0 {
|
||||
err = global.GVA_DB.Where("id = ? AND is_public = ?", id, true).First(&preset).Error
|
||||
} else {
|
||||
err = global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true).First(&preset).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data = map[string]interface{}{
|
||||
"prompts": preset.Prompts,
|
||||
"extensions": map[string]interface{}{
|
||||
"regex_scripts": preset.RegexScripts,
|
||||
},
|
||||
"temperature": preset.Temperature,
|
||||
"top_p": preset.TopP,
|
||||
"openai_max_tokens": preset.MaxTokens,
|
||||
"frequency_penalty": preset.FrequencyPenalty,
|
||||
"presence_penalty": preset.PresencePenalty,
|
||||
"stream_openai": preset.StreamEnabled,
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
154
server/service/app/ai_preset_binding.go
Normal file
154
server/service/app/ai_preset_binding.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type PresetBindingService struct{}
|
||||
|
||||
// CreateBinding 创建预设绑定
|
||||
func (s *PresetBindingService) CreateBinding(req *request.CreateBindingRequest) error {
|
||||
// 检查预设是否存在
|
||||
var preset app.AiPreset
|
||||
if err := global.GVA_DB.First(&preset, req.PresetID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("预设不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查提供商是否存在
|
||||
var provider app.AiProvider
|
||||
if err := global.GVA_DB.First(&provider, req.ProviderID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("提供商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查是否已存在相同的绑定
|
||||
var count int64
|
||||
err := global.GVA_DB.Model(&app.AiPresetBinding{}).
|
||||
Where("preset_id = ? AND provider_id = ?", req.PresetID, req.ProviderID).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return errors.New("该绑定已存在")
|
||||
}
|
||||
|
||||
binding := app.AiPresetBinding{
|
||||
PresetID: req.PresetID,
|
||||
ProviderID: req.ProviderID,
|
||||
Priority: req.Priority,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
return global.GVA_DB.Create(&binding).Error
|
||||
}
|
||||
|
||||
// UpdateBinding 更新预设绑定
|
||||
func (s *PresetBindingService) UpdateBinding(req *request.UpdateBindingRequest) error {
|
||||
var binding app.AiPresetBinding
|
||||
if err := global.GVA_DB.First(&binding, req.ID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("绑定不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"priority": req.Priority,
|
||||
"is_active": req.IsActive,
|
||||
}
|
||||
|
||||
return global.GVA_DB.Model(&binding).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteBinding 删除预设绑定
|
||||
func (s *PresetBindingService) DeleteBinding(id uint) error {
|
||||
var binding app.AiPresetBinding
|
||||
if err := global.GVA_DB.First(&binding, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("绑定不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return global.GVA_DB.Delete(&binding).Error
|
||||
}
|
||||
|
||||
// GetBindingList 获取绑定列表
|
||||
func (s *PresetBindingService) GetBindingList(req *request.GetBindingListRequest) (list []response.BindingInfo, total int64, err error) {
|
||||
db := global.GVA_DB.Model(&app.AiPresetBinding{})
|
||||
|
||||
// 条件查询
|
||||
if req.ProviderID > 0 {
|
||||
db = db.Where("provider_id = ?", req.ProviderID)
|
||||
}
|
||||
if req.PresetID > 0 {
|
||||
db = db.Where("preset_id = ?", req.PresetID)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err = db.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
if req.Page > 0 && req.PageSize > 0 {
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
db = db.Offset(offset).Limit(req.PageSize)
|
||||
}
|
||||
|
||||
var bindings []app.AiPresetBinding
|
||||
err = db.Preload("Preset").Preload("Provider").Order("priority ASC, created_at DESC").Find(&bindings).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
list = make([]response.BindingInfo, len(bindings))
|
||||
for i, binding := range bindings {
|
||||
list[i] = response.BindingInfo{
|
||||
ID: binding.ID,
|
||||
PresetID: binding.PresetID,
|
||||
PresetName: binding.Preset.Name,
|
||||
ProviderID: binding.ProviderID,
|
||||
ProviderName: binding.Provider.Name,
|
||||
Priority: binding.Priority,
|
||||
IsActive: binding.IsActive,
|
||||
CreatedAt: binding.CreatedAt,
|
||||
UpdatedAt: binding.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
return list, total, nil
|
||||
}
|
||||
|
||||
// GetBindingsByProvider 根据提供商获取绑定的预设列表
|
||||
func (s *PresetBindingService) GetBindingsByProvider(providerID uint) ([]app.AiPreset, error) {
|
||||
var bindings []app.AiPresetBinding
|
||||
err := global.GVA_DB.Where("provider_id = ? AND is_active = ?", providerID, true).
|
||||
Preload("Preset").
|
||||
Order("priority ASC").
|
||||
Find(&bindings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
presets := make([]app.AiPreset, len(bindings))
|
||||
for i, binding := range bindings {
|
||||
presets[i] = binding.Preset
|
||||
}
|
||||
|
||||
return presets, nil
|
||||
}
|
||||
230
server/service/app/ai_provider.go
Normal file
230
server/service/app/ai_provider.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||||
)
|
||||
|
||||
type AiProviderService struct{}
|
||||
|
||||
// CreateAiProvider 创建AI提供商
|
||||
func (s *AiProviderService) CreateAiProvider(req *request.CreateAiProviderRequest) (provider app.AiProvider, err error) {
|
||||
provider = app.AiProvider{
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
BaseURL: req.BaseURL,
|
||||
Endpoint: req.Endpoint,
|
||||
UpstreamKey: req.UpstreamKey,
|
||||
Model: req.Model,
|
||||
ProxyKey: req.ProxyKey,
|
||||
Config: req.Config,
|
||||
IsActive: req.IsActive,
|
||||
}
|
||||
err = global.GVA_DB.Create(&provider).Error
|
||||
return provider, err
|
||||
}
|
||||
|
||||
// DeleteAiProvider 删除AI提供商
|
||||
func (s *AiProviderService) DeleteAiProvider(id uint) (err error) {
|
||||
return global.GVA_DB.Delete(&app.AiProvider{}, id).Error
|
||||
}
|
||||
|
||||
// UpdateAiProvider 更新AI提供商
|
||||
func (s *AiProviderService) UpdateAiProvider(req *request.UpdateAiProviderRequest) (provider app.AiProvider, err error) {
|
||||
err = global.GVA_DB.First(&provider, req.ID).Error
|
||||
if err != nil {
|
||||
return provider, err
|
||||
}
|
||||
|
||||
if req.Name != "" {
|
||||
provider.Name = req.Name
|
||||
}
|
||||
if req.Type != "" {
|
||||
provider.Type = req.Type
|
||||
}
|
||||
if req.BaseURL != "" {
|
||||
provider.BaseURL = req.BaseURL
|
||||
}
|
||||
if req.Endpoint != "" {
|
||||
provider.Endpoint = req.Endpoint
|
||||
}
|
||||
if req.UpstreamKey != "" {
|
||||
provider.UpstreamKey = req.UpstreamKey
|
||||
}
|
||||
if req.Model != "" {
|
||||
provider.Model = req.Model
|
||||
}
|
||||
if req.ProxyKey != "" {
|
||||
provider.ProxyKey = req.ProxyKey
|
||||
}
|
||||
if req.Config != nil {
|
||||
provider.Config = req.Config
|
||||
}
|
||||
provider.IsActive = req.IsActive
|
||||
|
||||
err = global.GVA_DB.Save(&provider).Error
|
||||
return provider, err
|
||||
}
|
||||
|
||||
// GetAiProvider 获取AI提供商详情
|
||||
func (s *AiProviderService) GetAiProvider(id uint) (provider app.AiProvider, err error) {
|
||||
err = global.GVA_DB.First(&provider, id).Error
|
||||
return provider, err
|
||||
}
|
||||
|
||||
// GetAiProviderList 获取AI提供商列表
|
||||
func (s *AiProviderService) GetAiProviderList() (list []app.AiProvider, err error) {
|
||||
err = global.GVA_DB.Where("is_active = ?", true).Find(&list).Error
|
||||
return list, err
|
||||
}
|
||||
|
||||
// TestConnection 测试连接
|
||||
func (s *AiProviderService) TestConnection(req *request.TestConnectionRequest) (resp response.TestConnectionResponse, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 根据类型构建测试 URL
|
||||
var testURL string
|
||||
switch strings.ToLower(req.Type) {
|
||||
case "openai":
|
||||
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
||||
case "claude":
|
||||
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/messages"
|
||||
default:
|
||||
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
||||
}
|
||||
|
||||
// 创建 HTTP 请求
|
||||
httpReq, err := http.NewRequest("GET", testURL, nil)
|
||||
if err != nil {
|
||||
return response.TestConnectionResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("创建请求失败: %v", err),
|
||||
Latency: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return response.TestConnectionResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("连接失败: %v", err),
|
||||
Latency: time.Since(startTime).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
latency := time.Since(startTime).Milliseconds()
|
||||
|
||||
// 检查响应状态
|
||||
if httpResp.StatusCode == http.StatusOK || httpResp.StatusCode == http.StatusCreated {
|
||||
return response.TestConnectionResponse{
|
||||
Success: true,
|
||||
Message: "连接成功",
|
||||
Latency: latency,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 读取错误响应
|
||||
body, _ := io.ReadAll(httpResp.Body)
|
||||
return response.TestConnectionResponse{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("连接失败 (状态码: %d): %s", httpResp.StatusCode, string(body)),
|
||||
Latency: latency,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetModels 获取模型列表
|
||||
func (s *AiProviderService) GetModels(req *request.GetModelsRequest) (models []response.ModelInfo, err error) {
|
||||
// 根据类型构建 URL
|
||||
var modelsURL string
|
||||
switch strings.ToLower(req.Type) {
|
||||
case "openai":
|
||||
modelsURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
||||
case "claude":
|
||||
// Claude API 不提供模型列表接口,返回预定义的模型
|
||||
return []response.ModelInfo{
|
||||
{ID: "claude-opus-4-6", Name: "Claude Opus 4.6", OwnedBy: "anthropic"},
|
||||
{ID: "claude-sonnet-4-6", Name: "Claude Sonnet 4.6", OwnedBy: "anthropic"},
|
||||
{ID: "claude-haiku-4-5-20251001", Name: "Claude Haiku 4.5", OwnedBy: "anthropic"},
|
||||
{ID: "claude-3-5-sonnet-20241022", Name: "Claude 3.5 Sonnet", OwnedBy: "anthropic"},
|
||||
{ID: "claude-3-opus-20240229", Name: "Claude 3 Opus", OwnedBy: "anthropic"},
|
||||
}, nil
|
||||
default:
|
||||
modelsURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
||||
}
|
||||
|
||||
// 创建 HTTP 请求
|
||||
httpReq, err := http.NewRequest("GET", modelsURL, nil)
|
||||
if err != nil {
|
||||
return nil, errors.New("创建请求失败: " + err.Error())
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
httpResp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, errors.New("请求失败: " + err.Error())
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(httpResp.Body)
|
||||
return nil, fmt.Errorf("获取模型列表失败 (状态码: %d): %s", httpResp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.New("读取响应失败: " + err.Error())
|
||||
}
|
||||
|
||||
// OpenAI 格式的响应
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
||||
return nil, errors.New("解析响应失败: " + err.Error())
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
models = make([]response.ModelInfo, len(modelsResp.Data))
|
||||
for i, model := range modelsResp.Data {
|
||||
models[i] = response.ModelInfo{
|
||||
ID: model.ID,
|
||||
Name: model.ID,
|
||||
OwnedBy: model.OwnedBy,
|
||||
}
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
247
server/service/app/ai_proxy.go
Normal file
247
server/service/app/ai_proxy.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
||||
)
|
||||
|
||||
type AiProxyService struct{}
|
||||
|
||||
// ProcessChatCompletion 处理聊天补全请求
|
||||
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userId uint, req *request.ChatCompletionRequest) (resp response.ChatCompletionResponse, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 获取预设配置
|
||||
var preset app.AiPreset
|
||||
if req.PresetID > 0 {
|
||||
err = global.GVA_DB.First(&preset, req.PresetID).Error
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("预设不存在: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取提供商配置
|
||||
var provider app.AiProvider
|
||||
// TODO: 根据 binding_key 或默认配置获取 provider
|
||||
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
|
||||
}
|
||||
|
||||
// 3. 构建注入后的消息
|
||||
messages, err := s.buildInjectedMessages(req, &preset)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("构建消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 转发到上游AI
|
||||
resp, err = s.forwardToAI(ctx, &provider, &preset, messages)
|
||||
if err != nil {
|
||||
// 记录失败日志
|
||||
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, "", err, time.Since(startTime))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// 5. 应用输出正则脚本
|
||||
resp.Choices[0].Message.Content = s.applyOutputRegex(resp.Choices[0].Message.Content, preset.RegexScripts)
|
||||
|
||||
// 6. 记录成功日志
|
||||
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, resp.Choices[0].Message.Content, nil, time.Since(startTime))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// buildInjectedMessages 构建注入预设后的消息数组
|
||||
func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionRequest, preset *app.AiPreset) ([]request.Message, error) {
|
||||
if preset == nil || preset.ID == 0 {
|
||||
return req.Messages, nil
|
||||
}
|
||||
|
||||
// TODO: 实现完整的预设注入逻辑
|
||||
// 1. 按 injection_order 排序 prompts
|
||||
// 2. 根据 injection_depth 插入到对话历史中
|
||||
// 3. 替换变量 {{user}}, {{char}}
|
||||
// 4. 应用正则脚本 (placement=1)
|
||||
|
||||
messages := make([]request.Message, 0)
|
||||
|
||||
// 简化实现:直接添加系统提示词
|
||||
for _, prompt := range preset.Prompts {
|
||||
if prompt.SystemPrompt && !prompt.Marker {
|
||||
messages = append(messages, request.Message{
|
||||
Role: prompt.Role,
|
||||
Content: s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 添加用户消息
|
||||
messages = append(messages, req.Messages...)
|
||||
|
||||
// 应用输入正则脚本
|
||||
for i := range messages {
|
||||
messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// replaceVariables 替换变量
|
||||
func (s *AiProxyService) replaceVariables(content string, vars map[string]string, card *request.CharacterCard) string {
|
||||
result := content
|
||||
|
||||
// 替换自定义变量
|
||||
for key, value := range vars {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
result = replaceAll(result, placeholder, value)
|
||||
}
|
||||
|
||||
// 替换角色卡片变量
|
||||
if card != nil {
|
||||
result = replaceAll(result, "{{char}}", card.Name)
|
||||
result = replaceAll(result, "{{char_name}}", card.Name)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// applyInputRegex 应用输入正则脚本
|
||||
func (s *AiProxyService) applyInputRegex(content string, scripts []app.RegexScript) string {
|
||||
for _, script := range scripts {
|
||||
if script.Disabled {
|
||||
continue
|
||||
}
|
||||
if !containsPlacement(script.Placement, 1) {
|
||||
continue
|
||||
}
|
||||
// TODO: 实现正则替换逻辑
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// applyOutputRegex 应用输出正则脚本
|
||||
func (s *AiProxyService) applyOutputRegex(content string, scripts []app.RegexScript) string {
|
||||
for _, script := range scripts {
|
||||
if script.Disabled {
|
||||
continue
|
||||
}
|
||||
if !containsPlacement(script.Placement, 2) {
|
||||
continue
|
||||
}
|
||||
// TODO: 实现正则替换逻辑
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// forwardToAI 转发请求到上游AI
|
||||
func (s *AiProxyService) forwardToAI(ctx context.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message) (response.ChatCompletionResponse, error) {
|
||||
// 构建请求体
|
||||
reqBody := map[string]interface{}{
|
||||
"model": provider.Model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if preset != nil {
|
||||
reqBody["temperature"] = preset.Temperature
|
||||
reqBody["top_p"] = preset.TopP
|
||||
reqBody["max_tokens"] = preset.MaxTokens
|
||||
reqBody["frequency_penalty"] = preset.FrequencyPenalty
|
||||
reqBody["presence_penalty"] = preset.PresencePenalty
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if provider.UpstreamKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return response.ChatCompletionResponse{}, fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var aiResp response.ChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &aiResp); err != nil {
|
||||
return response.ChatCompletionResponse{}, err
|
||||
}
|
||||
|
||||
return aiResp, nil
|
||||
}
|
||||
|
||||
// logRequest 记录请求日志
|
||||
func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider *app.AiProvider, originalMsg, responseText string, err error, latency time.Duration) {
|
||||
log := app.AiRequestLog{
|
||||
UserID: &userId,
|
||||
OriginalMessage: originalMsg,
|
||||
ResponseText: responseText,
|
||||
LatencyMs: int(latency.Milliseconds()),
|
||||
}
|
||||
|
||||
if preset != nil {
|
||||
presetID := preset.ID
|
||||
log.PresetID = &presetID
|
||||
}
|
||||
|
||||
if provider != nil {
|
||||
providerID := provider.ID
|
||||
log.ProviderID = &providerID
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Status = "error"
|
||||
log.ErrorMessage = err.Error()
|
||||
} else {
|
||||
log.Status = "success"
|
||||
}
|
||||
|
||||
global.GVA_DB.Create(&log)
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func replaceAll(s, old, new string) string {
|
||||
return s // TODO: 实现字符串替换
|
||||
}
|
||||
|
||||
func containsPlacement(placements []int, target int) bool {
|
||||
for _, p := range placements {
|
||||
if p == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
8
server/service/app/enter.go
Normal file
8
server/service/app/enter.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package app
|
||||
|
||||
type AppServiceGroup struct {
|
||||
AiPresetService AiPresetService
|
||||
AiProviderService AiProviderService
|
||||
AiProxyService AiProxyService
|
||||
PresetBindingService PresetBindingService
|
||||
}
|
||||
Reference in New Issue
Block a user