Files
st/server/service/app/provider.go

950 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package app
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/app/response"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// ProviderService groups the operations on per-user AI providers and their
// models: CRUD, connectivity tests and lookups. It carries no fields; every
// method reads and writes through global.GVA_DB or calls the remote API.
type ProviderService struct{}
// ==================== 提供商 CRUD ====================
// CreateProvider creates an AI provider owned by userID, together with its models.
//
// The API key is obfuscated with encryptAPIKey before it is stored. An empty
// BaseURL falls back to getDefaultBaseURL, and capabilities are derived from the
// provider type. When req.Models is empty, the presets from getPresetModels are
// inserted instead. If the user has no other provider, the new one becomes the
// default. All writes run in one transaction. On success the stored provider is
// re-read via GetProviderDetail.
func (ps *ProviderService) CreateProvider(req request.CreateProviderRequest, userID uint) (response.ProviderResponse, error) {
	encryptedKey := encryptAPIKey(req.APIKey)

	// A nil config is stored as an empty JSON object rather than "null".
	apiConfigJSON := []byte("{}")
	if req.APIConfig != nil {
		raw, err := json.Marshal(req.APIConfig)
		if err != nil {
			return response.ProviderResponse{}, fmt.Errorf("序列化接口配置失败: %w", err)
		}
		apiConfigJSON = raw
	}

	// Marshalling a []string cannot fail, so the error is intentionally ignored.
	capJSON, _ := json.Marshal(getDefaultCapabilities(req.ProviderType))

	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = getDefaultBaseURL(req.ProviderType)
	}

	provider := app.AIProvider{
		UserID:       &userID,
		ProviderName: req.ProviderName,
		ProviderType: req.ProviderType,
		BaseURL:      baseURL,
		APIKey:       encryptedKey,
		APIConfig:    datatypes.JSON(apiConfigJSON),
		Capabilities: datatypes.JSON(capJSON),
		IsEnabled:    true,
		IsDefault:    false,
	}

	err := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(&provider).Error; err != nil {
			return err
		}

		if len(req.Models) > 0 {
			// Caller-supplied models: missing config -> "{}", missing flag -> enabled.
			for _, m := range req.Models {
				modelConfigJSON := []byte("{}")
				if m.Config != nil {
					raw, err := json.Marshal(m.Config)
					if err != nil {
						return fmt.Errorf("序列化模型配置失败: %w", err)
					}
					modelConfigJSON = raw
				}
				isEnabled := true
				if m.IsEnabled != nil {
					isEnabled = *m.IsEnabled
				}
				model := app.AIModel{
					ProviderID:  provider.ID,
					ModelName:   m.ModelName,
					DisplayName: m.DisplayName,
					ModelType:   m.ModelType,
					Config:      datatypes.JSON(modelConfigJSON),
					IsEnabled:   isEnabled,
				}
				if err := tx.Create(&model).Error; err != nil {
					return err
				}
			}
		} else {
			// No models given: seed the type's preset models, all enabled.
			for _, p := range getPresetModels(req.ProviderType) {
				model := app.AIModel{
					ProviderID:  provider.ID,
					ModelName:   p.ModelName,
					DisplayName: p.DisplayName,
					ModelType:   p.ModelType,
					Config:      datatypes.JSON([]byte("{}")),
					IsEnabled:   true,
				}
				if err := tx.Create(&model).Error; err != nil {
					return err
				}
			}
		}

		// The user's first provider becomes the default. The count error is
		// checked: a failed query would otherwise read as 0 and could mark a
		// second default.
		var others int64
		if err := tx.Model(&app.AIProvider{}).
			Where("user_id = ? AND id != ?", userID, provider.ID).
			Count(&others).Error; err != nil {
			return err
		}
		if others == 0 {
			provider.IsDefault = true
			if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return response.ProviderResponse{}, err
	}
	return ps.GetProviderDetail(provider.ID, userID)
}
// GetProviderList returns one page of userID's providers, each with its models.
//
// Page is 1-based. Keyword, if set, is a case-insensitive literal substring
// match on provider_name. Results are ordered default first, then by
// sort_order, then newest first. Total counts all matching providers, not just
// this page. Database errors from the count, page or model queries are returned.
func (ps *ProviderService) GetProviderList(req request.ProviderListRequest, userID uint) (response.ProviderListResponse, error) {
	db := global.GVA_DB.Model(&app.AIProvider{}).Where("user_id = ?", userID)
	if req.Keyword != "" {
		// Escape LIKE metacharacters so the keyword is matched literally.
		// Backslash is PostgreSQL's default LIKE escape character (ILIKE is PG-specific).
		escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(req.Keyword)
		db = db.Where("provider_name ILIKE ?", "%"+escaped+"%")
	}

	var total int64
	if err := db.Count(&total).Error; err != nil {
		return response.ProviderListResponse{}, err
	}

	var providers []app.AIProvider
	offset := (req.Page - 1) * req.PageSize
	if err := db.Order("is_default DESC, sort_order ASC, created_at DESC").
		Offset(offset).Limit(req.PageSize).Find(&providers).Error; err != nil {
		return response.ProviderListResponse{}, err
	}

	// Load all models of this page in one query and group them by provider.
	modelMap := make(map[uint][]app.AIModel, len(providers))
	if len(providers) > 0 {
		providerIDs := make([]uint, len(providers))
		for i := range providers {
			providerIDs[i] = providers[i].ID
		}
		var models []app.AIModel
		if err := global.GVA_DB.Where("provider_id IN ?", providerIDs).
			Order("model_type ASC, model_name ASC").Find(&models).Error; err != nil {
			return response.ProviderListResponse{}, err
		}
		for _, m := range models {
			modelMap[m.ProviderID] = append(modelMap[m.ProviderID], m)
		}
	}

	list := make([]response.ProviderResponse, len(providers))
	for i := range providers {
		list[i] = toProviderResponse(&providers[i], modelMap[providers[i].ID])
	}

	return response.ProviderListResponse{
		List:     list,
		Total:    total,
		Page:     req.Page,
		PageSize: req.PageSize,
	}, nil
}
// GetProviderDetail returns provider providerID owned by userID, with its
// models ordered by type and name. It returns "提供商不存在" when the provider
// does not exist or belongs to another user. Any other database error, including
// a failure to load the models, is returned as is.
func (ps *ProviderService) GetProviderDetail(providerID uint, userID uint) (response.ProviderResponse, error) {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.ProviderResponse{}, errors.New("提供商不存在")
		}
		return response.ProviderResponse{}, err
	}

	var models []app.AIModel
	if err := global.GVA_DB.Where("provider_id = ?", providerID).
		Order("model_type ASC, model_name ASC").Find(&models).Error; err != nil {
		return response.ProviderResponse{}, err
	}
	return toProviderResponse(&provider, models), nil
}
// UpdateProvider updates provider req.ID owned by userID and returns its new state.
//
// ProviderName, ProviderType and BaseURL are always written. An empty BaseURL
// falls back to the type's default, matching CreateProvider. APIKey is written
// only when non-empty. APIConfig, IsEnabled and SortOrder are written only when
// non-nil. Capabilities are re-derived from ProviderType. IsDefault=true makes
// this the user's only default; IsDefault=false is ignored. It returns
// "提供商不存在" when the provider does not belong to userID.
func (ps *ProviderService) UpdateProvider(req request.UpdateProviderRequest, userID uint) (response.ProviderResponse, error) {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userID).First(&provider).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.ProviderResponse{}, errors.New("提供商不存在")
		}
		return response.ProviderResponse{}, err
	}

	// Mirror CreateProvider rather than persisting an empty base URL.
	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = getDefaultBaseURL(req.ProviderType)
	}

	updates := map[string]interface{}{
		"provider_name": req.ProviderName,
		"provider_type": req.ProviderType,
		"base_url":      baseURL,
	}
	// An empty APIKey means "keep the stored key".
	if req.APIKey != "" {
		updates["api_key"] = encryptAPIKey(req.APIKey)
	}
	if req.APIConfig != nil {
		raw, err := json.Marshal(req.APIConfig)
		if err != nil {
			return response.ProviderResponse{}, fmt.Errorf("序列化接口配置失败: %w", err)
		}
		updates["api_config"] = datatypes.JSON(raw)
	}
	if req.IsEnabled != nil {
		updates["is_enabled"] = *req.IsEnabled
	}
	if req.SortOrder != nil {
		updates["sort_order"] = *req.SortOrder
	}
	// The type may have changed, so recompute capabilities.
	// Marshalling a []string cannot fail.
	capJSON, _ := json.Marshal(getDefaultCapabilities(req.ProviderType))
	updates["capabilities"] = datatypes.JSON(capJSON)

	err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&provider).Updates(updates).Error; err != nil {
			return err
		}
		if req.IsDefault != nil && *req.IsDefault {
			// Clear every other default first so exactly one remains.
			if err := tx.Model(&app.AIProvider{}).
				Where("user_id = ? AND id != ?", userID, req.ID).
				Update("is_default", false).Error; err != nil {
				return err
			}
			if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return response.ProviderResponse{}, err
	}
	return ps.GetProviderDetail(req.ID, userID)
}
// DeleteProvider deletes provider providerID owned by userID, together with all
// of its models, in one transaction. If the deleted provider was the default,
// the user's oldest remaining provider (by created_at) is promoted to default.
// Failing to promote now rolls the deletion back instead of being ignored. It
// returns "提供商不存在" when the provider does not belong to userID.
func (ps *ProviderService) DeleteProvider(providerID uint, userID uint) error {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("提供商不存在")
		}
		return err
	}

	return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Where("provider_id = ?", providerID).Delete(&app.AIModel{}).Error; err != nil {
			return err
		}
		if err := tx.Delete(&provider).Error; err != nil {
			return err
		}
		if !provider.IsDefault {
			return nil
		}

		// Promote the oldest remaining provider. Having none left is fine.
		var next app.AIProvider
		err := tx.Where("user_id = ?", userID).Order("created_at ASC").First(&next).Error
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil
		}
		if err != nil {
			return err
		}
		return tx.Model(&next).Update("is_default", true).Error
	})
}
// SetDefaultProvider makes providerID the only default provider of userID.
// It returns "提供商不存在" when the provider does not belong to the user. The
// reset of all defaults and the new assignment happen in one transaction.
func (ps *ProviderService) SetDefaultProvider(providerID uint, userID uint) error {
	var target app.AIProvider
	if lookupErr := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&target).Error; lookupErr != nil {
		if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			return errors.New("提供商不存在")
		}
		return lookupErr
	}

	return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		cleared := tx.Model(&app.AIProvider{}).
			Where("user_id = ?", userID).
			Update("is_default", false)
		if cleared.Error != nil {
			return cleared.Error
		}
		return tx.Model(&target).Update("is_default", true).Error
	})
}
// ==================== 模型 CRUD ====================
// AddModel adds a model to provider req.ProviderID after checking that the
// provider belongs to userID. A nil Config is stored as "{}", and a nil
// IsEnabled defaults to enabled. It returns "提供商不存在" only when the provider
// is genuinely missing. Other database errors are passed through instead of
// being masked.
func (ps *ProviderService) AddModel(req request.CreateModelRequest, userID uint) (response.ModelResponse, error) {
	var provider app.AIProvider
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ProviderID, userID).First(&provider).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.ModelResponse{}, errors.New("提供商不存在")
		}
		return response.ModelResponse{}, err
	}

	configJSON := []byte("{}")
	if req.Config != nil {
		raw, err := json.Marshal(req.Config)
		if err != nil {
			return response.ModelResponse{}, fmt.Errorf("序列化模型配置失败: %w", err)
		}
		configJSON = raw
	}

	isEnabled := true
	if req.IsEnabled != nil {
		isEnabled = *req.IsEnabled
	}

	model := app.AIModel{
		ProviderID:  req.ProviderID,
		ModelName:   req.ModelName,
		DisplayName: req.DisplayName,
		ModelType:   req.ModelType,
		Config:      datatypes.JSON(configJSON),
		IsEnabled:   isEnabled,
	}
	if err := global.GVA_DB.Create(&model).Error; err != nil {
		return response.ModelResponse{}, err
	}
	return toModelResponse(&model), nil
}
// UpdateModel updates model req.ID, provided its provider belongs to userID,
// and returns the reloaded row. ModelName, DisplayName and ModelType are always
// written. Config and IsEnabled are written only when non-nil. It returns
// "模型不存在" only when no such model is visible to the user. Other database
// errors, including a failed reload, are returned.
func (ps *ProviderService) UpdateModel(req request.UpdateModelRequest, userID uint) (response.ModelResponse, error) {
	var model app.AIModel
	err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
		Where("ai_models.id = ? AND ai_providers.user_id = ?", req.ID, userID).
		First(&model).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.ModelResponse{}, errors.New("模型不存在")
		}
		return response.ModelResponse{}, err
	}

	updates := map[string]interface{}{
		"model_name":   req.ModelName,
		"display_name": req.DisplayName,
		"model_type":   req.ModelType,
	}
	if req.Config != nil {
		raw, err := json.Marshal(req.Config)
		if err != nil {
			return response.ModelResponse{}, fmt.Errorf("序列化模型配置失败: %w", err)
		}
		updates["config"] = datatypes.JSON(raw)
	}
	if req.IsEnabled != nil {
		updates["is_enabled"] = *req.IsEnabled
	}
	if err := global.GVA_DB.Model(&model).Updates(updates).Error; err != nil {
		return response.ModelResponse{}, err
	}

	// Reload so the response reflects the row as stored.
	if err := global.GVA_DB.First(&model, model.ID).Error; err != nil {
		return response.ModelResponse{}, err
	}
	return toModelResponse(&model), nil
}
// DeleteModel deletes model modelID, provided its provider belongs to userID.
// It returns "模型不存在" only when no such model is visible to the user. Other
// database errors are passed through instead of being masked.
func (ps *ProviderService) DeleteModel(modelID uint, userID uint) error {
	var model app.AIModel
	err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
		Where("ai_models.id = ? AND ai_providers.user_id = ?", modelID, userID).
		First(&model).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("模型不存在")
		}
		return err
	}
	return global.GVA_DB.Delete(&model).Error
}
// ==================== 连通性测试 ====================
// TestProvider checks connectivity for a provider configuration that may not
// be saved yet. Every provider type is probed the same way, through the
// OpenAI-compatible /models endpoint (testOpenAICompatible). An empty BaseURL
// falls back to the type's default. Latency is the call duration in
// milliseconds. The error is always nil; failures are reported in the response.
func (ps *ProviderService) TestProvider(req request.TestProviderRequest) (response.TestProviderResponse, error) {
	began := time.Now()

	target := req.BaseURL
	if target == "" {
		target = getDefaultBaseURL(req.ProviderType)
	}

	outcome := testOpenAICompatible(target, req.APIKey, req.ModelName)
	outcome.Latency = time.Since(began).Milliseconds()
	return outcome, nil
}
// TestExistingProvider runs TestProvider against a provider saved by userID,
// using its stored base URL and decrypted API key. No model name is passed.
// It returns "提供商不存在" only when the provider is missing. Other database
// errors are passed through.
func (ps *ProviderService) TestExistingProvider(providerID uint, userID uint) (response.TestProviderResponse, error) {
	var provider app.AIProvider
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.TestProviderResponse{}, errors.New("提供商不存在")
		}
		return response.TestProviderResponse{}, err
	}
	return ps.TestProvider(request.TestProviderRequest{
		ProviderType: provider.ProviderType,
		BaseURL:      provider.BaseURL,
		APIKey:       decryptAPIKey(provider.APIKey),
	})
}
// ==================== 辅助查询 ====================
// GetProviderTypes lists the provider types offered in the frontend drop-down,
// each with a label, a description and its default base URL. The URLs come from
// getDefaultBaseURL, so this list cannot drift from the fallback used when
// BaseURL is empty. The "custom" label had an unbalanced parenthesis and is fixed.
func (ps *ProviderService) GetProviderTypes() []response.ProviderTypeOption {
	return []response.ProviderTypeOption{
		{
			Value:       "openai",
			Label:       "OpenAI",
			Description: "支持 GPT-4o、GPT-4、DALL·E 等模型,也兼容所有 OpenAI 格式的中转站",
			DefaultURL:  getDefaultBaseURL("openai"),
		},
		{
			Value:       "claude",
			Label:       "Claude",
			Description: "Anthropic 的 Claude 系列模型,支持长上下文对话",
			DefaultURL:  getDefaultBaseURL("claude"),
		},
		{
			Value:       "gemini",
			Label:       "Google Gemini",
			Description: "Google 的 Gemini 系列模型,支持多模态",
			DefaultURL:  getDefaultBaseURL("gemini"),
		},
		{
			Value:       "custom",
			Label:       "自定义(OpenAI 兼容)",
			Description: "兼容 OpenAI 格式的任意接口,如 DeepSeek、通义千问等中转站",
			DefaultURL:  getDefaultBaseURL("custom"),
		},
	}
}
// GetPresetModels exposes getPresetModels(providerType) as response options.
// The result is never nil: unknown or custom types yield an empty slice.
func (ps *ProviderService) GetPresetModels(providerType string) []response.PresetModelOption {
	src := getPresetModels(providerType)
	options := make([]response.PresetModelOption, 0, len(src))
	for _, preset := range src {
		options = append(options, response.PresetModelOption{
			ModelName:   preset.ModelName,
			DisplayName: preset.DisplayName,
			ModelType:   preset.ModelType,
		})
	}
	return options
}
// GetUserDefaultProvider 获取用户默认提供商(内部方法,给对话功能用)
func (ps *ProviderService) GetUserDefaultProvider(userID uint) (*app.AIProvider, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("user_id = ? AND is_default = ? AND is_enabled = ?", userID, true, true).
First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("请先配置 AI 接口")
}
return nil, err
}
return &provider, nil
}
// GetDecryptedAPIKey returns the plaintext API key stored on provider. It is an
// internal helper for code that calls the AI API. provider must be non-nil:
// its APIKey field is dereferenced directly.
func (ps *ProviderService) GetDecryptedAPIKey(provider *app.AIProvider) string {
	stored := provider.APIKey
	return decryptAPIKey(stored)
}
// ==================== 内部辅助函数 ====================
// toProviderResponse 转换为响应对象
func toProviderResponse(p *app.AIProvider, models []app.AIModel) response.ProviderResponse {
apiConfig := json.RawMessage(p.APIConfig)
if len(apiConfig) == 0 {
apiConfig = json.RawMessage("{}")
}
capabilities := json.RawMessage(p.Capabilities)
if len(capabilities) == 0 {
capabilities = json.RawMessage("[]")
}
// 模型列表
modelList := make([]response.ModelResponse, len(models))
for i, m := range models {
modelList[i] = toModelResponse(&m)
}
// API Key 提示
apiKeyHint := ""
apiKeySet := false
if p.APIKey != "" {
apiKeySet = true
decrypted := decryptAPIKey(p.APIKey)
if len(decrypted) > 8 {
apiKeyHint = decrypted[:4] + "****" + decrypted[len(decrypted)-4:]
} else if len(decrypted) > 0 {
apiKeyHint = "****"
}
}
return response.ProviderResponse{
ID: p.ID,
ProviderName: p.ProviderName,
ProviderType: p.ProviderType,
BaseURL: p.BaseURL,
APIKeySet: apiKeySet,
APIKeyHint: apiKeyHint,
APIConfig: apiConfig,
Capabilities: capabilities,
IsEnabled: p.IsEnabled,
IsDefault: p.IsDefault,
SortOrder: p.SortOrder,
Models: modelList,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
// toModelResponse converts a stored model to the API shape. An empty config
// column is rendered as "{}".
func toModelResponse(m *app.AIModel) response.ModelResponse {
	cfg := json.RawMessage("{}")
	if len(m.Config) > 0 {
		cfg = json.RawMessage(m.Config)
	}
	return response.ModelResponse{
		ID:          m.ID,
		ProviderID:  m.ProviderID,
		ModelName:   m.ModelName,
		DisplayName: m.DisplayName,
		ModelType:   m.ModelType,
		Config:      cfg,
		IsEnabled:   m.IsEnabled,
		CreatedAt:   m.CreatedAt,
	}
}
// encryptAPIKey obfuscates an API key for storage as "enc:" followed by the
// lowercase hex of every byte XOR 0x5A. An empty key stays empty.
//
// This is reversible obfuscation with a fixed byte, not encryption: anyone who
// can read the column and this code can recover the key.
// TODO: replace with real authenticated encryption (e.g. AES-GCM) plus a managed key.
func encryptAPIKey(key string) string {
	if len(key) == 0 {
		return ""
	}
	masked := make([]byte, len(key))
	for i := 0; i < len(key); i++ {
		masked[i] = key[i] ^ 0x5A
	}
	return "enc:" + fmt.Sprintf("%x", masked)
}
// decryptAPIKey reverses encryptAPIKey.
//
// Values without the "enc:" prefix are treated as legacy plaintext and returned
// unchanged. A value whose payload is not entirely well-formed hex yields ""
// rather than a partially decoded key, so a corrupted column is never sent
// upstream as a bogus credential.
func decryptAPIKey(encrypted string) string {
	if encrypted == "" {
		return ""
	}
	if !strings.HasPrefix(encrypted, "enc:") {
		return encrypted // legacy unencrypted value
	}
	hexStr := strings.TrimPrefix(encrypted, "enc:")

	var data []byte
	// Sscanf's %x can stop at a non-hex character without reporting an error,
	// so also require that the whole payload was consumed (2 hex chars per byte).
	if _, err := fmt.Sscanf(hexStr, "%x", &data); err != nil || len(data)*2 != len(hexStr) {
		return ""
	}
	for i := range data {
		data[i] ^= 0x5A
	}
	return string(data)
}
// getDefaultBaseURL returns the built-in API base URL for providerType, or ""
// for "custom" and unknown types.
//
// NOTE(review): callers append "/models" and "/chat/completions" to this URL.
// The claude and gemini defaults have no version/OpenAI-compat path segment,
// so those requests likely miss a valid endpoint. Confirm against the
// providers' docs.
func getDefaultBaseURL(providerType string) string {
	var url string
	switch providerType {
	case "openai":
		url = "https://api.openai.com/v1"
	case "claude":
		url = "https://api.anthropic.com"
	case "gemini":
		url = "https://generativelanguage.googleapis.com"
	}
	return url
}
// getDefaultCapabilities returns a fresh capability list for providerType:
// "openai" and "gemini" get chat plus image generation; every other type
// (including "claude" and "custom") gets chat only.
func getDefaultCapabilities(providerType string) []string {
	if providerType == "openai" || providerType == "gemini" {
		return []string{"chat", "image_gen"}
	}
	return []string{"chat"}
}
// presetModel describes one built-in model seeded for a provider type: the
// upstream model identifier, a human-readable name, and its type ("chat" or
// "image_gen" in the presets defined here).
type presetModel struct {
	ModelName, DisplayName, ModelType string
}
// getPresetModels returns a freshly allocated list of built-in models for
// providerType. "custom" and unknown types get an empty, non-nil slice.
func getPresetModels(providerType string) []presetModel {
	chat := func(name, display string) presetModel {
		return presetModel{ModelName: name, DisplayName: display, ModelType: "chat"}
	}
	image := func(name, display string) presetModel {
		return presetModel{ModelName: name, DisplayName: display, ModelType: "image_gen"}
	}

	switch providerType {
	case "openai":
		return []presetModel{
			chat("gpt-4o", "GPT-4o"),
			chat("gpt-4o-mini", "GPT-4o Mini"),
			chat("gpt-4.1", "GPT-4.1"),
			chat("gpt-4.1-mini", "GPT-4.1 Mini"),
			chat("gpt-4.1-nano", "GPT-4.1 Nano"),
			chat("o3-mini", "o3-mini"),
			image("dall-e-3", "DALL·E 3"),
		}
	case "claude":
		return []presetModel{
			chat("claude-sonnet-4-20250514", "Claude Sonnet 4"),
			chat("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet"),
			chat("claude-3-5-haiku-20241022", "Claude 3.5 Haiku"),
			chat("claude-3-opus-20240229", "Claude 3 Opus"),
		}
	case "gemini":
		return []presetModel{
			chat("gemini-2.5-flash-preview-05-20", "Gemini 2.5 Flash"),
			chat("gemini-2.5-pro-preview-05-06", "Gemini 2.5 Pro"),
			chat("gemini-2.0-flash", "Gemini 2.0 Flash"),
			image("imagen-3.0-generate-002", "Imagen 3"),
		}
	}
	// Custom endpoints have no presets; neither do unknown types.
	return []presetModel{}
}
// ==================== 获取远程模型列表 ====================
// FetchRemoteModels lists the models exposed by a provider configuration that
// may not be saved yet. Every provider type is queried the same way, through
// fetchModelsUniversal (GET {baseURL}/models, OpenAI-compatible format). An
// empty BaseURL falls back to the type's default. Latency is the call duration
// in milliseconds. The error is always nil; failures are reported in the
// response.
func (ps *ProviderService) FetchRemoteModels(req request.FetchRemoteModelsRequest) (response.FetchRemoteModelsResponse, error) {
	began := time.Now()

	target := req.BaseURL
	if target == "" {
		target = getDefaultBaseURL(req.ProviderType)
	}

	listing := fetchModelsUniversal(target, req.APIKey)
	listing.Latency = time.Since(began).Milliseconds()
	return listing, nil
}
// FetchRemoteModelsExisting runs FetchRemoteModels for a provider saved by
// userID, using its stored base URL and decrypted API key. It returns
// "提供商不存在" only when the provider is missing. Other database errors are
// passed through.
func (ps *ProviderService) FetchRemoteModelsExisting(providerID uint, userID uint) (response.FetchRemoteModelsResponse, error) {
	var provider app.AIProvider
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.FetchRemoteModelsResponse{}, errors.New("提供商不存在")
		}
		return response.FetchRemoteModelsResponse{}, err
	}
	return ps.FetchRemoteModels(request.FetchRemoteModelsRequest{
		ProviderType: provider.ProviderType,
		BaseURL:      provider.BaseURL,
		APIKey:       decryptAPIKey(provider.APIKey),
	})
}
// ==================== 发送测试消息 ====================
// SendTestMessage sends one chat message using the given provider settings.
// Every provider type goes through sendTestMessageUniversal
// (POST {baseURL}/chat/completions, OpenAI-compatible format). An empty BaseURL
// falls back to the type's default, and an empty Message is replaced by a
// built-in greeting prompt. Latency is the call duration in milliseconds. The
// error is always nil; failures are reported in the response.
func (ps *ProviderService) SendTestMessage(req request.SendTestMessageRequest) (response.SendTestMessageResponse, error) {
	began := time.Now()

	target := req.BaseURL
	if target == "" {
		target = getDefaultBaseURL(req.ProviderType)
	}

	prompt := req.Message
	if prompt == "" {
		prompt = "你好,请用一句话介绍你自己。"
	}

	reply := sendTestMessageUniversal(target, req.APIKey, req.ModelName, prompt)
	reply.Latency = time.Since(began).Milliseconds()
	return reply, nil
}
// SendTestMessageExisting runs SendTestMessage for a provider saved by userID,
// using its stored base URL and decrypted API key. It returns "提供商不存在"
// only when the provider is missing. Other database errors are passed through.
func (ps *ProviderService) SendTestMessageExisting(providerID uint, userID uint, modelName string, message string) (response.SendTestMessageResponse, error) {
	var provider app.AIProvider
	if err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.SendTestMessageResponse{}, errors.New("提供商不存在")
		}
		return response.SendTestMessageResponse{}, err
	}
	return ps.SendTestMessage(request.SendTestMessageRequest{
		ProviderType: provider.ProviderType,
		BaseURL:      provider.BaseURL,
		APIKey:       decryptAPIKey(provider.APIKey),
		ModelName:    modelName,
		Message:      message,
	})
}
// ==================== 连通性测试实现 ====================
// testOpenAICompatible probes an OpenAI-compatible API with
// GET {baseURL}/models and a Bearer token, bounded by a 15-second timeout.
//
// A 200 response counts as success. If the body parses as
// {"data":[{"id":...}]}, the model IDs are returned; otherwise Models is empty
// (the endpoint was still reachable). 401 is reported as an invalid key and any
// other status as an HTTP failure. A body that cannot be read completely (for
// example, the timeout hits mid-transfer) is now a failure, not a success.
// modelName is currently unused and kept for the caller's signature.
func testOpenAICompatible(baseURL, apiKey, modelName string) response.TestProviderResponse {
	endpoint := strings.TrimRight(baseURL, "/") + "/models"

	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return response.TestProviderResponse{Success: false, Message: "请求构建失败: " + err.Error()}
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return response.TestProviderResponse{Success: false, Message: "连接失败: " + err.Error()}
	}
	defer resp.Body.Close()

	// Read the body before the status checks so the connection can be reused.
	body, readErr := io.ReadAll(resp.Body)

	if resp.StatusCode == http.StatusUnauthorized {
		return response.TestProviderResponse{Success: false, Message: "API 密钥无效"}
	}
	if resp.StatusCode != http.StatusOK {
		return response.TestProviderResponse{
			Success: false,
			Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
		}
	}
	if readErr != nil {
		return response.TestProviderResponse{Success: false, Message: "读取响应失败: " + readErr.Error()}
	}

	var modelsResp struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	var modelNames []string
	// Best effort: a 200 with an unexpected body still proves connectivity.
	if err := json.Unmarshal(body, &modelsResp); err == nil {
		modelNames = make([]string, 0, len(modelsResp.Data))
		for _, m := range modelsResp.Data {
			modelNames = append(modelNames, m.ID)
		}
	}

	return response.TestProviderResponse{
		Success: true,
		Message: "连接成功",
		Models:  modelNames,
	}
}
// ==================== 获取远程模型列表实现 ====================
// fetchModelsUniversal lists models from any provider through the
// OpenAI-compatible GET {baseURL}/models endpoint, with Bearer auth and a
// 20-second timeout.
//
// It expects {"data":[{"id":...,"owned_by":...}]}. 401 is reported as an
// invalid key and any other non-200 status as an HTTP failure. A body that
// cannot be read, or does not parse as JSON, is now reported as a failure
// instead of "success with 0 models". Models is non-nil on success.
func fetchModelsUniversal(baseURL, apiKey string) response.FetchRemoteModelsResponse {
	endpoint := strings.TrimRight(baseURL, "/") + "/models"

	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return response.FetchRemoteModelsResponse{Success: false, Message: "请求构建失败: " + err.Error()}
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return response.FetchRemoteModelsResponse{Success: false, Message: "连接失败: " + err.Error()}
	}
	defer resp.Body.Close()

	// Read the body up front so the connection can be reused on every path.
	body, readErr := io.ReadAll(resp.Body)

	if resp.StatusCode == http.StatusUnauthorized {
		return response.FetchRemoteModelsResponse{Success: false, Message: "API 密钥无效"}
	}
	if resp.StatusCode != http.StatusOK {
		return response.FetchRemoteModelsResponse{
			Success: false,
			Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
		}
	}
	if readErr != nil {
		return response.FetchRemoteModelsResponse{Success: false, Message: "读取响应失败: " + readErr.Error()}
	}

	var modelsData struct {
		Data []struct {
			ID      string `json:"id"`
			OwnedBy string `json:"owned_by"`
		} `json:"data"`
	}
	if err := json.Unmarshal(body, &modelsData); err != nil {
		return response.FetchRemoteModelsResponse{Success: false, Message: "解析响应失败"}
	}

	models := make([]response.RemoteModel, 0, len(modelsData.Data))
	for _, m := range modelsData.Data {
		models = append(models, response.RemoteModel{
			ID:      m.ID,
			OwnedBy: m.OwnedBy,
		})
	}

	return response.FetchRemoteModelsResponse{
		Success: true,
		Message: fmt.Sprintf("获取成功,共 %d 个模型", len(models)),
		Models:  models,
	}
}
// ==================== 发送测试消息实现 ====================
// sendTestMessageUniversal sends one user message to any provider through the
// OpenAI-compatible POST {baseURL}/chat/completions endpoint, with Bearer auth,
// max_tokens=100 and a 30-second timeout.
//
// On 200 it returns the first choice's content, the reported model and the
// total token count. 401 is reported as an invalid key. For other non-200
// statuses, the API's {"error":{"message":...}} is surfaced when present. A
// body that cannot be read completely is now reported as a failure instead of
// being parsed as if it were whole.
func sendTestMessageUniversal(baseURL, apiKey, modelName, message string) response.SendTestMessageResponse {
	endpoint := strings.TrimRight(baseURL, "/") + "/chat/completions"

	payload := map[string]interface{}{
		"model":      modelName,
		"max_tokens": 100,
		"messages": []map[string]string{
			{"role": "user", "content": message},
		},
	}
	// Marshalling strings, ints and string maps cannot fail.
	payloadBytes, _ := json.Marshal(payload)

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payloadBytes))
	if err != nil {
		return response.SendTestMessageResponse{Success: false, Message: "请求构建失败: " + err.Error()}
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return response.SendTestMessageResponse{Success: false, Message: "连接失败: " + err.Error()}
	}
	defer resp.Body.Close()

	body, readErr := io.ReadAll(resp.Body)

	if resp.StatusCode == http.StatusUnauthorized {
		return response.SendTestMessageResponse{Success: false, Message: "API 密钥无效"}
	}
	if resp.StatusCode != http.StatusOK {
		var errResp struct {
			Error struct {
				Message string `json:"message"`
			} `json:"error"`
		}
		if readErr == nil && json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" {
			return response.SendTestMessageResponse{Success: false, Message: "API 错误: " + errResp.Error.Message}
		}
		return response.SendTestMessageResponse{
			Success: false,
			Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
		}
	}
	if readErr != nil {
		return response.SendTestMessageResponse{Success: false, Message: "读取响应失败: " + readErr.Error()}
	}

	var chatResp struct {
		Model   string `json:"model"`
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Usage struct {
			TotalTokens int `json:"total_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &chatResp); err != nil {
		return response.SendTestMessageResponse{Success: false, Message: "解析响应失败"}
	}

	reply := ""
	if len(chatResp.Choices) > 0 {
		reply = chatResp.Choices[0].Message.Content
	}

	return response.SendTestMessageResponse{
		Success: true,
		Message: "测试成功",
		Reply:   reply,
		Model:   chatResp.Model,
		Tokens:  chatResp.Usage.TotalTokens,
	}
}