🎨 添加中间件 && 完善预设注入功能 && 新增流式传输
This commit is contained in:
44
server/middleware/app_jwt.go
Normal file
44
server/middleware/app_jwt.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AppJWTAuth 前台用户 JWT 认证中间件
|
||||
func AppJWTAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
token = c.GetHeader("x-token")
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
response.FailWithDetailed(gin.H{"reload": true}, "未登录或非法访问", c)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 移除 Bearer 前缀
|
||||
if len(token) > 7 && token[:7] == "Bearer " {
|
||||
token = token[7:]
|
||||
}
|
||||
|
||||
// 解析 token
|
||||
claims, err := utils.ParseAppToken(token)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("解析 App Token 失败: " + err.Error())
|
||||
response.FailWithDetailed(gin.H{"reload": true}, "授权已过期或无效", c)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存入上下文
|
||||
c.Set("appClaims", claims)
|
||||
c.Set("userId", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
25
server/middleware/cors.go
Normal file
25
server/middleware/cors.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Cors 直接放行全部跨域请求
|
||||
func Cors() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token, Authorization, Token,X-Token,X-User-Id")
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS,DELETE,PUT")
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type, New-Token, New-Expires-At")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// 放行所有OPTIONS方法
|
||||
if method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
21
server/middleware/email.go
Normal file
21
server/middleware/email.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ErrorToEmail 错误发送邮件中间件
|
||||
func ErrorToEmail() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
global.GVA_LOG.Error("panic error", zap.Any("error", err))
|
||||
response.FailWithMessage("服务器内部错误", c)
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
18
server/middleware/error.go
Normal file
18
server/middleware/error.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ErrorLogger 错误日志中间件
|
||||
func ErrorLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
// 记录错误
|
||||
for _, err := range c.Errors {
|
||||
global.GVA_LOG.Error("request error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
75
server/middleware/jwt.go
Normal file
75
server/middleware/jwt.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/system"
|
||||
"git.echol.cn/loser/ai_proxy/server/service"
|
||||
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var jwtService = service.ServiceGroupApp.SystemServiceGroup.JwtService
|
||||
|
||||
func JWTAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从请求头获取 token
|
||||
token := utils.GetToken(c)
|
||||
if token == "" {
|
||||
response.FailWithDetailed(gin.H{"reload": true}, "未登录或非法访问", c)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 JWT 黑名单
|
||||
if jwtService.IsBlacklist(token) {
|
||||
response.FailWithDetailed(gin.H{"reload": true}, "您的帐户异地登陆或令牌失效", c)
|
||||
utils.ClearToken(c)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 token
|
||||
j := utils.NewJWT()
|
||||
claims, err := j.ParseToken(token)
|
||||
if err != nil {
|
||||
if err == utils.TokenExpired {
|
||||
response.FailWithDetailed(gin.H{"reload": true}, "授权已过期", c)
|
||||
utils.ClearToken(c)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
response.FailWithDetailed(gin.H{"reload": true}, err.Error(), c)
|
||||
utils.ClearToken(c)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Token 续期检查
|
||||
if claims.ExpiresAt.Unix()-time.Now().Unix() < claims.BufferTime {
|
||||
dr, _ := utils.ParseDuration(global.GVA_CONFIG.JWT.ExpiresTime)
|
||||
claims.ExpiresAt = utils.NewNumericDate(time.Now().Add(dr))
|
||||
newToken, _ := j.CreateTokenByOldToken(token, *claims)
|
||||
newClaims, _ := j.ParseToken(newToken)
|
||||
c.Header("new-token", newToken)
|
||||
c.Header("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
|
||||
utils.SetToken(c, newToken, int(dr.Seconds()))
|
||||
if global.GVA_CONFIG.System.UseMultipoint {
|
||||
RedisJwtToken, err := jwtService.GetRedisJWT(newClaims.Username)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("get redis jwt failed", zap.Error(err))
|
||||
} else {
|
||||
_ = jwtService.JsonInBlacklist(system.JwtBlacklist{Jwt: RedisJwtToken})
|
||||
}
|
||||
_ = jwtService.SetRedisJWT(newToken, newClaims.Username)
|
||||
}
|
||||
}
|
||||
|
||||
c.Set("claims", claims)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
66
server/middleware/limit_ip.go
Normal file
66
server/middleware/limit_ip.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type LimitConfig struct {
|
||||
GenerationDuration time.Duration
|
||||
Limit int
|
||||
}
|
||||
|
||||
func (l LimitConfig) LimitKey(c *gin.Context) string {
|
||||
return "GVA_Limit" + c.ClientIP()
|
||||
}
|
||||
|
||||
func (l LimitConfig) GetLimit(c *gin.Context) int {
|
||||
return l.Limit
|
||||
}
|
||||
|
||||
func (l LimitConfig) Reached(c *gin.Context) response.Response {
|
||||
return response.Response{Code: response.ERROR, Data: nil, Msg: "操作过于频繁,请稍后再试"}
|
||||
}
|
||||
|
||||
// IPLimit IP 限流中间件
|
||||
func IPLimit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if global.GVA_CONFIG.System.UseRedis {
|
||||
key := "GVA_Limit" + c.ClientIP()
|
||||
limit := global.GVA_CONFIG.System.IpLimitCount
|
||||
limitTime := global.GVA_CONFIG.System.IpLimitTime
|
||||
|
||||
ctx := context.Background()
|
||||
count, err := global.GVA_REDIS.Get(ctx, key).Int()
|
||||
if err != nil && !errors.Is(err, context.DeadlineExceeded) {
|
||||
global.GVA_LOG.Error("get redis key error", zap.Error(err))
|
||||
}
|
||||
|
||||
if count >= limit {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": response.ERROR,
|
||||
"msg": fmt.Sprintf("操作过于频繁,请在 %d 秒后再试", limitTime),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
pipe := global.GVA_REDIS.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, time.Second*time.Duration(limitTime))
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("redis pipeline error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
68
server/middleware/loadtls.go
Normal file
68
server/middleware/loadtls.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LoadTls 加载 TLS 证书
|
||||
func LoadTls() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if global.GVA_CONFIG.System.UseHttps {
|
||||
certFile := global.GVA_CONFIG.System.TlsCert
|
||||
keyFile := global.GVA_CONFIG.System.TlsKey
|
||||
|
||||
if certFile == "" || keyFile == "" {
|
||||
global.GVA_LOG.Error("TLS cert or key file not configured")
|
||||
c.AbortWithStatus(500)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查证书文件是否存在
|
||||
if _, err := os.Stat(certFile); os.IsNotExist(err) {
|
||||
global.GVA_LOG.Error("TLS cert file not found", zap.String("file", certFile))
|
||||
c.AbortWithStatus(500)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := os.Stat(keyFile); os.IsNotExist(err) {
|
||||
global.GVA_LOG.Error("TLS key file not found", zap.String("file", keyFile))
|
||||
c.AbortWithStatus(500)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// LoadTlsFromFile 从文件加载 TLS 证书内容
|
||||
func LoadTlsFromFile(certFile, keyFile string) (certPEM, keyPEM []byte, err error) {
|
||||
certF, err := os.Open(certFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("open cert file error: %w", err)
|
||||
}
|
||||
defer certF.Close()
|
||||
|
||||
keyF, err := os.Open(keyFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("open key file error: %w", err)
|
||||
}
|
||||
defer keyF.Close()
|
||||
|
||||
certPEM, err = io.ReadAll(certF)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read cert file error: %w", err)
|
||||
}
|
||||
|
||||
keyPEM, err = io.ReadAll(keyF)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read key file error: %w", err)
|
||||
}
|
||||
|
||||
return certPEM, keyPEM, nil
|
||||
}
|
||||
113
server/middleware/logger.go
Normal file
113
server/middleware/logger.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/system"
|
||||
"git.echol.cn/loser/ai_proxy/server/service"
|
||||
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var operationRecordService = service.ServiceGroupApp.SystemServiceGroup.OperationRecordService
|
||||
|
||||
// OperationRecord 操作记录中间件
|
||||
func OperationRecord() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var body []byte
|
||||
var userId int
|
||||
if c.Request.Method != http.MethodGet {
|
||||
var err error
|
||||
body, err = io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("read body from request error:", zap.Error(err))
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
}
|
||||
}
|
||||
|
||||
userId = int(utils.GetUserID(c))
|
||||
|
||||
writer := responseBodyWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
body: &bytes.Buffer{},
|
||||
}
|
||||
c.Writer = writer
|
||||
now := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(now)
|
||||
|
||||
if c.Request.Method != http.MethodGet {
|
||||
record := system.SysOperationRecord{
|
||||
Ip: c.ClientIP(),
|
||||
Method: c.Request.Method,
|
||||
Path: c.Request.URL.Path,
|
||||
Agent: c.Request.UserAgent(),
|
||||
Body: string(body),
|
||||
UserID: userId,
|
||||
Status: c.Writer.Status(),
|
||||
Latency: latency,
|
||||
Resp: writer.body.String(),
|
||||
}
|
||||
|
||||
values, _ := url.ParseQuery(c.Request.URL.RawQuery)
|
||||
record.Query = values.Encode()
|
||||
|
||||
if err := operationRecordService.CreateSysOperationRecord(record); err != nil {
|
||||
global.GVA_LOG.Error("create operation record error:", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type responseBodyWriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
}
|
||||
|
||||
func (r responseBodyWriter) Write(b []byte) (int, error) {
|
||||
r.body.Write(b)
|
||||
return r.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (r responseBodyWriter) WriteString(s string) (int, error) {
|
||||
r.body.WriteString(s)
|
||||
return r.ResponseWriter.WriteString(s)
|
||||
}
|
||||
|
||||
func (r responseBodyWriter) WriteJSON(obj interface{}) error {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.body.Write(data)
|
||||
return r.ResponseWriter.WriteJSON(obj)
|
||||
}
|
||||
|
||||
// NeedRecordPath 判断是否需要记录操作日志
|
||||
func NeedRecordPath(path string) bool {
|
||||
// 排除不需要记录的路径
|
||||
excludePaths := []string{
|
||||
"/health",
|
||||
"/swagger",
|
||||
"/api/captcha",
|
||||
}
|
||||
|
||||
for _, excludePath := range excludePaths {
|
||||
if strings.HasPrefix(path, excludePath) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
119
server/middleware/operation.go
Normal file
119
server/middleware/operation.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/system"
|
||||
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var respPool sync.Pool
|
||||
var bufferSize = 1024
|
||||
|
||||
func init() {
|
||||
respPool.New = func() interface{} {
|
||||
return make([]byte, bufferSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Operation 操作日志记录中间件
|
||||
func Operation() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var body []byte
|
||||
var userId int
|
||||
|
||||
if c.Request.Method != http.MethodGet {
|
||||
var err error
|
||||
body, err = io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("read body from request error:", zap.Error(err))
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
}
|
||||
}
|
||||
|
||||
userId = int(utils.GetUserID(c))
|
||||
|
||||
writer := responseBodyWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
body: &bytes.Buffer{},
|
||||
}
|
||||
c.Writer = writer
|
||||
now := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(now)
|
||||
|
||||
// 只记录需要的路径
|
||||
if NeedRecordPath(c.Request.URL.Path) && c.Request.Method != http.MethodGet {
|
||||
record := system.SysOperationRecord{
|
||||
Ip: c.ClientIP(),
|
||||
Method: c.Request.Method,
|
||||
Path: c.Request.URL.Path,
|
||||
Agent: c.Request.UserAgent(),
|
||||
Body: string(body),
|
||||
UserID: userId,
|
||||
Status: c.Writer.Status(),
|
||||
Latency: latency,
|
||||
Resp: writer.body.String(),
|
||||
}
|
||||
|
||||
values, _ := url.ParseQuery(c.Request.URL.RawQuery)
|
||||
record.Query = values.Encode()
|
||||
|
||||
// 异步记录操作日志
|
||||
go func() {
|
||||
if err := operationRecordService.CreateSysOperationRecord(record); err != nil {
|
||||
global.GVA_LOG.Error("create operation record error:", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FilteredPaths 需要过滤的路径
|
||||
var FilteredPaths = []string{
|
||||
"/health",
|
||||
"/swagger",
|
||||
"/api/captcha",
|
||||
"/api/base/login",
|
||||
}
|
||||
|
||||
// ShouldRecord 判断是否应该记录该路径
|
||||
func ShouldRecord(path string) bool {
|
||||
for _, filtered := range FilteredPaths {
|
||||
if strings.HasPrefix(path, filtered) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// MaskSensitiveData 脱敏敏感数据
|
||||
func MaskSensitiveData(data string) string {
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &result); err != nil {
|
||||
return data
|
||||
}
|
||||
|
||||
sensitiveFields := []string{"password", "token", "secret", "key"}
|
||||
for _, field := range sensitiveFields {
|
||||
if _, ok := result[field]; ok {
|
||||
result[field] = "******"
|
||||
}
|
||||
}
|
||||
|
||||
masked, _ := json.Marshal(result)
|
||||
return string(masked)
|
||||
}
|
||||
32
server/middleware/timeout.go
Normal file
32
server/middleware/timeout.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Timeout 超时中间件
|
||||
func Timeout(timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
c.Next()
|
||||
finished <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
response.FailWithMessage("请求超时", c)
|
||||
c.Abort()
|
||||
case <-finished:
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user