🎨 优化项目结构 && 完善ai配置

This commit is contained in:
2026-03-03 15:39:23 +08:00
parent 557c865948
commit 2714e63d2a
585 changed files with 62223 additions and 100018 deletions

View File

@@ -1,44 +0,0 @@
package middleware
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/response"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/gin-gonic/gin"
)
// AppJWTAuth 前台用户 JWT 认证中间件
func AppJWTAuth() gin.HandlerFunc {
return func(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == "" {
token = c.GetHeader("x-token")
}
if token == "" {
response.FailWithDetailed(gin.H{"reload": true}, "未登录或非法访问", c)
c.Abort()
return
}
// 移除 Bearer 前缀
if len(token) > 7 && token[:7] == "Bearer " {
token = token[7:]
}
// 解析 token
claims, err := utils.ParseAppToken(token)
if err != nil {
global.GVA_LOG.Error("解析 App Token 失败: " + err.Error())
response.FailWithDetailed(gin.H{"reload": true}, "授权已过期或无效", c)
c.Abort()
return
}
// 将用户信息存入上下文
c.Set("appClaims", claims)
c.Set("userId", claims.UserID)
c.Set("username", claims.Username)
c.Next()
}
}

View File

@@ -0,0 +1,32 @@
package middleware
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/response"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/gin-gonic/gin"
"strconv"
"strings"
)
// CasbinHandler 拦截器
func CasbinHandler() gin.HandlerFunc {
return func(c *gin.Context) {
waitUse, _ := utils.GetClaims(c)
//获取请求的PATH
path := c.Request.URL.Path
obj := strings.TrimPrefix(path, global.GVA_CONFIG.System.RouterPrefix)
// 获取请求方法
act := c.Request.Method
// 获取用户的角色
sub := strconv.Itoa(int(waitUse.AuthorityId))
e := utils.GetCasbin() // 判断策略中是否存在
success, _ := e.Enforce(sub, obj, act)
if !success {
response.FailWithDetailed(gin.H{}, "权限不足", c)
c.Abort()
return
}
c.Next()
}
}

View File

@@ -1,11 +1,13 @@
package middleware
import (
"git.echol.cn/loser/ai_proxy/server/config"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/gin-gonic/gin"
"net/http"
)
// Cors 直接放行全部跨域请求
// Cors 直接放行所有跨域请求并放行所有 OPTIONS 方法
func Cors() gin.HandlerFunc {
return func(c *gin.Context) {
method := c.Request.Method
@@ -20,6 +22,52 @@ func Cors() gin.HandlerFunc {
if method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
}
// 处理请求
c.Next()
}
}
// CorsByRules 按照配置处理跨域请求
func CorsByRules() gin.HandlerFunc {
// 放行全部
if global.GVA_CONFIG.Cors.Mode == "allow-all" {
return Cors()
}
return func(c *gin.Context) {
whitelist := checkCors(c.GetHeader("origin"))
// 通过检查, 添加请求头
if whitelist != nil {
c.Header("Access-Control-Allow-Origin", whitelist.AllowOrigin)
c.Header("Access-Control-Allow-Headers", whitelist.AllowHeaders)
c.Header("Access-Control-Allow-Methods", whitelist.AllowMethods)
c.Header("Access-Control-Expose-Headers", whitelist.ExposeHeaders)
if whitelist.AllowCredentials {
c.Header("Access-Control-Allow-Credentials", "true")
}
}
// 严格白名单模式且未通过检查,直接拒绝处理请求
if whitelist == nil && global.GVA_CONFIG.Cors.Mode == "strict-whitelist" && !(c.Request.Method == "GET" && c.Request.URL.Path == "/health") {
c.AbortWithStatus(http.StatusForbidden)
} else {
// 非严格白名单模式,无论是否通过检查均放行所有 OPTIONS 方法
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent)
}
}
// 处理请求
c.Next()
}
}
func checkCors(currentOrigin string) *config.CORSWhitelist {
for _, whitelist := range global.GVA_CONFIG.Cors.Whitelist {
// 遍历配置中的跨域头,寻找匹配项
if currentOrigin == whitelist.AllowOrigin {
return &whitelist
}
}
return nil
}

