🎉 初始化项目
This commit is contained in:
84
server/utils/app_jwt.go
Normal file
84
server/utils/app_jwt.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
UserTypeApp = "app" // 前台用户类型标识
|
||||
)
|
||||
|
||||
// AppJWTClaims 前台用户 JWT Claims
|
||||
type AppJWTClaims struct {
|
||||
UserID uint `json:"userId"`
|
||||
Username string `json:"username"`
|
||||
UserType string `json:"userType"` // 用户类型标识
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CreateAppToken 创建前台用户 Token(有效期 7 天)
|
||||
func CreateAppToken(userID uint, username string) (tokenString string, expiresAt int64, err error) {
|
||||
// Token 有效期为 7 天
|
||||
expiresTime := time.Now().Add(7 * 24 * time.Hour)
|
||||
expiresAt = expiresTime.Unix()
|
||||
|
||||
claims := AppJWTClaims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
UserType: UserTypeApp,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresTime),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: global.GVA_CONFIG.JWT.Issuer,
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err = token.SignedString([]byte(global.GVA_CONFIG.JWT.SigningKey))
|
||||
return
|
||||
}
|
||||
|
||||
// CreateAppRefreshToken 创建前台用户刷新 Token(有效期更长)
|
||||
func CreateAppRefreshToken(userID uint, username string) (tokenString string, expiresAt int64, err error) {
|
||||
// 刷新 Token 有效期为 7 天
|
||||
expiresTime := time.Now().Add(7 * 24 * time.Hour)
|
||||
expiresAt = expiresTime.Unix()
|
||||
|
||||
claims := AppJWTClaims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
UserType: UserTypeApp,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresTime),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: global.GVA_CONFIG.JWT.Issuer,
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err = token.SignedString([]byte(global.GVA_CONFIG.JWT.SigningKey))
|
||||
return
|
||||
}
|
||||
|
||||
// ParseAppToken 解析前台用户 Token
|
||||
func ParseAppToken(tokenString string) (*AppJWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &AppJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(global.GVA_CONFIG.JWT.SigningKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*AppJWTClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
121
server/utils/breakpoint_continue.go
Normal file
121
server/utils/breakpoint_continue.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 前端传来文件片与当前片为什么文件的第几片
|
||||
// 后端拿到以后比较次分片是否上传 或者是否为不完全片
|
||||
// 前端发送每片多大
|
||||
// 前端告知是否为最后一片且是否完成
|
||||
|
||||
const (
|
||||
breakpointDir = "./breakpointDir/"
|
||||
finishDir = "./fileDir/"
|
||||
)
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: BreakPointContinue
|
||||
//@description: 断点续传
|
||||
//@param: content []byte, fileName string, contentNumber int, contentTotal int, fileMd5 string
|
||||
//@return: error, string
|
||||
|
||||
func BreakPointContinue(content []byte, fileName string, contentNumber int, contentTotal int, fileMd5 string) (string, error) {
|
||||
if strings.Contains(fileName, "..") || strings.Contains(fileMd5, "..") {
|
||||
return "", errors.New("文件名或路径不合法")
|
||||
}
|
||||
path := breakpointDir + fileMd5 + "/"
|
||||
err := os.MkdirAll(path, os.ModePerm)
|
||||
if err != nil {
|
||||
return path, err
|
||||
}
|
||||
pathC, err := makeFileContent(content, fileName, path, contentNumber)
|
||||
return pathC, err
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: CheckMd5
|
||||
//@description: 检查Md5
|
||||
//@param: content []byte, chunkMd5 string
|
||||
//@return: CanUpload bool
|
||||
|
||||
func CheckMd5(content []byte, chunkMd5 string) (CanUpload bool) {
|
||||
fileMd5 := MD5V(content)
|
||||
if fileMd5 == chunkMd5 {
|
||||
return true // 可以继续上传
|
||||
} else {
|
||||
return false // 切片不完整,废弃
|
||||
}
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: makeFileContent
|
||||
//@description: 创建切片内容
|
||||
//@param: content []byte, fileName string, FileDir string, contentNumber int
|
||||
//@return: string, error
|
||||
|
||||
func makeFileContent(content []byte, fileName string, FileDir string, contentNumber int) (string, error) {
|
||||
if strings.Contains(fileName, "..") || strings.Contains(FileDir, "..") {
|
||||
return "", errors.New("文件名或路径不合法")
|
||||
}
|
||||
path := FileDir + fileName + "_" + strconv.Itoa(contentNumber)
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return path, err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = f.Write(content)
|
||||
if err != nil {
|
||||
return path, err
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: makeFileContent
|
||||
//@description: 创建切片文件
|
||||
//@param: fileName string, FileMd5 string
|
||||
//@return: error, string
|
||||
|
||||
func MakeFile(fileName string, FileMd5 string) (string, error) {
|
||||
if strings.Contains(fileName, "..") || strings.Contains(FileMd5, "..") {
|
||||
return "", errors.New("文件名或路径不合法")
|
||||
}
|
||||
rd, err := os.ReadDir(breakpointDir + FileMd5)
|
||||
if err != nil {
|
||||
return finishDir + fileName, err
|
||||
}
|
||||
_ = os.MkdirAll(finishDir, os.ModePerm)
|
||||
fd, err := os.OpenFile(finishDir+fileName, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return finishDir + fileName, err
|
||||
}
|
||||
defer fd.Close()
|
||||
for k := range rd {
|
||||
content, _ := os.ReadFile(breakpointDir + FileMd5 + "/" + fileName + "_" + strconv.Itoa(k))
|
||||
_, err = fd.Write(content)
|
||||
if err != nil {
|
||||
_ = os.Remove(finishDir + fileName)
|
||||
return finishDir + fileName, err
|
||||
}
|
||||
}
|
||||
return finishDir + fileName, nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: RemoveChunk
|
||||
//@description: 移除切片
|
||||
//@param: FileMd5 string
|
||||
//@return: error
|
||||
|
||||
func RemoveChunk(FileMd5 string) error {
|
||||
if strings.Contains(FileMd5, "..") {
|
||||
return errors.New("路径不合法")
|
||||
}
|
||||
err := os.RemoveAll(breakpointDir + FileMd5)
|
||||
return err
|
||||
}
|
||||
61
server/utils/captcha/redis.go
Normal file
61
server/utils/captcha/redis.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func NewDefaultRedisStore() *RedisStore {
|
||||
return &RedisStore{
|
||||
Expiration: time.Second * 180,
|
||||
PreKey: "CAPTCHA_",
|
||||
Context: context.TODO(),
|
||||
}
|
||||
}
|
||||
|
||||
type RedisStore struct {
|
||||
Expiration time.Duration
|
||||
PreKey string
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
func (rs *RedisStore) UseWithCtx(ctx context.Context) *RedisStore {
|
||||
if ctx == nil {
|
||||
rs.Context = ctx
|
||||
}
|
||||
return rs
|
||||
}
|
||||
|
||||
func (rs *RedisStore) Set(id string, value string) error {
|
||||
err := global.GVA_REDIS.Set(rs.Context, rs.PreKey+id, value, rs.Expiration).Err()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("RedisStoreSetError!", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *RedisStore) Get(key string, clear bool) string {
|
||||
val, err := global.GVA_REDIS.Get(rs.Context, key).Result()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("RedisStoreGetError!", zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
if clear {
|
||||
err := global.GVA_REDIS.Del(rs.Context, key).Err()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("RedisStoreClearError!", zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func (rs *RedisStore) Verify(id, answer string, clear bool) bool {
|
||||
key := rs.PreKey + id
|
||||
v := rs.Get(key, clear)
|
||||
return v == answer
|
||||
}
|
||||
52
server/utils/casbin_util.go
Normal file
52
server/utils/casbin_util.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"github.com/casbin/casbin/v2"
|
||||
"github.com/casbin/casbin/v2/model"
|
||||
gormadapter "github.com/casbin/gorm-adapter/v3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
syncedCachedEnforcer *casbin.SyncedCachedEnforcer
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// GetCasbin 获取casbin实例
|
||||
func GetCasbin() *casbin.SyncedCachedEnforcer {
|
||||
once.Do(func() {
|
||||
a, err := gormadapter.NewAdapterByDB(global.GVA_DB)
|
||||
if err != nil {
|
||||
zap.L().Error("适配数据库失败请检查casbin表是否为InnoDB引擎!", zap.Error(err))
|
||||
return
|
||||
}
|
||||
text := `
|
||||
[request_definition]
|
||||
r = sub, obj, act
|
||||
|
||||
[policy_definition]
|
||||
p = sub, obj, act
|
||||
|
||||
[role_definition]
|
||||
g = _, _
|
||||
|
||||
[policy_effect]
|
||||
e = some(where (p.eft == allow))
|
||||
|
||||
[matchers]
|
||||
m = r.sub == p.sub && keyMatch2(r.obj,p.obj) && r.act == p.act
|
||||
`
|
||||
m, err := model.NewModelFromString(text)
|
||||
if err != nil {
|
||||
zap.L().Error("字符串加载模型失败!", zap.Error(err))
|
||||
return
|
||||
}
|
||||
syncedCachedEnforcer, _ = casbin.NewSyncedCachedEnforcer(m, a)
|
||||
syncedCachedEnforcer.SetExpireTime(60 * 60)
|
||||
_ = syncedCachedEnforcer.LoadPolicy()
|
||||
})
|
||||
return syncedCachedEnforcer
|
||||
}
|
||||
285
server/utils/character_card.go
Normal file
285
server/utils/character_card.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"image"
|
||||
"image/png"
|
||||
"io"
|
||||
)
|
||||
|
||||
// CharacterCardV2 SillyTavern 角色卡 V2 格式
|
||||
type CharacterCardV2 struct {
|
||||
Spec string `json:"spec"`
|
||||
SpecVersion string `json:"spec_version"`
|
||||
Data CharacterCardV2Data `json:"data"`
|
||||
}
|
||||
|
||||
type CharacterCardV2Data struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Personality string `json:"personality"`
|
||||
Scenario string `json:"scenario"`
|
||||
FirstMes string `json:"first_mes"`
|
||||
MesExample string `json:"mes_example"`
|
||||
CreatorNotes string `json:"creator_notes"`
|
||||
SystemPrompt string `json:"system_prompt"`
|
||||
PostHistoryInstructions string `json:"post_history_instructions"`
|
||||
Tags []string `json:"tags"`
|
||||
Creator string `json:"creator"`
|
||||
CharacterVersion string `json:"character_version"`
|
||||
AlternateGreetings []string `json:"alternate_greetings"`
|
||||
CharacterBook map[string]interface{} `json:"character_book,omitempty"`
|
||||
Extensions map[string]interface{} `json:"extensions"`
|
||||
}
|
||||
|
||||
// ExtractCharacterFromPNG 从 PNG 图片中提取角色卡数据
|
||||
func ExtractCharacterFromPNG(pngData []byte) (*CharacterCardV2, error) {
|
||||
reader := bytes.NewReader(pngData)
|
||||
|
||||
// 验证 PNG 格式(解码但不保存图片)
|
||||
_, err := png.Decode(reader)
|
||||
if err != nil {
|
||||
return nil, errors.New("无效的 PNG 文件")
|
||||
}
|
||||
|
||||
// 重新读取以获取 tEXt chunks
|
||||
reader.Seek(0, 0)
|
||||
|
||||
// 查找 tEXt chunk 中的 "chara" 字段
|
||||
charaJSON, err := extractTextChunk(reader, "chara")
|
||||
if err != nil {
|
||||
return nil, errors.New("PNG 中没有找到角色卡数据")
|
||||
}
|
||||
|
||||
// 尝试 Base64 解码
|
||||
decodedJSON, err := base64.StdEncoding.DecodeString(charaJSON)
|
||||
if err != nil {
|
||||
// 如果不是 Base64,直接使用原始 JSON
|
||||
decodedJSON = []byte(charaJSON)
|
||||
}
|
||||
|
||||
// 解析 JSON
|
||||
var card CharacterCardV2
|
||||
err = json.Unmarshal(decodedJSON, &card)
|
||||
if err != nil {
|
||||
return nil, errors.New("解析角色卡数据失败: " + err.Error())
|
||||
}
|
||||
|
||||
return &card, nil
|
||||
}
|
||||
|
||||
// extractTextChunk 从 PNG 中提取指定 key 的 tEXt chunk
|
||||
func extractTextChunk(r io.Reader, key string) (string, error) {
|
||||
// 跳过 PNG signature (8 bytes)
|
||||
signature := make([]byte, 8)
|
||||
if _, err := io.ReadFull(r, signature); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 验证 PNG signature
|
||||
expectedSig := []byte{137, 80, 78, 71, 13, 10, 26, 10}
|
||||
if !bytes.Equal(signature, expectedSig) {
|
||||
return "", errors.New("invalid PNG signature")
|
||||
}
|
||||
|
||||
// 读取所有 chunks
|
||||
for {
|
||||
// 读取 chunk length (4 bytes)
|
||||
lengthBytes := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r, lengthBytes); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 |
|
||||
uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
|
||||
|
||||
// 读取 chunk type (4 bytes)
|
||||
chunkType := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r, chunkType); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 读取 chunk data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 读取 CRC (4 bytes)
|
||||
crc := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r, crc); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 检查是否是 tEXt chunk
|
||||
if string(chunkType) == "tEXt" {
|
||||
// tEXt chunk 格式: keyword\0text
|
||||
nullIndex := bytes.IndexByte(data, 0)
|
||||
if nullIndex == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
keyword := string(data[:nullIndex])
|
||||
text := string(data[nullIndex+1:])
|
||||
|
||||
if keyword == key {
|
||||
return text, nil
|
||||
}
|
||||
}
|
||||
|
||||
// IEND chunk 表示结束
|
||||
if string(chunkType) == "IEND" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("text chunk not found")
|
||||
}
|
||||
|
||||
// EmbedCharacterToPNG 将角色卡数据嵌入到 PNG 图片中
|
||||
func EmbedCharacterToPNG(img image.Image, card *CharacterCardV2) ([]byte, error) {
|
||||
// 序列化角色卡数据
|
||||
cardJSON, err := json.Marshal(card)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Base64 编码
|
||||
encodedJSON := base64.StdEncoding.EncodeToString(cardJSON)
|
||||
|
||||
// 创建一个 buffer 来写入 PNG
|
||||
var buf bytes.Buffer
|
||||
|
||||
// 写入 PNG signature
|
||||
buf.Write([]byte{137, 80, 78, 71, 13, 10, 26, 10})
|
||||
|
||||
// 编码原始图片到临时 buffer
|
||||
var imgBuf bytes.Buffer
|
||||
if err := png.Encode(&imgBuf, img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 跳过原始 PNG 的 signature
|
||||
imgData := imgBuf.Bytes()[8:]
|
||||
|
||||
// 将原始图片的 chunks 复制到输出,在 IEND 之前插入 tEXt chunk
|
||||
r := bytes.NewReader(imgData)
|
||||
|
||||
for {
|
||||
// 读取 chunk length
|
||||
lengthBytes := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r, lengthBytes); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 |
|
||||
uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
|
||||
|
||||
// 读取 chunk type
|
||||
chunkType := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r, chunkType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 读取 chunk data
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 读取 CRC
|
||||
crc := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r, crc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果是 IEND chunk,先写入 tEXt chunk
|
||||
if string(chunkType) == "IEND" {
|
||||
// 写入 tEXt chunk
|
||||
writeTextChunk(&buf, "chara", encodedJSON)
|
||||
}
|
||||
|
||||
// 写入原始 chunk
|
||||
buf.Write(lengthBytes)
|
||||
buf.Write(chunkType)
|
||||
buf.Write(data)
|
||||
buf.Write(crc)
|
||||
|
||||
if string(chunkType) == "IEND" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// writeTextChunk 写入 tEXt chunk
|
||||
func writeTextChunk(w io.Writer, keyword, text string) error {
|
||||
data := append([]byte(keyword), 0)
|
||||
data = append(data, []byte(text)...)
|
||||
|
||||
// 写入 length
|
||||
length := uint32(len(data))
|
||||
lengthBytes := []byte{
|
||||
byte(length >> 24),
|
||||
byte(length >> 16),
|
||||
byte(length >> 8),
|
||||
byte(length),
|
||||
}
|
||||
w.Write(lengthBytes)
|
||||
|
||||
// 写入 type
|
||||
w.Write([]byte("tEXt"))
|
||||
|
||||
// 写入 data
|
||||
w.Write(data)
|
||||
|
||||
// 计算并写入 CRC
|
||||
crcData := append([]byte("tEXt"), data...)
|
||||
crc := calculateCRC(crcData)
|
||||
crcBytes := []byte{
|
||||
byte(crc >> 24),
|
||||
byte(crc >> 16),
|
||||
byte(crc >> 8),
|
||||
byte(crc),
|
||||
}
|
||||
w.Write(crcBytes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateCRC 计算 CRC32
|
||||
func calculateCRC(data []byte) uint32 {
|
||||
crc := uint32(0xFFFFFFFF)
|
||||
|
||||
for _, b := range data {
|
||||
crc ^= uint32(b)
|
||||
for i := 0; i < 8; i++ {
|
||||
if crc&1 != 0 {
|
||||
crc = (crc >> 1) ^ 0xEDB88320
|
||||
} else {
|
||||
crc >>= 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return crc ^ 0xFFFFFFFF
|
||||
}
|
||||
|
||||
// ParseCharacterCardJSON 解析 JSON 格式的角色卡
|
||||
func ParseCharacterCardJSON(jsonData []byte) (*CharacterCardV2, error) {
|
||||
var card CharacterCardV2
|
||||
err := json.Unmarshal(jsonData, &card)
|
||||
if err != nil {
|
||||
return nil, errors.New("解析角色卡 JSON 失败: " + err.Error())
|
||||
}
|
||||
|
||||
return &card, nil
|
||||
}
|
||||
148
server/utils/claims.go
Normal file
148
server/utils/claims.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/system"
|
||||
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func ClearToken(c *gin.Context) {
|
||||
// 增加cookie x-token 向来源的web添加
|
||||
host, _, err := net.SplitHostPort(c.Request.Host)
|
||||
if err != nil {
|
||||
host = c.Request.Host
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
c.SetCookie("x-token", "", -1, "/", "", false, false)
|
||||
} else {
|
||||
c.SetCookie("x-token", "", -1, "/", host, false, false)
|
||||
}
|
||||
}
|
||||
|
||||
func SetToken(c *gin.Context, token string, maxAge int) {
|
||||
// 增加cookie x-token 向来源的web添加
|
||||
host, _, err := net.SplitHostPort(c.Request.Host)
|
||||
if err != nil {
|
||||
host = c.Request.Host
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
c.SetCookie("x-token", token, maxAge, "/", "", false, false)
|
||||
} else {
|
||||
c.SetCookie("x-token", token, maxAge, "/", host, false, false)
|
||||
}
|
||||
}
|
||||
|
||||
func GetToken(c *gin.Context) string {
|
||||
token := c.Request.Header.Get("x-token")
|
||||
if token == "" {
|
||||
j := NewJWT()
|
||||
token, _ = c.Cookie("x-token")
|
||||
claims, err := j.ParseToken(token)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("重新写入cookie token失败,未能成功解析token,请检查请求头是否存在x-token且claims是否为规定结构")
|
||||
return token
|
||||
}
|
||||
SetToken(c, token, int(claims.ExpiresAt.Unix()-time.Now().Unix()))
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func GetClaims(c *gin.Context) (*systemReq.CustomClaims, error) {
|
||||
token := GetToken(c)
|
||||
j := NewJWT()
|
||||
claims, err := j.ParseToken(token)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("从Gin的Context中获取从jwt解析信息失败, 请检查请求头是否存在x-token且claims是否为规定结构")
|
||||
}
|
||||
return claims, err
|
||||
}
|
||||
|
||||
// GetUserID 从Gin的Context中获取从jwt解析出来的用户ID
|
||||
func GetUserID(c *gin.Context) uint {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return 0
|
||||
} else {
|
||||
return cl.BaseClaims.ID
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse.BaseClaims.ID
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserUuid 从Gin的Context中获取从jwt解析出来的用户UUID
|
||||
func GetUserUuid(c *gin.Context) uuid.UUID {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return uuid.UUID{}
|
||||
} else {
|
||||
return cl.UUID
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse.UUID
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAuthorityId 从Gin的Context中获取从jwt解析出来的用户角色id
|
||||
func GetUserAuthorityId(c *gin.Context) uint {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return 0
|
||||
} else {
|
||||
return cl.AuthorityId
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse.AuthorityId
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserInfo 从Gin的Context中获取从jwt解析出来的用户角色id
|
||||
func GetUserInfo(c *gin.Context) *systemReq.CustomClaims {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return nil
|
||||
} else {
|
||||
return cl
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserName 从Gin的Context中获取从jwt解析出来的用户名
|
||||
func GetUserName(c *gin.Context) string {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return ""
|
||||
} else {
|
||||
return cl.Username
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse.Username
|
||||
}
|
||||
}
|
||||
|
||||
func LoginToken(user system.Login) (token string, claims systemReq.CustomClaims, err error) {
|
||||
j := NewJWT()
|
||||
claims = j.CreateClaims(systemReq.BaseClaims{
|
||||
UUID: user.GetUUID(),
|
||||
ID: user.GetUserId(),
|
||||
NickName: user.GetNickname(),
|
||||
Username: user.GetUsername(),
|
||||
AuthorityId: user.GetAuthorityId(),
|
||||
})
|
||||
token, err = j.CreateToken(claims)
|
||||
return
|
||||
}
|
||||
124
server/utils/directory.go
Normal file
124
server/utils/directory.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: PathExists
|
||||
//@description: 文件目录是否存在
|
||||
//@param: path string
|
||||
//@return: bool, error
|
||||
|
||||
func PathExists(path string) (bool, error) {
|
||||
fi, err := os.Stat(path)
|
||||
if err == nil {
|
||||
if fi.IsDir() {
|
||||
return true, nil
|
||||
}
|
||||
return false, errors.New("存在同名文件")
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: CreateDir
|
||||
//@description: 批量创建文件夹
|
||||
//@param: dirs ...string
|
||||
//@return: err error
|
||||
|
||||
func CreateDir(dirs ...string) (err error) {
|
||||
for _, v := range dirs {
|
||||
exist, err := PathExists(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exist {
|
||||
global.GVA_LOG.Debug("create directory" + v)
|
||||
if err := os.MkdirAll(v, os.ModePerm); err != nil {
|
||||
global.GVA_LOG.Error("create directory"+v, zap.Any(" error:", err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
//@author: [songzhibin97](https://github.com/songzhibin97)
|
||||
//@function: FileMove
|
||||
//@description: 文件移动供外部调用
|
||||
//@param: src string, dst string(src: 源位置,绝对路径or相对路径, dst: 目标位置,绝对路径or相对路径,必须为文件夹)
|
||||
//@return: err error
|
||||
|
||||
func FileMove(src string, dst string) (err error) {
|
||||
if dst == "" {
|
||||
return nil
|
||||
}
|
||||
src, err = filepath.Abs(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst, err = filepath.Abs(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revoke := false
|
||||
dir := filepath.Dir(dst)
|
||||
Redirect:
|
||||
_, err = os.Stat(dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !revoke {
|
||||
revoke = true
|
||||
goto Redirect
|
||||
}
|
||||
}
|
||||
return os.Rename(src, dst)
|
||||
}
|
||||
|
||||
func DeLFile(filePath string) error {
|
||||
return os.RemoveAll(filePath)
|
||||
}
|
||||
|
||||
//@author: [songzhibin97](https://github.com/songzhibin97)
|
||||
//@function: TrimSpace
|
||||
//@description: 去除结构体空格
|
||||
//@param: target interface (target: 目标结构体,传入必须是指针类型)
|
||||
//@return: null
|
||||
|
||||
func TrimSpace(target interface{}) {
|
||||
t := reflect.TypeOf(target)
|
||||
if t.Kind() != reflect.Ptr {
|
||||
return
|
||||
}
|
||||
t = t.Elem()
|
||||
v := reflect.ValueOf(target).Elem()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
switch v.Field(i).Kind() {
|
||||
case reflect.String:
|
||||
v.Field(i).SetString(strings.TrimSpace(v.Field(i).String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FileExist 判断文件是否存在
|
||||
func FileExist(path string) bool {
|
||||
fi, err := os.Lstat(path)
|
||||
if err == nil {
|
||||
return !fi.IsDir()
|
||||
}
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
126
server/utils/fmt_plus.go
Normal file
126
server/utils/fmt_plus.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common"
|
||||
)
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: StructToMap
|
||||
//@description: 利用反射将结构体转化为map
|
||||
//@param: obj interface{}
|
||||
//@return: map[string]interface{}
|
||||
|
||||
func StructToMap(obj interface{}) map[string]interface{} {
|
||||
obj1 := reflect.TypeOf(obj)
|
||||
obj2 := reflect.ValueOf(obj)
|
||||
|
||||
data := make(map[string]interface{})
|
||||
for i := 0; i < obj1.NumField(); i++ {
|
||||
if obj1.Field(i).Tag.Get("mapstructure") != "" {
|
||||
data[obj1.Field(i).Tag.Get("mapstructure")] = obj2.Field(i).Interface()
|
||||
} else {
|
||||
data[obj1.Field(i).Name] = obj2.Field(i).Interface()
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: ArrayToString
|
||||
//@description: 将数组格式化为字符串
|
||||
//@param: array []interface{}
|
||||
//@return: string
|
||||
|
||||
func ArrayToString(array []interface{}) string {
|
||||
return strings.Replace(strings.Trim(fmt.Sprint(array), "[]"), " ", ",", -1)
|
||||
}
|
||||
|
||||
func Pointer[T any](in T) (out *T) {
|
||||
return &in
|
||||
}
|
||||
|
||||
func FirstUpper(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToUpper(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
func FirstLower(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
// MaheHump 将字符串转换为驼峰命名
|
||||
func MaheHump(s string) string {
|
||||
words := strings.Split(s, "-")
|
||||
|
||||
for i := 1; i < len(words); i++ {
|
||||
words[i] = strings.Title(words[i])
|
||||
}
|
||||
|
||||
return strings.Join(words, "")
|
||||
}
|
||||
|
||||
// HumpToUnderscore 将驼峰命名转换为下划线分割模式
|
||||
func HumpToUnderscore(s string) string {
|
||||
var result strings.Builder
|
||||
|
||||
for i, char := range s {
|
||||
if i > 0 && char >= 'A' && char <= 'Z' {
|
||||
// 在大写字母前添加下划线
|
||||
result.WriteRune('_')
|
||||
result.WriteRune(char - 'A' + 'a') // 转小写
|
||||
} else {
|
||||
result.WriteRune(char)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// RandomString 随机字符串
|
||||
func RandomString(n int) string {
|
||||
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
b := make([]rune, n)
|
||||
for i := range b {
|
||||
b[i] = letters[RandomInt(0, len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func RandomInt(min, max int) int {
|
||||
return min + rand.Intn(max-min)
|
||||
}
|
||||
|
||||
// BuildTree 用于构建一个树形结构
|
||||
func BuildTree[T common.TreeNode[T]](nodes []T) []T {
|
||||
nodeMap := make(map[int]T)
|
||||
// 创建一个基本map
|
||||
for i := range nodes {
|
||||
nodeMap[nodes[i].GetID()] = nodes[i]
|
||||
}
|
||||
|
||||
for i := range nodes {
|
||||
if nodes[i].GetParentID() != 0 {
|
||||
parent := nodeMap[nodes[i].GetParentID()]
|
||||
parent.SetChildren(nodes[i])
|
||||
}
|
||||
}
|
||||
|
||||
var rootNodes []T
|
||||
|
||||
for i := range nodeMap {
|
||||
if nodeMap[i].GetParentID() == 0 {
|
||||
rootNodes = append(rootNodes, nodeMap[i])
|
||||
}
|
||||
}
|
||||
return rootNodes
|
||||
}
|
||||
32
server/utils/hash.go
Normal file
32
server/utils/hash.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// BcryptHash 使用 bcrypt 对密码进行加密
|
||||
func BcryptHash(password string) string {
|
||||
bytes, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// BcryptCheck 对比明文密码和数据库的哈希值
|
||||
func BcryptCheck(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: MD5V
|
||||
//@description: md5加密
|
||||
//@param: str []byte
|
||||
//@return: string
|
||||
|
||||
func MD5V(str []byte, b ...byte) string {
|
||||
h := md5.New()
|
||||
h.Write(str)
|
||||
return hex.EncodeToString(h.Sum(b))
|
||||
}
|
||||
29
server/utils/human_duration.go
Normal file
29
server/utils/human_duration.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ParseDuration(d string) (time.Duration, error) {
|
||||
d = strings.TrimSpace(d)
|
||||
dr, err := time.ParseDuration(d)
|
||||
if err == nil {
|
||||
return dr, nil
|
||||
}
|
||||
if strings.Contains(d, "d") {
|
||||
index := strings.Index(d, "d")
|
||||
|
||||
hour, _ := strconv.Atoi(d[:index])
|
||||
dr = time.Hour * 24 * time.Duration(hour)
|
||||
ndr, err := time.ParseDuration(d[index+1:])
|
||||
if err != nil {
|
||||
return dr, nil
|
||||
}
|
||||
return dr + ndr, nil
|
||||
}
|
||||
|
||||
dv, err := strconv.ParseInt(d, 10, 64)
|
||||
return time.Duration(dv), err
|
||||
}
|
||||
49
server/utils/human_duration_test.go
Normal file
49
server/utils/human_duration_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseDuration(t *testing.T) {
|
||||
type args struct {
|
||||
d string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "5h20m",
|
||||
args: args{"5h20m"},
|
||||
want: time.Hour*5 + 20*time.Minute,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "1d5h20m",
|
||||
args: args{"1d5h20m"},
|
||||
want: 24*time.Hour + time.Hour*5 + 20*time.Minute,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "1d",
|
||||
args: args{"1d"},
|
||||
want: 24 * time.Hour,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseDuration(tt.args.d)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseDuration() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseDuration() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
34
server/utils/json.go
Normal file
34
server/utils/json.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetJSONKeys(jsonStr string) (keys []string, err error) {
|
||||
// 使用json.Decoder,以便在解析过程中记录键的顺序
|
||||
dec := json.NewDecoder(strings.NewReader(jsonStr))
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 确保数据是一个对象
|
||||
if t != json.Delim('{') {
|
||||
return nil, err
|
||||
}
|
||||
for dec.More() {
|
||||
t, err = dec.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, t.(string))
|
||||
|
||||
// 解析值
|
||||
var value interface{}
|
||||
err = dec.Decode(&value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
53
server/utils/json_test.go
Normal file
53
server/utils/json_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetJSONKeys(t *testing.T) {
|
||||
var jsonStr = `
|
||||
{
|
||||
"Name": "test",
|
||||
"TableName": "test",
|
||||
"TemplateID": "test",
|
||||
"TemplateInfo": "test",
|
||||
"Limit": 0
|
||||
}`
|
||||
keys, err := GetJSONKeys(jsonStr)
|
||||
if err != nil {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
return
|
||||
}
|
||||
if len(keys) != 5 {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
return
|
||||
}
|
||||
if keys[0] != "Name" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
if keys[1] != "TableName" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
if keys[2] != "TemplateID" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
if keys[3] != "TemplateInfo" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
if keys[4] != "Limit" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println(keys)
|
||||
}
|
||||
105
server/utils/jwt.go
Normal file
105
server/utils/jwt.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/model/system/request"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type JWT struct {
|
||||
SigningKey []byte
|
||||
}
|
||||
|
||||
var (
|
||||
TokenValid = errors.New("未知错误")
|
||||
TokenExpired = errors.New("token已过期")
|
||||
TokenNotValidYet = errors.New("token尚未激活")
|
||||
TokenMalformed = errors.New("这不是一个token")
|
||||
TokenSignatureInvalid = errors.New("无效签名")
|
||||
TokenInvalid = errors.New("无法处理此token")
|
||||
)
|
||||
|
||||
func NewJWT() *JWT {
|
||||
return &JWT{
|
||||
[]byte(global.GVA_CONFIG.JWT.SigningKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (j *JWT) CreateClaims(baseClaims request.BaseClaims) request.CustomClaims {
|
||||
bf, _ := ParseDuration(global.GVA_CONFIG.JWT.BufferTime)
|
||||
ep, _ := ParseDuration(global.GVA_CONFIG.JWT.ExpiresTime)
|
||||
claims := request.CustomClaims{
|
||||
BaseClaims: baseClaims,
|
||||
BufferTime: int64(bf / time.Second), // 缓冲时间1天 缓冲时间内会获得新的token刷新令牌 此时一个用户会存在两个有效令牌 但是前端只留一个 另一个会丢失
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Audience: jwt.ClaimStrings{"GVA"}, // 受众
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-1000)), // 签名生效时间
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ep)), // 过期时间 7天 配置文件
|
||||
Issuer: global.GVA_CONFIG.JWT.Issuer, // 签名的发行者
|
||||
},
|
||||
}
|
||||
return claims
|
||||
}
|
||||
|
||||
// CreateToken 创建一个token
|
||||
func (j *JWT) CreateToken(claims request.CustomClaims) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(j.SigningKey)
|
||||
}
|
||||
|
||||
// CreateTokenByOldToken 旧token 换新token 使用归并回源避免并发问题
|
||||
func (j *JWT) CreateTokenByOldToken(oldToken string, claims request.CustomClaims) (string, error) {
|
||||
v, err, _ := global.GVA_Concurrency_Control.Do("JWT:"+oldToken, func() (interface{}, error) {
|
||||
return j.CreateToken(claims)
|
||||
})
|
||||
return v.(string), err
|
||||
}
|
||||
|
||||
// ParseToken 解析 token
|
||||
func (j *JWT) ParseToken(tokenString string) (*request.CustomClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &request.CustomClaims{}, func(token *jwt.Token) (i interface{}, e error) {
|
||||
return j.SigningKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, jwt.ErrTokenExpired):
|
||||
return nil, TokenExpired
|
||||
case errors.Is(err, jwt.ErrTokenMalformed):
|
||||
return nil, TokenMalformed
|
||||
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
|
||||
return nil, TokenSignatureInvalid
|
||||
case errors.Is(err, jwt.ErrTokenNotValidYet):
|
||||
return nil, TokenNotValidYet
|
||||
default:
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
}
|
||||
if token != nil {
|
||||
if claims, ok := token.Claims.(*request.CustomClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
}
|
||||
return nil, TokenValid
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: SetRedisJWT
|
||||
//@description: jwt存入redis并设置过期时间
|
||||
//@param: jwt string, userName string
|
||||
//@return: err error
|
||||
|
||||
func SetRedisJWT(jwt string, userName string) (err error) {
|
||||
// 此处过期时间等于jwt过期时间
|
||||
dr, err := ParseDuration(global.GVA_CONFIG.JWT.ExpiresTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
timer := dr
|
||||
err = global.GVA_REDIS.Set(context.Background(), userName, jwt, timer).Err()
|
||||
return err
|
||||
}
|
||||
26
server/utils/param.go
Normal file
26
server/utils/param.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// StringToUint 字符串转 uint
|
||||
func StringToUint(s string) (uint, error) {
|
||||
val, err := strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint(val), nil
|
||||
}
|
||||
|
||||
// GetIntQuery 获取查询参数(int类型)
|
||||
func GetIntQuery(c *gin.Context, key string, defaultValue int) int {
|
||||
if val := c.Query(key); val != "" {
|
||||
if intVal, err := strconv.Atoi(val); err == nil {
|
||||
return intVal
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
22
server/utils/param_helper.go
Normal file
22
server/utils/param_helper.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetUintParam 从 URL 参数中获取 uint 值
|
||||
func GetUintParam(c *gin.Context, key string) uint {
|
||||
val := c.Param(key)
|
||||
if val == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
uintVal, err := strconv.ParseUint(val, 10, 32)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return uint(uintVal)
|
||||
}
|
||||
18
server/utils/plugin/plugin.go
Normal file
18
server/utils/plugin/plugin.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
OnlyFuncName = "Plugin"
|
||||
)
|
||||
|
||||
// Plugin 插件模式接口化
|
||||
type Plugin interface {
|
||||
// Register 注册路由
|
||||
Register(group *gin.RouterGroup)
|
||||
|
||||
// RouterPath 用户返回注册路由
|
||||
RouterPath() string
|
||||
}
|
||||
11
server/utils/plugin/v2/plugin.go
Normal file
11
server/utils/plugin/v2/plugin.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Plugin 插件模式接口化v2
|
||||
type Plugin interface {
|
||||
// Register 注册路由
|
||||
Register(group *gin.Engine)
|
||||
}
|
||||
27
server/utils/plugin/v2/registry.go
Normal file
27
server/utils/plugin/v2/registry.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package plugin
|
||||
|
||||
import "sync"
|
||||
|
||||
var (
|
||||
registryMu sync.RWMutex
|
||||
registry []Plugin
|
||||
)
|
||||
|
||||
// Register records a plugin for auto initialization.
|
||||
func Register(p Plugin) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
registryMu.Lock()
|
||||
registry = append(registry, p)
|
||||
registryMu.Unlock()
|
||||
}
|
||||
|
||||
// Registered returns a snapshot of all registered plugins.
|
||||
func Registered() []Plugin {
|
||||
registryMu.RLock()
|
||||
defer registryMu.RUnlock()
|
||||
out := make([]Plugin, len(registry))
|
||||
copy(out, registry)
|
||||
return out
|
||||
}
|
||||
15
server/utils/random.go
Normal file
15
server/utils/random.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// GenerateRandomString 生成随机字符串
|
||||
func GenerateRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length/2+1)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes)[:length], nil
|
||||
}
|
||||
62
server/utils/request/http.go
Normal file
62
server/utils/request/http.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func HttpRequest(
|
||||
urlStr string,
|
||||
method string,
|
||||
headers map[string]string,
|
||||
params map[string]string,
|
||||
data any) (*http.Response, error) {
|
||||
// 创建URL
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 添加查询参数
|
||||
query := u.Query()
|
||||
for k, v := range params {
|
||||
query.Set(k, v)
|
||||
}
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
// 将数据编码为JSON
|
||||
buf := new(bytes.Buffer)
|
||||
if data != nil {
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf = bytes.NewBuffer(b)
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequest(method, u.String(), buf)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
if data != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 返回响应,让调用者处理
|
||||
return resp, nil
|
||||
}
|
||||
127
server/utils/server.go
Normal file
127
server/utils/server.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
)
|
||||
|
||||
const (
|
||||
B = 1
|
||||
KB = 1024 * B
|
||||
MB = 1024 * KB
|
||||
GB = 1024 * MB
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Os Os `json:"os"`
|
||||
Cpu Cpu `json:"cpu"`
|
||||
Ram Ram `json:"ram"`
|
||||
Disk []Disk `json:"disk"`
|
||||
}
|
||||
|
||||
type Os struct {
|
||||
GOOS string `json:"goos"`
|
||||
NumCPU int `json:"numCpu"`
|
||||
Compiler string `json:"compiler"`
|
||||
GoVersion string `json:"goVersion"`
|
||||
NumGoroutine int `json:"numGoroutine"`
|
||||
}
|
||||
|
||||
type Cpu struct {
|
||||
Cpus []float64 `json:"cpus"`
|
||||
Cores int `json:"cores"`
|
||||
}
|
||||
|
||||
type Ram struct {
|
||||
UsedMB int `json:"usedMb"`
|
||||
TotalMB int `json:"totalMb"`
|
||||
UsedPercent int `json:"usedPercent"`
|
||||
}
|
||||
|
||||
type Disk struct {
|
||||
MountPoint string `json:"mountPoint"`
|
||||
UsedMB int `json:"usedMb"`
|
||||
UsedGB int `json:"usedGb"`
|
||||
TotalMB int `json:"totalMb"`
|
||||
TotalGB int `json:"totalGb"`
|
||||
UsedPercent int `json:"usedPercent"`
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@function: InitCPU
|
||||
//@description: OS信息
|
||||
//@return: o Os, err error
|
||||
|
||||
func InitOS() (o Os) {
|
||||
o.GOOS = runtime.GOOS
|
||||
o.NumCPU = runtime.NumCPU()
|
||||
o.Compiler = runtime.Compiler
|
||||
o.GoVersion = runtime.Version()
|
||||
o.NumGoroutine = runtime.NumGoroutine()
|
||||
return o
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@function: InitCPU
|
||||
//@description: CPU信息
|
||||
//@return: c Cpu, err error
|
||||
|
||||
func InitCPU() (c Cpu, err error) {
|
||||
if cores, err := cpu.Counts(false); err != nil {
|
||||
return c, err
|
||||
} else {
|
||||
c.Cores = cores
|
||||
}
|
||||
if cpus, err := cpu.Percent(time.Duration(200)*time.Millisecond, true); err != nil {
|
||||
return c, err
|
||||
} else {
|
||||
c.Cpus = cpus
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@function: InitRAM
|
||||
//@description: RAM信息
|
||||
//@return: r Ram, err error
|
||||
|
||||
func InitRAM() (r Ram, err error) {
|
||||
if u, err := mem.VirtualMemory(); err != nil {
|
||||
return r, err
|
||||
} else {
|
||||
r.UsedMB = int(u.Used) / MB
|
||||
r.TotalMB = int(u.Total) / MB
|
||||
r.UsedPercent = int(u.UsedPercent)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@function: InitDisk
|
||||
//@description: 硬盘信息
|
||||
//@return: d Disk, err error
|
||||
|
||||
func InitDisk() (d []Disk, err error) {
|
||||
for i := range global.GVA_CONFIG.DiskList {
|
||||
mp := global.GVA_CONFIG.DiskList[i].MountPoint
|
||||
if u, err := disk.Usage(mp); err != nil {
|
||||
return d, err
|
||||
} else {
|
||||
d = append(d, Disk{
|
||||
MountPoint: mp,
|
||||
UsedMB: int(u.Used) / MB,
|
||||
UsedGB: int(u.Used) / GB,
|
||||
TotalMB: int(u.Total) / MB,
|
||||
TotalGB: int(u.Total) / GB,
|
||||
UsedPercent: int(u.UsedPercent),
|
||||
})
|
||||
}
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
79
server/utils/stacktrace/stacktrace.go
Normal file
79
server/utils/stacktrace/stacktrace.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package stacktrace
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Frame 表示一次栈帧解析结果
|
||||
type Frame struct {
|
||||
File string
|
||||
Line int
|
||||
Func string
|
||||
}
|
||||
|
||||
var fileLineRe = regexp.MustCompile(`\s*(.+\.go):(\d+)\s*$`)
|
||||
|
||||
// FindFinalCaller 从 zap 的 entry.Stack 文本中,解析“最终业务调用方”的文件与行号
|
||||
// 策略:自顶向下解析,优先选择第一条项目代码帧,过滤第三方库/标准库/框架中间件
|
||||
func FindFinalCaller(stack string) (Frame, bool) {
|
||||
if stack == "" {
|
||||
return Frame{}, false
|
||||
}
|
||||
lines := strings.Split(stack, "\n")
|
||||
var currFunc string
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := strings.TrimSpace(lines[i])
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if m := fileLineRe.FindStringSubmatch(line); m != nil {
|
||||
file := m[1]
|
||||
ln, _ := strconv.Atoi(m[2])
|
||||
if shouldSkip(file) {
|
||||
// 跳过此帧,同时重置函数名以避免错误配对
|
||||
currFunc = ""
|
||||
continue
|
||||
}
|
||||
return Frame{File: file, Line: ln, Func: currFunc}, true
|
||||
}
|
||||
// 记录函数名行,下一行通常是文件:行
|
||||
currFunc = line
|
||||
}
|
||||
return Frame{}, false
|
||||
}
|
||||
|
||||
func shouldSkip(file string) bool {
|
||||
// 第三方库与 Go 模块缓存
|
||||
if strings.Contains(file, "/go/pkg/mod/") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(file, "/go.uber.org/") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(file, "/gorm.io/") {
|
||||
return true
|
||||
}
|
||||
// 标准库
|
||||
if strings.Contains(file, "/go/go") && strings.Contains(file, "/src/") { // e.g. /Users/name/go/go1.24.2/src/net/http/server.go
|
||||
return true
|
||||
}
|
||||
// 框架内不需要作为最终调用方的路径
|
||||
if strings.Contains(file, "/server/core/zap.go") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(file, "/server/core/") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(file, "/server/utils/errorhook/") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(file, "/server/middleware/") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(file, "/server/router/") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
34
server/utils/system_events.go
Normal file
34
server/utils/system_events.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SystemEvents 定义系统级事件处理
|
||||
type SystemEvents struct {
|
||||
reloadHandlers []func() error
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// 全局事件管理器
|
||||
var GlobalSystemEvents = &SystemEvents{}
|
||||
|
||||
// RegisterReloadHandler 注册系统重载处理函数
|
||||
func (e *SystemEvents) RegisterReloadHandler(handler func() error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.reloadHandlers = append(e.reloadHandlers, handler)
|
||||
}
|
||||
|
||||
// TriggerReload 触发所有注册的重载处理函数
|
||||
func (e *SystemEvents) TriggerReload() error {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
for _, handler := range e.reloadHandlers {
|
||||
if err := handler(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
230
server/utils/timer/timed_task.go
Normal file
230
server/utils/timer/timed_task.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package timer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
type Timer interface {
|
||||
// 寻找所有Cron
|
||||
FindCronList() map[string]*taskManager
|
||||
// 添加Task 方法形式以秒的形式加入
|
||||
AddTaskByFuncWithSecond(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) // 添加Task Func以秒的形式加入
|
||||
// 添加Task 接口形式以秒的形式加入
|
||||
AddTaskByJobWithSeconds(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error)
|
||||
// 通过函数的方法添加任务
|
||||
AddTaskByFunc(cronName string, spec string, task func(), taskName string, option ...cron.Option) (cron.EntryID, error)
|
||||
// 通过接口的方法添加任务 要实现一个带有 Run方法的接口触发
|
||||
AddTaskByJob(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error)
|
||||
// 获取对应taskName的cron 可能会为空
|
||||
FindCron(cronName string) (*taskManager, bool)
|
||||
// 指定cron开始执行
|
||||
StartCron(cronName string)
|
||||
// 指定cron停止执行
|
||||
StopCron(cronName string)
|
||||
// 查找指定cron下的指定task
|
||||
FindTask(cronName string, taskName string) (*task, bool)
|
||||
// 根据id删除指定cron下的指定task
|
||||
RemoveTask(cronName string, id int)
|
||||
// 根据taskName删除指定cron下的指定task
|
||||
RemoveTaskByName(cronName string, taskName string)
|
||||
// 清理掉指定cronName
|
||||
Clear(cronName string)
|
||||
// 停止所有的cron
|
||||
Close()
|
||||
}
|
||||
|
||||
type task struct {
|
||||
EntryID cron.EntryID
|
||||
Spec string
|
||||
TaskName string
|
||||
}
|
||||
|
||||
type taskManager struct {
|
||||
corn *cron.Cron
|
||||
tasks map[cron.EntryID]*task
|
||||
}
|
||||
|
||||
// timer 定时任务管理
|
||||
type timer struct {
|
||||
cronList map[string]*taskManager
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// AddTaskByFunc 通过函数的方法添加任务
|
||||
func (t *timer) AddTaskByFunc(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if _, ok := t.cronList[cronName]; !ok {
|
||||
tasks := make(map[cron.EntryID]*task)
|
||||
t.cronList[cronName] = &taskManager{
|
||||
corn: cron.New(option...),
|
||||
tasks: tasks,
|
||||
}
|
||||
}
|
||||
id, err := t.cronList[cronName].corn.AddFunc(spec, fun)
|
||||
t.cronList[cronName].corn.Start()
|
||||
t.cronList[cronName].tasks[id] = &task{
|
||||
EntryID: id,
|
||||
Spec: spec,
|
||||
TaskName: taskName,
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// AddTaskByFuncWithSecond 通过函数的方法使用WithSeconds添加任务
|
||||
func (t *timer) AddTaskByFuncWithSecond(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
option = append(option, cron.WithSeconds())
|
||||
if _, ok := t.cronList[cronName]; !ok {
|
||||
tasks := make(map[cron.EntryID]*task)
|
||||
t.cronList[cronName] = &taskManager{
|
||||
corn: cron.New(option...),
|
||||
tasks: tasks,
|
||||
}
|
||||
}
|
||||
id, err := t.cronList[cronName].corn.AddFunc(spec, fun)
|
||||
t.cronList[cronName].corn.Start()
|
||||
t.cronList[cronName].tasks[id] = &task{
|
||||
EntryID: id,
|
||||
Spec: spec,
|
||||
TaskName: taskName,
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// AddTaskByJob 通过接口的方法添加任务
|
||||
func (t *timer) AddTaskByJob(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if _, ok := t.cronList[cronName]; !ok {
|
||||
tasks := make(map[cron.EntryID]*task)
|
||||
t.cronList[cronName] = &taskManager{
|
||||
corn: cron.New(option...),
|
||||
tasks: tasks,
|
||||
}
|
||||
}
|
||||
id, err := t.cronList[cronName].corn.AddJob(spec, job)
|
||||
t.cronList[cronName].corn.Start()
|
||||
t.cronList[cronName].tasks[id] = &task{
|
||||
EntryID: id,
|
||||
Spec: spec,
|
||||
TaskName: taskName,
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// AddTaskByJobWithSeconds 通过接口的方法添加任务
|
||||
func (t *timer) AddTaskByJobWithSeconds(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
option = append(option, cron.WithSeconds())
|
||||
if _, ok := t.cronList[cronName]; !ok {
|
||||
tasks := make(map[cron.EntryID]*task)
|
||||
t.cronList[cronName] = &taskManager{
|
||||
corn: cron.New(option...),
|
||||
tasks: tasks,
|
||||
}
|
||||
}
|
||||
id, err := t.cronList[cronName].corn.AddJob(spec, job)
|
||||
t.cronList[cronName].corn.Start()
|
||||
t.cronList[cronName].tasks[id] = &task{
|
||||
EntryID: id,
|
||||
Spec: spec,
|
||||
TaskName: taskName,
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// FindCron 获取对应cronName的cron 可能会为空
|
||||
func (t *timer) FindCron(cronName string) (*taskManager, bool) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
v, ok := t.cronList[cronName]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// FindTask 获取对应cronName的cron 可能会为空
|
||||
func (t *timer) FindTask(cronName string, taskName string) (*task, bool) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
v, ok := t.cronList[cronName]
|
||||
if !ok {
|
||||
return nil, ok
|
||||
}
|
||||
for _, t2 := range v.tasks {
|
||||
if t2.TaskName == taskName {
|
||||
return t2, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindCronList 获取所有的任务列表
|
||||
func (t *timer) FindCronList() map[string]*taskManager {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
return t.cronList
|
||||
}
|
||||
|
||||
// StartCron 开始任务
|
||||
func (t *timer) StartCron(cronName string) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if v, ok := t.cronList[cronName]; ok {
|
||||
v.corn.Start()
|
||||
}
|
||||
}
|
||||
|
||||
// StopCron 停止任务
|
||||
func (t *timer) StopCron(cronName string) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if v, ok := t.cronList[cronName]; ok {
|
||||
v.corn.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveTask 从cronName 删除指定任务
|
||||
func (t *timer) RemoveTask(cronName string, id int) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if v, ok := t.cronList[cronName]; ok {
|
||||
v.corn.Remove(cron.EntryID(id))
|
||||
delete(v.tasks, cron.EntryID(id))
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveTaskByName 从cronName 使用taskName 删除指定任务
|
||||
func (t *timer) RemoveTaskByName(cronName string, taskName string) {
|
||||
fTask, ok := t.FindTask(cronName, taskName)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
t.RemoveTask(cronName, int(fTask.EntryID))
|
||||
}
|
||||
|
||||
// Clear 清除任务
|
||||
func (t *timer) Clear(cronName string) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if v, ok := t.cronList[cronName]; ok {
|
||||
v.corn.Stop()
|
||||
delete(t.cronList, cronName)
|
||||
}
|
||||
}
|
||||
|
||||
// Close 释放资源
|
||||
func (t *timer) Close() {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
for _, v := range t.cronList {
|
||||
v.corn.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func NewTimerTask() Timer {
|
||||
return &timer{cronList: make(map[string]*taskManager)}
|
||||
}
|
||||
72
server/utils/timer/timed_task_test.go
Normal file
72
server/utils/timer/timed_task_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package timer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var job = mockJob{}
|
||||
|
||||
type mockJob struct{}
|
||||
|
||||
func (job mockJob) Run() {
|
||||
mockFunc()
|
||||
}
|
||||
|
||||
func mockFunc() {
|
||||
time.Sleep(time.Second)
|
||||
fmt.Println("1s...")
|
||||
}
|
||||
|
||||
func TestNewTimerTask(t *testing.T) {
|
||||
tm := NewTimerTask()
|
||||
_tm := tm.(*timer)
|
||||
|
||||
{
|
||||
_, err := tm.AddTaskByFunc("func", "@every 1s", mockFunc, "测试mockfunc")
|
||||
assert.Nil(t, err)
|
||||
_, ok := _tm.cronList["func"]
|
||||
if !ok {
|
||||
t.Error("no find func")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
_, err := tm.AddTaskByJob("job", "@every 1s", job, "测试job mockfunc")
|
||||
assert.Nil(t, err)
|
||||
_, ok := _tm.cronList["job"]
|
||||
if !ok {
|
||||
t.Error("no find job")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
_, ok := tm.FindCron("func")
|
||||
if !ok {
|
||||
t.Error("no find func")
|
||||
}
|
||||
_, ok = tm.FindCron("job")
|
||||
if !ok {
|
||||
t.Error("no find job")
|
||||
}
|
||||
_, ok = tm.FindCron("none")
|
||||
if ok {
|
||||
t.Error("find none")
|
||||
}
|
||||
}
|
||||
{
|
||||
tm.Clear("func")
|
||||
_, ok := tm.FindCron("func")
|
||||
if ok {
|
||||
t.Error("find func")
|
||||
}
|
||||
}
|
||||
{
|
||||
a := tm.FindCronList()
|
||||
b, c := tm.FindCron("job")
|
||||
fmt.Println(a, b, c)
|
||||
}
|
||||
}
|
||||
75
server/utils/upload/aliyun_oss.go
Normal file
75
server/utils/upload/aliyun_oss.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"github.com/aliyun/aliyun-oss-go-sdk/oss"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AliyunOSS struct{}
|
||||
|
||||
func (*AliyunOSS) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
bucket, err := NewBucket()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function AliyunOSS.NewBucket() Failed", zap.Any("err", err.Error()))
|
||||
return "", "", errors.New("function AliyunOSS.NewBucket() Failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
// 读取本地文件。
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() Failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() Failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
// 上传阿里云路径 文件名格式 自己可以改 建议保证唯一性
|
||||
// yunFileTmpPath := filepath.Join("uploads", time.Now().Format("2006-01-02")) + "/" + file.Filename
|
||||
yunFileTmpPath := global.GVA_CONFIG.AliyunOSS.BasePath + "/" + "uploads" + "/" + time.Now().Format("2006-01-02") + "/" + file.Filename
|
||||
|
||||
// 上传文件流。
|
||||
err = bucket.PutObject(yunFileTmpPath, f)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function formUploader.Put() Failed", zap.Any("err", err.Error()))
|
||||
return "", "", errors.New("function formUploader.Put() Failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
return global.GVA_CONFIG.AliyunOSS.BucketUrl + "/" + yunFileTmpPath, yunFileTmpPath, nil
|
||||
}
|
||||
|
||||
func (*AliyunOSS) DeleteFile(key string) error {
|
||||
bucket, err := NewBucket()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function AliyunOSS.NewBucket() Failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function AliyunOSS.NewBucket() Failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
// 删除单个文件。objectName表示删除OSS文件时需要指定包含文件后缀在内的完整路径,例如abc/efg/123.jpg。
|
||||
// 如需删除文件夹,请将objectName设置为对应的文件夹名称。如果文件夹非空,则需要将文件夹下的所有object删除后才能删除该文件夹。
|
||||
err = bucket.DeleteObject(key)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewBucket() (*oss.Bucket, error) {
|
||||
// 创建OSSClient实例。
|
||||
client, err := oss.New(global.GVA_CONFIG.AliyunOSS.Endpoint, global.GVA_CONFIG.AliyunOSS.AccessKeyId, global.GVA_CONFIG.AliyunOSS.AccessKeySecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取存储空间。
|
||||
bucket, err := client.Bucket(global.GVA_CONFIG.AliyunOSS.BucketName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bucket, nil
|
||||
}
|
||||
98
server/utils/upload/aws_s3.go
Normal file
98
server/utils/upload/aws_s3.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AwsS3 struct{}
|
||||
|
||||
//@author: [WqyJh](https://github.com/WqyJh)
|
||||
//@object: *AwsS3
|
||||
//@function: UploadFile
|
||||
//@description: Upload file to Aws S3 using aws-sdk-go. See https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/s3-example-basic-bucket-operations.html#s3-examples-bucket-ops-upload-file-to-bucket
|
||||
//@param: file *multipart.FileHeader
|
||||
//@return: string, string, error
|
||||
|
||||
func (*AwsS3) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
session := newSession()
|
||||
uploader := s3manager.NewUploader(session)
|
||||
|
||||
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename)
|
||||
filename := global.GVA_CONFIG.AwsS3.PathPrefix + "/" + fileKey
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
|
||||
_, err := uploader.Upload(&s3manager.UploadInput{
|
||||
Bucket: aws.String(global.GVA_CONFIG.AwsS3.Bucket),
|
||||
Key: aws.String(filename),
|
||||
Body: f,
|
||||
ContentType: aws.String(file.Header.Get("Content-Type")),
|
||||
})
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function uploader.Upload() failed", zap.Any("err", err.Error()))
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return global.GVA_CONFIG.AwsS3.BaseURL + "/" + filename, fileKey, nil
|
||||
}
|
||||
|
||||
//@author: [WqyJh](https://github.com/WqyJh)
|
||||
//@object: *AwsS3
|
||||
//@function: DeleteFile
|
||||
//@description: Delete file from Aws S3 using aws-sdk-go. See https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/s3-example-basic-bucket-operations.html#s3-examples-bucket-ops-delete-bucket-item
|
||||
//@param: file *multipart.FileHeader
|
||||
//@return: string, string, error
|
||||
|
||||
func (*AwsS3) DeleteFile(key string) error {
|
||||
session := newSession()
|
||||
svc := s3.New(session)
|
||||
filename := global.GVA_CONFIG.AwsS3.PathPrefix + "/" + key
|
||||
bucket := global.GVA_CONFIG.AwsS3.Bucket
|
||||
|
||||
_, err := svc.DeleteObject(&s3.DeleteObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(filename),
|
||||
})
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function svc.DeleteObject() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function svc.DeleteObject() failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
_ = svc.WaitUntilObjectNotExists(&s3.HeadObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(filename),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// newSession Create S3 session
|
||||
func newSession() *session.Session {
|
||||
sess, _ := session.NewSession(&aws.Config{
|
||||
Region: aws.String(global.GVA_CONFIG.AwsS3.Region),
|
||||
Endpoint: aws.String(global.GVA_CONFIG.AwsS3.Endpoint), //minio在这里设置地址,可以兼容
|
||||
S3ForcePathStyle: aws.Bool(global.GVA_CONFIG.AwsS3.S3ForcePathStyle),
|
||||
DisableSSL: aws.Bool(global.GVA_CONFIG.AwsS3.DisableSSL),
|
||||
Credentials: credentials.NewStaticCredentials(
|
||||
global.GVA_CONFIG.AwsS3.SecretID,
|
||||
global.GVA_CONFIG.AwsS3.SecretKey,
|
||||
"",
|
||||
),
|
||||
})
|
||||
return sess
|
||||
}
|
||||
85
server/utils/upload/cloudflare_r2.go
Normal file
85
server/utils/upload/cloudflare_r2.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type CloudflareR2 struct{}
|
||||
|
||||
func (c *CloudflareR2) UploadFile(file *multipart.FileHeader) (fileUrl string, fileName string, err error) {
|
||||
session := c.newSession()
|
||||
client := s3manager.NewUploader(session)
|
||||
|
||||
fileKey := fmt.Sprintf("%d_%s", time.Now().Unix(), file.Filename)
|
||||
fileName = fmt.Sprintf("%s/%s", global.GVA_CONFIG.CloudflareR2.Path, fileKey)
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
|
||||
input := &s3manager.UploadInput{
|
||||
Bucket: aws.String(global.GVA_CONFIG.CloudflareR2.Bucket),
|
||||
Key: aws.String(fileName),
|
||||
Body: f,
|
||||
}
|
||||
|
||||
_, err = client.Upload(input)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function uploader.Upload() failed", zap.Any("err", err.Error()))
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s", global.GVA_CONFIG.CloudflareR2.BaseURL,
|
||||
fileName),
|
||||
fileKey,
|
||||
nil
|
||||
}
|
||||
|
||||
func (c *CloudflareR2) DeleteFile(key string) error {
|
||||
session := newSession()
|
||||
svc := s3.New(session)
|
||||
filename := global.GVA_CONFIG.CloudflareR2.Path + "/" + key
|
||||
bucket := global.GVA_CONFIG.CloudflareR2.Bucket
|
||||
|
||||
_, err := svc.DeleteObject(&s3.DeleteObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(filename),
|
||||
})
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function svc.DeleteObject() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function svc.DeleteObject() failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
_ = svc.WaitUntilObjectNotExists(&s3.HeadObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(filename),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*CloudflareR2) newSession() *session.Session {
|
||||
endpoint := fmt.Sprintf("%s.r2.cloudflarestorage.com", global.GVA_CONFIG.CloudflareR2.AccountID)
|
||||
|
||||
return session.Must(session.NewSession(&aws.Config{
|
||||
Region: aws.String("auto"),
|
||||
Endpoint: aws.String(endpoint),
|
||||
Credentials: credentials.NewStaticCredentials(
|
||||
global.GVA_CONFIG.CloudflareR2.AccessKeyID,
|
||||
global.GVA_CONFIG.CloudflareR2.SecretAccessKey,
|
||||
"",
|
||||
),
|
||||
}))
|
||||
}
|
||||
109
server/utils/upload/local.go
Normal file
109
server/utils/upload/local.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
type Local struct{}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@author: [ccfish86](https://github.com/ccfish86)
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Local
|
||||
//@function: UploadFile
|
||||
//@description: 上传文件
|
||||
//@param: file *multipart.FileHeader
|
||||
//@return: string, string, error
|
||||
|
||||
func (*Local) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
// 读取文件后缀
|
||||
ext := filepath.Ext(file.Filename)
|
||||
// 读取文件名并加密
|
||||
name := strings.TrimSuffix(file.Filename, ext)
|
||||
name = utils.MD5V([]byte(name))
|
||||
// 拼接新文件名
|
||||
filename := name + "_" + time.Now().Format("20060102150405") + ext
|
||||
// 尝试创建此路径
|
||||
mkdirErr := os.MkdirAll(global.GVA_CONFIG.Local.StorePath, os.ModePerm)
|
||||
if mkdirErr != nil {
|
||||
global.GVA_LOG.Error("function os.MkdirAll() failed", zap.Any("err", mkdirErr.Error()))
|
||||
return "", "", errors.New("function os.MkdirAll() failed, err:" + mkdirErr.Error())
|
||||
}
|
||||
// 拼接路径和文件名
|
||||
p := global.GVA_CONFIG.Local.StorePath + "/" + filename
|
||||
filepath := global.GVA_CONFIG.Local.Path + "/" + filename
|
||||
|
||||
f, openError := file.Open() // 读取文件
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
|
||||
out, createErr := os.Create(p)
|
||||
if createErr != nil {
|
||||
global.GVA_LOG.Error("function os.Create() failed", zap.Any("err", createErr.Error()))
|
||||
|
||||
return "", "", errors.New("function os.Create() failed, err:" + createErr.Error())
|
||||
}
|
||||
defer out.Close() // 创建文件 defer 关闭
|
||||
|
||||
_, copyErr := io.Copy(out, f) // 传输(拷贝)文件
|
||||
if copyErr != nil {
|
||||
global.GVA_LOG.Error("function io.Copy() failed", zap.Any("err", copyErr.Error()))
|
||||
return "", "", errors.New("function io.Copy() failed, err:" + copyErr.Error())
|
||||
}
|
||||
return filepath, filename, nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@author: [ccfish86](https://github.com/ccfish86)
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Local
|
||||
//@function: DeleteFile
|
||||
//@description: 删除文件
|
||||
//@param: key string
|
||||
//@return: error
|
||||
|
||||
func (*Local) DeleteFile(key string) error {
|
||||
// 检查 key 是否为空
|
||||
if key == "" {
|
||||
return errors.New("key不能为空")
|
||||
}
|
||||
|
||||
// 验证 key 是否包含非法字符或尝试访问存储路径之外的文件
|
||||
if strings.Contains(key, "..") || strings.ContainsAny(key, `\/:*?"<>|`) {
|
||||
return errors.New("非法的key")
|
||||
}
|
||||
|
||||
p := filepath.Join(global.GVA_CONFIG.Local.StorePath, key)
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(p); os.IsNotExist(err) {
|
||||
return errors.New("文件不存在")
|
||||
}
|
||||
|
||||
// 使用文件锁防止并发删除
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
err := os.Remove(p)
|
||||
if err != nil {
|
||||
return errors.New("文件删除失败: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
106
server/utils/upload/minio_oss.go
Normal file
106
server/utils/upload/minio_oss.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"git.echol.cn/loser/ai_proxy/server/utils"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var MinioClient *Minio // 优化性能,但是不支持动态配置
|
||||
|
||||
type Minio struct {
|
||||
Client *minio.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
func GetMinio(endpoint, accessKeyID, secretAccessKey, bucketName string, useSSL bool) (*Minio, error) {
|
||||
if MinioClient != nil {
|
||||
return MinioClient, nil
|
||||
}
|
||||
// Initialize minio client object.
|
||||
minioClient, err := minio.New(endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
|
||||
Secure: useSSL, // Set to true if using https
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 尝试创建bucket
|
||||
err = minioClient.MakeBucket(context.Background(), bucketName, minio.MakeBucketOptions{})
|
||||
if err != nil {
|
||||
// Check to see if we already own this bucket (which happens if you run this twice)
|
||||
exists, errBucketExists := minioClient.BucketExists(context.Background(), bucketName)
|
||||
if errBucketExists == nil && exists {
|
||||
// log.Printf("We already own %s\n", bucketName)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
MinioClient = &Minio{Client: minioClient, bucket: bucketName}
|
||||
return MinioClient, nil
|
||||
}
|
||||
|
||||
func (m *Minio) UploadFile(file *multipart.FileHeader) (filePathres, key string, uploadErr error) {
|
||||
f, openError := file.Open()
|
||||
// mutipart.File to os.File
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() Failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() Failed, err:" + openError.Error())
|
||||
}
|
||||
|
||||
filecontent := bytes.Buffer{}
|
||||
_, err := io.Copy(&filecontent, f)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("读取文件失败", zap.Any("err", err.Error()))
|
||||
return "", "", errors.New("读取文件失败, err:" + err.Error())
|
||||
}
|
||||
f.Close() // 创建文件 defer 关闭
|
||||
|
||||
// 对文件名进行加密存储
|
||||
ext := filepath.Ext(file.Filename)
|
||||
filename := utils.MD5V([]byte(strings.TrimSuffix(file.Filename, ext))) + ext
|
||||
if global.GVA_CONFIG.Minio.BasePath == "" {
|
||||
filePathres = "uploads" + "/" + time.Now().Format("2006-01-02") + "/" + filename
|
||||
} else {
|
||||
filePathres = global.GVA_CONFIG.Minio.BasePath + "/" + time.Now().Format("2006-01-02") + "/" + filename
|
||||
}
|
||||
|
||||
// 根据文件扩展名检测 MIME 类型
|
||||
contentType := mime.TypeByExtension(ext)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
// 设置超时10分钟
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10)
|
||||
defer cancel()
|
||||
|
||||
// Upload the file with PutObject 大文件自动切换为分片上传
|
||||
info, err := m.Client.PutObject(ctx, global.GVA_CONFIG.Minio.BucketName, filePathres, &filecontent, file.Size, minio.PutObjectOptions{ContentType: contentType})
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("上传文件到minio失败", zap.Any("err", err.Error()))
|
||||
return "", "", errors.New("上传文件到minio失败, err:" + err.Error())
|
||||
}
|
||||
return global.GVA_CONFIG.Minio.BucketUrl + "/" + info.Key, filePathres, nil
|
||||
}
|
||||
|
||||
func (m *Minio) DeleteFile(key string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
// Delete the object from MinIO
|
||||
err := m.Client.RemoveObject(ctx, m.bucket, key, minio.RemoveObjectOptions{})
|
||||
return err
|
||||
}
|
||||
69
server/utils/upload/obs.go
Normal file
69
server/utils/upload/obs.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"mime/multipart"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"github.com/huaweicloud/huaweicloud-sdk-go-obs/obs"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var HuaWeiObs = new(Obs)
|
||||
|
||||
type Obs struct{}
|
||||
|
||||
func NewHuaWeiObsClient() (client *obs.ObsClient, err error) {
|
||||
return obs.New(global.GVA_CONFIG.HuaWeiObs.AccessKey, global.GVA_CONFIG.HuaWeiObs.SecretKey, global.GVA_CONFIG.HuaWeiObs.Endpoint)
|
||||
}
|
||||
|
||||
func (o *Obs) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
// var open multipart.File
|
||||
open, err := file.Open()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer open.Close()
|
||||
filename := file.Filename
|
||||
input := &obs.PutObjectInput{
|
||||
PutObjectBasicInput: obs.PutObjectBasicInput{
|
||||
ObjectOperationInput: obs.ObjectOperationInput{
|
||||
Bucket: global.GVA_CONFIG.HuaWeiObs.Bucket,
|
||||
Key: filename,
|
||||
},
|
||||
HttpHeader: obs.HttpHeader{
|
||||
ContentType: file.Header.Get("content-type"),
|
||||
},
|
||||
},
|
||||
Body: open,
|
||||
}
|
||||
|
||||
var client *obs.ObsClient
|
||||
client, err = NewHuaWeiObsClient()
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(err, "获取华为对象存储对象失败!")
|
||||
}
|
||||
|
||||
_, err = client.PutObject(input)
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(err, "文件上传失败!")
|
||||
}
|
||||
filepath := global.GVA_CONFIG.HuaWeiObs.Path + "/" + filename
|
||||
return filepath, filename, err
|
||||
}
|
||||
|
||||
func (o *Obs) DeleteFile(key string) error {
|
||||
client, err := NewHuaWeiObsClient()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "获取华为对象存储对象失败!")
|
||||
}
|
||||
input := &obs.DeleteObjectInput{
|
||||
Bucket: global.GVA_CONFIG.HuaWeiObs.Bucket,
|
||||
Key: key,
|
||||
}
|
||||
var output *obs.DeleteObjectOutput
|
||||
output, err = client.DeleteObject(input)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "删除对象(%s)失败!, output: %v", key, output)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
96
server/utils/upload/qiniu.go
Normal file
96
server/utils/upload/qiniu.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||
"github.com/qiniu/go-sdk/v7/storage"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Qiniu struct{}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@author: [ccfish86](https://github.com/ccfish86)
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Qiniu
|
||||
//@function: UploadFile
|
||||
//@description: 上传文件
|
||||
//@param: file *multipart.FileHeader
|
||||
//@return: string, string, error
|
||||
|
||||
func (*Qiniu) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
putPolicy := storage.PutPolicy{Scope: global.GVA_CONFIG.Qiniu.Bucket}
|
||||
mac := qbox.NewMac(global.GVA_CONFIG.Qiniu.AccessKey, global.GVA_CONFIG.Qiniu.SecretKey)
|
||||
upToken := putPolicy.UploadToken(mac)
|
||||
cfg := qiniuConfig()
|
||||
formUploader := storage.NewFormUploader(cfg)
|
||||
ret := storage.PutRet{}
|
||||
putExtra := storage.PutExtra{Params: map[string]string{"x:name": "github logo"}}
|
||||
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename) // 文件名格式 自己可以改 建议保证唯一性
|
||||
putErr := formUploader.Put(context.Background(), &ret, upToken, fileKey, f, file.Size, &putExtra)
|
||||
if putErr != nil {
|
||||
global.GVA_LOG.Error("function formUploader.Put() failed", zap.Any("err", putErr.Error()))
|
||||
return "", "", errors.New("function formUploader.Put() failed, err:" + putErr.Error())
|
||||
}
|
||||
return global.GVA_CONFIG.Qiniu.ImgPath + "/" + ret.Key, ret.Key, nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@author: [ccfish86](https://github.com/ccfish86)
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Qiniu
|
||||
//@function: DeleteFile
|
||||
//@description: 删除文件
|
||||
//@param: key string
|
||||
//@return: error
|
||||
|
||||
func (*Qiniu) DeleteFile(key string) error {
|
||||
mac := qbox.NewMac(global.GVA_CONFIG.Qiniu.AccessKey, global.GVA_CONFIG.Qiniu.SecretKey)
|
||||
cfg := qiniuConfig()
|
||||
bucketManager := storage.NewBucketManager(mac, cfg)
|
||||
if err := bucketManager.Delete(global.GVA_CONFIG.Qiniu.Bucket, key); err != nil {
|
||||
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Qiniu
|
||||
//@function: qiniuConfig
|
||||
//@description: 根据配置文件进行返回七牛云的配置
|
||||
//@return: *storage.Config
|
||||
|
||||
func qiniuConfig() *storage.Config {
|
||||
cfg := storage.Config{
|
||||
UseHTTPS: global.GVA_CONFIG.Qiniu.UseHTTPS,
|
||||
UseCdnDomains: global.GVA_CONFIG.Qiniu.UseCdnDomains,
|
||||
}
|
||||
switch global.GVA_CONFIG.Qiniu.Zone { // 根据配置文件进行初始化空间对应的机房
|
||||
case "ZoneHuadong":
|
||||
cfg.Zone = &storage.ZoneHuadong
|
||||
case "ZoneHuabei":
|
||||
cfg.Zone = &storage.ZoneHuabei
|
||||
case "ZoneHuanan":
|
||||
cfg.Zone = &storage.ZoneHuanan
|
||||
case "ZoneBeimei":
|
||||
cfg.Zone = &storage.ZoneBeimei
|
||||
case "ZoneXinjiapo":
|
||||
cfg.Zone = &storage.ZoneXinjiapo
|
||||
}
|
||||
return &cfg
|
||||
}
|
||||
61
server/utils/upload/tencent_cos.go
Normal file
61
server/utils/upload/tencent_cos.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
|
||||
"github.com/tencentyun/cos-go-sdk-v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type TencentCOS struct{}
|
||||
|
||||
// UploadFile upload file to COS
|
||||
func (*TencentCOS) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
client := NewClient()
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename)
|
||||
|
||||
_, err := client.Object.Put(context.Background(), global.GVA_CONFIG.TencentCOS.PathPrefix+"/"+fileKey, f, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return global.GVA_CONFIG.TencentCOS.BaseURL + "/" + global.GVA_CONFIG.TencentCOS.PathPrefix + "/" + fileKey, fileKey, nil
|
||||
}
|
||||
|
||||
// DeleteFile delete file form COS
|
||||
func (*TencentCOS) DeleteFile(key string) error {
|
||||
client := NewClient()
|
||||
name := global.GVA_CONFIG.TencentCOS.PathPrefix + "/" + key
|
||||
_, err := client.Object.Delete(context.Background(), name)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewClient init COS client
|
||||
func NewClient() *cos.Client {
|
||||
urlStr, _ := url.Parse("https://" + global.GVA_CONFIG.TencentCOS.Bucket + ".cos." + global.GVA_CONFIG.TencentCOS.Region + ".myqcloud.com")
|
||||
baseURL := &cos.BaseURL{BucketURL: urlStr}
|
||||
client := cos.NewClient(baseURL, &http.Client{
|
||||
Transport: &cos.AuthorizationTransport{
|
||||
SecretID: global.GVA_CONFIG.TencentCOS.SecretID,
|
||||
SecretKey: global.GVA_CONFIG.TencentCOS.SecretKey,
|
||||
},
|
||||
})
|
||||
return client
|
||||
}
|
||||
46
server/utils/upload/upload.go
Normal file
46
server/utils/upload/upload.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"mime/multipart"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/global"
|
||||
)
|
||||
|
||||
// OSS 对象存储接口
|
||||
// Author [SliverHorn](https://github.com/SliverHorn)
|
||||
// Author [ccfish86](https://github.com/ccfish86)
|
||||
type OSS interface {
|
||||
UploadFile(file *multipart.FileHeader) (string, string, error)
|
||||
DeleteFile(key string) error
|
||||
}
|
||||
|
||||
// NewOss OSS的实例化方法
|
||||
// Author [SliverHorn](https://github.com/SliverHorn)
|
||||
// Author [ccfish86](https://github.com/ccfish86)
|
||||
func NewOss() OSS {
|
||||
switch global.GVA_CONFIG.System.OssType {
|
||||
case "local":
|
||||
return &Local{}
|
||||
case "qiniu":
|
||||
return &Qiniu{}
|
||||
case "tencent-cos":
|
||||
return &TencentCOS{}
|
||||
case "aliyun-oss":
|
||||
return &AliyunOSS{}
|
||||
case "huawei-obs":
|
||||
return HuaWeiObs
|
||||
case "aws-s3":
|
||||
return &AwsS3{}
|
||||
case "cloudflare-r2":
|
||||
return &CloudflareR2{}
|
||||
case "minio":
|
||||
minioClient, err := GetMinio(global.GVA_CONFIG.Minio.Endpoint, global.GVA_CONFIG.Minio.AccessKeyId, global.GVA_CONFIG.Minio.AccessKeySecret, global.GVA_CONFIG.Minio.BucketName, global.GVA_CONFIG.Minio.UseSSL)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Warn("你配置了使用minio,但是初始化失败,请检查minio可用性或安全配置: " + err.Error())
|
||||
panic("minio初始化失败") // 建议这样做,用户自己配置了minio,如果报错了还要把服务开起来,使用起来也很危险
|
||||
}
|
||||
return minioClient
|
||||
default:
|
||||
return &Local{}
|
||||
}
|
||||
}
|
||||
294
server/utils/validator.go
Normal file
294
server/utils/validator.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Rules map[string][]string
|
||||
|
||||
type RulesMap map[string]Rules
|
||||
|
||||
var CustomizeMap = make(map[string]Rules)
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: RegisterRule
|
||||
//@description: 注册自定义规则方案建议在路由初始化层即注册
|
||||
//@param: key string, rule Rules
|
||||
//@return: err error
|
||||
|
||||
func RegisterRule(key string, rule Rules) (err error) {
|
||||
if CustomizeMap[key] != nil {
|
||||
return errors.New(key + "已注册,无法重复注册")
|
||||
} else {
|
||||
CustomizeMap[key] = rule
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: NotEmpty
|
||||
//@description: 非空 不能为其对应类型的0值
|
||||
//@return: string
|
||||
|
||||
func NotEmpty() string {
|
||||
return "notEmpty"
|
||||
}
|
||||
|
||||
// @author: [zooqkl](https://github.com/zooqkl)
|
||||
// @function: RegexpMatch
|
||||
// @description: 正则校验 校验输入项是否满足正则表达式
|
||||
// @param: rule string
|
||||
// @return: string
|
||||
|
||||
func RegexpMatch(rule string) string {
|
||||
return "regexp=" + rule
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Lt
|
||||
//@description: 小于入参(<) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Lt(mark string) string {
|
||||
return "lt=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Le
|
||||
//@description: 小于等于入参(<=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Le(mark string) string {
|
||||
return "le=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Eq
|
||||
//@description: 等于入参(==) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Eq(mark string) string {
|
||||
return "eq=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Ne
|
||||
//@description: 不等于入参(!=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Ne(mark string) string {
|
||||
return "ne=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Ge
|
||||
//@description: 大于等于入参(>=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Ge(mark string) string {
|
||||
return "ge=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Gt
|
||||
//@description: 大于入参(>) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Gt(mark string) string {
|
||||
return "gt=" + mark
|
||||
}
|
||||
|
||||
//
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Verify
|
||||
//@description: 校验方法
|
||||
//@param: st interface{}, roleMap Rules(入参实例,规则map)
|
||||
//@return: err error
|
||||
|
||||
func Verify(st interface{}, roleMap Rules) (err error) {
|
||||
compareMap := map[string]bool{
|
||||
"lt": true,
|
||||
"le": true,
|
||||
"eq": true,
|
||||
"ne": true,
|
||||
"ge": true,
|
||||
"gt": true,
|
||||
}
|
||||
|
||||
typ := reflect.TypeOf(st)
|
||||
val := reflect.ValueOf(st) // 获取reflect.Type类型
|
||||
|
||||
kd := val.Kind() // 获取到st对应的类别
|
||||
if kd != reflect.Struct {
|
||||
return errors.New("expect struct")
|
||||
}
|
||||
num := val.NumField()
|
||||
// 遍历结构体的所有字段
|
||||
for i := 0; i < num; i++ {
|
||||
tagVal := typ.Field(i)
|
||||
val := val.Field(i)
|
||||
if tagVal.Type.Kind() == reflect.Struct {
|
||||
if err = Verify(val.Interface(), roleMap); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(roleMap[tagVal.Name]) > 0 {
|
||||
for _, v := range roleMap[tagVal.Name] {
|
||||
switch {
|
||||
case v == "notEmpty":
|
||||
if isBlank(val) {
|
||||
return errors.New(tagVal.Name + "值不能为空")
|
||||
}
|
||||
case strings.Split(v, "=")[0] == "regexp":
|
||||
if !regexpMatch(strings.Split(v, "=")[1], val.String()) {
|
||||
return errors.New(tagVal.Name + "格式校验不通过")
|
||||
}
|
||||
case compareMap[strings.Split(v, "=")[0]]:
|
||||
if !compareVerify(val, v) {
|
||||
return errors.New(tagVal.Name + "长度或值不在合法范围," + v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: compareVerify
|
||||
//@description: 长度和数字的校验方法 根据类型自动校验
|
||||
//@param: value reflect.Value, VerifyStr string
|
||||
//@return: bool
|
||||
|
||||
func compareVerify(value reflect.Value, VerifyStr string) bool {
|
||||
switch value.Kind() {
|
||||
case reflect.String:
|
||||
return compare(len([]rune(value.String())), VerifyStr)
|
||||
case reflect.Slice, reflect.Array:
|
||||
return compare(value.Len(), VerifyStr)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return compare(value.Uint(), VerifyStr)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return compare(value.Float(), VerifyStr)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return compare(value.Int(), VerifyStr)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: isBlank
|
||||
//@description: 非空校验
|
||||
//@param: value reflect.Value
|
||||
//@return: bool
|
||||
|
||||
func isBlank(value reflect.Value) bool {
|
||||
switch value.Kind() {
|
||||
case reflect.String, reflect.Slice:
|
||||
return value.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !value.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return value.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return value.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return value.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return value.IsNil()
|
||||
}
|
||||
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: compare
|
||||
//@description: 比较函数
|
||||
//@param: value interface{}, VerifyStr string
|
||||
//@return: bool
|
||||
|
||||
func compare(value interface{}, VerifyStr string) bool {
|
||||
VerifyStrArr := strings.Split(VerifyStr, "=")
|
||||
val := reflect.ValueOf(value)
|
||||
switch val.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
VInt, VErr := strconv.ParseInt(VerifyStrArr[1], 10, 64)
|
||||
if VErr != nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case VerifyStrArr[0] == "lt":
|
||||
return val.Int() < VInt
|
||||
case VerifyStrArr[0] == "le":
|
||||
return val.Int() <= VInt
|
||||
case VerifyStrArr[0] == "eq":
|
||||
return val.Int() == VInt
|
||||
case VerifyStrArr[0] == "ne":
|
||||
return val.Int() != VInt
|
||||
case VerifyStrArr[0] == "ge":
|
||||
return val.Int() >= VInt
|
||||
case VerifyStrArr[0] == "gt":
|
||||
return val.Int() > VInt
|
||||
default:
|
||||
return false
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
VInt, VErr := strconv.Atoi(VerifyStrArr[1])
|
||||
if VErr != nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case VerifyStrArr[0] == "lt":
|
||||
return val.Uint() < uint64(VInt)
|
||||
case VerifyStrArr[0] == "le":
|
||||
return val.Uint() <= uint64(VInt)
|
||||
case VerifyStrArr[0] == "eq":
|
||||
return val.Uint() == uint64(VInt)
|
||||
case VerifyStrArr[0] == "ne":
|
||||
return val.Uint() != uint64(VInt)
|
||||
case VerifyStrArr[0] == "ge":
|
||||
return val.Uint() >= uint64(VInt)
|
||||
case VerifyStrArr[0] == "gt":
|
||||
return val.Uint() > uint64(VInt)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
VFloat, VErr := strconv.ParseFloat(VerifyStrArr[1], 64)
|
||||
if VErr != nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case VerifyStrArr[0] == "lt":
|
||||
return val.Float() < VFloat
|
||||
case VerifyStrArr[0] == "le":
|
||||
return val.Float() <= VFloat
|
||||
case VerifyStrArr[0] == "eq":
|
||||
return val.Float() == VFloat
|
||||
case VerifyStrArr[0] == "ne":
|
||||
return val.Float() != VFloat
|
||||
case VerifyStrArr[0] == "ge":
|
||||
return val.Float() >= VFloat
|
||||
case VerifyStrArr[0] == "gt":
|
||||
return val.Float() > VFloat
|
||||
default:
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func regexpMatch(rule, matchStr string) bool {
|
||||
return regexp.MustCompile(rule).MatchString(matchStr)
|
||||
}
|
||||
38
server/utils/validator_test.go
Normal file
38
server/utils/validator_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.echol.cn/loser/ai_proxy/server/model/common/request"
|
||||
)
|
||||
|
||||
type PageInfoTest struct {
|
||||
PageInfo request.PageInfo
|
||||
Name string
|
||||
}
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
PageInfoVerify := Rules{"Page": {NotEmpty()}, "PageSize": {NotEmpty()}, "Name": {NotEmpty()}}
|
||||
var testInfo PageInfoTest
|
||||
testInfo.Name = "test"
|
||||
testInfo.PageInfo.Page = 0
|
||||
testInfo.PageInfo.PageSize = 0
|
||||
err := Verify(testInfo, PageInfoVerify)
|
||||
if err == nil {
|
||||
t.Error("校验失败,未能捕捉0值")
|
||||
}
|
||||
testInfo.Name = ""
|
||||
testInfo.PageInfo.Page = 1
|
||||
testInfo.PageInfo.PageSize = 10
|
||||
err = Verify(testInfo, PageInfoVerify)
|
||||
if err == nil {
|
||||
t.Error("校验失败,未能正常检测name为空")
|
||||
}
|
||||
testInfo.Name = "test"
|
||||
testInfo.PageInfo.Page = 1
|
||||
testInfo.PageInfo.PageSize = 10
|
||||
err = Verify(testInfo, PageInfoVerify)
|
||||
if err != nil {
|
||||
t.Error("校验失败,未能正常通过检测")
|
||||
}
|
||||
}
|
||||
19
server/utils/verify.go
Normal file
19
server/utils/verify.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
var (
|
||||
IdVerify = Rules{"ID": []string{NotEmpty()}}
|
||||
ApiVerify = Rules{"Path": {NotEmpty()}, "Description": {NotEmpty()}, "ApiGroup": {NotEmpty()}, "Method": {NotEmpty()}}
|
||||
MenuVerify = Rules{"Path": {NotEmpty()}, "Name": {NotEmpty()}, "Component": {NotEmpty()}, "Sort": {Ge("0")}}
|
||||
MenuMetaVerify = Rules{"Title": {NotEmpty()}}
|
||||
LoginVerify = Rules{"Username": {NotEmpty()}, "Password": {NotEmpty()}}
|
||||
RegisterVerify = Rules{"Username": {NotEmpty()}, "NickName": {NotEmpty()}, "Password": {NotEmpty()}, "AuthorityId": {NotEmpty()}}
|
||||
PageInfoVerify = Rules{"Page": {NotEmpty()}, "PageSize": {NotEmpty()}}
|
||||
CustomerVerify = Rules{"CustomerName": {NotEmpty()}, "CustomerPhoneData": {NotEmpty()}}
|
||||
AutoCodeVerify = Rules{"Abbreviation": {NotEmpty()}, "StructName": {NotEmpty()}, "PackageName": {NotEmpty()}}
|
||||
AutoPackageVerify = Rules{"PackageName": {NotEmpty()}}
|
||||
AuthorityVerify = Rules{"AuthorityId": {NotEmpty()}, "AuthorityName": {NotEmpty()}}
|
||||
AuthorityIdVerify = Rules{"AuthorityId": {NotEmpty()}}
|
||||
OldAuthorityVerify = Rules{"OldAuthorityId": {NotEmpty()}}
|
||||
ChangePasswordVerify = Rules{"Password": {NotEmpty()}, "NewPassword": {NotEmpty()}}
|
||||
SetUserAuthorityVerify = Rules{"AuthorityId": {NotEmpty()}}
|
||||
)
|
||||
53
server/utils/zip.go
Normal file
53
server/utils/zip.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 解压
|
||||
func Unzip(zipFile string, destDir string) ([]string, error) {
|
||||
zipReader, err := zip.OpenReader(zipFile)
|
||||
var paths []string
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
defer zipReader.Close()
|
||||
|
||||
for _, f := range zipReader.File {
|
||||
if strings.Contains(f.Name, "..") {
|
||||
return []string{}, fmt.Errorf("%s 文件名不合法", f.Name)
|
||||
}
|
||||
fpath := filepath.Join(destDir, f.Name)
|
||||
paths = append(paths, fpath)
|
||||
if f.FileInfo().IsDir() {
|
||||
os.MkdirAll(fpath, os.ModePerm)
|
||||
} else {
|
||||
if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
inFile, err := f.Open()
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
defer inFile.Close()
|
||||
|
||||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
_, err = io.Copy(outFile, inFile)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return paths, nil
|
||||
}
|
||||
Reference in New Issue
Block a user