🎨 优化模型配置 && 新增apikey功能 && 完善通用接口

This commit is contained in:
2026-03-03 17:13:24 +08:00
parent 2714e63d2a
commit 7dae1a6e2b
46 changed files with 3063 additions and 278 deletions

View File

@@ -0,0 +1,107 @@
package app
import (
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
// AiApiKeyService implements API-key business logic.
type AiApiKeyService struct{}

// GenerateApiKey returns a random API key: "sk-" plus 32 bytes of entropy
// hex-encoded (total length 67).
func (s *AiApiKeyService) GenerateApiKey() string {
	buf := make([]byte, 32)
	// The original ignored this error. A failing system randomness source
	// would silently produce predictable keys; panicking is safer than
	// handing out a weak credential.
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Errorf("读取随机源失败: %w", err))
	}
	return "sk-" + hex.EncodeToString(buf)
}
// CreateAiApiKey persists a new API key record. When no key string is
// supplied, a fresh random key is generated first.
func (s *AiApiKeyService) CreateAiApiKey(apiKey *app.AiApiKey) error {
	if len(apiKey.Key) == 0 {
		apiKey.Key = s.GenerateApiKey()
	}
	return global.GVA_DB.Create(apiKey).Error
}
// DeleteAiApiKey removes the key with the given id, scoped to its owner so
// one user cannot delete another user's keys.
func (s *AiApiKeyService) DeleteAiApiKey(id uint, userID uint) error {
	db := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID)
	return db.Delete(&app.AiApiKey{}).Error
}
// UpdateAiApiKey updates an API key record owned by userID.
// The original filtered on user_id only; with a zero-ID struct GORM would
// then update EVERY key belonging to the user. Scope by primary key too.
// NOTE(review): assumes app.AiApiKey exposes an ID primary-key field
// (standard GVA model) — confirm against the model definition.
func (s *AiApiKeyService) UpdateAiApiKey(apiKey *app.AiApiKey, userID uint) error {
	return global.GVA_DB.
		Where("id = ? AND user_id = ?", apiKey.ID, userID).
		Updates(apiKey).Error
}
// GetAiApiKey loads a single key by id, scoped to its owner.
func (s *AiApiKeyService) GetAiApiKey(id uint, userID uint) (app.AiApiKey, error) {
	var record app.AiApiKey
	err := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&record).Error
	return record, err
}
// GetAiApiKeyList returns one page of the user's API keys plus the total
// row count.
func (s *AiApiKeyService) GetAiApiKeyList(info request.PageInfo, userID uint) (list []app.AiApiKey, total int64, err error) {
	limit := info.PageSize
	offset := info.PageSize * (info.Page - 1)
	if offset < 0 {
		// Guard Page<=0: the original computed a negative offset here.
		offset = 0
	}
	db := global.GVA_DB.Model(&app.AiApiKey{}).Where("user_id = ?", userID)
	if err = db.Count(&total).Error; err != nil {
		return
	}
	err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
	return
}
// ValidateApiKey looks up an enabled key by its key string and rejects it
// when it has passed its expiry timestamp (unix seconds).
func (s *AiApiKeyService) ValidateApiKey(key string) (*app.AiApiKey, error) {
	var record app.AiApiKey
	lookupErr := global.GVA_DB.Where("key = ? AND enabled = ?", key, true).First(&record).Error
	if lookupErr != nil {
		return nil, fmt.Errorf("无效的API密钥")
	}
	// A nil ExpiresAt means the key never expires.
	if exp := record.ExpiresAt; exp != nil && *exp < time.Now().Unix() {
		return nil, fmt.Errorf("API密钥已过期")
	}
	return &record, nil
}
// CheckModelPermission reports whether the key may use the given model.
// An empty allow-list means no restriction; "*" matches every model.
func (s *AiApiKeyService) CheckModelPermission(apiKey *app.AiApiKey, model string) bool {
	allowed := apiKey.AllowedModels
	if len(allowed) == 0 {
		return true
	}
	for i := range allowed {
		switch allowed[i] {
		case "*", model:
			return true
		}
	}
	return false
}
// CheckPresetPermission reports whether the key may use the given preset.
// An empty allow-list means no restriction; "*" matches every preset.
func (s *AiApiKeyService) CheckPresetPermission(apiKey *app.AiApiKey, presetName string) bool {
	allowed := apiKey.AllowedPresets
	if len(allowed) == 0 {
		return true
	}
	for i := range allowed {
		switch allowed[i] {
		case "*", presetName:
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,200 @@
package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ProcessClaudeMessage 处理 Claude 消息请求
func (s *AiProxyService) ProcessClaudeMessage(ctx context.Context, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
// 1. 根据模型获取配置
if req.Model == "" {
return nil, fmt.Errorf("model 参数不能为空")
}
preset, provider, err := s.getConfigByModel(req.Model)
if err != nil {
return nil, err
}
// 2. 注入预设
if preset != nil {
injector := NewPresetInjector(preset)
req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
}
// 3. 转发请求到上游
resp, err := s.forwardClaudeRequest(ctx, provider, req)
if err != nil {
return nil, err
}
// 4. 处理响应
if preset != nil && len(resp.Content) > 0 {
injector := NewPresetInjector(preset)
resp.Content[0].Text = injector.ProcessResponse(resp.Content[0].Text)
}
return resp, nil
}
// ProcessClaudeMessageStream 处理 Claude 流式消息请求
func (s *AiProxyService) ProcessClaudeMessageStream(c *gin.Context, req *request.ClaudeMessageRequest) {
// 1. 根据模型获取配置
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
return
}
preset, provider, err := s.getConfigByModel(req.Model)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 2. 注入预设
var injector *PresetInjector
if preset != nil {
injector = NewPresetInjector(preset)
req.Messages = s.convertClaudeMessages(injector.InjectMessages(s.convertToOpenAIMessages(req.Messages)))
}
// 3. 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
// 4. 转发流式请求
err = s.forwardClaudeStreamRequest(c, provider, req, injector)
if err != nil {
global.GVA_LOG.Error("Claude流式请求失败", zap.Error(err))
}
}
// forwardClaudeRequest 转发 Claude 请求
func (s *AiProxyService) forwardClaudeRequest(ctx context.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest) (*response.ClaudeMessageResponse, error) {
if req.Model == "" && provider.Model != "" {
req.Model = provider.Model
}
reqBody, _ := json.Marshal(req)
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
httpReq, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("x-api-key", provider.APIKey)
httpReq.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
httpResp, err := client.Do(httpReq)
if err != nil {
return nil, err
}
defer httpResp.Body.Close()
if httpResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(httpResp.Body)
return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
}
var resp response.ClaudeMessageResponse
json.NewDecoder(httpResp.Body).Decode(&resp)
return &resp, nil
}
// forwardClaudeStreamRequest 转发 Claude 流式请求
func (s *AiProxyService) forwardClaudeStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ClaudeMessageRequest, injector *PresetInjector) error {
if req.Model == "" && provider.Model != "" {
req.Model = provider.Model
}
reqBody, _ := json.Marshal(req)
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/messages"
httpReq, _ := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewReader(reqBody))
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("x-api-key", provider.APIKey)
httpReq.Header.Set("anthropic-version", "2023-06-01")
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
httpResp, err := client.Do(httpReq)
if err != nil {
return err
}
defer httpResp.Body.Close()
reader := bufio.NewReader(httpResp.Body)
flusher, _ := c.Writer.(http.Flusher)
for {
line, err := reader.ReadBytes('\n')
if err == io.EOF {
break
}
if len(bytes.TrimSpace(line)) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
data := bytes.TrimPrefix(line, []byte("data: "))
var chunk response.ClaudeStreamResponse
if json.Unmarshal(data, &chunk) == nil && chunk.Delta != nil {
if injector != nil {
chunk.Delta.Text = injector.ProcessResponse(chunk.Delta.Text)
}
processedData, _ := json.Marshal(chunk)
c.Writer.Write([]byte("data: "))
c.Writer.Write(processedData)
c.Writer.Write([]byte("\n\n"))
flusher.Flush()
}
}
}
return nil
}
// convertToOpenAIMessages 转换 Claude 消息为 OpenAI 格式
func (s *AiProxyService) convertToOpenAIMessages(messages []request.ClaudeMessage) []request.ChatMessage {
result := make([]request.ChatMessage, len(messages))
for i, msg := range messages {
content := ""
// 处理字符串类型的 content
if str, ok := msg.Content.(string); ok {
content = str
} else if blocks, ok := msg.Content.([]interface{}); ok {
// 处理对象数组类型的 content (Claude API 标准格式)
for _, block := range blocks {
if blockMap, ok := block.(map[string]interface{}); ok {
if blockMap["type"] == "text" {
if text, ok := blockMap["text"].(string); ok {
content += text
}
}
}
}
}
result[i] = request.ChatMessage{Role: msg.Role, Content: content}
}
return result
}
// convertClaudeMessages 转换 OpenAI 消息为 Claude 格式
func (s *AiProxyService) convertClaudeMessages(messages []request.ChatMessage) []request.ClaudeMessage {
result := make([]request.ClaudeMessage, len(messages))
for i, msg := range messages {
result[i] = request.ClaudeMessage{Role: msg.Role, Content: msg.Content}
}
return result
}

