950 lines
27 KiB
Go
950 lines
27 KiB
Go
package app
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"git.echol.cn/loser/st/server/global"
|
||
"git.echol.cn/loser/st/server/model/app"
|
||
"git.echol.cn/loser/st/server/model/app/request"
|
||
"git.echol.cn/loser/st/server/model/app/response"
|
||
"gorm.io/datatypes"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// ProviderService groups the AI provider and model management operations
// defined in this file. The type itself holds no fields.
type ProviderService struct{}
|
||
|
||
// ==================== 提供商 CRUD ====================
|
||
|
||
// CreateProvider 创建AI提供商
|
||
func (ps *ProviderService) CreateProvider(req request.CreateProviderRequest, userID uint) (response.ProviderResponse, error) {
|
||
// 加密 API Key
|
||
encryptedKey := encryptAPIKey(req.APIKey)
|
||
|
||
// 序列化额外配置
|
||
apiConfigJSON, _ := json.Marshal(req.APIConfig)
|
||
if req.APIConfig == nil {
|
||
apiConfigJSON = []byte("{}")
|
||
}
|
||
|
||
// 根据类型确定能力
|
||
capabilities := getDefaultCapabilities(req.ProviderType)
|
||
capJSON, _ := json.Marshal(capabilities)
|
||
|
||
// 如果 BaseURL 为空,使用默认地址
|
||
baseURL := req.BaseURL
|
||
if baseURL == "" {
|
||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||
}
|
||
|
||
provider := app.AIProvider{
|
||
UserID: &userID,
|
||
ProviderName: req.ProviderName,
|
||
ProviderType: req.ProviderType,
|
||
BaseURL: baseURL,
|
||
APIKey: encryptedKey,
|
||
APIConfig: datatypes.JSON(apiConfigJSON),
|
||
Capabilities: datatypes.JSON(capJSON),
|
||
IsEnabled: true,
|
||
IsDefault: false,
|
||
}
|
||
|
||
err := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||
// 创建提供商
|
||
if err := tx.Create(&provider).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 如果携带了模型列表,同时创建模型
|
||
if len(req.Models) > 0 {
|
||
for _, m := range req.Models {
|
||
modelConfigJSON, _ := json.Marshal(m.Config)
|
||
if m.Config == nil {
|
||
modelConfigJSON = []byte("{}")
|
||
}
|
||
isEnabled := true
|
||
if m.IsEnabled != nil {
|
||
isEnabled = *m.IsEnabled
|
||
}
|
||
model := app.AIModel{
|
||
ProviderID: provider.ID,
|
||
ModelName: m.ModelName,
|
||
DisplayName: m.DisplayName,
|
||
ModelType: m.ModelType,
|
||
Config: datatypes.JSON(modelConfigJSON),
|
||
IsEnabled: isEnabled,
|
||
}
|
||
if err := tx.Create(&model).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
} else {
|
||
// 没有指定模型时,自动添加预设模型
|
||
presets := getPresetModels(req.ProviderType)
|
||
for _, p := range presets {
|
||
model := app.AIModel{
|
||
ProviderID: provider.ID,
|
||
ModelName: p.ModelName,
|
||
DisplayName: p.DisplayName,
|
||
ModelType: p.ModelType,
|
||
Config: datatypes.JSON([]byte("{}")),
|
||
IsEnabled: true,
|
||
}
|
||
if err := tx.Create(&model).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果是用户的第一个提供商,自动设为默认
|
||
var count int64
|
||
tx.Model(&app.AIProvider{}).Where("user_id = ? AND id != ?", userID, provider.ID).Count(&count)
|
||
if count == 0 {
|
||
provider.IsDefault = true
|
||
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
return response.ProviderResponse{}, err
|
||
}
|
||
|
||
return ps.GetProviderDetail(provider.ID, userID)
|
||
}
|
||
|
||
// GetProviderList 获取用户的提供商列表
|
||
func (ps *ProviderService) GetProviderList(req request.ProviderListRequest, userID uint) (response.ProviderListResponse, error) {
|
||
db := global.GVA_DB.Model(&app.AIProvider{}).Where("user_id = ?", userID)
|
||
|
||
if req.Keyword != "" {
|
||
keyword := "%" + req.Keyword + "%"
|
||
db = db.Where("provider_name ILIKE ?", keyword)
|
||
}
|
||
|
||
var total int64
|
||
db.Count(&total)
|
||
|
||
var providers []app.AIProvider
|
||
offset := (req.Page - 1) * req.PageSize
|
||
err := db.Order("is_default DESC, sort_order ASC, created_at DESC").
|
||
Offset(offset).Limit(req.PageSize).Find(&providers).Error
|
||
if err != nil {
|
||
return response.ProviderListResponse{}, err
|
||
}
|
||
|
||
// 获取所有提供商的模型
|
||
providerIDs := make([]uint, len(providers))
|
||
for i, p := range providers {
|
||
providerIDs[i] = p.ID
|
||
}
|
||
|
||
var models []app.AIModel
|
||
if len(providerIDs) > 0 {
|
||
global.GVA_DB.Where("provider_id IN ?", providerIDs).
|
||
Order("model_type ASC, model_name ASC").Find(&models)
|
||
}
|
||
|
||
// 按提供商ID分组模型
|
||
modelMap := make(map[uint][]app.AIModel)
|
||
for _, m := range models {
|
||
modelMap[m.ProviderID] = append(modelMap[m.ProviderID], m)
|
||
}
|
||
|
||
list := make([]response.ProviderResponse, len(providers))
|
||
for i, p := range providers {
|
||
list[i] = toProviderResponse(&p, modelMap[p.ID])
|
||
}
|
||
|
||
return response.ProviderListResponse{
|
||
List: list,
|
||
Total: total,
|
||
Page: req.Page,
|
||
PageSize: req.PageSize,
|
||
}, nil
|
||
}
|
||
|
||
// GetProviderDetail 获取提供商详情
|
||
func (ps *ProviderService) GetProviderDetail(providerID uint, userID uint) (response.ProviderResponse, error) {
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return response.ProviderResponse{}, errors.New("提供商不存在")
|
||
}
|
||
return response.ProviderResponse{}, err
|
||
}
|
||
|
||
var models []app.AIModel
|
||
global.GVA_DB.Where("provider_id = ?", providerID).
|
||
Order("model_type ASC, model_name ASC").Find(&models)
|
||
|
||
return toProviderResponse(&provider, models), nil
|
||
}
|
||
|
||
// UpdateProvider 更新提供商
|
||
func (ps *ProviderService) UpdateProvider(req request.UpdateProviderRequest, userID uint) (response.ProviderResponse, error) {
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userID).First(&provider).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return response.ProviderResponse{}, errors.New("提供商不存在")
|
||
}
|
||
return response.ProviderResponse{}, err
|
||
}
|
||
|
||
// 更新字段
|
||
updates := map[string]interface{}{
|
||
"provider_name": req.ProviderName,
|
||
"provider_type": req.ProviderType,
|
||
"base_url": req.BaseURL,
|
||
}
|
||
|
||
// APIKey 不为空时才更新
|
||
if req.APIKey != "" {
|
||
updates["api_key"] = encryptAPIKey(req.APIKey)
|
||
}
|
||
|
||
if req.APIConfig != nil {
|
||
apiConfigJSON, _ := json.Marshal(req.APIConfig)
|
||
updates["api_config"] = datatypes.JSON(apiConfigJSON)
|
||
}
|
||
|
||
if req.IsEnabled != nil {
|
||
updates["is_enabled"] = *req.IsEnabled
|
||
}
|
||
|
||
if req.SortOrder != nil {
|
||
updates["sort_order"] = *req.SortOrder
|
||
}
|
||
|
||
// 更新能力
|
||
capabilities := getDefaultCapabilities(req.ProviderType)
|
||
capJSON, _ := json.Marshal(capabilities)
|
||
updates["capabilities"] = datatypes.JSON(capJSON)
|
||
|
||
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Model(&provider).Updates(updates).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 处理设置默认
|
||
if req.IsDefault != nil && *req.IsDefault {
|
||
// 先取消其他默认
|
||
if err := tx.Model(&app.AIProvider{}).
|
||
Where("user_id = ? AND id != ?", userID, req.ID).
|
||
Update("is_default", false).Error; err != nil {
|
||
return err
|
||
}
|
||
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
return response.ProviderResponse{}, err
|
||
}
|
||
|
||
return ps.GetProviderDetail(req.ID, userID)
|
||
}
|
||
|
||
// DeleteProvider 删除提供商
|
||
func (ps *ProviderService) DeleteProvider(providerID uint, userID uint) error {
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("提供商不存在")
|
||
}
|
||
return err
|
||
}
|
||
|
||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||
// 删除关联的模型
|
||
if err := tx.Where("provider_id = ?", providerID).Delete(&app.AIModel{}).Error; err != nil {
|
||
return err
|
||
}
|
||
// 删除提供商
|
||
if err := tx.Delete(&provider).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 如果删除的是默认提供商,自动将第一个提供商设为默认
|
||
if provider.IsDefault {
|
||
var firstProvider app.AIProvider
|
||
if err := tx.Where("user_id = ?", userID).Order("created_at ASC").First(&firstProvider).Error; err == nil {
|
||
tx.Model(&firstProvider).Update("is_default", true)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// SetDefaultProvider 设置默认提供商
|
||
func (ps *ProviderService) SetDefaultProvider(providerID uint, userID uint) error {
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("提供商不存在")
|
||
}
|
||
return err
|
||
}
|
||
|
||
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
|
||
// 先取消所有默认
|
||
if err := tx.Model(&app.AIProvider{}).
|
||
Where("user_id = ?", userID).
|
||
Update("is_default", false).Error; err != nil {
|
||
return err
|
||
}
|
||
// 设置新默认
|
||
return tx.Model(&provider).Update("is_default", true).Error
|
||
})
|
||
}
|
||
|
||
// ==================== 模型 CRUD ====================
|
||
|
||
// AddModel 为提供商添加模型
|
||
func (ps *ProviderService) AddModel(req request.CreateModelRequest, userID uint) (response.ModelResponse, error) {
|
||
// 验证提供商归属
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ProviderID, userID).First(&provider).Error
|
||
if err != nil {
|
||
return response.ModelResponse{}, errors.New("提供商不存在")
|
||
}
|
||
|
||
configJSON, _ := json.Marshal(req.Config)
|
||
if req.Config == nil {
|
||
configJSON = []byte("{}")
|
||
}
|
||
|
||
isEnabled := true
|
||
if req.IsEnabled != nil {
|
||
isEnabled = *req.IsEnabled
|
||
}
|
||
|
||
model := app.AIModel{
|
||
ProviderID: req.ProviderID,
|
||
ModelName: req.ModelName,
|
||
DisplayName: req.DisplayName,
|
||
ModelType: req.ModelType,
|
||
Config: datatypes.JSON(configJSON),
|
||
IsEnabled: isEnabled,
|
||
}
|
||
|
||
if err := global.GVA_DB.Create(&model).Error; err != nil {
|
||
return response.ModelResponse{}, err
|
||
}
|
||
|
||
return toModelResponse(&model), nil
|
||
}
|
||
|
||
// UpdateModel 更新模型
|
||
func (ps *ProviderService) UpdateModel(req request.UpdateModelRequest, userID uint) (response.ModelResponse, error) {
|
||
var model app.AIModel
|
||
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
|
||
Where("ai_models.id = ? AND ai_providers.user_id = ?", req.ID, userID).
|
||
First(&model).Error
|
||
if err != nil {
|
||
return response.ModelResponse{}, errors.New("模型不存在")
|
||
}
|
||
|
||
updates := map[string]interface{}{
|
||
"model_name": req.ModelName,
|
||
"display_name": req.DisplayName,
|
||
"model_type": req.ModelType,
|
||
}
|
||
|
||
if req.Config != nil {
|
||
configJSON, _ := json.Marshal(req.Config)
|
||
updates["config"] = datatypes.JSON(configJSON)
|
||
}
|
||
|
||
if req.IsEnabled != nil {
|
||
updates["is_enabled"] = *req.IsEnabled
|
||
}
|
||
|
||
if err := global.GVA_DB.Model(&model).Updates(updates).Error; err != nil {
|
||
return response.ModelResponse{}, err
|
||
}
|
||
|
||
// 重新查询
|
||
global.GVA_DB.First(&model, model.ID)
|
||
return toModelResponse(&model), nil
|
||
}
|
||
|
||
// DeleteModel 删除模型
|
||
func (ps *ProviderService) DeleteModel(modelID uint, userID uint) error {
|
||
var model app.AIModel
|
||
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
|
||
Where("ai_models.id = ? AND ai_providers.user_id = ?", modelID, userID).
|
||
First(&model).Error
|
||
if err != nil {
|
||
return errors.New("模型不存在")
|
||
}
|
||
|
||
return global.GVA_DB.Delete(&model).Error
|
||
}
|
||
|
||
// ==================== 连通性测试 ====================
|
||
|
||
// TestProvider 测试提供商连通性
|
||
func (ps *ProviderService) TestProvider(req request.TestProviderRequest) (response.TestProviderResponse, error) {
|
||
startTime := time.Now()
|
||
|
||
baseURL := req.BaseURL
|
||
if baseURL == "" {
|
||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||
}
|
||
|
||
// 所有提供商统一使用 OpenAI 兼容的 /models 端点测试连通性
|
||
result := testOpenAICompatible(baseURL, req.APIKey, req.ModelName)
|
||
result.Latency = time.Since(startTime).Milliseconds()
|
||
return result, nil
|
||
}
|
||
|
||
// TestExistingProvider 测试已保存的提供商连通性
|
||
func (ps *ProviderService) TestExistingProvider(providerID uint, userID uint) (response.TestProviderResponse, error) {
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||
if err != nil {
|
||
return response.TestProviderResponse{}, errors.New("提供商不存在")
|
||
}
|
||
|
||
apiKey := decryptAPIKey(provider.APIKey)
|
||
|
||
return ps.TestProvider(request.TestProviderRequest{
|
||
ProviderType: provider.ProviderType,
|
||
BaseURL: provider.BaseURL,
|
||
APIKey: apiKey,
|
||
})
|
||
}
|
||
|
||
// ==================== 辅助查询 ====================
|
||
|
||
// GetProviderTypes 获取支持的提供商类型列表(前端下拉用)
|
||
func (ps *ProviderService) GetProviderTypes() []response.ProviderTypeOption {
|
||
return []response.ProviderTypeOption{
|
||
{
|
||
Value: "openai",
|
||
Label: "OpenAI",
|
||
Description: "支持 GPT-4o、GPT-4、DALL·E 等模型,也兼容所有 OpenAI 格式的中转站",
|
||
DefaultURL: "https://api.openai.com/v1",
|
||
},
|
||
{
|
||
Value: "claude",
|
||
Label: "Claude",
|
||
Description: "Anthropic 的 Claude 系列模型,支持长上下文对话",
|
||
DefaultURL: "https://api.anthropic.com",
|
||
},
|
||
{
|
||
Value: "gemini",
|
||
Label: "Google Gemini",
|
||
Description: "Google 的 Gemini 系列模型,支持多模态",
|
||
DefaultURL: "https://generativelanguage.googleapis.com",
|
||
},
|
||
{
|
||
Value: "custom",
|
||
Label: "自定义(OpenAI 兼容)",
|
||
Description: "兼容 OpenAI 格式的任意接口,如 DeepSeek、通义千问等中转站",
|
||
DefaultURL: "",
|
||
},
|
||
}
|
||
}
|
||
|
||
// GetPresetModels 获取指定提供商类型的预设模型列表
|
||
func (ps *ProviderService) GetPresetModels(providerType string) []response.PresetModelOption {
|
||
presets := getPresetModels(providerType)
|
||
result := make([]response.PresetModelOption, len(presets))
|
||
for i, p := range presets {
|
||
result[i] = response.PresetModelOption{
|
||
ModelName: p.ModelName,
|
||
DisplayName: p.DisplayName,
|
||
ModelType: p.ModelType,
|
||
}
|
||
}
|
||
return result
|
||
}
|
||
|
||
// GetUserDefaultProvider 获取用户默认提供商(内部方法,给对话功能用)
|
||
func (ps *ProviderService) GetUserDefaultProvider(userID uint) (*app.AIProvider, error) {
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("user_id = ? AND is_default = ? AND is_enabled = ?", userID, true, true).
|
||
First(&provider).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("请先配置 AI 接口")
|
||
}
|
||
return nil, err
|
||
}
|
||
return &provider, nil
|
||
}
|
||
|
||
// GetDecryptedAPIKey 获取解密后的API密钥(内部方法,给AI调用用)
|
||
func (ps *ProviderService) GetDecryptedAPIKey(provider *app.AIProvider) string {
|
||
return decryptAPIKey(provider.APIKey)
|
||
}
|
||
|
||
// ==================== 内部辅助函数 ====================
|
||
|
||
// toProviderResponse 转换为响应对象
|
||
func toProviderResponse(p *app.AIProvider, models []app.AIModel) response.ProviderResponse {
|
||
apiConfig := json.RawMessage(p.APIConfig)
|
||
if len(apiConfig) == 0 {
|
||
apiConfig = json.RawMessage("{}")
|
||
}
|
||
|
||
capabilities := json.RawMessage(p.Capabilities)
|
||
if len(capabilities) == 0 {
|
||
capabilities = json.RawMessage("[]")
|
||
}
|
||
|
||
// 模型列表
|
||
modelList := make([]response.ModelResponse, len(models))
|
||
for i, m := range models {
|
||
modelList[i] = toModelResponse(&m)
|
||
}
|
||
|
||
// API Key 提示
|
||
apiKeyHint := ""
|
||
apiKeySet := false
|
||
if p.APIKey != "" {
|
||
apiKeySet = true
|
||
decrypted := decryptAPIKey(p.APIKey)
|
||
if len(decrypted) > 8 {
|
||
apiKeyHint = decrypted[:4] + "****" + decrypted[len(decrypted)-4:]
|
||
} else if len(decrypted) > 0 {
|
||
apiKeyHint = "****"
|
||
}
|
||
}
|
||
|
||
return response.ProviderResponse{
|
||
ID: p.ID,
|
||
ProviderName: p.ProviderName,
|
||
ProviderType: p.ProviderType,
|
||
BaseURL: p.BaseURL,
|
||
APIKeySet: apiKeySet,
|
||
APIKeyHint: apiKeyHint,
|
||
APIConfig: apiConfig,
|
||
Capabilities: capabilities,
|
||
IsEnabled: p.IsEnabled,
|
||
IsDefault: p.IsDefault,
|
||
SortOrder: p.SortOrder,
|
||
Models: modelList,
|
||
CreatedAt: p.CreatedAt,
|
||
UpdatedAt: p.UpdatedAt,
|
||
}
|
||
}
|
||
|
||
// toModelResponse 转换模型为响应对象
|
||
func toModelResponse(m *app.AIModel) response.ModelResponse {
|
||
config := json.RawMessage(m.Config)
|
||
if len(config) == 0 {
|
||
config = json.RawMessage("{}")
|
||
}
|
||
return response.ModelResponse{
|
||
ID: m.ID,
|
||
ProviderID: m.ProviderID,
|
||
ModelName: m.ModelName,
|
||
DisplayName: m.DisplayName,
|
||
ModelType: m.ModelType,
|
||
Config: config,
|
||
IsEnabled: m.IsEnabled,
|
||
CreatedAt: m.CreatedAt,
|
||
}
|
||
}
|
||
|
||
// encryptAPIKey obfuscates an API key for storage.
//
// Every byte is XORed with 0x5A and the result is hex-encoded behind an "enc:"
// prefix. An empty key yields an empty string.
//
// TODO: this is obfuscation only, not encryption; replace it with a real
// cipher (e.g. AES) for production use.
func encryptAPIKey(key string) string {
	if len(key) == 0 {
		return ""
	}

	masked := make([]byte, len(key))
	for i := 0; i < len(key); i++ {
		masked[i] = key[i] ^ 0x5A
	}
	return "enc:" + fmt.Sprintf("%x", masked)
}
|
||
|
||
// decryptAPIKey reverses encryptAPIKey.
//
// An empty input yields an empty string. Input without the "enc:" prefix is
// treated as legacy plaintext and returned unchanged. Otherwise the hex payload
// is decoded and every byte is XORed with 0x5A.
func decryptAPIKey(encrypted string) string {
	const marker = "enc:"

	switch {
	case encrypted == "":
		return ""
	case !strings.HasPrefix(encrypted, marker):
		return encrypted
	}

	var raw []byte
	// The scan result is ignored on purpose (best effort).
	// NOTE(review): malformed hex presumably produces an empty or partial
	// key here — confirm that is acceptable.
	_, _ = fmt.Sscanf(strings.TrimPrefix(encrypted, marker), "%x", &raw)
	for i, b := range raw {
		raw[i] = b ^ 0x5A
	}
	return string(raw)
}
|
||
|
||
// getDefaultBaseURL returns the default API base URL for a provider type, or
// an empty string for types without a known default.
func getDefaultBaseURL(providerType string) string {
	defaults := map[string]string{
		"openai": "https://api.openai.com/v1",
		"claude": "https://api.anthropic.com",
		"gemini": "https://generativelanguage.googleapis.com",
	}
	// A missing key yields the zero value "", which is the "no default" result.
	return defaults[providerType]
}
|
||
|
||
// getDefaultCapabilities returns the capability list for a provider type:
// "chat" and "image_gen" for openai and gemini, and only "chat" for every
// other type. A fresh slice is returned on each call.
func getDefaultCapabilities(providerType string) []string {
	if providerType == "openai" || providerType == "gemini" {
		return []string{"chat", "image_gen"}
	}
	// claude, custom and unknown types all support chat only.
	return []string{"chat"}
}
|
||
|
||
// presetModel describes one built-in model suggestion for a provider type.
type presetModel struct {
	ModelName   string // identifier stored as the model name
	DisplayName string // human-readable name
	ModelType   string // "chat" or "image_gen" in the presets below
}

