Files
ai_proxy/server/service/app/ai_model.go

149 lines
4.4 KiB
Go

package app
import (
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
// AiModelService groups the data-access operations for AI model
// configurations. It carries no state of its own; every method works through
// the package-level database handle global.GVA_DB.
type AiModelService struct{}
// CreateAiModel inserts model as a new record and returns the database error,
// if any. model is passed by pointer straight to the ORM's Create call.
func (s *AiModelService) CreateAiModel(model *app.AiModel) error {
	result := global.GVA_DB.Create(model)
	return result.Error
}
// DeleteAiModel deletes the model with the given id, restricted to rows owned
// by userID. The affected-row count is not inspected, so a call that matches
// nothing is presumably not reported as an error (depends on the ORM).
func (s *AiModelService) DeleteAiModel(id uint, userID uint) error {
	ownedByUser := global.GVA_DB.Where("id = ? AND user_id = ?", id, userID)
	return ownedByUser.Delete(&app.AiModel{}).Error
}
// UpdateAiModel updates the model whose id is model.ID, but only if that row
// belongs to userID. The user_id column is never written, so an update
// payload cannot move a model to another owner.
//
// It returns an error without touching the database when model is nil or has
// no id.
//
// NOTE(review): assuming GVA_DB is a *gorm.DB, Updates with a struct only
// writes non-zero fields, so values such as Enabled=false or an empty
// DisplayName are not persisted by this call. Confirm whether callers expect
// to be able to clear fields here.
func (s *AiModelService) UpdateAiModel(model *app.AiModel, userID uint) error {
	// Without a primary key the only remaining condition would be user_id,
	// and the update would silently rewrite every model the user owns.
	if model == nil || model.ID == 0 {
		return fmt.Errorf("模型ID不能为空")
	}
	return global.GVA_DB.
		Where("id = ? AND user_id = ?", model.ID, userID).
		Omit("user_id").
		Updates(model).Error
}
// GetAiModel loads a single model by id, restricted to rows owned by userID,
// with its Provider and Preset associations preloaded. Any lookup failure is
// returned unchanged from the database layer.
func (s *AiModelService) GetAiModel(id uint, userID uint) (model app.AiModel, err error) {
	query := global.GVA_DB.
		Preload("Provider").
		Preload("Preset").
		Where("id = ? AND user_id = ?", id, userID)
	err = query.First(&model).Error
	return model, err
}
// GetAiModelList returns one page of userID's models, ordered by id
// descending, with Provider and Preset preloaded. It also returns the total
// number of models the user owns.
//
// info.Page is treated as 1-based and info.PageSize as the page length. The
// derived limit/offset are handed to the database unchanged; no clamping is
// done here.
func (s *AiModelService) GetAiModelList(info request.PageInfo, userID uint) (list []app.AiModel, total int64, err error) {
	limit := info.PageSize
	offset := info.PageSize * (info.Page - 1)

	// Count on its own statement: it needs neither the preloads nor the
	// ordering, and GORM documents that a chained *gorm.DB is not safe to
	// reuse after a finisher method, because conditions can leak between queries.
	err = global.GVA_DB.Model(&app.AiModel{}).
		Where("user_id = ?", userID).
		Count(&total).Error
	if err != nil {
		return nil, 0, err
	}

	err = global.GVA_DB.Model(&app.AiModel{}).
		Preload("Provider").
		Preload("Preset").
		Where("user_id = ?", userID).
		Limit(limit).
		Offset(offset).
		Order("id desc").
		Find(&list).Error
	return list, total, err
}
// GetModelByNameAndProvider looks up the enabled model named modelName under
// providerID, with its Provider and Preset associations preloaded. Unlike the
// other lookups in this service, it is not scoped to a user.
//
// On failure the returned error keeps the "model config not found" prefix
// and also wraps the underlying database error. Callers can therefore use
// errors.Is to tell a missing row apart from, e.g., a connection failure
// (presumably against gorm.ErrRecordNotFound; confirm against the ORM in use).
func (s *AiModelService) GetModelByNameAndProvider(modelName string, providerID uint) (*app.AiModel, error) {
	var model app.AiModel
	err := global.GVA_DB.Preload("Provider").Preload("Preset").
		Where("name = ? AND provider_id = ? AND enabled = ?", modelName, providerID, true).
		First(&model).Error
	if err != nil {
		// Previously the cause was discarded, so every DB failure looked
		// like a missing model. Wrap it instead.
		return nil, fmt.Errorf("未找到模型配置: %s: %w", modelName, err)
	}
	return &model, nil
}
// FetchProviderModels calls GET {provider.BaseURL}/v1/models and returns the
// entries of the "data" array in the JSON response.
//
// Credentials depend on provider.Type:
//   - "openai" and "other" send an "Authorization: Bearer <APIKey>" header;
//   - "claude" sends x-api-key plus anthropic-version, and asks for the
//     largest page size;
//   - any other type sends no credentials.
//
// The request times out after 10 seconds. A non-200 response becomes an error
// containing the status code and (a bounded prefix of) the response body.
func (s *AiModelService) FetchProviderModels(provider *app.AiProvider) ([]ProviderModel, error) {
	// Maximum number of error-body bytes copied into the returned error.
	const maxErrorBody = 4 << 10

	// Strip trailing slashes so "https://host/" does not become
	// "https://host//v1/models".
	base := provider.BaseURL
	for len(base) > 0 && base[len(base)-1] == '/' {
		base = base[:len(base)-1]
	}

	req, err := http.NewRequest(http.MethodGet, base+"/v1/models", nil)
	if err != nil {
		return nil, err
	}

	switch provider.Type {
	case "openai", "other":
		req.Header.Set("Authorization", "Bearer "+provider.APIKey)
	case "claude":
		req.Header.Set("x-api-key", provider.APIKey)
		req.Header.Set("anthropic-version", "2023-06-01")
		// Anthropic's list endpoint is paginated with a default page size of
		// 20. Request the documented maximum so one call returns the full list.
		q := req.URL.Query()
		q.Set("limit", "1000")
		req.URL.RawQuery = q.Encode()
	}

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code already explains the failure, so a
		// read error on the (bounded) body is deliberately ignored.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return nil, fmt.Errorf("获取模型列表失败: %d - %s", resp.StatusCode, string(body))
	}

	var result struct {
		Data []ProviderModel `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	return result.Data, nil
}
// SyncProviderModels fetches the model list from the provider identified by
// providerID, which must belong to userID. It then inserts every listed
// model name that the user does not yet have for that provider.
//
// New models are created disabled, so an administrator has to enable them
// explicitly. Models that already exist are left untouched, which preserves
// any user edits across re-syncs. Entries with an empty id are skipped.
// Duplicate ids in the provider response are inserted once.
//
// All new rows are written in a single insert. If it fails, the error is
// returned and nothing from this sync is inserted.
func (s *AiModelService) SyncProviderModels(providerID uint, userID uint) error {
	// Find + RowsAffected (rather than First) separates "no such provider"
	// from a genuine database failure without an ORM-specific sentinel error.
	var provider app.AiProvider
	res := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).Limit(1).Find(&provider)
	if res.Error != nil {
		return fmt.Errorf("查询提供商失败: %w", res.Error)
	}
	if res.RowsAffected == 0 {
		return fmt.Errorf("提供商不存在")
	}

	models, err := s.FetchProviderModels(&provider)
	if err != nil {
		return err
	}

	// Load all names the user already has for this provider in one query,
	// instead of issuing one lookup per remote model.
	var existingNames []string
	if err := global.GVA_DB.Model(&app.AiModel{}).
		Where("provider_id = ? AND user_id = ?", providerID, userID).
		Pluck("name", &existingNames).Error; err != nil {
		return fmt.Errorf("查询已有模型失败: %w", err)
	}
	known := make(map[string]struct{}, len(existingNames)+len(models))
	for _, name := range existingNames {
		known[name] = struct{}{}
	}

	var toCreate []app.AiModel
	for _, m := range models {
		if m.ID == "" {
			continue
		}
		if _, ok := known[m.ID]; ok {
			continue // already configured: keep the user's settings
		}
		known[m.ID] = struct{}{} // also de-duplicates repeated ids in the response
		toCreate = append(toCreate, app.AiModel{
			Name:        m.ID,
			DisplayName: m.ID,
			ProviderID:  providerID,
			Enabled:     false, // disabled by default; an admin must enable it
			UserID:      userID,
		})
	}

	// Skip the insert entirely when there is nothing new.
	if len(toCreate) == 0 {
		return nil
	}
	if err := global.GVA_DB.Create(&toCreate).Error; err != nil {
		return fmt.Errorf("保存模型失败: %w", err)
	}
	return nil
}
// ProviderModel is one entry of the "data" array returned by a provider's
// model-listing endpoint, as decoded by FetchProviderModels.
type ProviderModel struct {
	ID      string `json:"id"`       // model identifier; used as AiModel.Name when syncing
	Object  string `json:"object"`   // object type tag, as reported by the provider
	Created int64  `json:"created"`  // creation timestamp as reported (unit not fixed by this code)
	OwnedBy string `json:"owned_by"` // owner/organization string, as reported by the provider
}