View File

@@ -1,21 +1,58 @@
package middleware
import (
"bytes"
"io"
"strconv"
"time"
"git.echol.cn/loser/ai_proxy/server/plugin/email/utils"
utils2 "git.echol.cn/loser/ai_proxy/server/utils"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/response"
"git.echol.cn/loser/ai_proxy/server/model/system"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ErrorToEmail 错误发送邮件中间件
func ErrorToEmail() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
global.GVA_LOG.Error("panic error", zap.Any("error", err))
response.FailWithMessage("服务器内部错误", c)
var username string
claims, _ := utils2.GetClaims(c)
if claims.Username != "" {
username = claims.Username
} else {
id, _ := strconv.Atoi(c.Request.Header.Get("x-user-id"))
var u system.SysUser
err := global.GVA_DB.Where("id = ?", id).First(&u).Error
if err != nil {
username = "Unknown"
}
}()
username = u.Username
}
body, _ := io.ReadAll(c.Request.Body)
// 再重新写回请求体body中ioutil.ReadAll会清空c.Request.Body中的数据
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
record := system.SysOperationRecord{
Ip: c.ClientIP(),
Method: c.Request.Method,
Path: c.Request.URL.Path,
Agent: c.Request.UserAgent(),
Body: string(body),
}
now := time.Now()
c.Next()
latency := time.Since(now)
status := c.Writer.Status()
record.ErrorMessage = c.Errors.ByType(gin.ErrorTypePrivate).String()
str := "接收到的请求为" + record.Body + "\n" + "请求方式为" + record.Method + "\n" + "报错信息如下" + record.ErrorMessage + "\n" + "耗时" + latency.String() + "\n"
if status != 200 {
subject := username + "" + record.Ip + "调用了" + record.Path + "报错了"
if err := utils.ErrorToEmail(subject, str); err != nil {
global.GVA_LOG.Error("ErrorToEmail Failed, err:", zap.Error(err))
}
}
}
}

View File

@@ -1,18 +1,80 @@
package middleware
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"os"
"runtime/debug"
"strings"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ErrorLogger 错误日志中间件
func ErrorLogger() gin.HandlerFunc {
// GinRecovery recover掉项目可能出现的panic并使用zap记录相关日志
func GinRecovery(stack bool) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}
httpRequest, _ := httputil.DumpRequest(c.Request, false)
if brokenPipe {
global.GVA_LOG.Error(c.Request.URL.Path,
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
// If the connection is dead, we can't write a status to it.
_ = c.Error(err.(error)) // nolint: errcheck
c.Abort()
return
}
if stack {
form := "后端"
info := fmt.Sprintf("Panic: %v\nRequest: %s\nStack: %s", err, string(httpRequest), string(debug.Stack()))
level := "error"
_ = service.ServiceGroupApp.SystemServiceGroup.SysErrorService.CreateSysError(context.Background(), &system.SysError{
Form: &form,
Info: &info,
Level: level,
})
global.GVA_LOG.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
} else {
form := "后端"
info := fmt.Sprintf("Panic: %v\nRequest: %s", err, string(httpRequest))
level := "error"
_ = service.ServiceGroupApp.SystemServiceGroup.SysErrorService.CreateSysError(context.Background(), &system.SysError{
Form: &form,
Info: &info,
Level: level,
})
global.GVA_LOG.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
}
c.AbortWithStatus(http.StatusInternalServerError)
}
}()
c.Next()
// 记录错误
for _, err := range c.Errors {
global.GVA_LOG.Error("request error", zap.Error(err))
}
}
}

View File