// getPresetModels returns the built-in model list for a provider type. Types
// without presets (custom and anything unknown) yield an empty, non-nil slice.
func getPresetModels(providerType string) []presetModel {
	chat := func(name, display string) presetModel {
		return presetModel{ModelName: name, DisplayName: display, ModelType: "chat"}
	}
	image := func(name, display string) presetModel {
		return presetModel{ModelName: name, DisplayName: display, ModelType: "image_gen"}
	}

	if providerType == "openai" {
		return []presetModel{
			chat("gpt-4o", "GPT-4o"),
			chat("gpt-4o-mini", "GPT-4o Mini"),
			chat("gpt-4.1", "GPT-4.1"),
			chat("gpt-4.1-mini", "GPT-4.1 Mini"),
			chat("gpt-4.1-nano", "GPT-4.1 Nano"),
			chat("o3-mini", "o3-mini"),
			image("dall-e-3", "DALL·E 3"),
		}
	}
	if providerType == "claude" {
		return []presetModel{
			chat("claude-sonnet-4-20250514", "Claude Sonnet 4"),
			chat("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet"),
			chat("claude-3-5-haiku-20241022", "Claude 3.5 Haiku"),
			chat("claude-3-opus-20240229", "Claude 3 Opus"),
		}
	}
	if providerType == "gemini" {
		return []presetModel{
			chat("gemini-2.5-flash-preview-05-20", "Gemini 2.5 Flash"),
			chat("gemini-2.5-pro-preview-05-06", "Gemini 2.5 Pro"),
			chat("gemini-2.0-flash", "Gemini 2.0 Flash"),
			image("imagen-3.0-generate-002", "Imagen 3"),
		}
	}

	// custom and unknown provider types ship without presets.
	return []presetModel{}
}
|
||
|
||
// ==================== 获取远程模型列表 ====================
|
||
|
||
// FetchRemoteModels 从远程API获取可用模型列表
|
||
// 所有提供商类型统一使用 baseURL + /models 端点(OpenAI 兼容格式)
|
||
func (ps *ProviderService) FetchRemoteModels(req request.FetchRemoteModelsRequest) (response.FetchRemoteModelsResponse, error) {
|
||
startTime := time.Now()
|
||
|
||
baseURL := req.BaseURL
|
||
if baseURL == "" {
|
||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||
}
|
||
|
||
result := fetchModelsUniversal(baseURL, req.APIKey)
|
||
result.Latency = time.Since(startTime).Milliseconds()
|
||
return result, nil
|
||
}
|
||
|
||
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
|
||
func (ps *ProviderService) FetchRemoteModelsExisting(providerID uint, userID uint) (response.FetchRemoteModelsResponse, error) {
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||
if err != nil {
|
||
return response.FetchRemoteModelsResponse{}, errors.New("提供商不存在")
|
||
}
|
||
|
||
apiKey := decryptAPIKey(provider.APIKey)
|
||
|
||
return ps.FetchRemoteModels(request.FetchRemoteModelsRequest{
|
||
ProviderType: provider.ProviderType,
|
||
BaseURL: provider.BaseURL,
|
||
APIKey: apiKey,
|
||
})
|
||
}
|
||
|
||
// ==================== 发送测试消息 ====================
|
||
|
||
// SendTestMessage 发送测试消息(使用指定的 provider 配置)
|
||
// 所有提供商类型统一使用 baseURL + /chat/completions 端点(OpenAI 兼容格式)
|
||
func (ps *ProviderService) SendTestMessage(req request.SendTestMessageRequest) (response.SendTestMessageResponse, error) {
|
||
startTime := time.Now()
|
||
|
||
baseURL := req.BaseURL
|
||
if baseURL == "" {
|
||
baseURL = getDefaultBaseURL(req.ProviderType)
|
||
}
|
||
|
||
message := req.Message
|
||
if message == "" {
|
||
message = "你好,请用一句话介绍你自己。"
|
||
}
|
||
|
||
result := sendTestMessageUniversal(baseURL, req.APIKey, req.ModelName, message)
|
||
result.Latency = time.Since(startTime).Milliseconds()
|
||
return result, nil
|
||
}
|
||
|
||
// SendTestMessageExisting 发送测试消息(已保存的提供商)
|
||
func (ps *ProviderService) SendTestMessageExisting(providerID uint, userID uint, modelName string, message string) (response.SendTestMessageResponse, error) {
|
||
var provider app.AIProvider
|
||
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
|
||
if err != nil {
|
||
return response.SendTestMessageResponse{}, errors.New("提供商不存在")
|
||
}
|
||
|
||
apiKey := decryptAPIKey(provider.APIKey)
|
||
|
||
return ps.SendTestMessage(request.SendTestMessageRequest{
|
||
ProviderType: provider.ProviderType,
|
||
BaseURL: provider.BaseURL,
|
||
APIKey: apiKey,
|
||
ModelName: modelName,
|
||
Message: message,
|
||
})
|
||
}
|
||
|
||
// ==================== 连通性测试实现 ====================
|
||
|
||
// testOpenAICompatible 测试 OpenAI 兼容接口
|
||
func testOpenAICompatible(baseURL, apiKey, modelName string) response.TestProviderResponse {
|
||
url := strings.TrimRight(baseURL, "/") + "/models"
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||
defer cancel()
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||
if err != nil {
|
||
return response.TestProviderResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
resp, err := http.DefaultClient.Do(req)
|
||
if err != nil {
|
||
return response.TestProviderResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, _ := io.ReadAll(resp.Body)
|
||
|
||
if resp.StatusCode == 401 {
|
||
return response.TestProviderResponse{Success: false, Message: "API 密钥无效"}
|
||
}
|
||
|
||
if resp.StatusCode != 200 {
|
||
return response.TestProviderResponse{
|
||
Success: false,
|
||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||
}
|
||
}
|
||
|
||
// 解析模型列表
|
||
var modelsResp struct {
|
||
Data []struct {
|
||
ID string `json:"id"`
|
||
} `json:"data"`
|
||
}
|
||
|
||
var modelNames []string
|
||
if err := json.Unmarshal(body, &modelsResp); err == nil {
|
||
for _, m := range modelsResp.Data {
|
||
modelNames = append(modelNames, m.ID)
|
||
}
|
||
}
|
||
|
||
return response.TestProviderResponse{
|
||
Success: true,
|
||
Message: "连接成功",
|
||
Models: modelNames,
|
||
}
|
||
}
|
||
|
||
// ==================== 获取远程模型列表实现 ====================
|
||
|
||
// fetchModelsUniversal 统一获取模型列表(所有提供商通用)
|
||
// 使用 baseURL + /models 端点,Authorization: Bearer 鉴权
|
||
func fetchModelsUniversal(baseURL, apiKey string) response.FetchRemoteModelsResponse {
|
||
url := strings.TrimRight(baseURL, "/") + "/models"
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||
defer cancel()
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||
if err != nil {
|
||
return response.FetchRemoteModelsResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
resp, err := http.DefaultClient.Do(req)
|
||
if err != nil {
|
||
return response.FetchRemoteModelsResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode == 401 {
|
||
return response.FetchRemoteModelsResponse{Success: false, Message: "API 密钥无效"}
|
||
}
|
||
if resp.StatusCode != 200 {
|
||
return response.FetchRemoteModelsResponse{
|
||
Success: false,
|
||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||
}
|
||
}
|
||
|
||
body, _ := io.ReadAll(resp.Body)
|
||
|
||
// 解析 OpenAI 兼容格式: { "data": [{ "id": "xxx", "owned_by": "xxx" }] }
|
||
var modelsData struct {
|
||
Data []struct {
|
||
ID string `json:"id"`
|
||
OwnedBy string `json:"owned_by"`
|
||
} `json:"data"`
|
||
}
|
||
|
||
var models []response.RemoteModel
|
||
if err := json.Unmarshal(body, &modelsData); err == nil {
|
||
for _, m := range modelsData.Data {
|
||
models = append(models, response.RemoteModel{
|
||
ID: m.ID,
|
||
OwnedBy: m.OwnedBy,
|
||
})
|
||
}
|
||
}
|
||
|
||
return response.FetchRemoteModelsResponse{
|
||
Success: true,
|
||
Message: fmt.Sprintf("获取成功,共 %d 个模型", len(models)),
|
||
Models: models,
|
||
}
|
||
}
|
||
|
||
// ==================== 发送测试消息实现 ====================
|
||
|
||
// sendTestMessageUniversal 统一发送测试消息(所有提供商通用)
|
||
// 使用 baseURL + /chat/completions 端点,Authorization: Bearer 鉴权
|
||
func sendTestMessageUniversal(baseURL, apiKey, modelName, message string) response.SendTestMessageResponse {
|
||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||
|
||
payload := map[string]interface{}{
|
||
"model": modelName,
|
||
"max_tokens": 100,
|
||
"messages": []map[string]string{
|
||
{"role": "user", "content": message},
|
||
},
|
||
}
|
||
payloadBytes, _ := json.Marshal(payload)
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payloadBytes))
|
||
if err != nil {
|
||
return response.SendTestMessageResponse{Success: false, Message: "请求构建失败: " + err.Error()}
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := http.DefaultClient.Do(req)
|
||
if err != nil {
|
||
return response.SendTestMessageResponse{Success: false, Message: "连接失败: " + err.Error()}
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, _ := io.ReadAll(resp.Body)
|
||
|
||
if resp.StatusCode == 401 {
|
||
return response.SendTestMessageResponse{Success: false, Message: "API 密钥无效"}
|
||
}
|
||
|
||
if resp.StatusCode != 200 {
|
||
var errResp struct {
|
||
Error struct {
|
||
Message string `json:"message"`
|
||
} `json:"error"`
|
||
}
|
||
if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" {
|
||
return response.SendTestMessageResponse{Success: false, Message: "API 错误: " + errResp.Error.Message}
|
||
}
|
||
return response.SendTestMessageResponse{
|
||
Success: false,
|
||
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
|
||
}
|
||
}
|
||
|
||
// 解析 OpenAI 兼容格式的 chat completion 响应
|
||
var chatResp struct {
|
||
Model string `json:"model"`
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
Usage struct {
|
||
TotalTokens int `json:"total_tokens"`
|
||
} `json:"usage"`
|
||
}
|
||
|
||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||
return response.SendTestMessageResponse{Success: false, Message: "解析响应失败"}
|
||
}
|
||
|
||
reply := ""
|
||
if len(chatResp.Choices) > 0 {
|
||
reply = chatResp.Choices[0].Message.Content
|
||
}
|
||
|
||
return response.SendTestMessageResponse{
|
||
Success: true,
|
||
Message: "测试成功",
|
||
Reply: reply,
|
||
Model: chatResp.Model,
|
||
Tokens: chatResp.Usage.TotalTokens,
|
||
}
|
||
}
|