🎨 添加中间件 && 完善预设注入功能 && 新增流式传输
This commit is contained in:
@@ -35,8 +35,7 @@ func (a *AiProxyApi) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
// 处理流式响应
|
// 处理流式响应
|
||||||
if req.Stream {
|
if req.Stream {
|
||||||
// TODO: 实现流式响应
|
aiProxyService.ProcessChatCompletionStream(c, userId, &req)
|
||||||
response.FailWithMessage("流式响应暂未实现", c)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
44
server/middleware/app_jwt.go
Normal file
44
server/middleware/app_jwt.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AppJWTAuth 前台用户 JWT 认证中间件
|
||||||
|
func AppJWTAuth() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
token := c.GetHeader("Authorization")
|
||||||
|
if token == "" {
|
||||||
|
token = c.GetHeader("x-token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
response.FailWithDetailed(gin.H{"reload": true}, "未登录或非法访问", c)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除 Bearer 前缀
|
||||||
|
if len(token) > 7 && token[:7] == "Bearer " {
|
||||||
|
token = token[7:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 token
|
||||||
|
claims, err := utils.ParseAppToken(token)
|
||||||
|
if err != nil {
|
||||||
|
global.GVA_LOG.Error("解析 App Token 失败: " + err.Error())
|
||||||
|
response.FailWithDetailed(gin.H{"reload": true}, "授权已过期或无效", c)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将用户信息存入上下文
|
||||||
|
c.Set("appClaims", claims)
|
||||||
|
c.Set("userId", claims.UserID)
|
||||||
|
c.Set("username", claims.Username)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
25
server/middleware/cors.go
Normal file
25
server/middleware/cors.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cors 直接放行全部跨域请求
|
||||||
|
func Cors() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
method := c.Request.Method
|
||||||
|
origin := c.Request.Header.Get("Origin")
|
||||||
|
c.Header("Access-Control-Allow-Origin", origin)
|
||||||
|
c.Header("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token, Authorization, Token,X-Token,X-User-Id")
|
||||||
|
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS,DELETE,PUT")
|
||||||
|
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type, New-Token, New-Expires-At")
|
||||||
|
c.Header("Access-Control-Allow-Credentials", "true")
|
||||||
|
|
||||||
|
// 放行所有OPTIONS方法
|
||||||
|
if method == "OPTIONS" {
|
||||||
|
c.AbortWithStatus(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
21
server/middleware/email.go
Normal file
21
server/middleware/email.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorToEmail 错误发送邮件中间件
|
||||||
|
func ErrorToEmail() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
global.GVA_LOG.Error("panic error", zap.Any("error", err))
|
||||||
|
response.FailWithMessage("服务器内部错误", c)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
18
server/middleware/error.go
Normal file
18
server/middleware/error.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorLogger 错误日志中间件
|
||||||
|
func ErrorLogger() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Next()
|
||||||
|
// 记录错误
|
||||||
|
for _, err := range c.Errors {
|
||||||
|
global.GVA_LOG.Error("request error", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
75
server/middleware/jwt.go
Normal file
75
server/middleware/jwt.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/system"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/service"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var jwtService = service.ServiceGroupApp.SystemServiceGroup.JwtService
|
||||||
|
|
||||||
|
func JWTAuth() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// 从请求头获取 token
|
||||||
|
token := utils.GetToken(c)
|
||||||
|
if token == "" {
|
||||||
|
response.FailWithDetailed(gin.H{"reload": true}, "未登录或非法访问", c)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 JWT 黑名单
|
||||||
|
if jwtService.IsBlacklist(token) {
|
||||||
|
response.FailWithDetailed(gin.H{"reload": true}, "您的帐户异地登陆或令牌失效", c)
|
||||||
|
utils.ClearToken(c)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 token
|
||||||
|
j := utils.NewJWT()
|
||||||
|
claims, err := j.ParseToken(token)
|
||||||
|
if err != nil {
|
||||||
|
if err == utils.TokenExpired {
|
||||||
|
response.FailWithDetailed(gin.H{"reload": true}, "授权已过期", c)
|
||||||
|
utils.ClearToken(c)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.FailWithDetailed(gin.H{"reload": true}, err.Error(), c)
|
||||||
|
utils.ClearToken(c)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token 续期检查
|
||||||
|
if claims.ExpiresAt.Unix()-time.Now().Unix() < claims.BufferTime {
|
||||||
|
dr, _ := utils.ParseDuration(global.GVA_CONFIG.JWT.ExpiresTime)
|
||||||
|
claims.ExpiresAt = utils.NewNumericDate(time.Now().Add(dr))
|
||||||
|
newToken, _ := j.CreateTokenByOldToken(token, *claims)
|
||||||
|
newClaims, _ := j.ParseToken(newToken)
|
||||||
|
c.Header("new-token", newToken)
|
||||||
|
c.Header("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
|
||||||
|
utils.SetToken(c, newToken, int(dr.Seconds()))
|
||||||
|
if global.GVA_CONFIG.System.UseMultipoint {
|
||||||
|
RedisJwtToken, err := jwtService.GetRedisJWT(newClaims.Username)
|
||||||
|
if err != nil {
|
||||||
|
global.GVA_LOG.Error("get redis jwt failed", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
_ = jwtService.JsonInBlacklist(system.JwtBlacklist{Jwt: RedisJwtToken})
|
||||||
|
}
|
||||||
|
_ = jwtService.SetRedisJWT(newToken, newClaims.Username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("claims", claims)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
66
server/middleware/limit_ip.go
Normal file
66
server/middleware/limit_ip.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LimitConfig struct {
|
||||||
|
GenerationDuration time.Duration
|
||||||
|
Limit int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l LimitConfig) LimitKey(c *gin.Context) string {
|
||||||
|
return "GVA_Limit" + c.ClientIP()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l LimitConfig) GetLimit(c *gin.Context) int {
|
||||||
|
return l.Limit
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l LimitConfig) Reached(c *gin.Context) response.Response {
|
||||||
|
return response.Response{Code: response.ERROR, Data: nil, Msg: "操作过于频繁,请稍后再试"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPLimit IP 限流中间件
|
||||||
|
func IPLimit() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if global.GVA_CONFIG.System.UseRedis {
|
||||||
|
key := "GVA_Limit" + c.ClientIP()
|
||||||
|
limit := global.GVA_CONFIG.System.IpLimitCount
|
||||||
|
limitTime := global.GVA_CONFIG.System.IpLimitTime
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
count, err := global.GVA_REDIS.Get(ctx, key).Int()
|
||||||
|
if err != nil && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
global.GVA_LOG.Error("get redis key error", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if count >= limit {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": response.ERROR,
|
||||||
|
"msg": fmt.Sprintf("操作过于频繁,请在 %d 秒后再试", limitTime),
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := global.GVA_REDIS.Pipeline()
|
||||||
|
pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, time.Second*time.Duration(limitTime))
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
global.GVA_LOG.Error("redis pipeline error", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
68
server/middleware/loadtls.go
Normal file
68
server/middleware/loadtls.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoadTls 加载 TLS 证书
|
||||||
|
func LoadTls() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if global.GVA_CONFIG.System.UseHttps {
|
||||||
|
certFile := global.GVA_CONFIG.System.TlsCert
|
||||||
|
keyFile := global.GVA_CONFIG.System.TlsKey
|
||||||
|
|
||||||
|
if certFile == "" || keyFile == "" {
|
||||||
|
global.GVA_LOG.Error("TLS cert or key file not configured")
|
||||||
|
c.AbortWithStatus(500)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查证书文件是否存在
|
||||||
|
if _, err := os.Stat(certFile); os.IsNotExist(err) {
|
||||||
|
global.GVA_LOG.Error("TLS cert file not found", zap.String("file", certFile))
|
||||||
|
c.AbortWithStatus(500)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(keyFile); os.IsNotExist(err) {
|
||||||
|
global.GVA_LOG.Error("TLS key file not found", zap.String("file", keyFile))
|
||||||
|
c.AbortWithStatus(500)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadTlsFromFile 从文件加载 TLS 证书内容
|
||||||
|
func LoadTlsFromFile(certFile, keyFile string) (certPEM, keyPEM []byte, err error) {
|
||||||
|
certF, err := os.Open(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("open cert file error: %w", err)
|
||||||
|
}
|
||||||
|
defer certF.Close()
|
||||||
|
|
||||||
|
keyF, err := os.Open(keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("open key file error: %w", err)
|
||||||
|
}
|
||||||
|
defer keyF.Close()
|
||||||
|
|
||||||
|
certPEM, err = io.ReadAll(certF)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("read cert file error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyPEM, err = io.ReadAll(keyF)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("read key file error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return certPEM, keyPEM, nil
|
||||||
|
}
|
||||||
113
server/middleware/logger.go
Normal file
113
server/middleware/logger.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/system"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/service"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var operationRecordService = service.ServiceGroupApp.SystemServiceGroup.OperationRecordService
|
||||||
|
|
||||||
|
// OperationRecord 操作记录中间件
|
||||||
|
func OperationRecord() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var body []byte
|
||||||
|
var userId int
|
||||||
|
if c.Request.Method != http.MethodGet {
|
||||||
|
var err error
|
||||||
|
body, err = io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
global.GVA_LOG.Error("read body from request error:", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userId = int(utils.GetUserID(c))
|
||||||
|
|
||||||
|
writer := responseBodyWriter{
|
||||||
|
ResponseWriter: c.Writer,
|
||||||
|
body: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
c.Writer = writer
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
latency := time.Since(now)
|
||||||
|
|
||||||
|
if c.Request.Method != http.MethodGet {
|
||||||
|
record := system.SysOperationRecord{
|
||||||
|
Ip: c.ClientIP(),
|
||||||
|
Method: c.Request.Method,
|
||||||
|
Path: c.Request.URL.Path,
|
||||||
|
Agent: c.Request.UserAgent(),
|
||||||
|
Body: string(body),
|
||||||
|
UserID: userId,
|
||||||
|
Status: c.Writer.Status(),
|
||||||
|
Latency: latency,
|
||||||
|
Resp: writer.body.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
values, _ := url.ParseQuery(c.Request.URL.RawQuery)
|
||||||
|
record.Query = values.Encode()
|
||||||
|
|
||||||
|
if err := operationRecordService.CreateSysOperationRecord(record); err != nil {
|
||||||
|
global.GVA_LOG.Error("create operation record error:", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseBodyWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
body *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r responseBodyWriter) Write(b []byte) (int, error) {
|
||||||
|
r.body.Write(b)
|
||||||
|
return r.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r responseBodyWriter) WriteString(s string) (int, error) {
|
||||||
|
r.body.WriteString(s)
|
||||||
|
return r.ResponseWriter.WriteString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r responseBodyWriter) WriteJSON(obj interface{}) error {
|
||||||
|
data, err := json.Marshal(obj)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.body.Write(data)
|
||||||
|
return r.ResponseWriter.WriteJSON(obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedRecordPath 判断是否需要记录操作日志
|
||||||
|
func NeedRecordPath(path string) bool {
|
||||||
|
// 排除不需要记录的路径
|
||||||
|
excludePaths := []string{
|
||||||
|
"/health",
|
||||||
|
"/swagger",
|
||||||
|
"/api/captcha",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, excludePath := range excludePaths {
|
||||||
|
if strings.HasPrefix(path, excludePath) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
119
server/middleware/operation.go
Normal file
119
server/middleware/operation.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/system"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var respPool sync.Pool
|
||||||
|
var bufferSize = 1024
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
respPool.New = func() interface{} {
|
||||||
|
return make([]byte, bufferSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operation 操作日志记录中间件
|
||||||
|
func Operation() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var body []byte
|
||||||
|
var userId int
|
||||||
|
|
||||||
|
if c.Request.Method != http.MethodGet {
|
||||||
|
var err error
|
||||||
|
body, err = io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
global.GVA_LOG.Error("read body from request error:", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userId = int(utils.GetUserID(c))
|
||||||
|
|
||||||
|
writer := responseBodyWriter{
|
||||||
|
ResponseWriter: c.Writer,
|
||||||
|
body: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
c.Writer = writer
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
latency := time.Since(now)
|
||||||
|
|
||||||
|
// 只记录需要的路径
|
||||||
|
if NeedRecordPath(c.Request.URL.Path) && c.Request.Method != http.MethodGet {
|
||||||
|
record := system.SysOperationRecord{
|
||||||
|
Ip: c.ClientIP(),
|
||||||
|
Method: c.Request.Method,
|
||||||
|
Path: c.Request.URL.Path,
|
||||||
|
Agent: c.Request.UserAgent(),
|
||||||
|
Body: string(body),
|
||||||
|
UserID: userId,
|
||||||
|
Status: c.Writer.Status(),
|
||||||
|
Latency: latency,
|
||||||
|
Resp: writer.body.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
values, _ := url.ParseQuery(c.Request.URL.RawQuery)
|
||||||
|
record.Query = values.Encode()
|
||||||
|
|
||||||
|
// 异步记录操作日志
|
||||||
|
go func() {
|
||||||
|
if err := operationRecordService.CreateSysOperationRecord(record); err != nil {
|
||||||
|
global.GVA_LOG.Error("create operation record error:", zap.Error(err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilteredPaths 需要过滤的路径
|
||||||
|
var FilteredPaths = []string{
|
||||||
|
"/health",
|
||||||
|
"/swagger",
|
||||||
|
"/api/captcha",
|
||||||
|
"/api/base/login",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldRecord 判断是否应该记录该路径
|
||||||
|
func ShouldRecord(path string) bool {
|
||||||
|
for _, filtered := range FilteredPaths {
|
||||||
|
if strings.HasPrefix(path, filtered) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskSensitiveData 脱敏敏感数据
|
||||||
|
func MaskSensitiveData(data string) string {
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(data), &result); err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
sensitiveFields := []string{"password", "token", "secret", "key"}
|
||||||
|
for _, field := range sensitiveFields {
|
||||||
|
if _, ok := result[field]; ok {
|
||||||
|
result[field] = "******"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
masked, _ := json.Marshal(result)
|
||||||
|
return string(masked)
|
||||||
|
}
|
||||||
32
server/middleware/timeout.go
Normal file
32
server/middleware/timeout.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Timeout 超时中间件
|
||||||
|
func Timeout(timeout time.Duration) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
|
||||||
|
finished := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
c.Next()
|
||||||
|
finished <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
response.FailWithMessage("请求超时", c)
|
||||||
|
c.Abort()
|
||||||
|
case <-finished:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"git.echol.cn/loser/ai_proxy/server/global"
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
"git.echol.cn/loser/ai_proxy/server/model/app"
|
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||||
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||||
@@ -110,9 +113,132 @@ func (s *AiPresetService) GetAiPresetList(userId uint, info req.PageInfo) (list
|
|||||||
|
|
||||||
// ImportAiPreset 导入AI预设(支持SillyTavern格式)
|
// ImportAiPreset 导入AI预设(支持SillyTavern格式)
|
||||||
func (s *AiPresetService) ImportAiPreset(userId uint, req *request.ImportAiPresetRequest) (preset app.AiPreset, err error) {
|
func (s *AiPresetService) ImportAiPreset(userId uint, req *request.ImportAiPresetRequest) (preset app.AiPreset, err error) {
|
||||||
// TODO: 解析SillyTavern JSON格式
|
// 解析 SillyTavern JSON 格式
|
||||||
// 这里需要实现JSON解析逻辑,将SillyTavern格式转换为我们的格式
|
var stData map[string]interface{}
|
||||||
return preset, nil
|
var jsonData []byte
|
||||||
|
|
||||||
|
// 类型断言处理 req.Data
|
||||||
|
switch v := req.Data.(type) {
|
||||||
|
case string:
|
||||||
|
jsonData = []byte(v)
|
||||||
|
case []byte:
|
||||||
|
jsonData = v
|
||||||
|
default:
|
||||||
|
return preset, fmt.Errorf("不支持的数据类型")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &stData); err != nil {
|
||||||
|
return preset, fmt.Errorf("JSON 解析失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取基本信息
|
||||||
|
preset = app.AiPreset{
|
||||||
|
UserID: userId,
|
||||||
|
Name: req.Name,
|
||||||
|
Description: getStringValue(stData, "description"),
|
||||||
|
IsPublic: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取参数
|
||||||
|
if temp, ok := stData["temperature"].(float64); ok {
|
||||||
|
preset.Temperature = temp
|
||||||
|
}
|
||||||
|
if topP, ok := stData["top_p"].(float64); ok {
|
||||||
|
preset.TopP = topP
|
||||||
|
}
|
||||||
|
if maxTokens, ok := stData["openai_max_tokens"].(float64); ok {
|
||||||
|
preset.MaxTokens = int(maxTokens)
|
||||||
|
} else if maxTokens, ok := stData["max_tokens"].(float64); ok {
|
||||||
|
preset.MaxTokens = int(maxTokens)
|
||||||
|
}
|
||||||
|
if freqPenalty, ok := stData["frequency_penalty"].(float64); ok {
|
||||||
|
preset.FrequencyPenalty = freqPenalty
|
||||||
|
}
|
||||||
|
if presPenalty, ok := stData["presence_penalty"].(float64); ok {
|
||||||
|
preset.PresencePenalty = presPenalty
|
||||||
|
}
|
||||||
|
if stream, ok := stData["stream_openai"].(bool); ok {
|
||||||
|
preset.StreamEnabled = stream
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取提示词
|
||||||
|
prompts := make([]app.Prompt, 0)
|
||||||
|
if promptsData, ok := stData["prompts"].([]interface{}); ok {
|
||||||
|
for i, p := range promptsData {
|
||||||
|
if promptMap, ok := p.(map[string]interface{}); ok {
|
||||||
|
prompt := app.Prompt{
|
||||||
|
Name: getStringValue(promptMap, "name"),
|
||||||
|
Role: getStringValue(promptMap, "role"),
|
||||||
|
Content: getStringValue(promptMap, "content"),
|
||||||
|
SystemPrompt: getBoolValue(promptMap, "system_prompt"),
|
||||||
|
Marker: getBoolValue(promptMap, "marker"),
|
||||||
|
InjectionOrder: i,
|
||||||
|
InjectionDepth: int(getFloatValue(promptMap, "injection_depth")),
|
||||||
|
InjectionPosition: int(getFloatValue(promptMap, "injection_position")),
|
||||||
|
}
|
||||||
|
prompts = append(prompts, prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
preset.Prompts = prompts
|
||||||
|
|
||||||
|
// 提取正则脚本
|
||||||
|
regexScripts := make([]app.RegexScript, 0)
|
||||||
|
if extensions, ok := stData["extensions"].(map[string]interface{}); ok {
|
||||||
|
if scripts, ok := extensions["regex_scripts"].([]interface{}); ok {
|
||||||
|
for _, s := range scripts {
|
||||||
|
if scriptMap, ok := s.(map[string]interface{}); ok {
|
||||||
|
script := app.RegexScript{
|
||||||
|
ScriptName: getStringValue(scriptMap, "script_name"),
|
||||||
|
FindRegex: getStringValue(scriptMap, "find_regex"),
|
||||||
|
ReplaceString: getStringValue(scriptMap, "replace_string"),
|
||||||
|
Disabled: getBoolValue(scriptMap, "disabled"),
|
||||||
|
Placement: getIntArray(scriptMap, "placement"),
|
||||||
|
}
|
||||||
|
regexScripts = append(regexScripts, script)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
preset.RegexScripts = regexScripts
|
||||||
|
|
||||||
|
// 保存到数据库
|
||||||
|
err = global.GVA_DB.Create(&preset).Error
|
||||||
|
return preset, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数
|
||||||
|
func getStringValue(m map[string]interface{}, key string) string {
|
||||||
|
if v, ok := m[key].(string); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBoolValue(m map[string]interface{}, key string) bool {
|
||||||
|
if v, ok := m[key].(bool); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFloatValue(m map[string]interface{}, key string) float64 {
|
||||||
|
if v, ok := m[key].(float64); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIntArray(m map[string]interface{}, key string) []int {
|
||||||
|
result := make([]int, 0)
|
||||||
|
if arr, ok := m[key].([]interface{}); ok {
|
||||||
|
for _, v := range arr {
|
||||||
|
if num, ok := v.(float64); ok {
|
||||||
|
result = append(result, int(num))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExportAiPreset 导出AI预设
|
// ExportAiPreset 导出AI预设
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.echol.cn/loser/ai_proxy/server/global"
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
@@ -32,10 +35,25 @@ func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userId uint,
|
|||||||
|
|
||||||
// 2. 获取提供商配置
|
// 2. 获取提供商配置
|
||||||
var provider app.AiProvider
|
var provider app.AiProvider
|
||||||
// TODO: 根据 binding_key 或默认配置获取 provider
|
|
||||||
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
// 根据 binding_key 或预设绑定获取 provider
|
||||||
if err != nil {
|
if req.BindingKey != "" {
|
||||||
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
|
// 通过 binding_key 查找绑定关系
|
||||||
|
var binding app.AiPresetBinding
|
||||||
|
err = global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
|
||||||
|
Order("priority ASC").
|
||||||
|
First(&binding).Error
|
||||||
|
if err == nil {
|
||||||
|
err = global.GVA_DB.First(&provider, binding.ProviderID).Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有找到,使用默认的活跃提供商
|
||||||
|
if provider.ID == 0 {
|
||||||
|
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
||||||
|
if err != nil {
|
||||||
|
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 构建注入后的消息
|
// 3. 构建注入后的消息
|
||||||
@@ -67,20 +85,38 @@ func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionReques
|
|||||||
return req.Messages, nil
|
return req.Messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: 实现完整的预设注入逻辑
|
|
||||||
// 1. 按 injection_order 排序 prompts
|
// 1. 按 injection_order 排序 prompts
|
||||||
// 2. 根据 injection_depth 插入到对话历史中
|
sortedPrompts := make([]app.Prompt, len(preset.Prompts))
|
||||||
// 3. 替换变量 {{user}}, {{char}}
|
copy(sortedPrompts, preset.Prompts)
|
||||||
// 4. 应用正则脚本 (placement=1)
|
sort.Slice(sortedPrompts, func(i, j int) bool {
|
||||||
|
return sortedPrompts[i].InjectionOrder < sortedPrompts[j].InjectionOrder
|
||||||
|
})
|
||||||
|
|
||||||
messages := make([]request.Message, 0)
|
messages := make([]request.Message, 0)
|
||||||
|
|
||||||
// 简化实现:直接添加系统提示词
|
// 2. 根据 injection_depth 插入到对话历史中
|
||||||
for _, prompt := range preset.Prompts {
|
for _, prompt := range sortedPrompts {
|
||||||
if prompt.SystemPrompt && !prompt.Marker {
|
if prompt.Marker {
|
||||||
|
continue // 跳过标记提示词
|
||||||
|
}
|
||||||
|
|
||||||
|
// 替换变量
|
||||||
|
content := s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard)
|
||||||
|
|
||||||
|
// 根据 injection_depth 决定插入位置
|
||||||
|
// depth=0: 插入到最前面(系统提示词)
|
||||||
|
// depth>0: 从对话历史末尾往前数 depth 条消息的位置插入
|
||||||
|
if prompt.InjectionDepth == 0 || prompt.SystemPrompt {
|
||||||
messages = append(messages, request.Message{
|
messages = append(messages, request.Message{
|
||||||
Role: prompt.Role,
|
Role: prompt.Role,
|
||||||
Content: s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard),
|
Content: content,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// 先添加用户消息,稍后根据 depth 插入
|
||||||
|
// 这里简化处理,将非系统提示词也添加到前面
|
||||||
|
messages = append(messages, request.Message{
|
||||||
|
Role: prompt.Role,
|
||||||
|
Content: content,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -88,7 +124,7 @@ func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionReques
|
|||||||
// 添加用户消息
|
// 添加用户消息
|
||||||
messages = append(messages, req.Messages...)
|
messages = append(messages, req.Messages...)
|
||||||
|
|
||||||
// 应用输入正则脚本
|
// 4. 应用输入正则脚本 (placement=1)
|
||||||
for i := range messages {
|
for i := range messages {
|
||||||
messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts)
|
messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts)
|
||||||
}
|
}
|
||||||
@@ -124,7 +160,16 @@ func (s *AiProxyService) applyInputRegex(content string, scripts []app.RegexScri
|
|||||||
if !containsPlacement(script.Placement, 1) {
|
if !containsPlacement(script.Placement, 1) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// TODO: 实现正则替换逻辑
|
|
||||||
|
// 编译正则表达式
|
||||||
|
re, err := regexp.Compile(script.FindRegex)
|
||||||
|
if err != nil {
|
||||||
|
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行替换
|
||||||
|
content = re.ReplaceAllString(content, script.ReplaceString)
|
||||||
}
|
}
|
||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
@@ -138,7 +183,16 @@ func (s *AiProxyService) applyOutputRegex(content string, scripts []app.RegexScr
|
|||||||
if !containsPlacement(script.Placement, 2) {
|
if !containsPlacement(script.Placement, 2) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// TODO: 实现正则替换逻辑
|
|
||||||
|
// 编译正则表达式
|
||||||
|
re, err := regexp.Compile(script.FindRegex)
|
||||||
|
if err != nil {
|
||||||
|
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行替换
|
||||||
|
content = re.ReplaceAllString(content, script.ReplaceString)
|
||||||
}
|
}
|
||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
@@ -234,7 +288,7 @@ func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider
|
|||||||
|
|
||||||
// 辅助函数
|
// 辅助函数
|
||||||
func replaceAll(s, old, new string) string {
|
func replaceAll(s, old, new string) string {
|
||||||
return s // TODO: 实现字符串替换
|
return strings.ReplaceAll(s, old, new)
|
||||||
}
|
}
|
||||||
|
|
||||||
func containsPlacement(placements []int, target int) bool {
|
func containsPlacement(placements []int, target int) bool {
|
||||||
|
|||||||
193
server/service/app/ai_proxy_stream.go
Normal file
193
server/service/app/ai_proxy_stream.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/global"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/app"
|
||||||
|
"git.echol.cn/loser/ai_proxy/server/model/app/request"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessChatCompletionStream 处理流式聊天补全请求
|
||||||
|
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userId uint, req *request.ChatCompletionRequest) {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// 1. 获取预设配置
|
||||||
|
var preset app.AiPreset
|
||||||
|
if req.PresetID > 0 {
|
||||||
|
err := global.GVA_DB.First(&preset, req.PresetID).Error
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "预设不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 获取提供商配置
|
||||||
|
var provider app.AiProvider
|
||||||
|
if req.BindingKey != "" {
|
||||||
|
var binding app.AiPresetBinding
|
||||||
|
err := global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
|
||||||
|
Order("priority ASC").
|
||||||
|
First(&binding).Error
|
||||||
|
if err == nil {
|
||||||
|
global.GVA_DB.First(&provider, binding.ProviderID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider.ID == 0 {
|
||||||
|
err := global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到可用的AI提供商"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 构建注入后的消息
|
||||||
|
messages, err := s.buildInjectedMessages(req, &preset)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "构建消息失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 转发流式请求到上游AI
|
||||||
|
err = s.forwardStreamToAI(c, &provider, &preset, messages, userId, startTime)
|
||||||
|
if err != nil {
|
||||||
|
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// forwardStreamToAI 转发流式请求到上游AI
|
||||||
|
func (s *AiProxyService) forwardStreamToAI(c *gin.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message, userId uint, startTime time.Time) error {
|
||||||
|
// 构建请求体
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"model": provider.Model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if preset != nil {
|
||||||
|
reqBody["temperature"] = preset.Temperature
|
||||||
|
reqBody["top_p"] = preset.TopP
|
||||||
|
reqBody["max_tokens"] = preset.MaxTokens
|
||||||
|
reqBody["frequency_penalty"] = preset.FrequencyPenalty
|
||||||
|
reqBody["presence_penalty"] = preset.PresencePenalty
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建HTTP请求
|
||||||
|
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
|
||||||
|
req, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
if provider.UpstreamKey != "" {
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
client := &http.Client{Timeout: 300 * time.Second}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置SSE响应头
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
// 读取并转发流式响应
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("streaming not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
var fullResponse strings.Builder
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过空行
|
||||||
|
if len(bytes.TrimSpace(line)) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析SSE数据
|
||||||
|
lineStr := string(line)
|
||||||
|
if strings.HasPrefix(lineStr, "data: ") {
|
||||||
|
data := strings.TrimPrefix(lineStr, "data: ")
|
||||||
|
data = strings.TrimSpace(data)
|
||||||
|
|
||||||
|
// 检查是否是结束标记
|
||||||
|
if data == "[DONE]" {
|
||||||
|
c.Writer.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析JSON并提取内容
|
||||||
|
var chunk map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
|
||||||
|
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
||||||
|
if choice, ok := choices[0].(map[string]interface{}); ok {
|
||||||
|
if delta, ok := choice["delta"].(map[string]interface{}); ok {
|
||||||
|
if content, ok := delta["content"].(string); ok {
|
||||||
|
fullResponse.WriteString(content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转发原始数据
|
||||||
|
c.Writer.Write(line)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用输出正则脚本
|
||||||
|
finalContent := fullResponse.String()
|
||||||
|
if preset != nil {
|
||||||
|
finalContent = s.applyOutputRegex(finalContent, preset.RegexScripts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录日志
|
||||||
|
var originalMsg string
|
||||||
|
if len(messages) > 0 {
|
||||||
|
originalMsg = messages[len(messages)-1].Content
|
||||||
|
}
|
||||||
|
s.logRequest(userId, preset, provider, originalMsg, finalContent, nil, time.Since(startTime))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
36
web/src/api/binding.js
Normal file
36
web/src/api/binding.js
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import request from '@/utils/request'
|
||||||
|
|
||||||
|
// 获取绑定列表
|
||||||
|
export const getBindingList = (params) => {
|
||||||
|
return request({
|
||||||
|
url: '/app/binding/list',
|
||||||
|
method: 'get',
|
||||||
|
params
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建绑定
|
||||||
|
export const createBinding = (data) => {
|
||||||
|
return request({
|
||||||
|
url: '/app/binding',
|
||||||
|
method: 'post',
|
||||||
|
data
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新绑定
|
||||||
|
export const updateBinding = (data) => {
|
||||||
|
return request({
|
||||||
|
url: '/app/binding',
|
||||||
|
method: 'put',
|
||||||
|
data
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除绑定
|
||||||
|
export const deleteBinding = (id) => {
|
||||||
|
return request({
|
||||||
|
url: `/app/binding/${id}`,
|
||||||
|
method: 'delete'
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user