@@ -1,75 +1,89 @@
package middleware
import (
"errors"
"strconv"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/response"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/service"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
"github.com/golang-jwt/jwt/v5"
var jwtService = service.ServiceGroupApp.SystemServiceGroup.JwtService
"git.echol.cn/loser/ai_proxy/server/model/common/response"
"github.com/gin-gonic/gin"
)
func JWTAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 从请求头获取 token
// 我们这里jwt鉴权取头部信息 x-token 登录时回返回token信息 这里前端需要把token存储到cookie或者本地localStorage中 不过需要跟后端协商过期时间 可以约定刷新令牌或者重新登录
token := utils.GetToken(c)
if token == "" {
response.FailWithDetailed(gin.H{"reload": true}, "未登录或非法访问", c)
response.NoAuth("未登录或非法访问,请登录", c)
c.Abort()
return
}
// 检查 JWT 黑名单
if jwtService.IsBlacklist(token) {
response.FailWithDetailed(gin.H{"reload": true}, "您的帐户异地登陆或令牌失效", c)
if isBlacklist(token) {
response.NoAuth("您的帐户异地登陆或令牌失效", c)
utils.ClearToken(c)
c.Abort()
return
}
// 解析 token
j := utils.NewJWT()
// parseToken 解析token包含的信息
claims, err := j.ParseToken(token)
if err != nil {
if err == utils.TokenExpired {
response.FailWithDetailed(gin.H{"reload": true}, "授权已过期", c)
if errors.Is(err, utils.TokenExpired) {
response.NoAuth("登录已过期,请重新登录", c)
utils.ClearToken(c)
c.Abort()
return
}
response.FailWithDetailed(gin.H{"reload": true}, err.Error(), c)
response.NoAuth(err.Error(), c)
utils.ClearToken(c)
c.Abort()
return
}
// Token 续期检查
// 已登录用户被管理员禁用 需要使该用户的jwt失效 此处比较消耗性能 如果需要 请自行打开
// 用户被删除的逻辑 需要优化 此处比较消耗性能 如果需要 请自行打开
//if user, err := userService.FindUserByUuid(claims.UUID.String()); err != nil || user.Enable == 2 {
// _ = jwtService.JsonInBlacklist(system.JwtBlacklist{Jwt: token})
// response.FailWithDetailed(gin.H{"reload": true}, err.Error(), c)
// c.Abort()
//}
c.Set("claims", claims)
if claims.ExpiresAt.Unix()-time.Now().Unix() < claims.BufferTime {
dr, _ := utils.ParseDuration(global.GVA_CONFIG.JWT.ExpiresTime)
claims.ExpiresAt = utils.NewNumericDate(time.Now().Add(dr))
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(dr))
newToken, _ := j.CreateTokenByOldToken(token, *claims)
newClaims, _ := j.ParseToken(newToken)
c.Header("new-token", newToken)
c.Header("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
utils.SetToken(c, newToken, int(dr.Seconds()))
utils.SetToken(c, newToken, int(dr.Seconds()/60))
if global.GVA_CONFIG.System.UseMultipoint {
RedisJwtToken, err := jwtService.GetRedisJWT(newClaims.Username)
if err != nil {
global.GVA_LOG.Error("get redis jwt failed", zap.Error(err))
} else {
_ = jwtService.JsonInBlacklist(system.JwtBlacklist{Jwt: RedisJwtToken})
}
_ = jwtService.SetRedisJWT(newToken, newClaims.Username)
// 记录新的活跃jwt
_ = utils.SetRedisJWT(newToken, newClaims.Username)
}
}
c.Set("claims", claims)
c.Next()
if newToken, exists := c.Get("new-token"); exists {
c.Header("new-token", newToken.(string))
}
if newExpiresAt, exists := c.Get("new-expires-at"); exists {
c.Header("new-expires-at", newExpiresAt.(string))
}
}
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: IsBlacklist
//@description: 判断JWT是否在黑名单内部
//@param: jwt string
//@return: bool
func isBlacklist(jwt string) bool {
_, ok := global.BlackCache.Get(jwt)
return ok
}

View File

@@ -3,64 +3,90 @@ package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type LimitConfig struct {
GenerationDuration time.Duration
Limit int
// GenerationKey 根据业务生成key 下面CheckOrMark查询生成
GenerationKey func(c *gin.Context) string
// 检查函数,用户可修改具体逻辑,更加灵活
CheckOrMark func(key string, expire int, limit int) error
// Expire key 过期时间
Expire int
// Limit 周期时间
Limit int
}
func (l LimitConfig) LimitKey(c *gin.Context) string {
func (l LimitConfig) LimitWithTime() gin.HandlerFunc {
return func(c *gin.Context) {
if err := l.CheckOrMark(l.GenerationKey(c), l.Expire, l.Limit); err != nil {
c.JSON(http.StatusOK, gin.H{"code": response.ERROR, "msg": err.Error()})
c.Abort()
return
} else {
c.Next()
}
}
}
// DefaultGenerationKey 默认生成key
func DefaultGenerationKey(c *gin.Context) string {
return "GVA_Limit" + c.ClientIP()
}
func (l LimitConfig) GetLimit(c *gin.Context) int {
return l.Limit
func DefaultCheckOrMark(key string, expire int, limit int) (err error) {
// 判断是否开启redis
if global.GVA_REDIS == nil {
return err
}
if err = SetLimitWithTime(key, limit, time.Duration(expire)*time.Second); err != nil {
global.GVA_LOG.Error("limit", zap.Error(err))
}
return err
}
func (l LimitConfig) Reached(c *gin.Context) response.Response {
return response.Response{Code: response.ERROR, Data: nil, Msg: "操作过于频繁,请稍后再试"}
func DefaultLimit() gin.HandlerFunc {
return LimitConfig{
GenerationKey: DefaultGenerationKey,
CheckOrMark: DefaultCheckOrMark,
Expire: global.GVA_CONFIG.System.LimitTimeIP,
Limit: global.GVA_CONFIG.System.LimitCountIP,
}.LimitWithTime()
}
// IPLimit IP 限流中间件
func IPLimit() gin.HandlerFunc {
return func(c *gin.Context) {
if global.GVA_CONFIG.System.UseRedis {
key := "GVA_Limit" + c.ClientIP()
limit := global.GVA_CONFIG.System.IpLimitCount
limitTime := global.GVA_CONFIG.System.IpLimitTime
ctx := context.Background()
count, err := global.GVA_REDIS.Get(ctx, key).Int()
if err != nil && !errors.Is(err, context.DeadlineExceeded) {
global.GVA_LOG.Error("get redis key error", zap.Error(err))
}
if count >= limit {
c.JSON(http.StatusOK, gin.H{
"code": response.ERROR,
"msg": fmt.Sprintf("操作过于频繁,请在 %d 秒后再试", limitTime),
})
c.Abort()
return
}
pipe := global.GVA_REDIS.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, time.Second*time.Duration(limitTime))
_, err = pipe.Exec(ctx)
if err != nil {
global.GVA_LOG.Error("redis pipeline error", zap.Error(err))
// SetLimitWithTime 设置访问次数
func SetLimitWithTime(key string, limit int, expiration time.Duration) error {
count, err := global.GVA_REDIS.Exists(context.Background(), key).Result()
if err != nil {
return err
}
if count == 0 {
pipe := global.GVA_REDIS.TxPipeline()
pipe.Incr(context.Background(), key)
pipe.Expire(context.Background(), key, expiration)
_, err = pipe.Exec(context.Background())
return err
} else {
// 次数
if times, err := global.GVA_REDIS.Get(context.Background(), key).Int(); err != nil {
return err
} else {
if times >= limit {
if t, err := global.GVA_REDIS.PTTL(context.Background(), key).Result(); err != nil {
return errors.New("请求太过频繁,请稍后再试")
} else {
return errors.New("请求太过频繁, 请 " + t.String() + " 秒后尝试")
}
} else {
return global.GVA_REDIS.Incr(context.Background(), key).Err()
}
}
c.Next()
}
}

View File

@@ -2,67 +2,26 @@ package middleware
import (
"fmt"
"io"
"os"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"github.com/unrolled/secure"
)
// LoadTls 加载 TLS 证书
// 用https把这个中间件在router里面use一下就好
func LoadTls() gin.HandlerFunc {
return func(c *gin.Context) {
if global.GVA_CONFIG.System.UseHttps {
certFile := global.GVA_CONFIG.System.TlsCert
keyFile := global.GVA_CONFIG.System.TlsKey
if certFile == "" || keyFile == "" {
global.GVA_LOG.Error("TLS cert or key file not configured")
c.AbortWithStatus(500)
return
}
// 检查证书文件是否存在
if _, err := os.Stat(certFile); os.IsNotExist(err) {
global.GVA_LOG.Error("TLS cert file not found", zap.String("file", certFile))
c.AbortWithStatus(500)
return
}
if _, err := os.Stat(keyFile); os.IsNotExist(err) {
global.GVA_LOG.Error("TLS key file not found", zap.String("file", keyFile))
c.AbortWithStatus(500)
return
}
middleware := secure.New(secure.Options{
SSLRedirect: true,
SSLHost: "localhost:443",
})
err := middleware.Process(c.Writer, c.Request)
if err != nil {
// 如果出现错误,请不要继续
fmt.Println(err)
return
}
// 继续往下处理
c.Next()
}
}
// LoadTlsFromFile 从文件加载 TLS 证书内容
func LoadTlsFromFile(certFile, keyFile string) (certPEM, keyPEM []byte, err error) {
certF, err := os.Open(certFile)
if err != nil {
return nil, nil, fmt.Errorf("open cert file error: %w", err)
}
defer certF.Close()
keyF, err := os.Open(keyFile)
if err != nil {
return nil, nil, fmt.Errorf("open key file error: %w", err)
}
defer keyF.Close()
certPEM, err = io.ReadAll(certF)
if err != nil {
return nil, nil, fmt.Errorf("read cert file error: %w", err)
}
keyPEM, err = io.ReadAll(keyF)
if err != nil {
return nil, nil, fmt.Errorf("read key file error: %w", err)
}
return certPEM, keyPEM, nil
}

View File

@@ -3,111 +3,87 @@ package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/service"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
var operationRecordService = service.ServiceGroupApp.SystemServiceGroup.OperationRecordService
// LogLayout 日志layout
type LogLayout struct {
Time time.Time
Metadata map[string]interface{} // 存储自定义原数据
Path string // 访问路径
Query string // 携带query
Body string // 携带body数据
IP string // ip地址
UserAgent string // 代理
Error string // 错误
Cost time.Duration // 花费时间
Source string // 来源
}
// OperationRecord 操作记录中间件
func OperationRecord() gin.HandlerFunc {
type Logger struct {
// Filter 用户自定义过滤
Filter func(c *gin.Context) bool
// FilterKeyword 关键字过滤(key)
FilterKeyword func(layout *LogLayout) bool
// AuthProcess 鉴权处理
AuthProcess func(c *gin.Context, layout *LogLayout)
// 日志处理
Print func(LogLayout)
// Source 服务唯一标识
Source string
}
func (l Logger) SetLoggerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
var body []byte
var userId int
if c.Request.Method != http.MethodGet {
var err error
body, err = io.ReadAll(c.Request.Body)
if err != nil {
global.GVA_LOG.Error("read body from request error:", zap.Error(err))
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
}
if l.Filter != nil && !l.Filter(c) {
body, _ = c.GetRawData()
// 将原body塞回去
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
}
userId = int(utils.GetUserID(c))
writer := responseBodyWriter{
ResponseWriter: c.Writer,
body: &bytes.Buffer{},
}
c.Writer = writer
now := time.Now()
c.Next()
latency := time.Since(now)
if c.Request.Method != http.MethodGet {
record := system.SysOperationRecord{
Ip: c.ClientIP(),
Method: c.Request.Method,
Path: c.Request.URL.Path,
Agent: c.Request.UserAgent(),
Body: string(body),
UserID: userId,
Status: c.Writer.Status(),
Latency: latency,
Resp: writer.body.String(),
}
values, _ := url.ParseQuery(c.Request.URL.RawQuery)
record.Query = values.Encode()
if err := operationRecordService.CreateSysOperationRecord(record); err != nil {
global.GVA_LOG.Error("create operation record error:", zap.Error(err))
}
cost := time.Since(start)
layout := LogLayout{
Time: time.Now(),
Path: path,
Query: query,
IP: c.ClientIP(),
UserAgent: c.Request.UserAgent(),
Error: strings.TrimRight(c.Errors.ByType(gin.ErrorTypePrivate).String(), "\n"),
Cost: cost,
Source: l.Source,
}
}
}
type responseBodyWriter struct {
gin.ResponseWriter
body *bytes.Buffer
}
func (r responseBodyWriter) Write(b []byte) (int, error) {
r.body.Write(b)
return r.ResponseWriter.Write(b)
}
func (r responseBodyWriter) WriteString(s string) (int, error) {
r.body.WriteString(s)
return r.ResponseWriter.WriteString(s)
}
func (r responseBodyWriter) WriteJSON(obj interface{}) error {
data, err := json.Marshal(obj)
if err != nil {
return err
}
r.body.Write(data)
return r.ResponseWriter.WriteJSON(obj)
}
// NeedRecordPath 判断是否需要记录操作日志
func NeedRecordPath(path string) bool {
// 排除不需要记录的路径
excludePaths := []string{
"/health",
"/swagger",
"/api/captcha",
}
for _, excludePath := range excludePaths {
if strings.HasPrefix(path, excludePath) {
return false
if l.Filter != nil && !l.Filter(c) {
layout.Body = string(body)
}
if l.AuthProcess != nil {
// 处理鉴权需要的信息
l.AuthProcess(c, &layout)
}
if l.FilterKeyword != nil {
// 自行判断key/value 脱敏等
l.FilterKeyword(&layout)
}
// 自行处理日志
l.Print(layout)
}
return true
}
func DefaultLogger() gin.HandlerFunc {
return Logger{
Print: func(layout LogLayout) {
// 标准输出,k8s做收集
v, _ := json.Marshal(layout)
fmt.Println(string(v))
},
Source: "GVA",
}.SetLoggerMiddleware()
}

View File

@@ -6,13 +6,15 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"git.echol.cn/loser/ai_proxy/server/utils"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
@@ -26,12 +28,10 @@ func init() {
}
}
// Operation 操作日志记录中间件
func Operation() gin.HandlerFunc {
func OperationRecord() gin.HandlerFunc {
return func(c *gin.Context) {
var body []byte
var userId int
if c.Request.Method != http.MethodGet {
var err error
body, err = io.ReadAll(c.Request.Body)
@@ -40,9 +40,48 @@ func Operation() gin.HandlerFunc {
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
}
} else {
query := c.Request.URL.RawQuery
query, _ = url.QueryUnescape(query)
split := strings.Split(query, "&")
m := make(map[string]string)
for _, v := range split {
kv := strings.Split(v, "=")
if len(kv) == 2 {
m[kv[0]] = kv[1]
}
}
body, _ = json.Marshal(&m)
}
claims, _ := utils.GetClaims(c)
if claims != nil && claims.BaseClaims.ID != 0 {
userId = int(claims.BaseClaims.ID)
} else {
id, err := strconv.Atoi(c.Request.Header.Get("x-user-id"))
if err != nil {
userId = 0
}
userId = id
}
record := system.SysOperationRecord{
Ip: c.ClientIP(),
Method: c.Request.Method,
Path: c.Request.URL.Path,
Agent: c.Request.UserAgent(),
Body: "",
UserID: userId,
}
userId = int(utils.GetUserID(c))
// 上传文件时候 中间件日志进行裁断操作
if strings.Contains(c.GetHeader("Content-Type"), "multipart/form-data") {
record.Body = "[文件]"
} else {
if len(body) > bufferSize {
record.Body = "[超出记录长度]"
} else {
record.Body = string(body)
}
}
writer := responseBodyWriter{
ResponseWriter: c.Writer,
@@ -54,66 +93,37 @@ func Operation() gin.HandlerFunc {
c.Next()
latency := time.Since(now)
record.ErrorMessage = c.Errors.ByType(gin.ErrorTypePrivate).String()
record.Status = c.Writer.Status()
record.Latency = latency
record.Resp = writer.body.String()
// 只记录需要的路径
if NeedRecordPath(c.Request.URL.Path) && c.Request.Method != http.MethodGet {
record := system.SysOperationRecord{
Ip: c.ClientIP(),
Method: c.Request.Method,
Path: c.Request.URL.Path,
Agent: c.Request.UserAgent(),
Body: string(body),
UserID: userId,
Status: c.Writer.Status(),
Latency: latency,
Resp: writer.body.String(),
if strings.Contains(c.Writer.Header().Get("Pragma"), "public") ||
strings.Contains(c.Writer.Header().Get("Expires"), "0") ||
strings.Contains(c.Writer.Header().Get("Cache-Control"), "must-revalidate, post-check=0, pre-check=0") ||
strings.Contains(c.Writer.Header().Get("Content-Type"), "application/force-download") ||
strings.Contains(c.Writer.Header().Get("Content-Type"), "application/octet-stream") ||
strings.Contains(c.Writer.Header().Get("Content-Type"), "application/vnd.ms-excel") ||
strings.Contains(c.Writer.Header().Get("Content-Type"), "application/download") ||
strings.Contains(c.Writer.Header().Get("Content-Disposition"), "attachment") ||
strings.Contains(c.Writer.Header().Get("Content-Transfer-Encoding"), "binary") {
if len(record.Resp) > bufferSize {
// 截断
record.Body = "超出记录长度"
}
values, _ := url.ParseQuery(c.Request.URL.RawQuery)
record.Query = values.Encode()
// 异步记录操作日志
go func() {
if err := operationRecordService.CreateSysOperationRecord(record); err != nil {
global.GVA_LOG.Error("create operation record error:", zap.Error(err))
}
}()
}
if err := global.GVA_DB.Create(&record).Error; err != nil {
global.GVA_LOG.Error("create operation record error:", zap.Error(err))
}
}
}
// FilteredPaths 需要过滤的路径
var FilteredPaths = []string{
"/health",
"/swagger",
"/api/captcha",
"/api/base/login",
type responseBodyWriter struct {
gin.ResponseWriter
body *bytes.Buffer
}
// ShouldRecord 判断是否应该记录该路径
func ShouldRecord(path string) bool {
for _, filtered := range FilteredPaths {
if strings.HasPrefix(path, filtered) {
return false
}
}
return true
}
// MaskSensitiveData 脱敏敏感数据
func MaskSensitiveData(data string) string {
var result map[string]interface{}
if err := json.Unmarshal([]byte(data), &result); err != nil {
return data
}
sensitiveFields := []string{"password", "token", "secret", "key"}
for _, field := range sensitiveFields {
if _, ok := result[field]; ok {
result[field] = "******"
}
}
masked, _ := json.Marshal(result)
return string(masked)
func (r responseBodyWriter) Write(b []byte) (int, error) {
r.body.Write(b)
return r.ResponseWriter.Write(b)
}

View File

@@ -2,31 +2,54 @@ package middleware
import (
"context"
"time"
"git.echol.cn/loser/ai_proxy/server/model/common/response"
"github.com/gin-gonic/gin"
"net/http"
"time"
)
// Timeout 超时中间件
func Timeout(timeout time.Duration) gin.HandlerFunc {
// TimeoutMiddleware 创建超时中间件
// 入参 timeout 设置超时时间例如time.Second * 5
// 使用示例 xxx.Get("path",middleware.TimeoutMiddleware(30*time.Second),HandleFunc)
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
c.Request = c.Request.WithContext(ctx)
finished := make(chan struct{})
// 使用 buffered channel 避免 goroutine 泄漏
done := make(chan struct{}, 1)
panicChan := make(chan interface{}, 1)
go func() {
defer func() {
if p := recover(); p != nil {
select {
case panicChan <- p:
default:
}
}
select {
case done <- struct{}{}:
default:
}
}()
c.Next()
finished <- struct{}{}
}()
select {
case p := <-panicChan:
panic(p)
case <-done:
return
case <-ctx.Done():
response.FailWithMessage("请求超时", c)
c.Abort()
case <-finished:
// 确保服务器超时设置足够长
c.Header("Connection", "close")
c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
"code": 504,
"msg": "请求超时",
})
return
}
}
}