231 lines
6.3 KiB
Go
231 lines
6.3 KiB
Go
package app
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.echol.cn/loser/ai_proxy/server/global"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
|
"git.echol.cn/loser/ai_proxy/server/model/app/response"
|
|
)
|
|
|
|
// AiProviderService groups the AI provider operations defined in this file:
// create/read/update/delete of app.AiProvider records through global.GVA_DB,
// plus connectivity and model-list probes against an upstream HTTP API.
// The struct has no fields, so it carries no state of its own.
type AiProviderService struct{}
|
|
|
|
// CreateAiProvider 创建AI提供商
|
|
func (s *AiProviderService) CreateAiProvider(req *request.CreateAiProviderRequest) (provider app.AiProvider, err error) {
|
|
provider = app.AiProvider{
|
|
Name: req.Name,
|
|
Type: req.Type,
|
|
BaseURL: req.BaseURL,
|
|
Endpoint: req.Endpoint,
|
|
UpstreamKey: req.UpstreamKey,
|
|
Model: req.Model,
|
|
ProxyKey: req.ProxyKey,
|
|
Config: req.Config,
|
|
IsActive: req.IsActive,
|
|
}
|
|
err = global.GVA_DB.Create(&provider).Error
|
|
return provider, err
|
|
}
|
|
|
|
// DeleteAiProvider 删除AI提供商
|
|
func (s *AiProviderService) DeleteAiProvider(id uint) (err error) {
|
|
return global.GVA_DB.Delete(&app.AiProvider{}, id).Error
|
|
}
|
|
|
|
// UpdateAiProvider 更新AI提供商
|
|
func (s *AiProviderService) UpdateAiProvider(req *request.UpdateAiProviderRequest) (provider app.AiProvider, err error) {
|
|
err = global.GVA_DB.First(&provider, req.ID).Error
|
|
if err != nil {
|
|
return provider, err
|
|
}
|
|
|
|
if req.Name != "" {
|
|
provider.Name = req.Name
|
|
}
|
|
if req.Type != "" {
|
|
provider.Type = req.Type
|
|
}
|
|
if req.BaseURL != "" {
|
|
provider.BaseURL = req.BaseURL
|
|
}
|
|
if req.Endpoint != "" {
|
|
provider.Endpoint = req.Endpoint
|
|
}
|
|
if req.UpstreamKey != "" {
|
|
provider.UpstreamKey = req.UpstreamKey
|
|
}
|
|
if req.Model != "" {
|
|
provider.Model = req.Model
|
|
}
|
|
if req.ProxyKey != "" {
|
|
provider.ProxyKey = req.ProxyKey
|
|
}
|
|
if req.Config != nil {
|
|
provider.Config = req.Config
|
|
}
|
|
provider.IsActive = req.IsActive
|
|
|
|
err = global.GVA_DB.Save(&provider).Error
|
|
return provider, err
|
|
}
|
|
|
|
// GetAiProvider 获取AI提供商详情
|
|
func (s *AiProviderService) GetAiProvider(id uint) (provider app.AiProvider, err error) {
|
|
err = global.GVA_DB.First(&provider, id).Error
|
|
return provider, err
|
|
}
|
|
|
|
// GetAiProviderList 获取AI提供商列表
|
|
func (s *AiProviderService) GetAiProviderList() (list []app.AiProvider, err error) {
|
|
err = global.GVA_DB.Where("is_active = ?", true).Find(&list).Error
|
|
return list, err
|
|
}
|
|
|
|
// TestConnection 测试连接
|
|
func (s *AiProviderService) TestConnection(req *request.TestConnectionRequest) (resp response.TestConnectionResponse, err error) {
|
|
startTime := time.Now()
|
|
|
|
// 根据类型构建测试 URL
|
|
var testURL string
|
|
switch strings.ToLower(req.Type) {
|
|
case "openai":
|
|
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
|
case "claude":
|
|
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/messages"
|
|
default:
|
|
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
|
}
|
|
|
|
// 创建 HTTP 请求
|
|
httpReq, err := http.NewRequest("GET", testURL, nil)
|
|
if err != nil {
|
|
return response.TestConnectionResponse{
|
|
Success: false,
|
|
Message: fmt.Sprintf("创建请求失败: %v", err),
|
|
Latency: 0,
|
|
}, nil
|
|
}
|
|
|
|
// 设置请求头
|
|
httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
// 发送请求
|
|
client := &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
httpResp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
return response.TestConnectionResponse{
|
|
Success: false,
|
|
Message: fmt.Sprintf("连接失败: %v", err),
|
|
Latency: time.Since(startTime).Milliseconds(),
|
|
}, nil
|
|
}
|
|
defer httpResp.Body.Close()
|
|
|
|
latency := time.Since(startTime).Milliseconds()
|
|
|
|
// 检查响应状态
|
|
if httpResp.StatusCode == http.StatusOK || httpResp.StatusCode == http.StatusCreated {
|
|
return response.TestConnectionResponse{
|
|
Success: true,
|
|
Message: "连接成功",
|
|
Latency: latency,
|
|
}, nil
|
|
}
|
|
|
|
// 读取错误响应
|
|
body, _ := io.ReadAll(httpResp.Body)
|
|
return response.TestConnectionResponse{
|
|
Success: false,
|
|
Message: fmt.Sprintf("连接失败 (状态码: %d): %s", httpResp.StatusCode, string(body)),
|
|
Latency: latency,
|
|
}, nil
|
|
}
|
|
|
|
// GetModels 获取模型列表
|
|
func (s *AiProviderService) GetModels(req *request.GetModelsRequest) (models []response.ModelInfo, err error) {
|
|
// 根据类型构建 URL
|
|
var modelsURL string
|
|
switch strings.ToLower(req.Type) {
|
|
case "openai":
|
|
modelsURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
|
case "claude":
|
|
// Claude API 不提供模型列表接口,返回预定义的模型
|
|
return []response.ModelInfo{
|
|
{ID: "claude-opus-4-6", Name: "Claude Opus 4.6", OwnedBy: "anthropic"},
|
|
{ID: "claude-sonnet-4-6", Name: "Claude Sonnet 4.6", OwnedBy: "anthropic"},
|
|
{ID: "claude-haiku-4-5-20251001", Name: "Claude Haiku 4.5", OwnedBy: "anthropic"},
|
|
{ID: "claude-3-5-sonnet-20241022", Name: "Claude 3.5 Sonnet", OwnedBy: "anthropic"},
|
|
{ID: "claude-3-opus-20240229", Name: "Claude 3 Opus", OwnedBy: "anthropic"},
|
|
}, nil
|
|
default:
|
|
modelsURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
|
|
}
|
|
|
|
// 创建 HTTP 请求
|
|
httpReq, err := http.NewRequest("GET", modelsURL, nil)
|
|
if err != nil {
|
|
return nil, errors.New("创建请求失败: " + err.Error())
|
|
}
|
|
|
|
// 设置请求头
|
|
httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
// 发送请求
|
|
client := &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
httpResp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
return nil, errors.New("请求失败: " + err.Error())
|
|
}
|
|
defer httpResp.Body.Close()
|
|
|
|
// 检查响应状态
|
|
if httpResp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(httpResp.Body)
|
|
return nil, fmt.Errorf("获取模型列表失败 (状态码: %d): %s", httpResp.StatusCode, string(body))
|
|
}
|
|
|
|
// 解析响应
|
|
body, err := io.ReadAll(httpResp.Body)
|
|
if err != nil {
|
|
return nil, errors.New("读取响应失败: " + err.Error())
|
|
}
|
|
|
|
// OpenAI 格式的响应
|
|
var modelsResp struct {
|
|
Data []struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
OwnedBy string `json:"owned_by"`
|
|
} `json:"data"`
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
|
return nil, errors.New("解析响应失败: " + err.Error())
|
|
}
|
|
|
|
// 转换为响应格式
|
|
models = make([]response.ModelInfo, len(modelsResp.Data))
|
|
for i, model := range modelsResp.Data {
|
|
models[i] = response.ModelInfo{
|
|
ID: model.ID,
|
|
Name: model.ID,
|
|
OwnedBy: model.OwnedBy,
|
|
}
|
|
}
|
|
|
|
return models, nil
|
|
}
|