🎨 优化扩展模块,完成ai接入和对话功能

This commit is contained in:
2026-02-12 23:12:28 +08:00
parent 4e611d3a5e
commit 572f3aa15b
779 changed files with 194400 additions and 3136 deletions

View File

@@ -0,0 +1,949 @@
package app
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/st/server/global"
"git.echol.cn/loser/st/server/model/app"
"git.echol.cn/loser/st/server/model/app/request"
"git.echol.cn/loser/st/server/model/app/response"
"gorm.io/datatypes"
"gorm.io/gorm"
)
type ProviderService struct{}
// ==================== 提供商 CRUD ====================
// CreateProvider 创建AI提供商
func (ps *ProviderService) CreateProvider(req request.CreateProviderRequest, userID uint) (response.ProviderResponse, error) {
// 加密 API Key
encryptedKey := encryptAPIKey(req.APIKey)
// 序列化额外配置
apiConfigJSON, _ := json.Marshal(req.APIConfig)
if req.APIConfig == nil {
apiConfigJSON = []byte("{}")
}
// 根据类型确定能力
capabilities := getDefaultCapabilities(req.ProviderType)
capJSON, _ := json.Marshal(capabilities)
// 如果 BaseURL 为空,使用默认地址
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
provider := app.AIProvider{
UserID: &userID,
ProviderName: req.ProviderName,
ProviderType: req.ProviderType,
BaseURL: baseURL,
APIKey: encryptedKey,
APIConfig: datatypes.JSON(apiConfigJSON),
Capabilities: datatypes.JSON(capJSON),
IsEnabled: true,
IsDefault: false,
}
err := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 创建提供商
if err := tx.Create(&provider).Error; err != nil {
return err
}
// 如果携带了模型列表,同时创建模型
if len(req.Models) > 0 {
for _, m := range req.Models {
modelConfigJSON, _ := json.Marshal(m.Config)
if m.Config == nil {
modelConfigJSON = []byte("{}")
}
isEnabled := true
if m.IsEnabled != nil {
isEnabled = *m.IsEnabled
}
model := app.AIModel{
ProviderID: provider.ID,
ModelName: m.ModelName,
DisplayName: m.DisplayName,
ModelType: m.ModelType,
Config: datatypes.JSON(modelConfigJSON),
IsEnabled: isEnabled,
}
if err := tx.Create(&model).Error; err != nil {
return err
}
}
} else {
// 没有指定模型时,自动添加预设模型
presets := getPresetModels(req.ProviderType)
for _, p := range presets {
model := app.AIModel{
ProviderID: provider.ID,
ModelName: p.ModelName,
DisplayName: p.DisplayName,
ModelType: p.ModelType,
Config: datatypes.JSON([]byte("{}")),
IsEnabled: true,
}
if err := tx.Create(&model).Error; err != nil {
return err
}
}
}
// 如果是用户的第一个提供商,自动设为默认
var count int64
tx.Model(&app.AIProvider{}).Where("user_id = ? AND id != ?", userID, provider.ID).Count(&count)
if count == 0 {
provider.IsDefault = true
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return response.ProviderResponse{}, err
}
return ps.GetProviderDetail(provider.ID, userID)
}
// GetProviderList 获取用户的提供商列表
func (ps *ProviderService) GetProviderList(req request.ProviderListRequest, userID uint) (response.ProviderListResponse, error) {
db := global.GVA_DB.Model(&app.AIProvider{}).Where("user_id = ?", userID)
if req.Keyword != "" {
keyword := "%" + req.Keyword + "%"
db = db.Where("provider_name ILIKE ?", keyword)
}
var total int64
db.Count(&total)
var providers []app.AIProvider
offset := (req.Page - 1) * req.PageSize
err := db.Order("is_default DESC, sort_order ASC, created_at DESC").
Offset(offset).Limit(req.PageSize).Find(&providers).Error
if err != nil {
return response.ProviderListResponse{}, err
}
// 获取所有提供商的模型
providerIDs := make([]uint, len(providers))
for i, p := range providers {
providerIDs[i] = p.ID
}
var models []app.AIModel
if len(providerIDs) > 0 {
global.GVA_DB.Where("provider_id IN ?", providerIDs).
Order("model_type ASC, model_name ASC").Find(&models)
}
// 按提供商ID分组模型
modelMap := make(map[uint][]app.AIModel)
for _, m := range models {
modelMap[m.ProviderID] = append(modelMap[m.ProviderID], m)
}
list := make([]response.ProviderResponse, len(providers))
for i, p := range providers {
list[i] = toProviderResponse(&p, modelMap[p.ID])
}
return response.ProviderListResponse{
List: list,
Total: total,
Page: req.Page,
PageSize: req.PageSize,
}, nil
}
// GetProviderDetail 获取提供商详情
func (ps *ProviderService) GetProviderDetail(providerID uint, userID uint) (response.ProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ProviderResponse{}, errors.New("提供商不存在")
}
return response.ProviderResponse{}, err
}
var models []app.AIModel
global.GVA_DB.Where("provider_id = ?", providerID).
Order("model_type ASC, model_name ASC").Find(&models)
return toProviderResponse(&provider, models), nil
}
// UpdateProvider 更新提供商
func (ps *ProviderService) UpdateProvider(req request.UpdateProviderRequest, userID uint) (response.ProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return response.ProviderResponse{}, errors.New("提供商不存在")
}
return response.ProviderResponse{}, err
}
// 更新字段
updates := map[string]interface{}{
"provider_name": req.ProviderName,
"provider_type": req.ProviderType,
"base_url": req.BaseURL,
}
// APIKey 不为空时才更新
if req.APIKey != "" {
updates["api_key"] = encryptAPIKey(req.APIKey)
}
if req.APIConfig != nil {
apiConfigJSON, _ := json.Marshal(req.APIConfig)
updates["api_config"] = datatypes.JSON(apiConfigJSON)
}
if req.IsEnabled != nil {
updates["is_enabled"] = *req.IsEnabled
}
if req.SortOrder != nil {
updates["sort_order"] = *req.SortOrder
}
// 更新能力
capabilities := getDefaultCapabilities(req.ProviderType)
capJSON, _ := json.Marshal(capabilities)
updates["capabilities"] = datatypes.JSON(capJSON)
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&provider).Updates(updates).Error; err != nil {
return err
}
// 处理设置默认
if req.IsDefault != nil && *req.IsDefault {
// 先取消其他默认
if err := tx.Model(&app.AIProvider{}).
Where("user_id = ? AND id != ?", userID, req.ID).
Update("is_default", false).Error; err != nil {
return err
}
if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
return err
}
}
return nil
})
if err != nil {
return response.ProviderResponse{}, err
}
return ps.GetProviderDetail(req.ID, userID)
}
// DeleteProvider 删除提供商
func (ps *ProviderService) DeleteProvider(providerID uint, userID uint) error {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("提供商不存在")
}
return err
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 删除关联的模型
if err := tx.Where("provider_id = ?", providerID).Delete(&app.AIModel{}).Error; err != nil {
return err
}
// 删除提供商
if err := tx.Delete(&provider).Error; err != nil {
return err
}
// 如果删除的是默认提供商,自动将第一个提供商设为默认
if provider.IsDefault {
var firstProvider app.AIProvider
if err := tx.Where("user_id = ?", userID).Order("created_at ASC").First(&firstProvider).Error; err == nil {
tx.Model(&firstProvider).Update("is_default", true)
}
}
return nil
})
}
// SetDefaultProvider 设置默认提供商
func (ps *ProviderService) SetDefaultProvider(providerID uint, userID uint) error {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("提供商不存在")
}
return err
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 先取消所有默认
if err := tx.Model(&app.AIProvider{}).
Where("user_id = ?", userID).
Update("is_default", false).Error; err != nil {
return err
}
// 设置新默认
return tx.Model(&provider).Update("is_default", true).Error
})
}
// ==================== 模型 CRUD ====================
// AddModel 为提供商添加模型
func (ps *ProviderService) AddModel(req request.CreateModelRequest, userID uint) (response.ModelResponse, error) {
// 验证提供商归属
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ProviderID, userID).First(&provider).Error
if err != nil {
return response.ModelResponse{}, errors.New("提供商不存在")
}
configJSON, _ := json.Marshal(req.Config)
if req.Config == nil {
configJSON = []byte("{}")
}
isEnabled := true
if req.IsEnabled != nil {
isEnabled = *req.IsEnabled
}
model := app.AIModel{
ProviderID: req.ProviderID,
ModelName: req.ModelName,
DisplayName: req.DisplayName,
ModelType: req.ModelType,
Config: datatypes.JSON(configJSON),
IsEnabled: isEnabled,
}
if err := global.GVA_DB.Create(&model).Error; err != nil {
return response.ModelResponse{}, err
}
return toModelResponse(&model), nil
}
// UpdateModel 更新模型
func (ps *ProviderService) UpdateModel(req request.UpdateModelRequest, userID uint) (response.ModelResponse, error) {
var model app.AIModel
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
Where("ai_models.id = ? AND ai_providers.user_id = ?", req.ID, userID).
First(&model).Error
if err != nil {
return response.ModelResponse{}, errors.New("模型不存在")
}
updates := map[string]interface{}{
"model_name": req.ModelName,
"display_name": req.DisplayName,
"model_type": req.ModelType,
}
if req.Config != nil {
configJSON, _ := json.Marshal(req.Config)
updates["config"] = datatypes.JSON(configJSON)
}
if req.IsEnabled != nil {
updates["is_enabled"] = *req.IsEnabled
}
if err := global.GVA_DB.Model(&model).Updates(updates).Error; err != nil {
return response.ModelResponse{}, err
}
// 重新查询
global.GVA_DB.First(&model, model.ID)
return toModelResponse(&model), nil
}
// DeleteModel 删除模型
func (ps *ProviderService) DeleteModel(modelID uint, userID uint) error {
var model app.AIModel
err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
Where("ai_models.id = ? AND ai_providers.user_id = ?", modelID, userID).
First(&model).Error
if err != nil {
return errors.New("模型不存在")
}
return global.GVA_DB.Delete(&model).Error
}
// ==================== 连通性测试 ====================
// TestProvider 测试提供商连通性
func (ps *ProviderService) TestProvider(req request.TestProviderRequest) (response.TestProviderResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
// 所有提供商统一使用 OpenAI 兼容的 /models 端点测试连通性
result := testOpenAICompatible(baseURL, req.APIKey, req.ModelName)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// TestExistingProvider 测试已保存的提供商连通性
func (ps *ProviderService) TestExistingProvider(providerID uint, userID uint) (response.TestProviderResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.TestProviderResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.TestProvider(request.TestProviderRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
})
}
// ==================== 辅助查询 ====================
// GetProviderTypes 获取支持的提供商类型列表(前端下拉用)
func (ps *ProviderService) GetProviderTypes() []response.ProviderTypeOption {
return []response.ProviderTypeOption{
{
Value: "openai",
Label: "OpenAI",
Description: "支持 GPT-4o、GPT-4、DALL·E 等模型,也兼容所有 OpenAI 格式的中转站",
DefaultURL: "https://api.openai.com/v1",
},
{
Value: "claude",
Label: "Claude",
Description: "Anthropic 的 Claude 系列模型,支持长上下文对话",
DefaultURL: "https://api.anthropic.com",
},
{
Value: "gemini",
Label: "Google Gemini",
Description: "Google 的 Gemini 系列模型,支持多模态",
DefaultURL: "https://generativelanguage.googleapis.com",
},
{
Value: "custom",
Label: "自定义OpenAI 兼容)",
Description: "兼容 OpenAI 格式的任意接口,如 DeepSeek、通义千问等中转站",
DefaultURL: "",
},
}
}
// GetPresetModels 获取指定提供商类型的预设模型列表
func (ps *ProviderService) GetPresetModels(providerType string) []response.PresetModelOption {
presets := getPresetModels(providerType)
result := make([]response.PresetModelOption, len(presets))
for i, p := range presets {
result[i] = response.PresetModelOption{
ModelName: p.ModelName,
DisplayName: p.DisplayName,
ModelType: p.ModelType,
}
}
return result
}
// GetUserDefaultProvider 获取用户默认提供商(内部方法,给对话功能用)
func (ps *ProviderService) GetUserDefaultProvider(userID uint) (*app.AIProvider, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("user_id = ? AND is_default = ? AND is_enabled = ?", userID, true, true).
First(&provider).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("请先配置 AI 接口")
}
return nil, err
}
return &provider, nil
}
// GetDecryptedAPIKey 获取解密后的API密钥内部方法给AI调用用
func (ps *ProviderService) GetDecryptedAPIKey(provider *app.AIProvider) string {
return decryptAPIKey(provider.APIKey)
}
// ==================== 内部辅助函数 ====================
// toProviderResponse 转换为响应对象
func toProviderResponse(p *app.AIProvider, models []app.AIModel) response.ProviderResponse {
apiConfig := json.RawMessage(p.APIConfig)
if len(apiConfig) == 0 {
apiConfig = json.RawMessage("{}")
}
capabilities := json.RawMessage(p.Capabilities)
if len(capabilities) == 0 {
capabilities = json.RawMessage("[]")
}
// 模型列表
modelList := make([]response.ModelResponse, len(models))
for i, m := range models {
modelList[i] = toModelResponse(&m)
}
// API Key 提示
apiKeyHint := ""
apiKeySet := false
if p.APIKey != "" {
apiKeySet = true
decrypted := decryptAPIKey(p.APIKey)
if len(decrypted) > 8 {
apiKeyHint = decrypted[:4] + "****" + decrypted[len(decrypted)-4:]
} else if len(decrypted) > 0 {
apiKeyHint = "****"
}
}
return response.ProviderResponse{
ID: p.ID,
ProviderName: p.ProviderName,
ProviderType: p.ProviderType,
BaseURL: p.BaseURL,
APIKeySet: apiKeySet,
APIKeyHint: apiKeyHint,
APIConfig: apiConfig,
Capabilities: capabilities,
IsEnabled: p.IsEnabled,
IsDefault: p.IsDefault,
SortOrder: p.SortOrder,
Models: modelList,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
// toModelResponse 转换模型为响应对象
func toModelResponse(m *app.AIModel) response.ModelResponse {
config := json.RawMessage(m.Config)
if len(config) == 0 {
config = json.RawMessage("{}")
}
return response.ModelResponse{
ID: m.ID,
ProviderID: m.ProviderID,
ModelName: m.ModelName,
DisplayName: m.DisplayName,
ModelType: m.ModelType,
Config: config,
IsEnabled: m.IsEnabled,
CreatedAt: m.CreatedAt,
}
}
// encryptAPIKey 加密API密钥
// TODO: 后续可以替换为更安全的加密方式(如 AES当前使用简单的 Base64 编码
func encryptAPIKey(key string) string {
if key == "" {
return ""
}
// 简单的混淆处理,生产环境应替换为 AES 加密
import_encoding := []byte(key)
for i := range import_encoding {
import_encoding[i] ^= 0x5A
}
return fmt.Sprintf("enc:%x", import_encoding)
}
// decryptAPIKey 解密API密钥
func decryptAPIKey(encrypted string) string {
if encrypted == "" {
return ""
}
if !strings.HasPrefix(encrypted, "enc:") {
return encrypted // 未加密的旧数据,直接返回
}
hexStr := encrypted[4:]
var data []byte
fmt.Sscanf(hexStr, "%x", &data)
for i := range data {
data[i] ^= 0x5A
}
return string(data)
}
// getDefaultBaseURL 获取默认API基础地址
func getDefaultBaseURL(providerType string) string {
switch providerType {
case "openai":
return "https://api.openai.com/v1"
case "claude":
return "https://api.anthropic.com"
case "gemini":
return "https://generativelanguage.googleapis.com"
default:
return ""
}
}
// getDefaultCapabilities 获取默认能力列表
func getDefaultCapabilities(providerType string) []string {
switch providerType {
case "openai":
return []string{"chat", "image_gen"}
case "claude":
return []string{"chat"}
case "gemini":
return []string{"chat", "image_gen"}
case "custom":
return []string{"chat"}
default:
return []string{"chat"}
}
}
// presetModel 预设模型内部结构
type presetModel struct {
ModelName string
DisplayName string
ModelType string
}
// getPresetModels 获取预设模型列表
func getPresetModels(providerType string) []presetModel {
switch providerType {
case "openai":
return []presetModel{
{ModelName: "gpt-4o", DisplayName: "GPT-4o", ModelType: "chat"},
{ModelName: "gpt-4o-mini", DisplayName: "GPT-4o Mini", ModelType: "chat"},
{ModelName: "gpt-4.1", DisplayName: "GPT-4.1", ModelType: "chat"},
{ModelName: "gpt-4.1-mini", DisplayName: "GPT-4.1 Mini", ModelType: "chat"},
{ModelName: "gpt-4.1-nano", DisplayName: "GPT-4.1 Nano", ModelType: "chat"},
{ModelName: "o3-mini", DisplayName: "o3-mini", ModelType: "chat"},
{ModelName: "dall-e-3", DisplayName: "DALL·E 3", ModelType: "image_gen"},
}
case "claude":
return []presetModel{
{ModelName: "claude-sonnet-4-20250514", DisplayName: "Claude Sonnet 4", ModelType: "chat"},
{ModelName: "claude-3-5-sonnet-20241022", DisplayName: "Claude 3.5 Sonnet", ModelType: "chat"},
{ModelName: "claude-3-5-haiku-20241022", DisplayName: "Claude 3.5 Haiku", ModelType: "chat"},
{ModelName: "claude-3-opus-20240229", DisplayName: "Claude 3 Opus", ModelType: "chat"},
}
case "gemini":
return []presetModel{
{ModelName: "gemini-2.5-flash-preview-05-20", DisplayName: "Gemini 2.5 Flash", ModelType: "chat"},
{ModelName: "gemini-2.5-pro-preview-05-06", DisplayName: "Gemini 2.5 Pro", ModelType: "chat"},
{ModelName: "gemini-2.0-flash", DisplayName: "Gemini 2.0 Flash", ModelType: "chat"},
{ModelName: "imagen-3.0-generate-002", DisplayName: "Imagen 3", ModelType: "image_gen"},
}
case "custom":
return []presetModel{} // 自定义不提供预设
default:
return []presetModel{}
}
}
// ==================== 获取远程模型列表 ====================
// FetchRemoteModels 从远程API获取可用模型列表
// 所有提供商类型统一使用 baseURL + /models 端点OpenAI 兼容格式)
func (ps *ProviderService) FetchRemoteModels(req request.FetchRemoteModelsRequest) (response.FetchRemoteModelsResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
result := fetchModelsUniversal(baseURL, req.APIKey)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// FetchRemoteModelsExisting 获取已保存提供商的远程模型列表
func (ps *ProviderService) FetchRemoteModelsExisting(providerID uint, userID uint) (response.FetchRemoteModelsResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.FetchRemoteModelsResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.FetchRemoteModels(request.FetchRemoteModelsRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
})
}
// ==================== 发送测试消息 ====================
// SendTestMessage 发送测试消息(使用指定的 provider 配置)
// 所有提供商类型统一使用 baseURL + /chat/completions 端点OpenAI 兼容格式)
func (ps *ProviderService) SendTestMessage(req request.SendTestMessageRequest) (response.SendTestMessageResponse, error) {
startTime := time.Now()
baseURL := req.BaseURL
if baseURL == "" {
baseURL = getDefaultBaseURL(req.ProviderType)
}
message := req.Message
if message == "" {
message = "你好,请用一句话介绍你自己。"
}
result := sendTestMessageUniversal(baseURL, req.APIKey, req.ModelName, message)
result.Latency = time.Since(startTime).Milliseconds()
return result, nil
}
// SendTestMessageExisting 发送测试消息(已保存的提供商)
func (ps *ProviderService) SendTestMessageExisting(providerID uint, userID uint, modelName string, message string) (response.SendTestMessageResponse, error) {
var provider app.AIProvider
err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
if err != nil {
return response.SendTestMessageResponse{}, errors.New("提供商不存在")
}
apiKey := decryptAPIKey(provider.APIKey)
return ps.SendTestMessage(request.SendTestMessageRequest{
ProviderType: provider.ProviderType,
BaseURL: provider.BaseURL,
APIKey: apiKey,
ModelName: modelName,
Message: message,
})
}
// ==================== 连通性测试实现 ====================
// testOpenAICompatible 测试 OpenAI 兼容接口
func testOpenAICompatible(baseURL, apiKey, modelName string) response.TestProviderResponse {
url := strings.TrimRight(baseURL, "/") + "/models"
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return response.TestProviderResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.TestProviderResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == 401 {
return response.TestProviderResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
return response.TestProviderResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
// 解析模型列表
var modelsResp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
var modelNames []string
if err := json.Unmarshal(body, &modelsResp); err == nil {
for _, m := range modelsResp.Data {
modelNames = append(modelNames, m.ID)
}
}
return response.TestProviderResponse{
Success: true,
Message: "连接成功",
Models: modelNames,
}
}
// ==================== 获取远程模型列表实现 ====================
// fetchModelsUniversal 统一获取模型列表(所有提供商通用)
// 使用 baseURL + /models 端点Authorization: Bearer 鉴权
func fetchModelsUniversal(baseURL, apiKey string) response.FetchRemoteModelsResponse {
url := strings.TrimRight(baseURL, "/") + "/models"
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return response.FetchRemoteModelsResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.FetchRemoteModelsResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
if resp.StatusCode == 401 {
return response.FetchRemoteModelsResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
return response.FetchRemoteModelsResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
body, _ := io.ReadAll(resp.Body)
// 解析 OpenAI 兼容格式: { "data": [{ "id": "xxx", "owned_by": "xxx" }] }
var modelsData struct {
Data []struct {
ID string `json:"id"`
OwnedBy string `json:"owned_by"`
} `json:"data"`
}
var models []response.RemoteModel
if err := json.Unmarshal(body, &modelsData); err == nil {
for _, m := range modelsData.Data {
models = append(models, response.RemoteModel{
ID: m.ID,
OwnedBy: m.OwnedBy,
})
}
}
return response.FetchRemoteModelsResponse{
Success: true,
Message: fmt.Sprintf("获取成功,共 %d 个模型", len(models)),
Models: models,
}
}
// ==================== 发送测试消息实现 ====================
// sendTestMessageUniversal 统一发送测试消息(所有提供商通用)
// 使用 baseURL + /chat/completions 端点Authorization: Bearer 鉴权
func sendTestMessageUniversal(baseURL, apiKey, modelName, message string) response.SendTestMessageResponse {
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
payload := map[string]interface{}{
"model": modelName,
"max_tokens": 100,
"messages": []map[string]string{
{"role": "user", "content": message},
},
}
payloadBytes, _ := json.Marshal(payload)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payloadBytes))
if err != nil {
return response.SendTestMessageResponse{Success: false, Message: "请求构建失败: " + err.Error()}
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response.SendTestMessageResponse{Success: false, Message: "连接失败: " + err.Error()}
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == 401 {
return response.SendTestMessageResponse{Success: false, Message: "API 密钥无效"}
}
if resp.StatusCode != 200 {
var errResp struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" {
return response.SendTestMessageResponse{Success: false, Message: "API 错误: " + errResp.Error.Message}
}
return response.SendTestMessageResponse{
Success: false,
Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
}
}
// 解析 OpenAI 兼容格式的 chat completion 响应
var chatResp struct {
Model string `json:"model"`
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal(body, &chatResp); err != nil {
return response.SendTestMessageResponse{Success: false, Message: "解析响应失败"}
}
reply := ""
if len(chatResp.Choices) > 0 {
reply = chatResp.Choices[0].Message.Content
}
return response.SendTestMessageResponse{
Success: true,
Message: "测试成功",
Reply: reply,
Model: chatResp.Model,
Tokens: chatResp.Usage.TotalTokens,
}
}