Files
ai_proxy/server/service/app/ai_provider.go
2026-03-03 06:05:51 +08:00

231 lines
6.3 KiB
Go

package app
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
)
// AiProviderService is a stateless receiver for the AI-provider service
// methods in this file: database create/read/update/delete on app.AiProvider
// plus upstream connectivity and model-list probes.
type AiProviderService struct{}
// CreateAiProvider builds an app.AiProvider from the request fields and
// inserts it through global.GVA_DB.
//
// The populated provider is returned together with the error reported by the
// insert; the provider value is returned even when the insert fails.
func (s *AiProviderService) CreateAiProvider(req *request.CreateAiProviderRequest) (provider app.AiProvider, err error) {
	// The named result starts as the zero value, so fill it field by field.
	provider.Name = req.Name
	provider.Type = req.Type
	provider.BaseURL = req.BaseURL
	provider.Endpoint = req.Endpoint
	provider.UpstreamKey = req.UpstreamKey
	provider.Model = req.Model
	provider.ProxyKey = req.ProxyKey
	provider.Config = req.Config
	provider.IsActive = req.IsActive

	if err = global.GVA_DB.Create(&provider).Error; err != nil {
		return provider, err
	}
	return provider, nil
}
// DeleteAiProvider issues a delete on app.AiProvider through global.GVA_DB,
// passing id as the condition, and returns the error reported by that call.
func (s *AiProviderService) DeleteAiProvider(id uint) (err error) {
	target := &app.AiProvider{}
	result := global.GVA_DB.Delete(target, id)
	return result.Error
}
// UpdateAiProvider loads the provider identified by req.ID, overlays the
// request on top of it, and saves the result.
//
// String fields are only overwritten when the request value is non-empty and
// Config only when it is non-nil, so omitted fields keep their stored values.
// IsActive is always overwritten with the request value.
//
// If the initial lookup fails, its error is returned along with whatever the
// lookup left in provider, and nothing is saved.
//
// NOTE(review): because IsActive is copied unconditionally, a request that
// leaves it at its zero value will overwrite the stored flag — confirm that
// callers always send it.
func (s *AiProviderService) UpdateAiProvider(req *request.UpdateAiProviderRequest) (provider app.AiProvider, err error) {
	if err = global.GVA_DB.First(&provider, req.ID).Error; err != nil {
		return provider, err
	}

	// Overlay only the fields the caller actually supplied.
	if v := req.Name; v != "" {
		provider.Name = v
	}
	if v := req.Type; v != "" {
		provider.Type = v
	}
	if v := req.BaseURL; v != "" {
		provider.BaseURL = v
	}
	if v := req.Endpoint; v != "" {
		provider.Endpoint = v
	}
	if v := req.UpstreamKey; v != "" {
		provider.UpstreamKey = v
	}
	if v := req.Model; v != "" {
		provider.Model = v
	}
	if v := req.ProxyKey; v != "" {
		provider.ProxyKey = v
	}
	if v := req.Config; v != nil {
		provider.Config = v
	}

	// Unlike the fields above, the active flag is always taken from the request.
	provider.IsActive = req.IsActive

	saved := global.GVA_DB.Save(&provider)
	return provider, saved.Error
}
// GetAiProvider looks up a single app.AiProvider through global.GVA_DB using
// id as the condition. It returns the loaded provider and the error reported
// by the lookup.
func (s *AiProviderService) GetAiProvider(id uint) (provider app.AiProvider, err error) {
	lookup := global.GVA_DB.First(&provider, id)
	return provider, lookup.Error
}
// GetAiProviderList returns the providers whose is_active column is true,
// together with the error reported by the query.
func (s *AiProviderService) GetAiProviderList() (list []app.AiProvider, err error) {
	activeOnly := global.GVA_DB.Where("is_active = ?", true)
	if err = activeOnly.Find(&list).Error; err != nil {
		return list, err
	}
	return list, nil
}
// TestConnection probes an upstream provider with a single GET request to
// <BaseURL>/v1/models and reports whether it answered successfully.
//
// Authentication depends on req.Type (case-insensitive):
//   - "claude": the key is sent as x-api-key together with an
//     anthropic-version header, which is how the Anthropic API authenticates.
//   - anything else: the key is sent as "Authorization: Bearer <key>".
//
// The outcome is always reported through the response value (Success, Message,
// Latency in milliseconds); the returned error is always nil. Latency is 0
// when the request could not even be constructed. HTTP 200 and 201 count as
// success; for any other status the message contains the status code and a
// size-capped copy of the response body.
func (s *AiProviderService) TestConnection(req *request.TestConnectionRequest) (resp response.TestConnectionResponse, err error) {
	// Upper bound on how much of an upstream error body is echoed back. The
	// URL is caller-supplied, so the body size must not be trusted.
	const maxErrorBodyBytes = 64 << 10

	startTime := time.Now()

	// Every provider type is probed via the model-list endpoint. The Claude
	// branch previously issued GET /v1/messages, but that endpoint only
	// accepts POST, so the probe could never succeed.
	testURL := strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"

	httpReq, err := http.NewRequest(http.MethodGet, testURL, nil)
	if err != nil {
		return response.TestConnectionResponse{
			Success: false,
			Message: fmt.Sprintf("创建请求失败: %v", err),
			Latency: 0,
		}, nil
	}

	if strings.EqualFold(req.Type, "claude") {
		httpReq.Header.Set("x-api-key", req.UpstreamKey)
		httpReq.Header.Set("anthropic-version", "2023-06-01")
	} else {
		httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	client := &http.Client{
		Timeout: 10 * time.Second,
	}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return response.TestConnectionResponse{
			Success: false,
			Message: fmt.Sprintf("连接失败: %v", err),
			Latency: time.Since(startTime).Milliseconds(),
		}, nil
	}
	defer httpResp.Body.Close()

	// Latency covers the round trip up to the response headers, not the time
	// spent reading an error body below.
	latency := time.Since(startTime).Milliseconds()

	if httpResp.StatusCode == http.StatusOK || httpResp.StatusCode == http.StatusCreated {
		return response.TestConnectionResponse{
			Success: true,
			Message: "连接成功",
			Latency: latency,
		}, nil
	}

	// Best effort: the body is only used to enrich the failure message, so a
	// read error is deliberately ignored and whatever was read is reported.
	body, _ := io.ReadAll(io.LimitReader(httpResp.Body, maxErrorBodyBytes))
	return response.TestConnectionResponse{
		Success: false,
		Message: fmt.Sprintf("连接失败 (状态码: %d): %s", httpResp.StatusCode, string(body)),
		Latency: latency,
	}, nil
}
// GetModels returns the models available for a provider.
//
// For req.Type "claude" (case-insensitive) no upstream request is made and a
// fixed, predefined list is returned. For every other type the function sends
// GET <BaseURL>/v1/models with "Authorization: Bearer <UpstreamKey>" and
// decodes an OpenAI-style {"data": [{"id", "object", "owned_by"}]} payload;
// each entry's ID is used as both ID and Name of the returned ModelInfo.
//
// A non-nil error is returned when the request cannot be built or sent, the
// upstream answers with a status other than 200, the body cannot be read or
// exceeds the size cap, or the payload is not valid JSON. Underlying errors
// are wrapped, so callers can inspect them with errors.Is / errors.As.
func (s *AiProviderService) GetModels(req *request.GetModelsRequest) (models []response.ModelInfo, err error) {
	// The URL is caller-supplied, so response sizes must not be trusted.
	const (
		maxErrorBodyBytes  = 64 << 10 // echoed back inside the error message
		maxModelsBodyBytes = 8 << 20  // generous cap for a model-list payload
	)

	if strings.EqualFold(req.Type, "claude") {
		// Claude providers get a predefined list instead of an upstream call.
		return []response.ModelInfo{
			{ID: "claude-opus-4-6", Name: "Claude Opus 4.6", OwnedBy: "anthropic"},
			{ID: "claude-sonnet-4-6", Name: "Claude Sonnet 4.6", OwnedBy: "anthropic"},
			{ID: "claude-haiku-4-5-20251001", Name: "Claude Haiku 4.5", OwnedBy: "anthropic"},
			{ID: "claude-3-5-sonnet-20241022", Name: "Claude 3.5 Sonnet", OwnedBy: "anthropic"},
			{ID: "claude-3-opus-20240229", Name: "Claude 3 Opus", OwnedBy: "anthropic"},
		}, nil
	}

	modelsURL := strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"

	httpReq, err := http.NewRequest(http.MethodGet, modelsURL, nil)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
	httpReq.Header.Set("Content-Type", "application/json")

	client := &http.Client{
		Timeout: 10 * time.Second,
	}
	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message, so a read
		// error is deliberately ignored and whatever was read is reported.
		errBody, _ := io.ReadAll(io.LimitReader(httpResp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("获取模型列表失败 (状态码: %d): %s", httpResp.StatusCode, string(errBody))
	}

	// Read one byte past the cap so an oversized payload is detected rather
	// than silently truncated into a confusing JSON parse error.
	body, err := io.ReadAll(io.LimitReader(httpResp.Body, maxModelsBodyBytes+1))
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if len(body) > maxModelsBodyBytes {
		return nil, errors.New("读取响应失败: 响应体过大")
	}

	// OpenAI-style response envelope.
	var modelsResp struct {
		Data []struct {
			ID      string `json:"id"`
			Object  string `json:"object"`
			OwnedBy string `json:"owned_by"`
		} `json:"data"`
	}
	if err := json.Unmarshal(body, &modelsResp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	models = make([]response.ModelInfo, len(modelsResp.Data))
	for i, model := range modelsResp.Data {
		models[i] = response.ModelInfo{
			ID:      model.ID,
			Name:    model.ID, // the payload carries no display name
			OwnedBy: model.OwnedBy,
		}
	}
	return models, nil
}