🎨 Improve the extension module: complete AI provider integration and the chat feature
server/service/app/provider.go · new file · 949 lines added

@@ -0,0 +1,949 @@
package app

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"git.echol.cn/loser/st/server/global"
	"git.echol.cn/loser/st/server/model/app"
	"git.echol.cn/loser/st/server/model/app/request"
	"git.echol.cn/loser/st/server/model/app/response"
	"gorm.io/datatypes"
	"gorm.io/gorm"
)

type ProviderService struct{}
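// ProviderService covers provider CRUD, model CRUD, connectivity and chat
// tests against OpenAI-compatible endpoints, remote model discovery, and the
// internal helpers consumed by the chat feature; the sections below follow
// that order.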
// ==================== Provider CRUD ====================

// CreateProvider creates an AI provider
func (ps *ProviderService) CreateProvider(req request.CreateProviderRequest, userID uint) (response.ProviderResponse, error) {
	// Encrypt the API key
	encryptedKey := encryptAPIKey(req.APIKey)

	// Serialize the extra configuration
	apiConfigJSON, _ := json.Marshal(req.APIConfig)
	if req.APIConfig == nil {
		apiConfigJSON = []byte("{}")
	}

	// Derive the capabilities from the provider type
	capabilities := getDefaultCapabilities(req.ProviderType)
	capJSON, _ := json.Marshal(capabilities)

	// Fall back to the default base URL when none is supplied
	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = getDefaultBaseURL(req.ProviderType)
	}

	provider := app.AIProvider{
		UserID:       &userID,
		ProviderName: req.ProviderName,
		ProviderType: req.ProviderType,
		BaseURL:      baseURL,
		APIKey:       encryptedKey,
		APIConfig:    datatypes.JSON(apiConfigJSON),
		Capabilities: datatypes.JSON(capJSON),
		IsEnabled:    true,
		IsDefault:    false,
	}

	err := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		// Create the provider
		if err := tx.Create(&provider).Error; err != nil {
			return err
		}

		// If the request carries a model list, create those models as well
		if len(req.Models) > 0 {
			for _, m := range req.Models {
				modelConfigJSON, _ := json.Marshal(m.Config)
				if m.Config == nil {
					modelConfigJSON = []byte("{}")
				}
				isEnabled := true
				if m.IsEnabled != nil {
					isEnabled = *m.IsEnabled
				}
				model := app.AIModel{
					ProviderID:  provider.ID,
					ModelName:   m.ModelName,
					DisplayName: m.DisplayName,
					ModelType:   m.ModelType,
					Config:      datatypes.JSON(modelConfigJSON),
					IsEnabled:   isEnabled,
				}
				if err := tx.Create(&model).Error; err != nil {
					return err
				}
			}
		} else {
			// No models specified: add the preset models automatically
			presets := getPresetModels(req.ProviderType)
			for _, p := range presets {
				model := app.AIModel{
					ProviderID:  provider.ID,
					ModelName:   p.ModelName,
					DisplayName: p.DisplayName,
					ModelType:   p.ModelType,
					Config:      datatypes.JSON([]byte("{}")),
					IsEnabled:   true,
				}
				if err := tx.Create(&model).Error; err != nil {
					return err
				}
			}
		}

		// If this is the user's first provider, make it the default automatically
		var count int64
		tx.Model(&app.AIProvider{}).Where("user_id = ? AND id != ?", userID, provider.ID).Count(&count)
		if count == 0 {
			provider.IsDefault = true
			if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
				return err
			}
		}

		return nil
	})

	if err != nil {
		return response.ProviderResponse{}, err
	}

	return ps.GetProviderDetail(provider.ID, userID)
}

// GetProviderList returns the current user's providers
func (ps *ProviderService) GetProviderList(req request.ProviderListRequest, userID uint) (response.ProviderListResponse, error) {
	db := global.GVA_DB.Model(&app.AIProvider{}).Where("user_id = ?", userID)

	if req.Keyword != "" {
		keyword := "%" + req.Keyword + "%"
		db = db.Where("provider_name ILIKE ?", keyword)
	}

	var total int64
	db.Count(&total)

	var providers []app.AIProvider
	offset := (req.Page - 1) * req.PageSize
	err := db.Order("is_default DESC, sort_order ASC, created_at DESC").
		Offset(offset).Limit(req.PageSize).Find(&providers).Error
	if err != nil {
		return response.ProviderListResponse{}, err
	}

	// Load the models of all listed providers
	providerIDs := make([]uint, len(providers))
	for i, p := range providers {
		providerIDs[i] = p.ID
	}

	var models []app.AIModel
	if len(providerIDs) > 0 {
		global.GVA_DB.Where("provider_id IN ?", providerIDs).
			Order("model_type ASC, model_name ASC").Find(&models)
	}

	// Group the models by provider ID
	modelMap := make(map[uint][]app.AIModel)
	for _, m := range models {
		modelMap[m.ProviderID] = append(modelMap[m.ProviderID], m)
	}

	list := make([]response.ProviderResponse, len(providers))
	for i, p := range providers {
		list[i] = toProviderResponse(&p, modelMap[p.ID])
	}

	return response.ProviderListResponse{
		List:     list,
		Total:    total,
		Page:     req.Page,
		PageSize: req.PageSize,
	}, nil
}
// GetProviderDetail returns one provider together with its models
func (ps *ProviderService) GetProviderDetail(providerID uint, userID uint) (response.ProviderResponse, error) {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.ProviderResponse{}, errors.New("提供商不存在")
		}
		return response.ProviderResponse{}, err
	}

	var models []app.AIModel
	global.GVA_DB.Where("provider_id = ?", providerID).
		Order("model_type ASC, model_name ASC").Find(&models)

	return toProviderResponse(&provider, models), nil
}

// UpdateProvider updates a provider
func (ps *ProviderService) UpdateProvider(req request.UpdateProviderRequest, userID uint) (response.ProviderResponse, error) {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userID).First(&provider).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return response.ProviderResponse{}, errors.New("提供商不存在")
		}
		return response.ProviderResponse{}, err
	}

	// Fields to update
	updates := map[string]interface{}{
		"provider_name": req.ProviderName,
		"provider_type": req.ProviderType,
		"base_url":      req.BaseURL,
	}

	// Only update the API key when a new one is supplied
	if req.APIKey != "" {
		updates["api_key"] = encryptAPIKey(req.APIKey)
	}

	if req.APIConfig != nil {
		apiConfigJSON, _ := json.Marshal(req.APIConfig)
		updates["api_config"] = datatypes.JSON(apiConfigJSON)
	}

	if req.IsEnabled != nil {
		updates["is_enabled"] = *req.IsEnabled
	}

	if req.SortOrder != nil {
		updates["sort_order"] = *req.SortOrder
	}

	// Refresh the capabilities
	capabilities := getDefaultCapabilities(req.ProviderType)
	capJSON, _ := json.Marshal(capabilities)
	updates["capabilities"] = datatypes.JSON(capJSON)

	err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&provider).Updates(updates).Error; err != nil {
			return err
		}

		// Handle the default flag
		if req.IsDefault != nil && *req.IsDefault {
			// Clear the default flag on all other providers first
			if err := tx.Model(&app.AIProvider{}).
				Where("user_id = ? AND id != ?", userID, req.ID).
				Update("is_default", false).Error; err != nil {
				return err
			}
			if err := tx.Model(&provider).Update("is_default", true).Error; err != nil {
				return err
			}
		}

		return nil
	})

	if err != nil {
		return response.ProviderResponse{}, err
	}

	return ps.GetProviderDetail(req.ID, userID)
}

// DeleteProvider deletes a provider
func (ps *ProviderService) DeleteProvider(providerID uint, userID uint) error {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("提供商不存在")
		}
		return err
	}

	return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		// Delete the associated models
		if err := tx.Where("provider_id = ?", providerID).Delete(&app.AIModel{}).Error; err != nil {
			return err
		}
		// Delete the provider
		if err := tx.Delete(&provider).Error; err != nil {
			return err
		}

		// If the deleted provider was the default, promote the oldest remaining provider
		if provider.IsDefault {
			var firstProvider app.AIProvider
			if err := tx.Where("user_id = ?", userID).Order("created_at ASC").First(&firstProvider).Error; err == nil {
				tx.Model(&firstProvider).Update("is_default", true)
			}
		}

		return nil
	})
}

// SetDefaultProvider marks a provider as the default one
func (ps *ProviderService) SetDefaultProvider(providerID uint, userID uint) error {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("提供商不存在")
		}
		return err
	}

	return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
		// Clear every default flag first
		if err := tx.Model(&app.AIProvider{}).
			Where("user_id = ?", userID).
			Update("is_default", false).Error; err != nil {
			return err
		}
		// Set the new default
		return tx.Model(&provider).Update("is_default", true).Error
	})
}
// ==================== Model CRUD ====================

// AddModel adds a model to a provider
func (ps *ProviderService) AddModel(req request.CreateModelRequest, userID uint) (response.ModelResponse, error) {
	// Verify that the provider belongs to this user
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", req.ProviderID, userID).First(&provider).Error
	if err != nil {
		return response.ModelResponse{}, errors.New("提供商不存在")
	}

	configJSON, _ := json.Marshal(req.Config)
	if req.Config == nil {
		configJSON = []byte("{}")
	}

	isEnabled := true
	if req.IsEnabled != nil {
		isEnabled = *req.IsEnabled
	}

	model := app.AIModel{
		ProviderID:  req.ProviderID,
		ModelName:   req.ModelName,
		DisplayName: req.DisplayName,
		ModelType:   req.ModelType,
		Config:      datatypes.JSON(configJSON),
		IsEnabled:   isEnabled,
	}

	if err := global.GVA_DB.Create(&model).Error; err != nil {
		return response.ModelResponse{}, err
	}

	return toModelResponse(&model), nil
}

// UpdateModel updates a model
func (ps *ProviderService) UpdateModel(req request.UpdateModelRequest, userID uint) (response.ModelResponse, error) {
	var model app.AIModel
	err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
		Where("ai_models.id = ? AND ai_providers.user_id = ?", req.ID, userID).
		First(&model).Error
	if err != nil {
		return response.ModelResponse{}, errors.New("模型不存在")
	}

	updates := map[string]interface{}{
		"model_name":   req.ModelName,
		"display_name": req.DisplayName,
		"model_type":   req.ModelType,
	}

	if req.Config != nil {
		configJSON, _ := json.Marshal(req.Config)
		updates["config"] = datatypes.JSON(configJSON)
	}

	if req.IsEnabled != nil {
		updates["is_enabled"] = *req.IsEnabled
	}

	if err := global.GVA_DB.Model(&model).Updates(updates).Error; err != nil {
		return response.ModelResponse{}, err
	}

	// Re-load the updated record
	global.GVA_DB.First(&model, model.ID)
	return toModelResponse(&model), nil
}

// DeleteModel deletes a model
func (ps *ProviderService) DeleteModel(modelID uint, userID uint) error {
	var model app.AIModel
	err := global.GVA_DB.Joins("JOIN ai_providers ON ai_providers.id = ai_models.provider_id").
		Where("ai_models.id = ? AND ai_providers.user_id = ?", modelID, userID).
		First(&model).Error
	if err != nil {
		return errors.New("模型不存在")
	}

	return global.GVA_DB.Delete(&model).Error
}
// ==================== Connectivity testing ====================

// TestProvider tests connectivity for a provider configuration
func (ps *ProviderService) TestProvider(req request.TestProviderRequest) (response.TestProviderResponse, error) {
	startTime := time.Now()

	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = getDefaultBaseURL(req.ProviderType)
	}

	// Every provider is tested through the OpenAI-compatible /models endpoint
	result := testOpenAICompatible(baseURL, req.APIKey, req.ModelName)
	result.Latency = time.Since(startTime).Milliseconds()
	return result, nil
}

// TestExistingProvider tests connectivity for an already saved provider
func (ps *ProviderService) TestExistingProvider(providerID uint, userID uint) (response.TestProviderResponse, error) {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
	if err != nil {
		return response.TestProviderResponse{}, errors.New("提供商不存在")
	}

	apiKey := decryptAPIKey(provider.APIKey)

	return ps.TestProvider(request.TestProviderRequest{
		ProviderType: provider.ProviderType,
		BaseURL:      provider.BaseURL,
		APIKey:       apiKey,
	})
}
// ==================== Helper queries ====================

// GetProviderTypes returns the supported provider types (for frontend dropdowns)
func (ps *ProviderService) GetProviderTypes() []response.ProviderTypeOption {
	return []response.ProviderTypeOption{
		{
			Value:       "openai",
			Label:       "OpenAI",
			Description: "支持 GPT-4o、GPT-4、DALL·E 等模型,也兼容所有 OpenAI 格式的中转站",
			DefaultURL:  "https://api.openai.com/v1",
		},
		{
			Value:       "claude",
			Label:       "Claude",
			Description: "Anthropic 的 Claude 系列模型,支持长上下文对话",
			DefaultURL:  "https://api.anthropic.com",
		},
		{
			Value:       "gemini",
			Label:       "Google Gemini",
			Description: "Google 的 Gemini 系列模型,支持多模态",
			DefaultURL:  "https://generativelanguage.googleapis.com",
		},
		{
			Value:       "custom",
			Label:       "自定义(OpenAI 兼容)",
			Description: "兼容 OpenAI 格式的任意接口,如 DeepSeek、通义千问等中转站",
			DefaultURL:  "",
		},
	}
}

// GetPresetModels returns the preset model list for the given provider type
func (ps *ProviderService) GetPresetModels(providerType string) []response.PresetModelOption {
	presets := getPresetModels(providerType)
	result := make([]response.PresetModelOption, len(presets))
	for i, p := range presets {
		result[i] = response.PresetModelOption{
			ModelName:   p.ModelName,
			DisplayName: p.DisplayName,
			ModelType:   p.ModelType,
		}
	}
	return result
}

// GetUserDefaultProvider returns the user's default provider (internal, used by the chat feature)
func (ps *ProviderService) GetUserDefaultProvider(userID uint) (*app.AIProvider, error) {
	var provider app.AIProvider
	err := global.GVA_DB.Where("user_id = ? AND is_default = ? AND is_enabled = ?", userID, true, true).
		First(&provider).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("请先配置 AI 接口")
		}
		return nil, err
	}
	return &provider, nil
}

// GetDecryptedAPIKey returns the decrypted API key (internal, used when calling AI APIs)
func (ps *ProviderService) GetDecryptedAPIKey(provider *app.AIProvider) string {
	return decryptAPIKey(provider.APIKey)
}
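// Illustrative sketch (not part of this commit): the chat feature mentioned
// above could combine the two internal helpers roughly like this. ChatService,
// its Ask method, and the package-level providerService variable are assumed
// names, not definitions from this file.
//
// func (cs *ChatService) Ask(userID uint, modelName, question string) (string, error) {
// 	provider, err := providerService.GetUserDefaultProvider(userID)
// 	if err != nil {
// 		return "", err // e.g. "请先配置 AI 接口"
// 	}
// 	apiKey := providerService.GetDecryptedAPIKey(provider)
// 	// Reuse the OpenAI-compatible request path, as sendTestMessageUniversal does below.
// 	result := sendTestMessageUniversal(provider.BaseURL, apiKey, modelName, question)
// 	if !result.Success {
// 		return "", errors.New(result.Message)
// 	}
// 	return result.Reply, nil
// }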
// ==================== Internal helpers ====================

// toProviderResponse converts a provider into its response object
func toProviderResponse(p *app.AIProvider, models []app.AIModel) response.ProviderResponse {
	apiConfig := json.RawMessage(p.APIConfig)
	if len(apiConfig) == 0 {
		apiConfig = json.RawMessage("{}")
	}

	capabilities := json.RawMessage(p.Capabilities)
	if len(capabilities) == 0 {
		capabilities = json.RawMessage("[]")
	}

	// Model list
	modelList := make([]response.ModelResponse, len(models))
	for i, m := range models {
		modelList[i] = toModelResponse(&m)
	}

	// API key hint (masked)
	apiKeyHint := ""
	apiKeySet := false
	if p.APIKey != "" {
		apiKeySet = true
		decrypted := decryptAPIKey(p.APIKey)
		if len(decrypted) > 8 {
			apiKeyHint = decrypted[:4] + "****" + decrypted[len(decrypted)-4:]
		} else if len(decrypted) > 0 {
			apiKeyHint = "****"
		}
	}

	return response.ProviderResponse{
		ID:           p.ID,
		ProviderName: p.ProviderName,
		ProviderType: p.ProviderType,
		BaseURL:      p.BaseURL,
		APIKeySet:    apiKeySet,
		APIKeyHint:   apiKeyHint,
		APIConfig:    apiConfig,
		Capabilities: capabilities,
		IsEnabled:    p.IsEnabled,
		IsDefault:    p.IsDefault,
		SortOrder:    p.SortOrder,
		Models:       modelList,
		CreatedAt:    p.CreatedAt,
		UpdatedAt:    p.UpdatedAt,
	}
}

// toModelResponse converts a model into its response object
func toModelResponse(m *app.AIModel) response.ModelResponse {
	config := json.RawMessage(m.Config)
	if len(config) == 0 {
		config = json.RawMessage("{}")
	}
	return response.ModelResponse{
		ID:          m.ID,
		ProviderID:  m.ProviderID,
		ModelName:   m.ModelName,
		DisplayName: m.DisplayName,
		ModelType:   m.ModelType,
		Config:      config,
		IsEnabled:   m.IsEnabled,
		CreatedAt:   m.CreatedAt,
	}
}
// encryptAPIKey obfuscates an API key before it is persisted.
// TODO: replace this with a proper encryption scheme (e.g. AES); the current
// implementation is only XOR obfuscation plus hex encoding, not real encryption.
func encryptAPIKey(key string) string {
	if key == "" {
		return ""
	}
	// Simple obfuscation only; production deployments should switch to AES encryption.
	obfuscated := []byte(key)
	for i := range obfuscated {
		obfuscated[i] ^= 0x5A
	}
	return fmt.Sprintf("enc:%x", obfuscated)
}

// decryptAPIKey reverses encryptAPIKey
func decryptAPIKey(encrypted string) string {
	if encrypted == "" {
		return ""
	}
	if !strings.HasPrefix(encrypted, "enc:") {
		return encrypted // legacy value stored in plain text, return as-is
	}

	hexStr := encrypted[4:]
	var data []byte
	fmt.Sscanf(hexStr, "%x", &data)
	for i := range data {
		data[i] ^= 0x5A
	}
	return string(data)
}
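// Hypothetical sketch (not part of this commit): the TODO above suggests moving
// to real encryption. An AES-GCM variant could look roughly like the following;
// it assumes a 16/24/32-byte key loaded from configuration and would also need
// the crypto/aes, crypto/cipher, crypto/rand and encoding/hex imports.
//
// func encryptAPIKeyAES(plain string, key []byte) (string, error) {
// 	block, err := aes.NewCipher(key)
// 	if err != nil {
// 		return "", err
// 	}
// 	gcm, err := cipher.NewGCM(block)
// 	if err != nil {
// 		return "", err
// 	}
// 	nonce := make([]byte, gcm.NonceSize())
// 	if _, err := rand.Read(nonce); err != nil {
// 		return "", err
// 	}
// 	// Prepend the nonce so decryption can recover it.
// 	sealed := gcm.Seal(nonce, nonce, []byte(plain), nil)
// 	return "aes:" + hex.EncodeToString(sealed), nil
// }
//
// func decryptAPIKeyAES(stored string, key []byte) (string, error) {
// 	raw, err := hex.DecodeString(strings.TrimPrefix(stored, "aes:"))
// 	if err != nil {
// 		return "", err
// 	}
// 	block, err := aes.NewCipher(key)
// 	if err != nil {
// 		return "", err
// 	}
// 	gcm, err := cipher.NewGCM(block)
// 	if err != nil {
// 		return "", err
// 	}
// 	if len(raw) < gcm.NonceSize() {
// 		return "", errors.New("ciphertext too short")
// 	}
// 	nonce, ciphertext := raw[:gcm.NonceSize()], raw[gcm.NonceSize():]
// 	plain, err := gcm.Open(nil, nonce, ciphertext, nil)
// 	return string(plain), err
// }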
// getDefaultBaseURL returns the default API base URL for a provider type
func getDefaultBaseURL(providerType string) string {
	switch providerType {
	case "openai":
		return "https://api.openai.com/v1"
	case "claude":
		return "https://api.anthropic.com"
	case "gemini":
		return "https://generativelanguage.googleapis.com"
	default:
		return ""
	}
}

// getDefaultCapabilities returns the default capability list for a provider type
func getDefaultCapabilities(providerType string) []string {
	switch providerType {
	case "openai":
		return []string{"chat", "image_gen"}
	case "claude":
		return []string{"chat"}
	case "gemini":
		return []string{"chat", "image_gen"}
	case "custom":
		return []string{"chat"}
	default:
		return []string{"chat"}
	}
}

// presetModel is the internal structure for preset models
type presetModel struct {
	ModelName   string
	DisplayName string
	ModelType   string
}

// getPresetModels returns the preset models for a provider type
func getPresetModels(providerType string) []presetModel {
	switch providerType {
	case "openai":
		return []presetModel{
			{ModelName: "gpt-4o", DisplayName: "GPT-4o", ModelType: "chat"},
			{ModelName: "gpt-4o-mini", DisplayName: "GPT-4o Mini", ModelType: "chat"},
			{ModelName: "gpt-4.1", DisplayName: "GPT-4.1", ModelType: "chat"},
			{ModelName: "gpt-4.1-mini", DisplayName: "GPT-4.1 Mini", ModelType: "chat"},
			{ModelName: "gpt-4.1-nano", DisplayName: "GPT-4.1 Nano", ModelType: "chat"},
			{ModelName: "o3-mini", DisplayName: "o3-mini", ModelType: "chat"},
			{ModelName: "dall-e-3", DisplayName: "DALL·E 3", ModelType: "image_gen"},
		}
	case "claude":
		return []presetModel{
			{ModelName: "claude-sonnet-4-20250514", DisplayName: "Claude Sonnet 4", ModelType: "chat"},
			{ModelName: "claude-3-5-sonnet-20241022", DisplayName: "Claude 3.5 Sonnet", ModelType: "chat"},
			{ModelName: "claude-3-5-haiku-20241022", DisplayName: "Claude 3.5 Haiku", ModelType: "chat"},
			{ModelName: "claude-3-opus-20240229", DisplayName: "Claude 3 Opus", ModelType: "chat"},
		}
	case "gemini":
		return []presetModel{
			{ModelName: "gemini-2.5-flash-preview-05-20", DisplayName: "Gemini 2.5 Flash", ModelType: "chat"},
			{ModelName: "gemini-2.5-pro-preview-05-06", DisplayName: "Gemini 2.5 Pro", ModelType: "chat"},
			{ModelName: "gemini-2.0-flash", DisplayName: "Gemini 2.0 Flash", ModelType: "chat"},
			{ModelName: "imagen-3.0-generate-002", DisplayName: "Imagen 3", ModelType: "image_gen"},
		}
	case "custom":
		return []presetModel{} // custom providers have no presets
	default:
		return []presetModel{}
	}
}
// ==================== Fetching the remote model list ====================

// FetchRemoteModels fetches the list of available models from the remote API.
// Every provider type uses the OpenAI-compatible baseURL + /models endpoint.
func (ps *ProviderService) FetchRemoteModels(req request.FetchRemoteModelsRequest) (response.FetchRemoteModelsResponse, error) {
	startTime := time.Now()

	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = getDefaultBaseURL(req.ProviderType)
	}

	result := fetchModelsUniversal(baseURL, req.APIKey)
	result.Latency = time.Since(startTime).Milliseconds()
	return result, nil
}

// FetchRemoteModelsExisting fetches the remote model list of a saved provider
func (ps *ProviderService) FetchRemoteModelsExisting(providerID uint, userID uint) (response.FetchRemoteModelsResponse, error) {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
	if err != nil {
		return response.FetchRemoteModelsResponse{}, errors.New("提供商不存在")
	}

	apiKey := decryptAPIKey(provider.APIKey)

	return ps.FetchRemoteModels(request.FetchRemoteModelsRequest{
		ProviderType: provider.ProviderType,
		BaseURL:      provider.BaseURL,
		APIKey:       apiKey,
	})
}
// ==================== Sending test messages ====================

// SendTestMessage sends a test message using the given provider configuration.
// Every provider type uses the OpenAI-compatible baseURL + /chat/completions endpoint.
func (ps *ProviderService) SendTestMessage(req request.SendTestMessageRequest) (response.SendTestMessageResponse, error) {
	startTime := time.Now()

	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = getDefaultBaseURL(req.ProviderType)
	}

	message := req.Message
	if message == "" {
		message = "你好,请用一句话介绍你自己。"
	}

	result := sendTestMessageUniversal(baseURL, req.APIKey, req.ModelName, message)
	result.Latency = time.Since(startTime).Milliseconds()
	return result, nil
}

// SendTestMessageExisting sends a test message through a saved provider
func (ps *ProviderService) SendTestMessageExisting(providerID uint, userID uint, modelName string, message string) (response.SendTestMessageResponse, error) {
	var provider app.AIProvider
	err := global.GVA_DB.Where("id = ? AND user_id = ?", providerID, userID).First(&provider).Error
	if err != nil {
		return response.SendTestMessageResponse{}, errors.New("提供商不存在")
	}

	apiKey := decryptAPIKey(provider.APIKey)

	return ps.SendTestMessage(request.SendTestMessageRequest{
		ProviderType: provider.ProviderType,
		BaseURL:      provider.BaseURL,
		APIKey:       apiKey,
		ModelName:    modelName,
		Message:      message,
	})
}
// ==================== Connectivity test implementation ====================

// testOpenAICompatible tests an OpenAI-compatible endpoint
func testOpenAICompatible(baseURL, apiKey, modelName string) response.TestProviderResponse {
	url := strings.TrimRight(baseURL, "/") + "/models"

	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return response.TestProviderResponse{Success: false, Message: "请求构建失败: " + err.Error()}
	}

	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return response.TestProviderResponse{Success: false, Message: "连接失败: " + err.Error()}
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)

	if resp.StatusCode == 401 {
		return response.TestProviderResponse{Success: false, Message: "API 密钥无效"}
	}

	if resp.StatusCode != 200 {
		return response.TestProviderResponse{
			Success: false,
			Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
		}
	}

	// Parse the model list
	var modelsResp struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}

	var modelNames []string
	if err := json.Unmarshal(body, &modelsResp); err == nil {
		for _, m := range modelsResp.Data {
			modelNames = append(modelNames, m.ID)
		}
	}

	return response.TestProviderResponse{
		Success: true,
		Message: "连接成功",
		Models:  modelNames,
	}
}
// ==================== Remote model list implementation ====================

// fetchModelsUniversal fetches the model list (shared by all provider types).
// It calls baseURL + /models with Authorization: Bearer authentication.
func fetchModelsUniversal(baseURL, apiKey string) response.FetchRemoteModelsResponse {
	url := strings.TrimRight(baseURL, "/") + "/models"

	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return response.FetchRemoteModelsResponse{Success: false, Message: "请求构建失败: " + err.Error()}
	}

	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return response.FetchRemoteModelsResponse{Success: false, Message: "连接失败: " + err.Error()}
	}
	defer resp.Body.Close()

	if resp.StatusCode == 401 {
		return response.FetchRemoteModelsResponse{Success: false, Message: "API 密钥无效"}
	}
	if resp.StatusCode != 200 {
		return response.FetchRemoteModelsResponse{
			Success: false,
			Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
		}
	}

	body, _ := io.ReadAll(resp.Body)

	// Parse the OpenAI-compatible format: { "data": [{ "id": "xxx", "owned_by": "xxx" }] }
	var modelsData struct {
		Data []struct {
			ID      string `json:"id"`
			OwnedBy string `json:"owned_by"`
		} `json:"data"`
	}

	var models []response.RemoteModel
	if err := json.Unmarshal(body, &modelsData); err == nil {
		for _, m := range modelsData.Data {
			models = append(models, response.RemoteModel{
				ID:      m.ID,
				OwnedBy: m.OwnedBy,
			})
		}
	}

	return response.FetchRemoteModelsResponse{
		Success: true,
		Message: fmt.Sprintf("获取成功,共 %d 个模型", len(models)),
		Models:  models,
	}
}
// ==================== Test message implementation ====================

// sendTestMessageUniversal sends a test message (shared by all provider types).
// It calls baseURL + /chat/completions with Authorization: Bearer authentication.
func sendTestMessageUniversal(baseURL, apiKey, modelName, message string) response.SendTestMessageResponse {
	url := strings.TrimRight(baseURL, "/") + "/chat/completions"

	payload := map[string]interface{}{
		"model":      modelName,
		"max_tokens": 100,
		"messages": []map[string]string{
			{"role": "user", "content": message},
		},
	}
	payloadBytes, _ := json.Marshal(payload)

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payloadBytes))
	if err != nil {
		return response.SendTestMessageResponse{Success: false, Message: "请求构建失败: " + err.Error()}
	}

	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return response.SendTestMessageResponse{Success: false, Message: "连接失败: " + err.Error()}
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)

	if resp.StatusCode == 401 {
		return response.SendTestMessageResponse{Success: false, Message: "API 密钥无效"}
	}

	if resp.StatusCode != 200 {
		var errResp struct {
			Error struct {
				Message string `json:"message"`
			} `json:"error"`
		}
		if json.Unmarshal(body, &errResp) == nil && errResp.Error.Message != "" {
			return response.SendTestMessageResponse{Success: false, Message: "API 错误: " + errResp.Error.Message}
		}
		return response.SendTestMessageResponse{
			Success: false,
			Message: fmt.Sprintf("请求失败 (HTTP %d)", resp.StatusCode),
		}
	}

	// Parse the OpenAI-compatible chat completion response
	var chatResp struct {
		Model   string `json:"model"`
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Usage struct {
			TotalTokens int `json:"total_tokens"`
		} `json:"usage"`
	}

	if err := json.Unmarshal(body, &chatResp); err != nil {
		return response.SendTestMessageResponse{Success: false, Message: "解析响应失败"}
	}

	reply := ""
	if len(chatResp.Choices) > 0 {
		reply = chatResp.Choices[0].Message.Content
	}

	return response.SendTestMessageResponse{
		Success: true,
		Message: "测试成功",
		Reply:   reply,
		Model:   chatResp.Model,
		Tokens:  chatResp.Usage.TotalTokens,
	}
}