🎨 优化模型配置 && 新增apikey功能 && 完善通用接口
This commit is contained in:
@@ -22,53 +22,60 @@ import (
|
||||
type AiProxyService struct{}
|
||||
|
||||
// ProcessChatCompletion 处理聊天补全请求
|
||||
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userID uint, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
|
||||
startTime := time.Now()
|
||||
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
|
||||
// 1. 根据模型获取配置
|
||||
if req.Model == "" {
|
||||
return nil, fmt.Errorf("model 参数不能为空")
|
||||
}
|
||||
|
||||
// 1. 获取绑定配置
|
||||
binding, err := s.getBinding(userID, req)
|
||||
preset, provider, err := s.getConfigByModel(req.Model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取绑定配置失败: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 注入预设
|
||||
injector := NewPresetInjector(&binding.Preset)
|
||||
req.Messages = injector.InjectMessages(req.Messages)
|
||||
injector.ApplyPresetParameters(req)
|
||||
if preset != nil {
|
||||
injector := NewPresetInjector(preset)
|
||||
req.Messages = injector.InjectMessages(req.Messages)
|
||||
injector.ApplyPresetParameters(req)
|
||||
}
|
||||
|
||||
// 3. 转发请求到上游
|
||||
resp, err := s.forwardRequest(ctx, &binding.Provider, req)
|
||||
resp, err := s.forwardRequest(ctx, provider, req)
|
||||
if err != nil {
|
||||
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 处理响应
|
||||
if len(resp.Choices) > 0 {
|
||||
if preset != nil && len(resp.Choices) > 0 {
|
||||
injector := NewPresetInjector(preset)
|
||||
resp.Choices[0].Message.Content = injector.ProcessResponse(resp.Choices[0].Message.Content)
|
||||
}
|
||||
|
||||
// 5. 记录日志
|
||||
s.logRequest(userID, binding, req, resp, nil, time.Since(startTime))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ProcessChatCompletionStream 处理流式聊天补全请求
|
||||
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint, req *request.ChatCompletionRequest) {
|
||||
startTime := time.Now()
|
||||
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, req *request.ChatCompletionRequest) {
|
||||
// 1. 根据模型获取配置
|
||||
if req.Model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 获取绑定配置
|
||||
binding, err := s.getBinding(userID, req)
|
||||
preset, provider, err := s.getConfigByModel(req.Model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 注入预设
|
||||
injector := NewPresetInjector(&binding.Preset)
|
||||
req.Messages = injector.InjectMessages(req.Messages)
|
||||
injector.ApplyPresetParameters(req)
|
||||
var injector *PresetInjector
|
||||
if preset != nil {
|
||||
injector = NewPresetInjector(preset)
|
||||
req.Messages = injector.InjectMessages(req.Messages)
|
||||
injector.ApplyPresetParameters(req)
|
||||
}
|
||||
|
||||
// 3. 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
@@ -77,45 +84,30 @@ func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// 4. 转发流式请求
|
||||
err = s.forwardStreamRequest(c, &binding.Provider, req, injector)
|
||||
err = s.forwardStreamRequest(c, provider, req, injector)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
|
||||
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
|
||||
// getBinding 获取绑定配置
|
||||
func (s *AiProxyService) getBinding(userID uint, req *request.ChatCompletionRequest) (*app.AiPresetBinding, error) {
|
||||
var binding app.AiPresetBinding
|
||||
// getConfigByModel 根据模型名称获取配置
|
||||
func (s *AiProxyService) getConfigByModel(modelName string) (*app.AiPreset, *app.AiProvider, error) {
|
||||
// 查找启用的模型配置
|
||||
var model app.AiModel
|
||||
err := global.GVA_DB.Preload("Provider").Preload("Preset").
|
||||
Where("name = ? AND enabled = ?", modelName, true).
|
||||
First(&model).Error
|
||||
|
||||
query := global.GVA_DB.Preload("Preset").Preload("Provider").Where("user_id = ? AND enabled = ?", userID, true)
|
||||
|
||||
// 优先使用 binding_name
|
||||
if req.BindingName != "" {
|
||||
query = query.Where("name = ?", req.BindingName)
|
||||
} else if req.PresetName != "" && req.ProviderName != "" {
|
||||
// 使用 preset_name 和 provider_name
|
||||
query = query.Joins("JOIN ai_presets ON ai_presets.id = ai_preset_bindings.preset_id").
|
||||
Joins("JOIN ai_providers ON ai_providers.id = ai_preset_bindings.provider_id").
|
||||
Where("ai_presets.name = ? AND ai_providers.name = ?", req.PresetName, req.ProviderName)
|
||||
} else {
|
||||
// 使用默认绑定(第一个启用的)
|
||||
query = query.Order("id ASC")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("未找到模型配置: %s", modelName)
|
||||
}
|
||||
|
||||
if err := query.First(&binding).Error; err != nil {
|
||||
return nil, fmt.Errorf("未找到可用的绑定配置")
|
||||
// 检查提供商是否启用
|
||||
if !model.Provider.Enabled {
|
||||
return nil, nil, fmt.Errorf("提供商已禁用")
|
||||
}
|
||||
|
||||
if !binding.Provider.Enabled {
|
||||
return nil, fmt.Errorf("提供商已禁用")
|
||||
}
|
||||
|
||||
if !binding.Preset.Enabled {
|
||||
return nil, fmt.Errorf("预设已禁用")
|
||||
}
|
||||
|
||||
return &binding, nil
|
||||
return model.Preset, &model.Provider, nil
|
||||
}
|
||||
|
||||
// forwardRequest 转发请求到上游 AI 服务
|
||||
@@ -235,7 +227,7 @@ func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiPr
|
||||
}
|
||||
|
||||
// 应用输出正则处理
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
if injector != nil && len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
chunk.Choices[0].Delta.Content = injector.ProcessResponse(chunk.Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
@@ -251,29 +243,40 @@ func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiPr
|
||||
return nil
|
||||
}
|
||||
|
||||
// logRequest 记录请求日志
|
||||
func (s *AiProxyService) logRequest(userID uint, binding *app.AiPresetBinding, req *request.ChatCompletionRequest, resp *response.ChatCompletionResponse, err error, duration time.Duration) {
|
||||
log := app.AiRequestLog{
|
||||
UserID: userID,
|
||||
BindingID: binding.ID,
|
||||
ProviderID: binding.ProviderID,
|
||||
PresetID: binding.PresetID,
|
||||
Model: req.Model,
|
||||
Duration: duration.Milliseconds(),
|
||||
RequestTime: time.Now(),
|
||||
// GetAvailableModels 获取用户可用的模型列表
|
||||
func (s *AiProxyService) GetAvailableModels(apiKey *app.AiApiKey) (*response.ModelListResponse, error) {
|
||||
// 查询所有启用的模型
|
||||
var models []app.AiModel
|
||||
query := global.GVA_DB.Where("enabled = ?", true)
|
||||
|
||||
// 如果 API Key 限制了模型,只返回允许的模型
|
||||
if len(apiKey.AllowedModels) > 0 {
|
||||
query = query.Where("name IN ?", apiKey.AllowedModels)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Status = "error"
|
||||
log.ErrorMessage = err.Error()
|
||||
} else {
|
||||
log.Status = "success"
|
||||
if resp != nil {
|
||||
log.PromptTokens = resp.Usage.PromptTokens
|
||||
log.CompletionTokens = resp.Usage.CompletionTokens
|
||||
log.TotalTokens = resp.Usage.TotalTokens
|
||||
if err := query.Find(&models).Error; err != nil {
|
||||
return nil, fmt.Errorf("查询模型列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
modelList := &response.ModelListResponse{
|
||||
Object: "list",
|
||||
Data: make([]response.ModelInfo, 0, len(models)),
|
||||
}
|
||||
|
||||
// 去重(同一模型可能在多个提供商下配置)
|
||||
seen := make(map[string]bool)
|
||||
for _, model := range models {
|
||||
if !seen[model.Name] {
|
||||
seen[model.Name] = true
|
||||
modelList.Data = append(modelList.Data, response.ModelInfo{
|
||||
ID: model.Name,
|
||||
Object: "model",
|
||||
Created: model.CreatedAt.Unix(),
|
||||
OwnedBy: "system",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
global.GVA_DB.Create(&log)
|
||||
return modelList, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user