Files
ai_proxy/server/service/app/ai_model.go
2026-03-03 20:33:46 +08:00

150 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"git.echol.cn/loser/ai_proxy/server/global"
	"git.echol.cn/loser/ai_proxy/server/model/app"
	"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
// AiModelService groups the CRUD and provider-sync operations for AI model
// records (app.AiModel). It carries no state of its own; its database-backed
// methods all go through global.GVA_DB.
type AiModelService struct{}
// CreateAiModel inserts model as a new AI model record.
//
// The struct is stored exactly as given; this method does not fill in
// ownership fields such as UserID itself, so the caller must. The returned
// error is the database error from the insert, if any.
func (s *AiModelService) CreateAiModel(model *app.AiModel) error {
	res := global.GVA_DB.Create(model)
	return res.Error
}
// DeleteAiModel deletes the model with the given id, restricted to rows owned
// by userID so one user cannot delete another user's model.
//
// NOTE(review): the affected-row count is not inspected, so deleting an id
// that does not exist (or belongs to someone else) presumably returns nil.
func (s *AiModelService) DeleteAiModel(id uint, userID uint) error {
	owned := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID)
	return owned.Delete(&app.AiModel{}).Error
}
// UpdateAiModel overwrites the record identified by model.ID with the field
// values carried in model, provided that record belongs to userID.
//
// Select("*") makes GORM write every column, including zero values such as
// Enabled=false, which a plain struct Updates would skip. Because of that,
// the identity/ownership columns are explicitly omitted:
//   - id and user_id already pin the row in the WHERE clause, so writing them
//     can only be a no-op or, if model.UserID is unset or different, silently
//     hand the record to another user (or to user 0);
//   - created_at (if the model has such a column) must never change on update,
//     and a struct that was not loaded from the database would zero it.
//
// Omitting a column name the model does not have is harmless.
func (s *AiModelService) UpdateAiModel(model *app.AiModel, userID uint) error {
	return global.GVA_DB.Model(&app.AiModel{}).
		Where("id = ? AND user_id = ?", model.ID, userID).
		Select("*").
		Omit("id", "user_id", "created_at").
		Updates(model).Error
}
// GetAiModel loads the model with the given id owned by userID, together with
// its Provider and Preset associations.
//
// err is the error from GORM's First, so a missing (or foreign-owned) record
// surfaces as an error rather than a zero-valued model.
func (s *AiModelService) GetAiModel(id uint, userID uint) (model app.AiModel, err error) {
	query := global.GVA_DB.
		Preload("Provider").
		Preload("Preset").
		Where("id = ? AND user_id = ?", id, userID)
	err = query.First(&model).Error
	return model, err
}
// GetAiModelList returns one page of userID's models, newest id first, with
// Provider and Preset preloaded, plus the total number of that user's models.
//
// Paging comes straight from info: the page size is the LIMIT and
// (Page-1)*PageSize is the OFFSET. NOTE(review): Page/PageSize are not
// validated here — presumably the handler does that; confirm.
func (s *AiModelService) GetAiModelList(info request.PageInfo, userID uint) (list []app.AiModel, total int64, err error) {
	query := global.GVA_DB.Model(&app.AiModel{}).
		Preload("Provider").
		Preload("Preset").
		Where("user_id = ?", userID)

	// Total count is taken over the same filtered query before paging.
	if err = query.Count(&total).Error; err != nil {
		return list, total, err
	}

	skip := (info.Page - 1) * info.PageSize
	err = query.Order("id desc").Limit(info.PageSize).Offset(skip).Find(&list).Error
	return list, total, err
}
// GetModelByNameAndProvider looks up the enabled model configuration named
// modelName under provider providerID, preloading its Provider and Preset
// associations.
//
// Note that, unlike the id-based lookups in this service, this query is not
// scoped by user_id.
//
// On failure the error keeps the "未找到模型配置" prefix but now wraps the
// underlying database error with %w instead of discarding it. Callers can then
// tell "no matching row" apart from e.g. a connection failure via errors.Is.
func (s *AiModelService) GetModelByNameAndProvider(modelName string, providerID uint) (*app.AiModel, error) {
	var model app.AiModel
	err := global.GVA_DB.Preload("Provider").Preload("Preset").
		Where("name = ? AND provider_id = ? AND enabled = ?", modelName, providerID, true).
		First(&model).Error
	if err != nil {
		return nil, fmt.Errorf("未找到模型配置: %s: %w", modelName, err)
	}
	return &model, nil
}
// FetchProviderModels requests "<BaseURL>/v1/models" from provider and returns
// the entries of the response's "data" array.
//
// Credentials depend on provider.Type:
//   - "openai" and "other" send an "Authorization: Bearer <APIKey>" header;
//   - "claude" sends "x-api-key" plus an "anthropic-version" header;
//   - any other type sends no credentials.
//
// The request times out after 10 seconds. A non-200 status is returned as an
// error containing the status code and a bounded prefix of the response body.
// A nil provider is rejected with an error instead of panicking.
func (s *AiModelService) FetchProviderModels(provider *app.AiProvider) ([]ProviderModel, error) {
	if provider == nil {
		return nil, fmt.Errorf("提供商不能为空")
	}

	// Tolerate a trailing slash in the configured base URL so we never request
	// "https://host//v1/models".
	endpoint := strings.TrimRight(provider.BaseURL, "/") + "/v1/models"

	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	switch provider.Type {
	case "openai", "other":
		req.Header.Set("Authorization", "Bearer "+provider.APIKey)
	case "claude":
		req.Header.Set("x-api-key", provider.APIKey)
		req.Header.Set("anthropic-version", "2023-06-01")
	}

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Cap how much of an error page is read and echoed into the error; a
		// misbehaving upstream could otherwise send an arbitrarily large body.
		const maxErrBody = 4 << 10
		// Best effort: the body only enriches the message, so a read error is
		// deliberately ignored and whatever was read is reported.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		return nil, fmt.Errorf("获取模型列表失败: %d - %s", resp.StatusCode, string(body))
	}

	var result struct {
		Data []ProviderModel `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	return result.Data, nil
}
// SyncProviderModels imports the model list reported by provider providerID
// into the database on behalf of userID.
//
// The provider must belong to userID; otherwise an error wrapping the lookup
// failure is returned. For each model ID reported by the provider that has no
// (name, provider_id, user_id) row yet, a new record is created with
// Enabled=false, so an administrator must enable it explicitly. Existing rows
// are left untouched to preserve the user's configuration. Entries with an
// empty ID are skipped.
//
// Database errors are propagated rather than swallowed. Processing stops at
// the first failure; models created before it remain. Re-running the sync is
// safe because existing rows are skipped.
func (s *AiModelService) SyncProviderModels(providerID uint, userID uint) error {
	var provider app.AiProvider
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error; err != nil {
		return fmt.Errorf("提供商不存在: %w", err)
	}

	models, err := s.FetchProviderModels(&provider)
	if err != nil {
		return err
	}

	for _, m := range models {
		if m.ID == "" {
			// Nothing to name the record after.
			continue
		}

		// Existence check via Count instead of First: a real query failure is
		// returned as an error instead of being mistaken for "row missing" and
		// triggering an insert.
		var existing int64
		if err := global.GVA_DB.Model(&app.AiModel{}).
			Where("name = ? AND provider_id = ? AND user_id = ?", m.ID, providerID, userID).
			Count(&existing).Error; err != nil {
			return fmt.Errorf("查询模型 %s 失败: %w", m.ID, err)
		}
		if existing > 0 {
			continue // keep the user's existing configuration
		}

		newModel := app.AiModel{
			Name:        m.ID,
			DisplayName: m.ID,
			ProviderID:  providerID,
			Enabled:     false, // disabled by default; must be enabled manually
			UserID:      userID,
		}
		if err := global.GVA_DB.Create(&newModel).Error; err != nil {
			return fmt.Errorf("创建模型 %s 失败: %w", m.ID, err)
		}
	}
	return nil
}
// ProviderModel is one entry of the "data" array returned by a provider's
// /v1/models endpoint, as decoded by FetchProviderModels.
//
// The JSON tags look like the OpenAI-style model list shape. NOTE(review):
// other provider types may use different keys; any key a provider omits just
// decodes to the field's zero value — confirm per provider.
type ProviderModel struct {
	ID      string `json:"id"` // model identifier; SyncProviderModels stores it as Name and DisplayName
	Object  string `json:"object"`
	Created int64  `json:"created"` // presumably a Unix timestamp — TODO confirm per provider
	OwnedBy string `json:"owned_by"`
}