150 lines
4.6 KiB
Go
150 lines
4.6 KiB
Go
package app
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"time"
|
||
|
||
"git.echol.cn/loser/ai_proxy/server/global"
|
||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||
"git.echol.cn/loser/ai_proxy/server/model/common/request"
|
||
)
|
||
|
||
// AiModelService groups the database operations on app.AiModel records and
// the synchronisation of model lists from upstream providers. It carries no
// state of its own; every method goes through global.GVA_DB.
type AiModelService struct{}
|
||
|
||
// CreateAiModel 创建模型
|
||
func (s *AiModelService) CreateAiModel(model *app.AiModel) error {
|
||
return global.GVA_DB.Create(model).Error
|
||
}
|
||
|
||
// DeleteAiModel 删除模型
|
||
func (s *AiModelService) DeleteAiModel(id uint, userID uint) error {
|
||
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiModel{}).Error
|
||
}
|
||
|
||
// UpdateAiModel 更新模型
|
||
func (s *AiModelService) UpdateAiModel(model *app.AiModel, userID uint) error {
|
||
// 使用 Select("*") 来更新所有字段,包括零值字段(如 enabled=false)
|
||
return global.GVA_DB.Model(&app.AiModel{}).Where("id = ? AND user_id = ?", model.ID, userID).Select("*").Updates(model).Error
|
||
}
|
||
|
||
// GetAiModel 查询模型
|
||
func (s *AiModelService) GetAiModel(id uint, userID uint) (model app.AiModel, err error) {
|
||
err = global.GVA_DB.Preload("Provider").Preload("Preset").Where("id = ? AND user_id = ?", id, userID).First(&model).Error
|
||
return
|
||
}
|
||
|
||
// GetAiModelList 获取模型列表
|
||
func (s *AiModelService) GetAiModelList(info request.PageInfo, userID uint) (list []app.AiModel, total int64, err error) {
|
||
limit := info.PageSize
|
||
offset := info.PageSize * (info.Page - 1)
|
||
db := global.GVA_DB.Model(&app.AiModel{}).Preload("Provider").Preload("Preset").Where("user_id = ?", userID)
|
||
err = db.Count(&total).Error
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
|
||
return
|
||
}
|
||
|
||
// GetModelByNameAndProvider 根据模型名称和提供商ID查询模型配置
|
||
func (s *AiModelService) GetModelByNameAndProvider(modelName string, providerID uint) (*app.AiModel, error) {
|
||
var model app.AiModel
|
||
err := global.GVA_DB.Preload("Provider").Preload("Preset").
|
||
Where("name = ? AND provider_id = ? AND enabled = ?", modelName, providerID, true).
|
||
First(&model).Error
|
||
if err != nil {
|
||
return nil, fmt.Errorf("未找到模型配置: %s", modelName)
|
||
}
|
||
return &model, nil
|
||
}
|
||
|
||
// FetchProviderModels 从提供商获取可用模型列表
|
||
func (s *AiModelService) FetchProviderModels(provider *app.AiProvider) ([]ProviderModel, error) {
|
||
// 构建请求 URL
|
||
url := fmt.Sprintf("%s/v1/models", provider.BaseURL)
|
||
|
||
// 创建 HTTP 请求
|
||
req, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 设置请求头
|
||
if provider.Type == "openai" || provider.Type == "other" {
|
||
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||
} else if provider.Type == "claude" {
|
||
req.Header.Set("x-api-key", provider.APIKey)
|
||
req.Header.Set("anthropic-version", "2023-06-01")
|
||
}
|
||
|
||
// 发送请求
|
||
client := &http.Client{Timeout: 10 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return nil, fmt.Errorf("获取模型列表失败: %d - %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// 解析响应
|
||
var result struct {
|
||
Data []ProviderModel `json:"data"`
|
||
}
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
return result.Data, nil
|
||
}
|
||
|
||
// SyncProviderModels 同步提供商的模型列表
|
||
func (s *AiModelService) SyncProviderModels(providerID uint, userID uint) error {
|
||
// 获取提供商信息
|
||
var provider app.AiProvider
|
||
if err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error; err != nil {
|
||
return fmt.Errorf("提供商不存在")
|
||
}
|
||
|
||
// 从提供商获取模型列表
|
||
models, err := s.FetchProviderModels(&provider)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 同步到数据库
|
||
for _, model := range models {
|
||
var existingModel app.AiModel
|
||
err := global.GVA_DB.Where("name = ? AND provider_id = ? AND user_id = ?", model.ID, providerID, userID).First(&existingModel).Error
|
||
|
||
if err != nil {
|
||
// 模型不存在,创建新记录
|
||
newModel := app.AiModel{
|
||
Name: model.ID,
|
||
DisplayName: model.ID,
|
||
ProviderID: providerID,
|
||
Enabled: false, // 默认不启用,需要管理员手动启用
|
||
UserID: userID,
|
||
}
|
||
global.GVA_DB.Create(&newModel)
|
||
}
|
||
// 如果模型已存在,不做任何操作(保留用户的配置)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ProviderModel is one entry of the "data" array that FetchProviderModels
// decodes from a provider's /v1/models response.
type ProviderModel struct {
	// ID is the provider's model identifier; SyncProviderModels stores it as
	// both Name and DisplayName of the created AiModel.
	ID string `json:"id"`
	// Object is the raw "object" field, if the provider sends one.
	Object string `json:"object"`
	// Created is the raw "created" field; presumably a Unix timestamp (unverified).
	Created int64 `json:"created"`
	// OwnedBy is the raw "owned_by" field, if the provider sends one.
	OwnedBy string `json:"owned_by"`
}
|