View File

@@ -0,0 +1,148 @@
package app
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"git.echol.cn/loser/ai_proxy/server/global"
	"git.echol.cn/loser/ai_proxy/server/model/app"
	"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
// AiModelService implements model-configuration business logic.
type AiModelService struct{}

// CreateAiModel persists a new model configuration record.
func (s *AiModelService) CreateAiModel(model *app.AiModel) error {
	db := global.GVA_DB
	return db.Create(model).Error
}
// DeleteAiModel removes the model with the given id, scoped to its owner.
func (s *AiModelService) DeleteAiModel(id uint, userID uint) error {
	db := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID)
	return db.Delete(&app.AiModel{}).Error
}
// UpdateAiModel updates a model configuration owned by userID.
// The original filtered on user_id only; with a zero-ID struct GORM would
// then update EVERY model belonging to the user. Scope by primary key too.
// NOTE(review): assumes app.AiModel exposes an ID primary-key field
// (standard GVA model) — confirm against the model definition.
func (s *AiModelService) UpdateAiModel(model *app.AiModel, userID uint) error {
	return global.GVA_DB.
		Where("id = ? AND user_id = ?", model.ID, userID).
		Updates(model).Error
}
// GetAiModel loads one model by id with its Provider and Preset
// associations preloaded, scoped to its owner.
func (s *AiModelService) GetAiModel(id uint, userID uint) (app.AiModel, error) {
	var record app.AiModel
	err := global.GVA_DB.
		Preload("Provider").
		Preload("Preset").
		Where("id = ? AND user_id = ?", id, userID).
		First(&record).Error
	return record, err
}
// GetAiModelList returns one page of the user's models (with Provider and
// Preset preloaded) plus the total row count.
func (s *AiModelService) GetAiModelList(info request.PageInfo, userID uint) (list []app.AiModel, total int64, err error) {
	limit := info.PageSize
	offset := info.PageSize * (info.Page - 1)
	if offset < 0 {
		// Guard Page<=0: the original computed a negative offset here.
		offset = 0
	}
	db := global.GVA_DB.Model(&app.AiModel{}).Preload("Provider").Preload("Preset").Where("user_id = ?", userID)
	if err = db.Count(&total).Error; err != nil {
		return
	}
	err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
	return
}
// GetModelByNameAndProvider finds an enabled model by name under a specific
// provider, with associations preloaded.
// Fix: the original replaced the underlying DB error with a generic
// message, making "record not found" indistinguishable from a database
// failure; wrap it with %w instead.
func (s *AiModelService) GetModelByNameAndProvider(modelName string, providerID uint) (*app.AiModel, error) {
	var model app.AiModel
	err := global.GVA_DB.Preload("Provider").Preload("Preset").
		Where("name = ? AND provider_id = ? AND enabled = ?", modelName, providerID, true).
		First(&model).Error
	if err != nil {
		return nil, fmt.Errorf("未找到模型配置 %s: %w", modelName, err)
	}
	return &model, nil
}
// FetchProviderModels queries the provider's /v1/models endpoint and
// returns the advertised models.
// Fix: the original built the URL with Sprintf, so a BaseURL with a
// trailing slash produced "...//v1/models"; trim it, matching how the
// proxy forwarding code builds its URLs.
func (s *AiModelService) FetchProviderModels(provider *app.AiProvider) ([]ProviderModel, error) {
	url := strings.TrimRight(provider.BaseURL, "/") + "/v1/models"

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}

	// Auth header style depends on the provider type.
	if provider.Type == "openai" || provider.Type == "other" {
		req.Header.Set("Authorization", "Bearer "+provider.APIKey)
	} else if provider.Type == "claude" {
		req.Header.Set("x-api-key", provider.APIKey)
		req.Header.Set("anthropic-version", "2023-06-01")
	}

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best effort: include upstream body
		return nil, fmt.Errorf("获取模型列表失败: %d - %s", resp.StatusCode, string(body))
	}

	// OpenAI-compatible list shape: {"data": [...]}.
	var result struct {
		Data []ProviderModel `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	return result.Data, nil
}
// SyncProviderModels fetches the provider's model list and inserts any
// models not yet present for this user. Existing records are left untouched
// to preserve the user's configuration.
// Fix: the original ignored the error from Create, so a failed insert was
// silently dropped; it is now propagated with context.
func (s *AiModelService) SyncProviderModels(providerID uint, userID uint) error {
	var provider app.AiProvider
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error; err != nil {
		return fmt.Errorf("提供商不存在")
	}

	models, err := s.FetchProviderModels(&provider)
	if err != nil {
		return err
	}

	for _, m := range models {
		var existing app.AiModel
		// Already synced: keep the user's configuration untouched.
		if err := global.GVA_DB.
			Where("name = ? AND provider_id = ? AND user_id = ?", m.ID, providerID, userID).
			First(&existing).Error; err == nil {
			continue
		}
		newModel := app.AiModel{
			Name:        m.ID,
			DisplayName: m.ID,
			ProviderID:  providerID,
			Enabled:     false, // disabled by default; an admin must enable it
			UserID:      userID,
		}
		if err := global.GVA_DB.Create(&newModel).Error; err != nil {
			return fmt.Errorf("创建模型 %s 失败: %w", m.ID, err)
		}
	}
	return nil
}
// ProviderModel is one model entry as returned by a provider's
// OpenAI-compatible /v1/models endpoint.
type ProviderModel struct {
	ID      string `json:"id"`       // model identifier, e.g. "gpt-4o"
	Object  string `json:"object"`   // object type tag from the API (typically "model")
	Created int64  `json:"created"`  // creation time as a unix timestamp
	OwnedBy string `json:"owned_by"` // owning organization reported by the provider
}

View File

@@ -1,6 +1,9 @@
package app
import (
"encoding/json"
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
@@ -41,3 +44,88 @@ func (s *AiPresetService) GetAiPresetList(info request.PageInfo, userID uint) (l
err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
return
}
// ParseImportedPreset converts an imported raw JSON object (SillyTavern
// formats supported) into an AiPreset.
// Name is accepted under "name", "preset_name" or "presetName"; sampling
// parameters fall back to defaults when absent; prompts / prompt_order /
// regex_scripts / extensions are re-marshalled into their typed fields.
// Fix: the original silently ignored every json.Unmarshal error, so a
// malformed prompts section imported as an empty preset with no warning;
// such failures now return a wrapped error.
func (s *AiPresetService) ParseImportedPreset(rawData map[string]interface{}) (*app.AiPreset, error) {
	preset := &app.AiPreset{
		Enabled: true,
	}

	// Name: several key spellings are accepted for compatibility.
	preset.Name = firstNonEmptyString(rawData, "name", "preset_name", "presetName")
	if preset.Name == "" {
		return nil, fmt.Errorf("预设名称不能为空")
	}

	if desc, ok := rawData["description"].(string); ok {
		preset.Description = desc
	}

	// Sampling parameters with defaults matching the original behavior.
	preset.Temperature = floatValue(rawData, "temperature", 1.0)
	preset.TopP = floatValue(rawData, "top_p", 0.9)
	preset.TopK = int(floatValue(rawData, "top_k", 0))
	preset.MaxTokens = int(floatValue(rawData, "max_tokens", 4096))
	preset.FrequencyPenalty = floatValue(rawData, "frequency_penalty", 0)
	preset.PresencePenalty = floatValue(rawData, "presence_penalty", 0)

	// Typed sub-sections: re-encode the decoded JSON into the struct fields.
	if prompts, ok := rawData["prompts"].([]interface{}); ok {
		if err := remarshal(prompts, &preset.Prompts, "prompts"); err != nil {
			return nil, err
		}
	}
	if promptOrder, ok := rawData["prompt_order"].([]interface{}); ok {
		if err := remarshal(promptOrder, &preset.PromptOrder, "prompt_order"); err != nil {
			return nil, err
		}
	}
	if regexScripts, ok := rawData["regex_scripts"].([]interface{}); ok {
		if err := remarshal(regexScripts, &preset.RegexScripts, "regex_scripts"); err != nil {
			return nil, err
		}
	}
	if extensions, ok := rawData["extensions"].(map[string]interface{}); ok {
		if err := remarshal(extensions, &preset.Extensions, "extensions"); err != nil {
			return nil, err
		}
	}

	if enabled, ok := rawData["enabled"].(bool); ok {
		preset.Enabled = enabled
	}
	return preset, nil
}

// firstNonEmptyString returns the first non-empty string found under any of
// the given keys.
func firstNonEmptyString(m map[string]interface{}, keys ...string) string {
	for _, k := range keys {
		if s, ok := m[k].(string); ok && s != "" {
			return s
		}
	}
	return ""
}

// floatValue returns the float64 under key, or def when absent or not a
// number (JSON numbers decode to float64).
func floatValue(m map[string]interface{}, key string, def float64) float64 {
	if v, ok := m[key].(float64); ok {
		return v
	}
	return def
}

// remarshal re-encodes a decoded JSON value into dst, wrapping any failure
// with the field name for a readable import error.
func remarshal(v interface{}, dst interface{}, field string) error {
	data, err := json.Marshal(v)
	if err != nil {
		return fmt.Errorf("编码 %s 失败: %w", field, err)
	}
	if err := json.Unmarshal(data, dst); err != nil {
		return fmt.Errorf("解析 %s 失败: %w", field, err)
	}
	return nil
}

View File

@@ -1,43 +0,0 @@
package app
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
// AiPresetBindingService implements CRUD for preset/provider bindings.
// NOTE(review): this file is deleted in this commit; model-based routing
// (AiModelService) appears to replace it — annotating only, no code change.
type AiPresetBindingService struct{}

// CreateAiPresetBinding persists a new binding record.
func (s *AiPresetBindingService) CreateAiPresetBinding(binding *app.AiPresetBinding) error {
	return global.GVA_DB.Create(binding).Error
}

// DeleteAiPresetBinding removes a binding by id, scoped to its owner.
func (s *AiPresetBindingService) DeleteAiPresetBinding(id uint, userID uint) error {
	return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiPresetBinding{}).Error
}

// UpdateAiPresetBinding updates a binding scoped to its owner.
// NOTE(review): filters on user_id only — a zero-ID struct would update all
// of the user's bindings; moot since the file is deleted.
func (s *AiPresetBindingService) UpdateAiPresetBinding(binding *app.AiPresetBinding, userID uint) error {
	return global.GVA_DB.Where("user_id = ?", userID).Updates(binding).Error
}

// GetAiPresetBinding loads one binding by id with Preset and Provider
// associations preloaded, scoped to its owner.
func (s *AiPresetBindingService) GetAiPresetBinding(id uint, userID uint) (binding app.AiPresetBinding, err error) {
	err = global.GVA_DB.Preload("Preset").Preload("Provider").Where("id = ? AND user_id = ?", id, userID).First(&binding).Error
	return
}

// GetAiPresetBindingList returns one page of the user's bindings plus the
// total row count.
func (s *AiPresetBindingService) GetAiPresetBindingList(info request.PageInfo, userID uint) (list []app.AiPresetBinding, total int64, err error) {
	limit := info.PageSize
	offset := info.PageSize * (info.Page - 1)
	db := global.GVA_DB.Model(&app.AiPresetBinding{}).Preload("Preset").Preload("Provider").Where("user_id = ?", userID)
	err = db.Count(&total).Error
	if err != nil {
		return
	}
	err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
	return
}

View File

@@ -59,45 +59,94 @@ func (p *PresetInjector) getEnabledPrompts() []app.PresetPrompt {
}
// buildInjectedMessages 构建注入后的消息列表
// 参考 SillyTavern 的实现逻辑
func (p *PresetInjector) buildInjectedMessages(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
result := make([]request.ChatMessage, 0)
// 分离系统提示词和对话消息
var systemPrompts []app.PresetPrompt
var otherPrompts []app.PresetPrompt
// 按照 injection_position 分组
// 0 = 相对位置(从对话历史的特定深度注入)
// 1 = 绝对位置(在消息列表的固定位置注入)
var relativePrompts []app.PresetPrompt
var absolutePrompts []app.PresetPrompt
for _, prompt := range prompts {
if prompt.Role == "system" {
systemPrompts = append(systemPrompts, prompt)
if prompt.InjectionPosition == 0 {
relativePrompts = append(relativePrompts, prompt)
} else {
otherPrompts = append(otherPrompts, prompt)
absolutePrompts = append(absolutePrompts, prompt)
}
}
// 1. 先添加系统提示
for _, prompt := range systemPrompts {
result = append(result, request.ChatMessage{
Role: "system",
Content: p.processPromptContent(prompt.Content),
})
// 处理绝对位置的提示词(通常是系统提示
for _, prompt := range absolutePrompts {
if prompt.InjectionDepth == 0 {
// depth=0 表示在最开始
result = append(result, request.ChatMessage{
Role: prompt.Role,
Content: p.processPromptContent(prompt.Content),
})
}
}
// 2. 处理对话历史注入
chatHistoryIndex := p.findMarkerIndex("chatHistory")
if chatHistoryIndex >= 0 {
// 在 chatHistory 标记位置注入原始消息
result = append(result, messages...)
} else {
// 如果没有 chatHistory 标记,直接添加到末尾
result = append(result, messages...)
// 处理相对位置的提示词和对话历史
// 按 injection_depth 从大到小排序(深度越大越靠前)
sort.Slice(relativePrompts, func(i, j int) bool {
if relativePrompts[i].InjectionDepth != relativePrompts[j].InjectionDepth {
return relativePrompts[i].InjectionDepth > relativePrompts[j].InjectionDepth
}
return relativePrompts[i].InjectionOrder < relativePrompts[j].InjectionOrder
})
// 注入相对位置的提示词到对话历史中
injectedMessages := p.injectRelativePrompts(messages, relativePrompts)
result = append(result, injectedMessages...)
// 处理绝对位置在末尾的提示词
for _, prompt := range absolutePrompts {
if prompt.InjectionDepth > 0 {
// depth>0 表示在末尾
result = append(result, request.ChatMessage{
Role: prompt.Role,
Content: p.processPromptContent(prompt.Content),
})
}
}
// 3. 添加其他角色的提示词(assistant等)
for _, prompt := range otherPrompts {
result = append(result, request.ChatMessage{
Role: prompt.Role,
Content: p.processPromptContent(prompt.Content),
})
return result
}
// injectRelativePrompts 将相对位置的提示词注入到对话历史中
func (p *PresetInjector) injectRelativePrompts(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
if len(prompts) == 0 {
return messages
}
result := make([]request.ChatMessage, 0, len(messages)+len(prompts))
messageCount := len(messages)
// 按深度分组提示词
depthMap := make(map[int][]app.PresetPrompt)
for _, prompt := range prompts {
depthMap[prompt.InjectionDepth] = append(depthMap[prompt.InjectionDepth], prompt)
}
// 遍历消息,在指定深度注入提示词
for i, msg := range messages {
// 计算当前位置的深度(从末尾开始计数)
depth := messageCount - i
// 在当前消息之前注入对应深度的提示词
if promptsAtDepth, exists := depthMap[depth]; exists {
for _, prompt := range promptsAtDepth {
result = append(result, request.ChatMessage{
Role: prompt.Role,
Content: p.processPromptContent(prompt.Content),
})
}
}
// 添加当前消息
result = append(result, msg)
}
return result

View File

@@ -22,53 +22,60 @@ import (
type AiProxyService struct{}
// ProcessChatCompletion 处理聊天补全请求
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userID uint, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
startTime := time.Now()
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
// 1. 根据模型获取配置
if req.Model == "" {
return nil, fmt.Errorf("model 参数不能为空")
}
// 1. 获取绑定配置
binding, err := s.getBinding(userID, req)
preset, provider, err := s.getConfigByModel(req.Model)
if err != nil {
return nil, fmt.Errorf("获取绑定配置失败: %w", err)
return nil, err
}
// 2. 注入预设
injector := NewPresetInjector(&binding.Preset)
req.Messages = injector.InjectMessages(req.Messages)
injector.ApplyPresetParameters(req)
if preset != nil {
injector := NewPresetInjector(preset)
req.Messages = injector.InjectMessages(req.Messages)
injector.ApplyPresetParameters(req)
}
// 3. 转发请求到上游
resp, err := s.forwardRequest(ctx, &binding.Provider, req)
resp, err := s.forwardRequest(ctx, provider, req)
if err != nil {
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
return nil, err
}
// 4. 处理响应
if len(resp.Choices) > 0 {
if preset != nil && len(resp.Choices) > 0 {
injector := NewPresetInjector(preset)
resp.Choices[0].Message.Content = injector.ProcessResponse(resp.Choices[0].Message.Content)
}
// 5. 记录日志
s.logRequest(userID, binding, req, resp, nil, time.Since(startTime))
return resp, nil
}
// ProcessChatCompletionStream 处理流式聊天补全请求
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint, req *request.ChatCompletionRequest) {
startTime := time.Now()
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, req *request.ChatCompletionRequest) {
// 1. 根据模型获取配置
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "model 参数不能为空"})
return
}
// 1. 获取绑定配置
binding, err := s.getBinding(userID, req)
preset, provider, err := s.getConfigByModel(req.Model)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 2. 注入预设
injector := NewPresetInjector(&binding.Preset)
req.Messages = injector.InjectMessages(req.Messages)
injector.ApplyPresetParameters(req)
var injector *PresetInjector
if preset != nil {
injector = NewPresetInjector(preset)
req.Messages = injector.InjectMessages(req.Messages)
injector.ApplyPresetParameters(req)
}
// 3. 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
@@ -77,45 +84,30 @@ func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint
c.Header("X-Accel-Buffering", "no")
// 4. 转发流式请求
err = s.forwardStreamRequest(c, &binding.Provider, req, injector)
err = s.forwardStreamRequest(c, provider, req, injector)
if err != nil {
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
}
}
// getBinding 获取绑定配置
func (s *AiProxyService) getBinding(userID uint, req *request.ChatCompletionRequest) (*app.AiPresetBinding, error) {
var binding app.AiPresetBinding
// getConfigByModel 根据模型名称获取配置
func (s *AiProxyService) getConfigByModel(modelName string) (*app.AiPreset, *app.AiProvider, error) {
// 查找启用的模型配置
var model app.AiModel
err := global.GVA_DB.Preload("Provider").Preload("Preset").
Where("name = ? AND enabled = ?", modelName, true).
First(&model).Error
query := global.GVA_DB.Preload("Preset").Preload("Provider").Where("user_id = ? AND enabled = ?", userID, true)
// 优先使用 binding_name
if req.BindingName != "" {
query = query.Where("name = ?", req.BindingName)
} else if req.PresetName != "" && req.ProviderName != "" {
// 使用 preset_name 和 provider_name
query = query.Joins("JOIN ai_presets ON ai_presets.id = ai_preset_bindings.preset_id").
Joins("JOIN ai_providers ON ai_providers.id = ai_preset_bindings.provider_id").
Where("ai_presets.name = ? AND ai_providers.name = ?", req.PresetName, req.ProviderName)
} else {
// 使用默认绑定(第一个启用的)
query = query.Order("id ASC")
if err != nil {
return nil, nil, fmt.Errorf("未找到模型配置: %s", modelName)
}
if err := query.First(&binding).Error; err != nil {
return nil, fmt.Errorf("未找到可用的绑定配置")
// 检查提供商是否启用
if !model.Provider.Enabled {
return nil, nil, fmt.Errorf("提供商已禁用")
}
if !binding.Provider.Enabled {
return nil, fmt.Errorf("提供商已禁用")
}
if !binding.Preset.Enabled {
return nil, fmt.Errorf("预设已禁用")
}
return &binding, nil
return model.Preset, &model.Provider, nil
}
// forwardRequest 转发请求到上游 AI 服务
@@ -235,7 +227,7 @@ func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiPr
}
// 应用输出正则处理
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
if injector != nil && len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
chunk.Choices[0].Delta.Content = injector.ProcessResponse(chunk.Choices[0].Delta.Content)
}
@@ -251,29 +243,40 @@ func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiPr
return nil
}
// logRequest 记录请求日志
func (s *AiProxyService) logRequest(userID uint, binding *app.AiPresetBinding, req *request.ChatCompletionRequest, resp *response.ChatCompletionResponse, err error, duration time.Duration) {
log := app.AiRequestLog{
UserID: userID,
BindingID: binding.ID,
ProviderID: binding.ProviderID,
PresetID: binding.PresetID,
Model: req.Model,
Duration: duration.Milliseconds(),
RequestTime: time.Now(),
// GetAvailableModels 获取用户可用的模型列表
func (s *AiProxyService) GetAvailableModels(apiKey *app.AiApiKey) (*response.ModelListResponse, error) {
// 查询所有启用的模型
var models []app.AiModel
query := global.GVA_DB.Where("enabled = ?", true)
// 如果 API Key 限制了模型,只返回允许的模型
if len(apiKey.AllowedModels) > 0 {
query = query.Where("name IN ?", apiKey.AllowedModels)
}
if err != nil {
log.Status = "error"
log.ErrorMessage = err.Error()
} else {
log.Status = "success"
if resp != nil {
log.PromptTokens = resp.Usage.PromptTokens
log.CompletionTokens = resp.Usage.CompletionTokens
log.TotalTokens = resp.Usage.TotalTokens
if err := query.Find(&models).Error; err != nil {
return nil, fmt.Errorf("查询模型列表失败: %w", err)
}
// 构建响应
modelList := &response.ModelListResponse{
Object: "list",
Data: make([]response.ModelInfo, 0, len(models)),
}
// 去重(同一模型可能在多个提供商下配置)
seen := make(map[string]bool)
for _, model := range models {
if !seen[model.Name] {
seen[model.Name] = true
modelList.Data = append(modelList.Data, response.ModelInfo{
ID: model.Name,
Object: "model",
Created: model.CreatedAt.Unix(),
OwnedBy: "system",
})
}
}
global.GVA_DB.Create(&log)
return modelList, nil
}

View File

@@ -4,5 +4,6 @@ type AppServiceGroup struct {
AiProxyService
AiPresetService
AiProviderService
AiPresetBindingService
AiApiKeyService
AiModelService
}