🎉 初始化项目

This commit is contained in:
2026-03-03 06:05:51 +08:00
commit e1c70fe218
241 changed files with 148285 additions and 0 deletions

84
server/utils/app_jwt.go Normal file
View File

@@ -0,0 +1,84 @@
package utils
import (
"errors"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/golang-jwt/jwt/v5"
)
const (
UserTypeApp = "app" // 前台用户类型标识
)
// AppJWTClaims 前台用户 JWT Claims
type AppJWTClaims struct {
UserID uint `json:"userId"`
Username string `json:"username"`
UserType string `json:"userType"` // 用户类型标识
jwt.RegisteredClaims
}
// CreateAppToken 创建前台用户 Token有效期 7 天)
func CreateAppToken(userID uint, username string) (tokenString string, expiresAt int64, err error) {
// Token 有效期为 7 天
expiresTime := time.Now().Add(7 * 24 * time.Hour)
expiresAt = expiresTime.Unix()
claims := AppJWTClaims{
UserID: userID,
Username: username,
UserType: UserTypeApp,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: global.GVA_CONFIG.JWT.Issuer,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err = token.SignedString([]byte(global.GVA_CONFIG.JWT.SigningKey))
return
}
// CreateAppRefreshToken 创建前台用户刷新 Token有效期更长
func CreateAppRefreshToken(userID uint, username string) (tokenString string, expiresAt int64, err error) {
// 刷新 Token 有效期为 7 天
expiresTime := time.Now().Add(7 * 24 * time.Hour)
expiresAt = expiresTime.Unix()
claims := AppJWTClaims{
UserID: userID,
Username: username,
UserType: UserTypeApp,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: global.GVA_CONFIG.JWT.Issuer,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err = token.SignedString([]byte(global.GVA_CONFIG.JWT.SigningKey))
return
}
// ParseAppToken 解析前台用户 Token
func ParseAppToken(tokenString string) (*AppJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AppJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(global.GVA_CONFIG.JWT.SigningKey), nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*AppJWTClaims); ok && token.Valid {
return claims, nil
}
return nil, errors.New("invalid token")
}

View File

@@ -0,0 +1,121 @@
package utils
import (
"errors"
"os"
"strconv"
"strings"
)
// 前端传来文件片与当前片为什么文件的第几片
// 后端拿到以后比较次分片是否上传 或者是否为不完全片
// 前端发送每片多大
// 前端告知是否为最后一片且是否完成
const (
breakpointDir = "./breakpointDir/"
finishDir = "./fileDir/"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: BreakPointContinue
//@description: 断点续传
//@param: content []byte, fileName string, contentNumber int, contentTotal int, fileMd5 string
//@return: error, string
func BreakPointContinue(content []byte, fileName string, contentNumber int, contentTotal int, fileMd5 string) (string, error) {
if strings.Contains(fileName, "..") || strings.Contains(fileMd5, "..") {
return "", errors.New("文件名或路径不合法")
}
path := breakpointDir + fileMd5 + "/"
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
return path, err
}
pathC, err := makeFileContent(content, fileName, path, contentNumber)
return pathC, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CheckMd5
//@description: 检查Md5
//@param: content []byte, chunkMd5 string
//@return: CanUpload bool
func CheckMd5(content []byte, chunkMd5 string) (CanUpload bool) {
fileMd5 := MD5V(content)
if fileMd5 == chunkMd5 {
return true // 可以继续上传
} else {
return false // 切片不完整,废弃
}
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: makeFileContent
//@description: 创建切片内容
//@param: content []byte, fileName string, FileDir string, contentNumber int
//@return: string, error
func makeFileContent(content []byte, fileName string, FileDir string, contentNumber int) (string, error) {
if strings.Contains(fileName, "..") || strings.Contains(FileDir, "..") {
return "", errors.New("文件名或路径不合法")
}
path := FileDir + fileName + "_" + strconv.Itoa(contentNumber)
f, err := os.Create(path)
if err != nil {
return path, err
}
defer f.Close()
_, err = f.Write(content)
if err != nil {
return path, err
}
return path, nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: makeFileContent
//@description: 创建切片文件
//@param: fileName string, FileMd5 string
//@return: error, string
func MakeFile(fileName string, FileMd5 string) (string, error) {
if strings.Contains(fileName, "..") || strings.Contains(FileMd5, "..") {
return "", errors.New("文件名或路径不合法")
}
rd, err := os.ReadDir(breakpointDir + FileMd5)
if err != nil {
return finishDir + fileName, err
}
_ = os.MkdirAll(finishDir, os.ModePerm)
fd, err := os.OpenFile(finishDir+fileName, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644)
if err != nil {
return finishDir + fileName, err
}
defer fd.Close()
for k := range rd {
content, _ := os.ReadFile(breakpointDir + FileMd5 + "/" + fileName + "_" + strconv.Itoa(k))
_, err = fd.Write(content)
if err != nil {
_ = os.Remove(finishDir + fileName)
return finishDir + fileName, err
}
}
return finishDir + fileName, nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: RemoveChunk
//@description: 移除切片
//@param: FileMd5 string
//@return: error
func RemoveChunk(FileMd5 string) error {
if strings.Contains(FileMd5, "..") {
return errors.New("路径不合法")
}
err := os.RemoveAll(breakpointDir + FileMd5)
return err
}

View File

@@ -0,0 +1,61 @@
package captcha
import (
"context"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"go.uber.org/zap"
)
func NewDefaultRedisStore() *RedisStore {
return &RedisStore{
Expiration: time.Second * 180,
PreKey: "CAPTCHA_",
Context: context.TODO(),
}
}
type RedisStore struct {
Expiration time.Duration
PreKey string
Context context.Context
}
func (rs *RedisStore) UseWithCtx(ctx context.Context) *RedisStore {
if ctx == nil {
rs.Context = ctx
}
return rs
}
func (rs *RedisStore) Set(id string, value string) error {
err := global.GVA_REDIS.Set(rs.Context, rs.PreKey+id, value, rs.Expiration).Err()
if err != nil {
global.GVA_LOG.Error("RedisStoreSetError!", zap.Error(err))
return err
}
return nil
}
func (rs *RedisStore) Get(key string, clear bool) string {
val, err := global.GVA_REDIS.Get(rs.Context, key).Result()
if err != nil {
global.GVA_LOG.Error("RedisStoreGetError!", zap.Error(err))
return ""
}
if clear {
err := global.GVA_REDIS.Del(rs.Context, key).Err()
if err != nil {
global.GVA_LOG.Error("RedisStoreClearError!", zap.Error(err))
return ""
}
}
return val
}
func (rs *RedisStore) Verify(id, answer string, clear bool) bool {
key := rs.PreKey + id
v := rs.Get(key, clear)
return v == answer
}

View File

@@ -0,0 +1,52 @@
package utils
import (
"sync"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
gormadapter "github.com/casbin/gorm-adapter/v3"
"go.uber.org/zap"
)
var (
syncedCachedEnforcer *casbin.SyncedCachedEnforcer
once sync.Once
)
// GetCasbin 获取casbin实例
func GetCasbin() *casbin.SyncedCachedEnforcer {
once.Do(func() {
a, err := gormadapter.NewAdapterByDB(global.GVA_DB)
if err != nil {
zap.L().Error("适配数据库失败请检查casbin表是否为InnoDB引擎!", zap.Error(err))
return
}
text := `
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[role_definition]
g = _, _
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = r.sub == p.sub && keyMatch2(r.obj,p.obj) && r.act == p.act
`
m, err := model.NewModelFromString(text)
if err != nil {
zap.L().Error("字符串加载模型失败!", zap.Error(err))
return
}
syncedCachedEnforcer, _ = casbin.NewSyncedCachedEnforcer(m, a)
syncedCachedEnforcer.SetExpireTime(60 * 60)
_ = syncedCachedEnforcer.LoadPolicy()
})
return syncedCachedEnforcer
}

View File

@@ -0,0 +1,285 @@
package utils
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"image"
"image/png"
"io"
)
// CharacterCardV2 SillyTavern 角色卡 V2 格式
type CharacterCardV2 struct {
Spec string `json:"spec"`
SpecVersion string `json:"spec_version"`
Data CharacterCardV2Data `json:"data"`
}
type CharacterCardV2Data struct {
Name string `json:"name"`
Description string `json:"description"`
Personality string `json:"personality"`
Scenario string `json:"scenario"`
FirstMes string `json:"first_mes"`
MesExample string `json:"mes_example"`
CreatorNotes string `json:"creator_notes"`
SystemPrompt string `json:"system_prompt"`
PostHistoryInstructions string `json:"post_history_instructions"`
Tags []string `json:"tags"`
Creator string `json:"creator"`
CharacterVersion string `json:"character_version"`
AlternateGreetings []string `json:"alternate_greetings"`
CharacterBook map[string]interface{} `json:"character_book,omitempty"`
Extensions map[string]interface{} `json:"extensions"`
}
// ExtractCharacterFromPNG 从 PNG 图片中提取角色卡数据
func ExtractCharacterFromPNG(pngData []byte) (*CharacterCardV2, error) {
reader := bytes.NewReader(pngData)
// 验证 PNG 格式(解码但不保存图片)
_, err := png.Decode(reader)
if err != nil {
return nil, errors.New("无效的 PNG 文件")
}
// 重新读取以获取 tEXt chunks
reader.Seek(0, 0)
// 查找 tEXt chunk 中的 "chara" 字段
charaJSON, err := extractTextChunk(reader, "chara")
if err != nil {
return nil, errors.New("PNG 中没有找到角色卡数据")
}
// 尝试 Base64 解码
decodedJSON, err := base64.StdEncoding.DecodeString(charaJSON)
if err != nil {
// 如果不是 Base64直接使用原始 JSON
decodedJSON = []byte(charaJSON)
}
// 解析 JSON
var card CharacterCardV2
err = json.Unmarshal(decodedJSON, &card)
if err != nil {
return nil, errors.New("解析角色卡数据失败: " + err.Error())
}
return &card, nil
}
// extractTextChunk 从 PNG 中提取指定 key 的 tEXt chunk
func extractTextChunk(r io.Reader, key string) (string, error) {
// 跳过 PNG signature (8 bytes)
signature := make([]byte, 8)
if _, err := io.ReadFull(r, signature); err != nil {
return "", err
}
// 验证 PNG signature
expectedSig := []byte{137, 80, 78, 71, 13, 10, 26, 10}
if !bytes.Equal(signature, expectedSig) {
return "", errors.New("invalid PNG signature")
}
// 读取所有 chunks
for {
// 读取 chunk length (4 bytes)
lengthBytes := make([]byte, 4)
if _, err := io.ReadFull(r, lengthBytes); err != nil {
if err == io.EOF {
break
}
return "", err
}
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 |
uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
// 读取 chunk type (4 bytes)
chunkType := make([]byte, 4)
if _, err := io.ReadFull(r, chunkType); err != nil {
return "", err
}
// 读取 chunk data
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return "", err
}
// 读取 CRC (4 bytes)
crc := make([]byte, 4)
if _, err := io.ReadFull(r, crc); err != nil {
return "", err
}
// 检查是否是 tEXt chunk
if string(chunkType) == "tEXt" {
// tEXt chunk 格式: keyword\0text
nullIndex := bytes.IndexByte(data, 0)
if nullIndex == -1 {
continue
}
keyword := string(data[:nullIndex])
text := string(data[nullIndex+1:])
if keyword == key {
return text, nil
}
}
// IEND chunk 表示结束
if string(chunkType) == "IEND" {
break
}
}
return "", errors.New("text chunk not found")
}
// EmbedCharacterToPNG 将角色卡数据嵌入到 PNG 图片中
func EmbedCharacterToPNG(img image.Image, card *CharacterCardV2) ([]byte, error) {
// 序列化角色卡数据
cardJSON, err := json.Marshal(card)
if err != nil {
return nil, err
}
// Base64 编码
encodedJSON := base64.StdEncoding.EncodeToString(cardJSON)
// 创建一个 buffer 来写入 PNG
var buf bytes.Buffer
// 写入 PNG signature
buf.Write([]byte{137, 80, 78, 71, 13, 10, 26, 10})
// 编码原始图片到临时 buffer
var imgBuf bytes.Buffer
if err := png.Encode(&imgBuf, img); err != nil {
return nil, err
}
// 跳过原始 PNG 的 signature
imgData := imgBuf.Bytes()[8:]
// 将原始图片的 chunks 复制到输出,在 IEND 之前插入 tEXt chunk
r := bytes.NewReader(imgData)
for {
// 读取 chunk length
lengthBytes := make([]byte, 4)
if _, err := io.ReadFull(r, lengthBytes); err != nil {
if err == io.EOF {
break
}
return nil, err
}
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 |
uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
// 读取 chunk type
chunkType := make([]byte, 4)
if _, err := io.ReadFull(r, chunkType); err != nil {
return nil, err
}
// 读取 chunk data
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return nil, err
}
// 读取 CRC
crc := make([]byte, 4)
if _, err := io.ReadFull(r, crc); err != nil {
return nil, err
}
// 如果是 IEND chunk先写入 tEXt chunk
if string(chunkType) == "IEND" {
// 写入 tEXt chunk
writeTextChunk(&buf, "chara", encodedJSON)
}
// 写入原始 chunk
buf.Write(lengthBytes)
buf.Write(chunkType)
buf.Write(data)
buf.Write(crc)
if string(chunkType) == "IEND" {
break
}
}
return buf.Bytes(), nil
}
// writeTextChunk 写入 tEXt chunk
func writeTextChunk(w io.Writer, keyword, text string) error {
data := append([]byte(keyword), 0)
data = append(data, []byte(text)...)
// 写入 length
length := uint32(len(data))
lengthBytes := []byte{
byte(length >> 24),
byte(length >> 16),
byte(length >> 8),
byte(length),
}
w.Write(lengthBytes)
// 写入 type
w.Write([]byte("tEXt"))
// 写入 data
w.Write(data)
// 计算并写入 CRC
crcData := append([]byte("tEXt"), data...)
crc := calculateCRC(crcData)
crcBytes := []byte{
byte(crc >> 24),
byte(crc >> 16),
byte(crc >> 8),
byte(crc),
}
w.Write(crcBytes)
return nil
}
// calculateCRC 计算 CRC32
func calculateCRC(data []byte) uint32 {
crc := uint32(0xFFFFFFFF)
for _, b := range data {
crc ^= uint32(b)
for i := 0; i < 8; i++ {
if crc&1 != 0 {
crc = (crc >> 1) ^ 0xEDB88320
} else {
crc >>= 1
}
}
}
return crc ^ 0xFFFFFFFF
}
// ParseCharacterCardJSON 解析 JSON 格式的角色卡
func ParseCharacterCardJSON(jsonData []byte) (*CharacterCardV2, error) {
var card CharacterCardV2
err := json.Unmarshal(jsonData, &card)
if err != nil {
return nil, errors.New("解析角色卡 JSON 失败: " + err.Error())
}
return &card, nil
}

148
server/utils/claims.go Normal file
View File

@@ -0,0 +1,148 @@
package utils
import (
"net"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
func ClearToken(c *gin.Context) {
// 增加cookie x-token 向来源的web添加
host, _, err := net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}
if net.ParseIP(host) != nil {
c.SetCookie("x-token", "", -1, "/", "", false, false)
} else {
c.SetCookie("x-token", "", -1, "/", host, false, false)
}
}
func SetToken(c *gin.Context, token string, maxAge int) {
// 增加cookie x-token 向来源的web添加
host, _, err := net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}
if net.ParseIP(host) != nil {
c.SetCookie("x-token", token, maxAge, "/", "", false, false)
} else {
c.SetCookie("x-token", token, maxAge, "/", host, false, false)
}
}
func GetToken(c *gin.Context) string {
token := c.Request.Header.Get("x-token")
if token == "" {
j := NewJWT()
token, _ = c.Cookie("x-token")
claims, err := j.ParseToken(token)
if err != nil {
global.GVA_LOG.Error("重新写入cookie token失败,未能成功解析token,请检查请求头是否存在x-token且claims是否为规定结构")
return token
}
SetToken(c, token, int(claims.ExpiresAt.Unix()-time.Now().Unix()))
}
return token
}
func GetClaims(c *gin.Context) (*systemReq.CustomClaims, error) {
token := GetToken(c)
j := NewJWT()
claims, err := j.ParseToken(token)
if err != nil {
global.GVA_LOG.Error("从Gin的Context中获取从jwt解析信息失败, 请检查请求头是否存在x-token且claims是否为规定结构")
}
return claims, err
}
// GetUserID 从Gin的Context中获取从jwt解析出来的用户ID
func GetUserID(c *gin.Context) uint {
if claims, exists := c.Get("claims"); !exists {
if cl, err := GetClaims(c); err != nil {
return 0
} else {
return cl.BaseClaims.ID
}
} else {
waitUse := claims.(*systemReq.CustomClaims)
return waitUse.BaseClaims.ID
}
}
// GetUserUuid 从Gin的Context中获取从jwt解析出来的用户UUID
func GetUserUuid(c *gin.Context) uuid.UUID {
if claims, exists := c.Get("claims"); !exists {
if cl, err := GetClaims(c); err != nil {
return uuid.UUID{}
} else {
return cl.UUID
}
} else {
waitUse := claims.(*systemReq.CustomClaims)
return waitUse.UUID
}
}
// GetUserAuthorityId 从Gin的Context中获取从jwt解析出来的用户角色id
func GetUserAuthorityId(c *gin.Context) uint {
if claims, exists := c.Get("claims"); !exists {
if cl, err := GetClaims(c); err != nil {
return 0
} else {
return cl.AuthorityId
}
} else {
waitUse := claims.(*systemReq.CustomClaims)
return waitUse.AuthorityId
}
}
// GetUserInfo 从Gin的Context中获取从jwt解析出来的用户角色id
func GetUserInfo(c *gin.Context) *systemReq.CustomClaims {
if claims, exists := c.Get("claims"); !exists {
if cl, err := GetClaims(c); err != nil {
return nil
} else {
return cl
}
} else {
waitUse := claims.(*systemReq.CustomClaims)
return waitUse
}
}
// GetUserName 从Gin的Context中获取从jwt解析出来的用户名
func GetUserName(c *gin.Context) string {
if claims, exists := c.Get("claims"); !exists {
if cl, err := GetClaims(c); err != nil {
return ""
} else {
return cl.Username
}
} else {
waitUse := claims.(*systemReq.CustomClaims)
return waitUse.Username
}
}
func LoginToken(user system.Login) (token string, claims systemReq.CustomClaims, err error) {
j := NewJWT()
claims = j.CreateClaims(systemReq.BaseClaims{
UUID: user.GetUUID(),
ID: user.GetUserId(),
NickName: user.GetNickname(),
Username: user.GetUsername(),
AuthorityId: user.GetAuthorityId(),
})
token, err = j.CreateToken(claims)
return
}

124
server/utils/directory.go Normal file
View File

@@ -0,0 +1,124 @@
package utils
import (
"errors"
"os"
"path/filepath"
"reflect"
"strings"
"git.echol.cn/loser/ai_proxy/server/global"
"go.uber.org/zap"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: PathExists
//@description: 文件目录是否存在
//@param: path string
//@return: bool, error
func PathExists(path string) (bool, error) {
fi, err := os.Stat(path)
if err == nil {
if fi.IsDir() {
return true, nil
}
return false, errors.New("存在同名文件")
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CreateDir
//@description: 批量创建文件夹
//@param: dirs ...string
//@return: err error
func CreateDir(dirs ...string) (err error) {
for _, v := range dirs {
exist, err := PathExists(v)
if err != nil {
return err
}
if !exist {
global.GVA_LOG.Debug("create directory" + v)
if err := os.MkdirAll(v, os.ModePerm); err != nil {
global.GVA_LOG.Error("create directory"+v, zap.Any(" error:", err))
return err
}
}
}
return err
}
//@author: [songzhibin97](https://github.com/songzhibin97)
//@function: FileMove
//@description: 文件移动供外部调用
//@param: src string, dst string(src: 源位置,绝对路径or相对路径, dst: 目标位置,绝对路径or相对路径,必须为文件夹)
//@return: err error
func FileMove(src string, dst string) (err error) {
if dst == "" {
return nil
}
src, err = filepath.Abs(src)
if err != nil {
return err
}
dst, err = filepath.Abs(dst)
if err != nil {
return err
}
revoke := false
dir := filepath.Dir(dst)
Redirect:
_, err = os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0o755)
if err != nil {
return err
}
if !revoke {
revoke = true
goto Redirect
}
}
return os.Rename(src, dst)
}
func DeLFile(filePath string) error {
return os.RemoveAll(filePath)
}
//@author: [songzhibin97](https://github.com/songzhibin97)
//@function: TrimSpace
//@description: 去除结构体空格
//@param: target interface (target: 目标结构体,传入必须是指针类型)
//@return: null
func TrimSpace(target interface{}) {
t := reflect.TypeOf(target)
if t.Kind() != reflect.Ptr {
return
}
t = t.Elem()
v := reflect.ValueOf(target).Elem()
for i := 0; i < t.NumField(); i++ {
switch v.Field(i).Kind() {
case reflect.String:
v.Field(i).SetString(strings.TrimSpace(v.Field(i).String()))
}
}
}
// FileExist 判断文件是否存在
func FileExist(path string) bool {
fi, err := os.Lstat(path)
if err == nil {
return !fi.IsDir()
}
return !os.IsNotExist(err)
}

126
server/utils/fmt_plus.go Normal file
View File

@@ -0,0 +1,126 @@
package utils
import (
"fmt"
"math/rand"
"reflect"
"strings"
"git.echol.cn/loser/ai_proxy/server/model/common"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: StructToMap
//@description: 利用反射将结构体转化为map
//@param: obj interface{}
//@return: map[string]interface{}
func StructToMap(obj interface{}) map[string]interface{} {
obj1 := reflect.TypeOf(obj)
obj2 := reflect.ValueOf(obj)
data := make(map[string]interface{})
for i := 0; i < obj1.NumField(); i++ {
if obj1.Field(i).Tag.Get("mapstructure") != "" {
data[obj1.Field(i).Tag.Get("mapstructure")] = obj2.Field(i).Interface()
} else {
data[obj1.Field(i).Name] = obj2.Field(i).Interface()
}
}
return data
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: ArrayToString
//@description: 将数组格式化为字符串
//@param: array []interface{}
//@return: string
func ArrayToString(array []interface{}) string {
return strings.Replace(strings.Trim(fmt.Sprint(array), "[]"), " ", ",", -1)
}
func Pointer[T any](in T) (out *T) {
return &in
}
func FirstUpper(s string) string {
if s == "" {
return ""
}
return strings.ToUpper(s[:1]) + s[1:]
}
func FirstLower(s string) string {
if s == "" {
return ""
}
return strings.ToLower(s[:1]) + s[1:]
}
// MaheHump 将字符串转换为驼峰命名
func MaheHump(s string) string {
words := strings.Split(s, "-")
for i := 1; i < len(words); i++ {
words[i] = strings.Title(words[i])
}
return strings.Join(words, "")
}
// HumpToUnderscore 将驼峰命名转换为下划线分割模式
func HumpToUnderscore(s string) string {
var result strings.Builder
for i, char := range s {
if i > 0 && char >= 'A' && char <= 'Z' {
// 在大写字母前添加下划线
result.WriteRune('_')
result.WriteRune(char - 'A' + 'a') // 转小写
} else {
result.WriteRune(char)
}
}
return strings.ToLower(result.String())
}
// RandomString 随机字符串
func RandomString(n int) string {
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
b := make([]rune, n)
for i := range b {
b[i] = letters[RandomInt(0, len(letters))]
}
return string(b)
}
func RandomInt(min, max int) int {
return min + rand.Intn(max-min)
}
// BuildTree 用于构建一个树形结构
func BuildTree[T common.TreeNode[T]](nodes []T) []T {
nodeMap := make(map[int]T)
// 创建一个基本map
for i := range nodes {
nodeMap[nodes[i].GetID()] = nodes[i]
}
for i := range nodes {
if nodes[i].GetParentID() != 0 {
parent := nodeMap[nodes[i].GetParentID()]
parent.SetChildren(nodes[i])
}
}
var rootNodes []T
for i := range nodeMap {
if nodeMap[i].GetParentID() == 0 {
rootNodes = append(rootNodes, nodeMap[i])
}
}
return rootNodes
}

32
server/utils/hash.go Normal file
View File

@@ -0,0 +1,32 @@
package utils
import (
"crypto/md5"
"encoding/hex"
"golang.org/x/crypto/bcrypt"
)
// BcryptHash 使用 bcrypt 对密码进行加密
func BcryptHash(password string) string {
bytes, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes)
}
// BcryptCheck 对比明文密码和数据库的哈希值
func BcryptCheck(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: MD5V
//@description: md5加密
//@param: str []byte
//@return: string
func MD5V(str []byte, b ...byte) string {
h := md5.New()
h.Write(str)
return hex.EncodeToString(h.Sum(b))
}

View File

@@ -0,0 +1,29 @@
package utils
import (
"strconv"
"strings"
"time"
)
func ParseDuration(d string) (time.Duration, error) {
d = strings.TrimSpace(d)
dr, err := time.ParseDuration(d)
if err == nil {
return dr, nil
}
if strings.Contains(d, "d") {
index := strings.Index(d, "d")
hour, _ := strconv.Atoi(d[:index])
dr = time.Hour * 24 * time.Duration(hour)
ndr, err := time.ParseDuration(d[index+1:])
if err != nil {
return dr, nil
}
return dr + ndr, nil
}
dv, err := strconv.ParseInt(d, 10, 64)
return time.Duration(dv), err
}

View File

@@ -0,0 +1,49 @@
package utils
import (
"testing"
"time"
)
func TestParseDuration(t *testing.T) {
type args struct {
d string
}
tests := []struct {
name string
args args
want time.Duration
wantErr bool
}{
{
name: "5h20m",
args: args{"5h20m"},
want: time.Hour*5 + 20*time.Minute,
wantErr: false,
},
{
name: "1d5h20m",
args: args{"1d5h20m"},
want: 24*time.Hour + time.Hour*5 + 20*time.Minute,
wantErr: false,
},
{
name: "1d",
args: args{"1d"},
want: 24 * time.Hour,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseDuration(tt.args.d)
if (err != nil) != tt.wantErr {
t.Errorf("ParseDuration() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ParseDuration() got = %v, want %v", got, tt.want)
}
})
}
}

34
server/utils/json.go Normal file
View File

@@ -0,0 +1,34 @@
package utils
import (
"encoding/json"
"strings"
)
func GetJSONKeys(jsonStr string) (keys []string, err error) {
// 使用json.Decoder以便在解析过程中记录键的顺序
dec := json.NewDecoder(strings.NewReader(jsonStr))
t, err := dec.Token()
if err != nil {
return nil, err
}
// 确保数据是一个对象
if t != json.Delim('{') {
return nil, err
}
for dec.More() {
t, err = dec.Token()
if err != nil {
return nil, err
}
keys = append(keys, t.(string))
// 解析值
var value interface{}
err = dec.Decode(&value)
if err != nil {
return nil, err
}
}
return keys, nil
}

53
server/utils/json_test.go Normal file
View File

@@ -0,0 +1,53 @@
package utils
import (
"fmt"
"testing"
)
func TestGetJSONKeys(t *testing.T) {
var jsonStr = `
{
"Name": "test",
"TableName": "test",
"TemplateID": "test",
"TemplateInfo": "test",
"Limit": 0
}`
keys, err := GetJSONKeys(jsonStr)
if err != nil {
t.Errorf("GetJSONKeys failed" + err.Error())
return
}
if len(keys) != 5 {
t.Errorf("GetJSONKeys failed" + err.Error())
return
}
if keys[0] != "Name" {
t.Errorf("GetJSONKeys failed" + err.Error())
return
}
if keys[1] != "TableName" {
t.Errorf("GetJSONKeys failed" + err.Error())
return
}
if keys[2] != "TemplateID" {
t.Errorf("GetJSONKeys failed" + err.Error())
return
}
if keys[3] != "TemplateInfo" {
t.Errorf("GetJSONKeys failed" + err.Error())
return
}
if keys[4] != "Limit" {
t.Errorf("GetJSONKeys failed" + err.Error())
return
}
fmt.Println(keys)
}

105
server/utils/jwt.go Normal file
View File

@@ -0,0 +1,105 @@
package utils
import (
"context"
"errors"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
jwt "github.com/golang-jwt/jwt/v5"
)
type JWT struct {
SigningKey []byte
}
var (
TokenValid = errors.New("未知错误")
TokenExpired = errors.New("token已过期")
TokenNotValidYet = errors.New("token尚未激活")
TokenMalformed = errors.New("这不是一个token")
TokenSignatureInvalid = errors.New("无效签名")
TokenInvalid = errors.New("无法处理此token")
)
func NewJWT() *JWT {
return &JWT{
[]byte(global.GVA_CONFIG.JWT.SigningKey),
}
}
func (j *JWT) CreateClaims(baseClaims request.BaseClaims) request.CustomClaims {
bf, _ := ParseDuration(global.GVA_CONFIG.JWT.BufferTime)
ep, _ := ParseDuration(global.GVA_CONFIG.JWT.ExpiresTime)
claims := request.CustomClaims{
BaseClaims: baseClaims,
BufferTime: int64(bf / time.Second), // 缓冲时间1天 缓冲时间内会获得新的token刷新令牌 此时一个用户会存在两个有效令牌 但是前端只留一个 另一个会丢失
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"GVA"}, // 受众
NotBefore: jwt.NewNumericDate(time.Now().Add(-1000)), // 签名生效时间
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ep)), // 过期时间 7天 配置文件
Issuer: global.GVA_CONFIG.JWT.Issuer, // 签名的发行者
},
}
return claims
}
// CreateToken 创建一个token
func (j *JWT) CreateToken(claims request.CustomClaims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(j.SigningKey)
}
// CreateTokenByOldToken 旧token 换新token 使用归并回源避免并发问题
func (j *JWT) CreateTokenByOldToken(oldToken string, claims request.CustomClaims) (string, error) {
v, err, _ := global.GVA_Concurrency_Control.Do("JWT:"+oldToken, func() (interface{}, error) {
return j.CreateToken(claims)
})
return v.(string), err
}
// ParseToken 解析 token
func (j *JWT) ParseToken(tokenString string) (*request.CustomClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &request.CustomClaims{}, func(token *jwt.Token) (i interface{}, e error) {
return j.SigningKey, nil
})
if err != nil {
switch {
case errors.Is(err, jwt.ErrTokenExpired):
return nil, TokenExpired
case errors.Is(err, jwt.ErrTokenMalformed):
return nil, TokenMalformed
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
return nil, TokenSignatureInvalid
case errors.Is(err, jwt.ErrTokenNotValidYet):
return nil, TokenNotValidYet
default:
return nil, TokenInvalid
}
}
if token != nil {
if claims, ok := token.Claims.(*request.CustomClaims); ok && token.Valid {
return claims, nil
}
}
return nil, TokenValid
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetRedisJWT
//@description: jwt存入redis并设置过期时间
//@param: jwt string, userName string
//@return: err error
func SetRedisJWT(jwt string, userName string) (err error) {
// 此处过期时间等于jwt过期时间
dr, err := ParseDuration(global.GVA_CONFIG.JWT.ExpiresTime)
if err != nil {
return err
}
timer := dr
err = global.GVA_REDIS.Set(context.Background(), userName, jwt, timer).Err()
return err
}

26
server/utils/param.go Normal file
View File

@@ -0,0 +1,26 @@
package utils
import (
"strconv"
"github.com/gin-gonic/gin"
)
// StringToUint 字符串转 uint
func StringToUint(s string) (uint, error) {
val, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return 0, err
}
return uint(val), nil
}
// GetIntQuery 获取查询参数int类型
func GetIntQuery(c *gin.Context, key string, defaultValue int) int {
if val := c.Query(key); val != "" {
if intVal, err := strconv.Atoi(val); err == nil {
return intVal
}
}
return defaultValue
}

View File

@@ -0,0 +1,22 @@
package utils
import (
"strconv"
"github.com/gin-gonic/gin"
)
// GetUintParam 从 URL 参数中获取 uint 值
func GetUintParam(c *gin.Context, key string) uint {
val := c.Param(key)
if val == "" {
return 0
}
uintVal, err := strconv.ParseUint(val, 10, 32)
if err != nil {
return 0
}
return uint(uintVal)
}

View File

@@ -0,0 +1,18 @@
package plugin
import (
"github.com/gin-gonic/gin"
)
const (
OnlyFuncName = "Plugin"
)
// Plugin 插件模式接口化
type Plugin interface {
// Register 注册路由
Register(group *gin.RouterGroup)
// RouterPath 用户返回注册路由
RouterPath() string
}

View File

@@ -0,0 +1,11 @@
package plugin
import (
"github.com/gin-gonic/gin"
)
// Plugin 插件模式接口化v2
type Plugin interface {
// Register 注册路由
Register(group *gin.Engine)
}

View File

@@ -0,0 +1,27 @@
package plugin
import "sync"
var (
registryMu sync.RWMutex
registry []Plugin
)
// Register records a plugin for auto initialization.
func Register(p Plugin) {
if p == nil {
return
}
registryMu.Lock()
registry = append(registry, p)
registryMu.Unlock()
}
// Registered returns a snapshot of all registered plugins.
func Registered() []Plugin {
registryMu.RLock()
defer registryMu.RUnlock()
out := make([]Plugin, len(registry))
copy(out, registry)
return out
}

15
server/utils/random.go Normal file
View File

@@ -0,0 +1,15 @@
package utils
import (
"crypto/rand"
"encoding/hex"
)
// GenerateRandomString 生成随机字符串
func GenerateRandomString(length int) (string, error) {
bytes := make([]byte, length/2+1)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes)[:length], nil
}

View File

@@ -0,0 +1,62 @@
package request
import (
"bytes"
"encoding/json"
"net/http"
"net/url"
)
func HttpRequest(
urlStr string,
method string,
headers map[string]string,
params map[string]string,
data any) (*http.Response, error) {
// 创建URL
u, err := url.Parse(urlStr)
if err != nil {
return nil, err
}
// 添加查询参数
query := u.Query()
for k, v := range params {
query.Set(k, v)
}
u.RawQuery = query.Encode()
// 将数据编码为JSON
buf := new(bytes.Buffer)
if data != nil {
b, err := json.Marshal(data)
if err != nil {
return nil, err
}
buf = bytes.NewBuffer(b)
}
// 创建请求
req, err := http.NewRequest(method, u.String(), buf)
if err != nil {
return nil, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
if data != nil {
req.Header.Set("Content-Type", "application/json")
}
// 发送请求
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
// 返回响应,让调用者处理
return resp, nil
}

127
server/utils/server.go Normal file
View File

@@ -0,0 +1,127 @@
package utils
import (
"runtime"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/mem"
)
const (
B = 1
KB = 1024 * B
MB = 1024 * KB
GB = 1024 * MB
)
type Server struct {
Os Os `json:"os"`
Cpu Cpu `json:"cpu"`
Ram Ram `json:"ram"`
Disk []Disk `json:"disk"`
}
type Os struct {
GOOS string `json:"goos"`
NumCPU int `json:"numCpu"`
Compiler string `json:"compiler"`
GoVersion string `json:"goVersion"`
NumGoroutine int `json:"numGoroutine"`
}
type Cpu struct {
Cpus []float64 `json:"cpus"`
Cores int `json:"cores"`
}
type Ram struct {
UsedMB int `json:"usedMb"`
TotalMB int `json:"totalMb"`
UsedPercent int `json:"usedPercent"`
}
type Disk struct {
MountPoint string `json:"mountPoint"`
UsedMB int `json:"usedMb"`
UsedGB int `json:"usedGb"`
TotalMB int `json:"totalMb"`
TotalGB int `json:"totalGb"`
UsedPercent int `json:"usedPercent"`
}
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: InitCPU
//@description: OS信息
//@return: o Os, err error
func InitOS() (o Os) {
o.GOOS = runtime.GOOS
o.NumCPU = runtime.NumCPU()
o.Compiler = runtime.Compiler
o.GoVersion = runtime.Version()
o.NumGoroutine = runtime.NumGoroutine()
return o
}
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: InitCPU
//@description: CPU信息
//@return: c Cpu, err error
func InitCPU() (c Cpu, err error) {
if cores, err := cpu.Counts(false); err != nil {
return c, err
} else {
c.Cores = cores
}
if cpus, err := cpu.Percent(time.Duration(200)*time.Millisecond, true); err != nil {
return c, err
} else {
c.Cpus = cpus
}
return c, nil
}
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: InitRAM
//@description: RAM信息
//@return: r Ram, err error
func InitRAM() (r Ram, err error) {
if u, err := mem.VirtualMemory(); err != nil {
return r, err
} else {
r.UsedMB = int(u.Used) / MB
r.TotalMB = int(u.Total) / MB
r.UsedPercent = int(u.UsedPercent)
}
return r, nil
}
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: InitDisk
//@description: 硬盘信息
//@return: d Disk, err error
func InitDisk() (d []Disk, err error) {
for i := range global.GVA_CONFIG.DiskList {
mp := global.GVA_CONFIG.DiskList[i].MountPoint
if u, err := disk.Usage(mp); err != nil {
return d, err
} else {
d = append(d, Disk{
MountPoint: mp,
UsedMB: int(u.Used) / MB,
UsedGB: int(u.Used) / GB,
TotalMB: int(u.Total) / MB,
TotalGB: int(u.Total) / GB,
UsedPercent: int(u.UsedPercent),
})
}
}
return d, nil
}

View File

@@ -0,0 +1,79 @@
package stacktrace
import (
"regexp"
"strconv"
"strings"
)
// Frame 表示一次栈帧解析结果
type Frame struct {
File string
Line int
Func string
}
var fileLineRe = regexp.MustCompile(`\s*(.+\.go):(\d+)\s*$`)
// FindFinalCaller 从 zap 的 entry.Stack 文本中,解析“最终业务调用方”的文件与行号
// 策略:自顶向下解析,优先选择第一条项目代码帧,过滤第三方库/标准库/框架中间件
func FindFinalCaller(stack string) (Frame, bool) {
if stack == "" {
return Frame{}, false
}
lines := strings.Split(stack, "\n")
var currFunc string
for i := 0; i < len(lines); i++ {
line := strings.TrimSpace(lines[i])
if line == "" {
continue
}
if m := fileLineRe.FindStringSubmatch(line); m != nil {
file := m[1]
ln, _ := strconv.Atoi(m[2])
if shouldSkip(file) {
// 跳过此帧,同时重置函数名以避免错误配对
currFunc = ""
continue
}
return Frame{File: file, Line: ln, Func: currFunc}, true
}
// 记录函数名行,下一行通常是文件:行
currFunc = line
}
return Frame{}, false
}
func shouldSkip(file string) bool {
// 第三方库与 Go 模块缓存
if strings.Contains(file, "/go/pkg/mod/") {
return true
}
if strings.Contains(file, "/go.uber.org/") {
return true
}
if strings.Contains(file, "/gorm.io/") {
return true
}
// 标准库
if strings.Contains(file, "/go/go") && strings.Contains(file, "/src/") { // e.g. /Users/name/go/go1.24.2/src/net/http/server.go
return true
}
// 框架内不需要作为最终调用方的路径
if strings.Contains(file, "/server/core/zap.go") {
return true
}
if strings.Contains(file, "/server/core/") {
return true
}
if strings.Contains(file, "/server/utils/errorhook/") {
return true
}
if strings.Contains(file, "/server/middleware/") {
return true
}
if strings.Contains(file, "/server/router/") {
return true
}
return false
}

View File

@@ -0,0 +1,34 @@
package utils
import (
"sync"
)
// SystemEvents 定义系统级事件处理
type SystemEvents struct {
reloadHandlers []func() error
mu sync.RWMutex
}
// 全局事件管理器
var GlobalSystemEvents = &SystemEvents{}
// RegisterReloadHandler 注册系统重载处理函数
func (e *SystemEvents) RegisterReloadHandler(handler func() error) {
e.mu.Lock()
defer e.mu.Unlock()
e.reloadHandlers = append(e.reloadHandlers, handler)
}
// TriggerReload 触发所有注册的重载处理函数
func (e *SystemEvents) TriggerReload() error {
e.mu.RLock()
defer e.mu.RUnlock()
for _, handler := range e.reloadHandlers {
if err := handler(); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,230 @@
package timer
import (
"sync"
"github.com/robfig/cron/v3"
)
type Timer interface {
// 寻找所有Cron
FindCronList() map[string]*taskManager
// 添加Task 方法形式以秒的形式加入
AddTaskByFuncWithSecond(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) // 添加Task Func以秒的形式加入
// 添加Task 接口形式以秒的形式加入
AddTaskByJobWithSeconds(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error)
// 通过函数的方法添加任务
AddTaskByFunc(cronName string, spec string, task func(), taskName string, option ...cron.Option) (cron.EntryID, error)
// 通过接口的方法添加任务 要实现一个带有 Run方法的接口触发
AddTaskByJob(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error)
// 获取对应taskName的cron 可能会为空
FindCron(cronName string) (*taskManager, bool)
// 指定cron开始执行
StartCron(cronName string)
// 指定cron停止执行
StopCron(cronName string)
// 查找指定cron下的指定task
FindTask(cronName string, taskName string) (*task, bool)
// 根据id删除指定cron下的指定task
RemoveTask(cronName string, id int)
// 根据taskName删除指定cron下的指定task
RemoveTaskByName(cronName string, taskName string)
// 清理掉指定cronName
Clear(cronName string)
// 停止所有的cron
Close()
}
type task struct {
EntryID cron.EntryID
Spec string
TaskName string
}
type taskManager struct {
corn *cron.Cron
tasks map[cron.EntryID]*task
}
// timer 定时任务管理
type timer struct {
cronList map[string]*taskManager
sync.Mutex
}
// AddTaskByFunc 通过函数的方法添加任务
func (t *timer) AddTaskByFunc(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) {
t.Lock()
defer t.Unlock()
if _, ok := t.cronList[cronName]; !ok {
tasks := make(map[cron.EntryID]*task)
t.cronList[cronName] = &taskManager{
corn: cron.New(option...),
tasks: tasks,
}
}
id, err := t.cronList[cronName].corn.AddFunc(spec, fun)
t.cronList[cronName].corn.Start()
t.cronList[cronName].tasks[id] = &task{
EntryID: id,
Spec: spec,
TaskName: taskName,
}
return id, err
}
// AddTaskByFuncWithSecond 通过函数的方法使用WithSeconds添加任务
func (t *timer) AddTaskByFuncWithSecond(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) {
t.Lock()
defer t.Unlock()
option = append(option, cron.WithSeconds())
if _, ok := t.cronList[cronName]; !ok {
tasks := make(map[cron.EntryID]*task)
t.cronList[cronName] = &taskManager{
corn: cron.New(option...),
tasks: tasks,
}
}
id, err := t.cronList[cronName].corn.AddFunc(spec, fun)
t.cronList[cronName].corn.Start()
t.cronList[cronName].tasks[id] = &task{
EntryID: id,
Spec: spec,
TaskName: taskName,
}
return id, err
}
// AddTaskByJob 通过接口的方法添加任务
func (t *timer) AddTaskByJob(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error) {
t.Lock()
defer t.Unlock()
if _, ok := t.cronList[cronName]; !ok {
tasks := make(map[cron.EntryID]*task)
t.cronList[cronName] = &taskManager{
corn: cron.New(option...),
tasks: tasks,
}
}
id, err := t.cronList[cronName].corn.AddJob(spec, job)
t.cronList[cronName].corn.Start()
t.cronList[cronName].tasks[id] = &task{
EntryID: id,
Spec: spec,
TaskName: taskName,
}
return id, err
}
// AddTaskByJobWithSeconds 通过接口的方法添加任务
func (t *timer) AddTaskByJobWithSeconds(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error) {
t.Lock()
defer t.Unlock()
option = append(option, cron.WithSeconds())
if _, ok := t.cronList[cronName]; !ok {
tasks := make(map[cron.EntryID]*task)
t.cronList[cronName] = &taskManager{
corn: cron.New(option...),
tasks: tasks,
}
}
id, err := t.cronList[cronName].corn.AddJob(spec, job)
t.cronList[cronName].corn.Start()
t.cronList[cronName].tasks[id] = &task{
EntryID: id,
Spec: spec,
TaskName: taskName,
}
return id, err
}
// FindCron 获取对应cronName的cron 可能会为空
func (t *timer) FindCron(cronName string) (*taskManager, bool) {
t.Lock()
defer t.Unlock()
v, ok := t.cronList[cronName]
return v, ok
}
// FindTask 获取对应cronName的cron 可能会为空
func (t *timer) FindTask(cronName string, taskName string) (*task, bool) {
t.Lock()
defer t.Unlock()
v, ok := t.cronList[cronName]
if !ok {
return nil, ok
}
for _, t2 := range v.tasks {
if t2.TaskName == taskName {
return t2, true
}
}
return nil, false
}
// FindCronList 获取所有的任务列表
func (t *timer) FindCronList() map[string]*taskManager {
t.Lock()
defer t.Unlock()
return t.cronList
}
// StartCron 开始任务
func (t *timer) StartCron(cronName string) {
t.Lock()
defer t.Unlock()
if v, ok := t.cronList[cronName]; ok {
v.corn.Start()
}
}
// StopCron 停止任务
func (t *timer) StopCron(cronName string) {
t.Lock()
defer t.Unlock()
if v, ok := t.cronList[cronName]; ok {
v.corn.Stop()
}
}
// RemoveTask 从cronName 删除指定任务
func (t *timer) RemoveTask(cronName string, id int) {
t.Lock()
defer t.Unlock()
if v, ok := t.cronList[cronName]; ok {
v.corn.Remove(cron.EntryID(id))
delete(v.tasks, cron.EntryID(id))
}
}
// RemoveTaskByName 从cronName 使用taskName 删除指定任务
func (t *timer) RemoveTaskByName(cronName string, taskName string) {
fTask, ok := t.FindTask(cronName, taskName)
if !ok {
return
}
t.RemoveTask(cronName, int(fTask.EntryID))
}
// Clear 清除任务
func (t *timer) Clear(cronName string) {
t.Lock()
defer t.Unlock()
if v, ok := t.cronList[cronName]; ok {
v.corn.Stop()
delete(t.cronList, cronName)
}
}
// Close 释放资源
func (t *timer) Close() {
t.Lock()
defer t.Unlock()
for _, v := range t.cronList {
v.corn.Stop()
}
}
func NewTimerTask() Timer {
return &timer{cronList: make(map[string]*taskManager)}
}

View File

@@ -0,0 +1,72 @@
package timer
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var job = mockJob{}
type mockJob struct{}
func (job mockJob) Run() {
mockFunc()
}
func mockFunc() {
time.Sleep(time.Second)
fmt.Println("1s...")
}
func TestNewTimerTask(t *testing.T) {
tm := NewTimerTask()
_tm := tm.(*timer)
{
_, err := tm.AddTaskByFunc("func", "@every 1s", mockFunc, "测试mockfunc")
assert.Nil(t, err)
_, ok := _tm.cronList["func"]
if !ok {
t.Error("no find func")
}
}
{
_, err := tm.AddTaskByJob("job", "@every 1s", job, "测试job mockfunc")
assert.Nil(t, err)
_, ok := _tm.cronList["job"]
if !ok {
t.Error("no find job")
}
}
{
_, ok := tm.FindCron("func")
if !ok {
t.Error("no find func")
}
_, ok = tm.FindCron("job")
if !ok {
t.Error("no find job")
}
_, ok = tm.FindCron("none")
if ok {
t.Error("find none")
}
}
{
tm.Clear("func")
_, ok := tm.FindCron("func")
if ok {
t.Error("find func")
}
}
{
a := tm.FindCronList()
b, c := tm.FindCron("job")
fmt.Println(a, b, c)
}
}

View File

@@ -0,0 +1,75 @@
package upload
import (
"errors"
"mime/multipart"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/aliyun/aliyun-oss-go-sdk/oss"
"go.uber.org/zap"
)
type AliyunOSS struct{}
func (*AliyunOSS) UploadFile(file *multipart.FileHeader) (string, string, error) {
bucket, err := NewBucket()
if err != nil {
global.GVA_LOG.Error("function AliyunOSS.NewBucket() Failed", zap.Any("err", err.Error()))
return "", "", errors.New("function AliyunOSS.NewBucket() Failed, err:" + err.Error())
}
// 读取本地文件。
f, openError := file.Open()
if openError != nil {
global.GVA_LOG.Error("function file.Open() Failed", zap.Any("err", openError.Error()))
return "", "", errors.New("function file.Open() Failed, err:" + openError.Error())
}
defer f.Close() // 创建文件 defer 关闭
// 上传阿里云路径 文件名格式 自己可以改 建议保证唯一性
// yunFileTmpPath := filepath.Join("uploads", time.Now().Format("2006-01-02")) + "/" + file.Filename
yunFileTmpPath := global.GVA_CONFIG.AliyunOSS.BasePath + "/" + "uploads" + "/" + time.Now().Format("2006-01-02") + "/" + file.Filename
// 上传文件流。
err = bucket.PutObject(yunFileTmpPath, f)
if err != nil {
global.GVA_LOG.Error("function formUploader.Put() Failed", zap.Any("err", err.Error()))
return "", "", errors.New("function formUploader.Put() Failed, err:" + err.Error())
}
return global.GVA_CONFIG.AliyunOSS.BucketUrl + "/" + yunFileTmpPath, yunFileTmpPath, nil
}
func (*AliyunOSS) DeleteFile(key string) error {
bucket, err := NewBucket()
if err != nil {
global.GVA_LOG.Error("function AliyunOSS.NewBucket() Failed", zap.Any("err", err.Error()))
return errors.New("function AliyunOSS.NewBucket() Failed, err:" + err.Error())
}
// 删除单个文件。objectName表示删除OSS文件时需要指定包含文件后缀在内的完整路径例如abc/efg/123.jpg。
// 如需删除文件夹请将objectName设置为对应的文件夹名称。如果文件夹非空则需要将文件夹下的所有object删除后才能删除该文件夹。
err = bucket.DeleteObject(key)
if err != nil {
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
}
return nil
}
func NewBucket() (*oss.Bucket, error) {
// 创建OSSClient实例。
client, err := oss.New(global.GVA_CONFIG.AliyunOSS.Endpoint, global.GVA_CONFIG.AliyunOSS.AccessKeyId, global.GVA_CONFIG.AliyunOSS.AccessKeySecret)
if err != nil {
return nil, err
}
// 获取存储空间。
bucket, err := client.Bucket(global.GVA_CONFIG.AliyunOSS.BucketName)
if err != nil {
return nil, err
}
return bucket, nil
}

View File

@@ -0,0 +1,98 @@
package upload
import (
"errors"
"fmt"
"mime/multipart"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"go.uber.org/zap"
)
type AwsS3 struct{}
//@author: [WqyJh](https://github.com/WqyJh)
//@object: *AwsS3
//@function: UploadFile
//@description: Upload file to Aws S3 using aws-sdk-go. See https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/s3-example-basic-bucket-operations.html#s3-examples-bucket-ops-upload-file-to-bucket
//@param: file *multipart.FileHeader
//@return: string, string, error
func (*AwsS3) UploadFile(file *multipart.FileHeader) (string, string, error) {
session := newSession()
uploader := s3manager.NewUploader(session)
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename)
filename := global.GVA_CONFIG.AwsS3.PathPrefix + "/" + fileKey
f, openError := file.Open()
if openError != nil {
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
}
defer f.Close() // 创建文件 defer 关闭
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(global.GVA_CONFIG.AwsS3.Bucket),
Key: aws.String(filename),
Body: f,
ContentType: aws.String(file.Header.Get("Content-Type")),
})
if err != nil {
global.GVA_LOG.Error("function uploader.Upload() failed", zap.Any("err", err.Error()))
return "", "", err
}
return global.GVA_CONFIG.AwsS3.BaseURL + "/" + filename, fileKey, nil
}
//@author: [WqyJh](https://github.com/WqyJh)
//@object: *AwsS3
//@function: DeleteFile
//@description: Delete file from Aws S3 using aws-sdk-go. See https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/s3-example-basic-bucket-operations.html#s3-examples-bucket-ops-delete-bucket-item
//@param: file *multipart.FileHeader
//@return: string, string, error
func (*AwsS3) DeleteFile(key string) error {
session := newSession()
svc := s3.New(session)
filename := global.GVA_CONFIG.AwsS3.PathPrefix + "/" + key
bucket := global.GVA_CONFIG.AwsS3.Bucket
_, err := svc.DeleteObject(&s3.DeleteObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(filename),
})
if err != nil {
global.GVA_LOG.Error("function svc.DeleteObject() failed", zap.Any("err", err.Error()))
return errors.New("function svc.DeleteObject() failed, err:" + err.Error())
}
_ = svc.WaitUntilObjectNotExists(&s3.HeadObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(filename),
})
return nil
}
// newSession Create S3 session
func newSession() *session.Session {
sess, _ := session.NewSession(&aws.Config{
Region: aws.String(global.GVA_CONFIG.AwsS3.Region),
Endpoint: aws.String(global.GVA_CONFIG.AwsS3.Endpoint), //minio在这里设置地址,可以兼容
S3ForcePathStyle: aws.Bool(global.GVA_CONFIG.AwsS3.S3ForcePathStyle),
DisableSSL: aws.Bool(global.GVA_CONFIG.AwsS3.DisableSSL),
Credentials: credentials.NewStaticCredentials(
global.GVA_CONFIG.AwsS3.SecretID,
global.GVA_CONFIG.AwsS3.SecretKey,
"",
),
})
return sess
}

View File

@@ -0,0 +1,85 @@
package upload
import (
"errors"
"fmt"
"mime/multipart"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"go.uber.org/zap"
)
type CloudflareR2 struct{}
func (c *CloudflareR2) UploadFile(file *multipart.FileHeader) (fileUrl string, fileName string, err error) {
session := c.newSession()
client := s3manager.NewUploader(session)
fileKey := fmt.Sprintf("%d_%s", time.Now().Unix(), file.Filename)
fileName = fmt.Sprintf("%s/%s", global.GVA_CONFIG.CloudflareR2.Path, fileKey)
f, openError := file.Open()
if openError != nil {
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
}
defer f.Close() // 创建文件 defer 关闭
input := &s3manager.UploadInput{
Bucket: aws.String(global.GVA_CONFIG.CloudflareR2.Bucket),
Key: aws.String(fileName),
Body: f,
}
_, err = client.Upload(input)
if err != nil {
global.GVA_LOG.Error("function uploader.Upload() failed", zap.Any("err", err.Error()))
return "", "", err
}
return fmt.Sprintf("%s/%s", global.GVA_CONFIG.CloudflareR2.BaseURL,
fileName),
fileKey,
nil
}
func (c *CloudflareR2) DeleteFile(key string) error {
session := newSession()
svc := s3.New(session)
filename := global.GVA_CONFIG.CloudflareR2.Path + "/" + key
bucket := global.GVA_CONFIG.CloudflareR2.Bucket
_, err := svc.DeleteObject(&s3.DeleteObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(filename),
})
if err != nil {
global.GVA_LOG.Error("function svc.DeleteObject() failed", zap.Any("err", err.Error()))
return errors.New("function svc.DeleteObject() failed, err:" + err.Error())
}
_ = svc.WaitUntilObjectNotExists(&s3.HeadObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(filename),
})
return nil
}
func (*CloudflareR2) newSession() *session.Session {
endpoint := fmt.Sprintf("%s.r2.cloudflarestorage.com", global.GVA_CONFIG.CloudflareR2.AccountID)
return session.Must(session.NewSession(&aws.Config{
Region: aws.String("auto"),
Endpoint: aws.String(endpoint),
Credentials: credentials.NewStaticCredentials(
global.GVA_CONFIG.CloudflareR2.AccessKeyID,
global.GVA_CONFIG.CloudflareR2.SecretAccessKey,
"",
),
}))
}

View File

@@ -0,0 +1,109 @@
package upload
import (
"errors"
"io"
"mime/multipart"
"os"
"path/filepath"
"strings"
"sync"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/utils"
"go.uber.org/zap"
)
var mu sync.Mutex
type Local struct{}
//@author: [piexlmax](https://github.com/piexlmax)
//@author: [ccfish86](https://github.com/ccfish86)
//@author: [SliverHorn](https://github.com/SliverHorn)
//@object: *Local
//@function: UploadFile
//@description: 上传文件
//@param: file *multipart.FileHeader
//@return: string, string, error
func (*Local) UploadFile(file *multipart.FileHeader) (string, string, error) {
// 读取文件后缀
ext := filepath.Ext(file.Filename)
// 读取文件名并加密
name := strings.TrimSuffix(file.Filename, ext)
name = utils.MD5V([]byte(name))
// 拼接新文件名
filename := name + "_" + time.Now().Format("20060102150405") + ext
// 尝试创建此路径
mkdirErr := os.MkdirAll(global.GVA_CONFIG.Local.StorePath, os.ModePerm)
if mkdirErr != nil {
global.GVA_LOG.Error("function os.MkdirAll() failed", zap.Any("err", mkdirErr.Error()))
return "", "", errors.New("function os.MkdirAll() failed, err:" + mkdirErr.Error())
}
// 拼接路径和文件名
p := global.GVA_CONFIG.Local.StorePath + "/" + filename
filepath := global.GVA_CONFIG.Local.Path + "/" + filename
f, openError := file.Open() // 读取文件
if openError != nil {
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
}
defer f.Close() // 创建文件 defer 关闭
out, createErr := os.Create(p)
if createErr != nil {
global.GVA_LOG.Error("function os.Create() failed", zap.Any("err", createErr.Error()))
return "", "", errors.New("function os.Create() failed, err:" + createErr.Error())
}
defer out.Close() // 创建文件 defer 关闭
_, copyErr := io.Copy(out, f) // 传输(拷贝)文件
if copyErr != nil {
global.GVA_LOG.Error("function io.Copy() failed", zap.Any("err", copyErr.Error()))
return "", "", errors.New("function io.Copy() failed, err:" + copyErr.Error())
}
return filepath, filename, nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@author: [ccfish86](https://github.com/ccfish86)
//@author: [SliverHorn](https://github.com/SliverHorn)
//@object: *Local
//@function: DeleteFile
//@description: 删除文件
//@param: key string
//@return: error
func (*Local) DeleteFile(key string) error {
// 检查 key 是否为空
if key == "" {
return errors.New("key不能为空")
}
// 验证 key 是否包含非法字符或尝试访问存储路径之外的文件
if strings.Contains(key, "..") || strings.ContainsAny(key, `\/:*?"<>|`) {
return errors.New("非法的key")
}
p := filepath.Join(global.GVA_CONFIG.Local.StorePath, key)
// 检查文件是否存在
if _, err := os.Stat(p); os.IsNotExist(err) {
return errors.New("文件不存在")
}
// 使用文件锁防止并发删除
mu.Lock()
defer mu.Unlock()
err := os.Remove(p)
if err != nil {
return errors.New("文件删除失败: " + err.Error())
}
return nil
}

View File

@@ -0,0 +1,106 @@
package upload
import (
"bytes"
"context"
"errors"
"io"
"mime"
"mime/multipart"
"path/filepath"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"go.uber.org/zap"
)
var MinioClient *Minio // 优化性能,但是不支持动态配置
type Minio struct {
Client *minio.Client
bucket string
}
func GetMinio(endpoint, accessKeyID, secretAccessKey, bucketName string, useSSL bool) (*Minio, error) {
if MinioClient != nil {
return MinioClient, nil
}
// Initialize minio client object.
minioClient, err := minio.New(endpoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: useSSL, // Set to true if using https
})
if err != nil {
return nil, err
}
// 尝试创建bucket
err = minioClient.MakeBucket(context.Background(), bucketName, minio.MakeBucketOptions{})
if err != nil {
// Check to see if we already own this bucket (which happens if you run this twice)
exists, errBucketExists := minioClient.BucketExists(context.Background(), bucketName)
if errBucketExists == nil && exists {
// log.Printf("We already own %s\n", bucketName)
} else {
return nil, err
}
}
MinioClient = &Minio{Client: minioClient, bucket: bucketName}
return MinioClient, nil
}
func (m *Minio) UploadFile(file *multipart.FileHeader) (filePathres, key string, uploadErr error) {
f, openError := file.Open()
// mutipart.File to os.File
if openError != nil {
global.GVA_LOG.Error("function file.Open() Failed", zap.Any("err", openError.Error()))
return "", "", errors.New("function file.Open() Failed, err:" + openError.Error())
}
filecontent := bytes.Buffer{}
_, err := io.Copy(&filecontent, f)
if err != nil {
global.GVA_LOG.Error("读取文件失败", zap.Any("err", err.Error()))
return "", "", errors.New("读取文件失败, err:" + err.Error())
}
f.Close() // 创建文件 defer 关闭
// 对文件名进行加密存储
ext := filepath.Ext(file.Filename)
filename := utils.MD5V([]byte(strings.TrimSuffix(file.Filename, ext))) + ext
if global.GVA_CONFIG.Minio.BasePath == "" {
filePathres = "uploads" + "/" + time.Now().Format("2006-01-02") + "/" + filename
} else {
filePathres = global.GVA_CONFIG.Minio.BasePath + "/" + time.Now().Format("2006-01-02") + "/" + filename
}
// 根据文件扩展名检测 MIME 类型
contentType := mime.TypeByExtension(ext)
if contentType == "" {
contentType = "application/octet-stream"
}
// 设置超时10分钟
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10)
defer cancel()
// Upload the file with PutObject 大文件自动切换为分片上传
info, err := m.Client.PutObject(ctx, global.GVA_CONFIG.Minio.BucketName, filePathres, &filecontent, file.Size, minio.PutObjectOptions{ContentType: contentType})
if err != nil {
global.GVA_LOG.Error("上传文件到minio失败", zap.Any("err", err.Error()))
return "", "", errors.New("上传文件到minio失败, err:" + err.Error())
}
return global.GVA_CONFIG.Minio.BucketUrl + "/" + info.Key, filePathres, nil
}
func (m *Minio) DeleteFile(key string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
// Delete the object from MinIO
err := m.Client.RemoveObject(ctx, m.bucket, key, minio.RemoveObjectOptions{})
return err
}

View File

@@ -0,0 +1,69 @@
package upload
import (
"mime/multipart"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/huaweicloud/huaweicloud-sdk-go-obs/obs"
"github.com/pkg/errors"
)
var HuaWeiObs = new(Obs)
type Obs struct{}
func NewHuaWeiObsClient() (client *obs.ObsClient, err error) {
return obs.New(global.GVA_CONFIG.HuaWeiObs.AccessKey, global.GVA_CONFIG.HuaWeiObs.SecretKey, global.GVA_CONFIG.HuaWeiObs.Endpoint)
}
func (o *Obs) UploadFile(file *multipart.FileHeader) (string, string, error) {
// var open multipart.File
open, err := file.Open()
if err != nil {
return "", "", err
}
defer open.Close()
filename := file.Filename
input := &obs.PutObjectInput{
PutObjectBasicInput: obs.PutObjectBasicInput{
ObjectOperationInput: obs.ObjectOperationInput{
Bucket: global.GVA_CONFIG.HuaWeiObs.Bucket,
Key: filename,
},
HttpHeader: obs.HttpHeader{
ContentType: file.Header.Get("content-type"),
},
},
Body: open,
}
var client *obs.ObsClient
client, err = NewHuaWeiObsClient()
if err != nil {
return "", "", errors.Wrap(err, "获取华为对象存储对象失败!")
}
_, err = client.PutObject(input)
if err != nil {
return "", "", errors.Wrap(err, "文件上传失败!")
}
filepath := global.GVA_CONFIG.HuaWeiObs.Path + "/" + filename
return filepath, filename, err
}
func (o *Obs) DeleteFile(key string) error {
client, err := NewHuaWeiObsClient()
if err != nil {
return errors.Wrap(err, "获取华为对象存储对象失败!")
}
input := &obs.DeleteObjectInput{
Bucket: global.GVA_CONFIG.HuaWeiObs.Bucket,
Key: key,
}
var output *obs.DeleteObjectOutput
output, err = client.DeleteObject(input)
if err != nil {
return errors.Wrapf(err, "删除对象(%s)失败!, output: %v", key, output)
}
return nil
}

View File

@@ -0,0 +1,96 @@
package upload
import (
"context"
"errors"
"fmt"
"mime/multipart"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
"go.uber.org/zap"
)
type Qiniu struct{}
//@author: [piexlmax](https://github.com/piexlmax)
//@author: [ccfish86](https://github.com/ccfish86)
//@author: [SliverHorn](https://github.com/SliverHorn)
//@object: *Qiniu
//@function: UploadFile
//@description: 上传文件
//@param: file *multipart.FileHeader
//@return: string, string, error
func (*Qiniu) UploadFile(file *multipart.FileHeader) (string, string, error) {
putPolicy := storage.PutPolicy{Scope: global.GVA_CONFIG.Qiniu.Bucket}
mac := qbox.NewMac(global.GVA_CONFIG.Qiniu.AccessKey, global.GVA_CONFIG.Qiniu.SecretKey)
upToken := putPolicy.UploadToken(mac)
cfg := qiniuConfig()
formUploader := storage.NewFormUploader(cfg)
ret := storage.PutRet{}
putExtra := storage.PutExtra{Params: map[string]string{"x:name": "github logo"}}
f, openError := file.Open()
if openError != nil {
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
}
defer f.Close() // 创建文件 defer 关闭
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename) // 文件名格式 自己可以改 建议保证唯一性
putErr := formUploader.Put(context.Background(), &ret, upToken, fileKey, f, file.Size, &putExtra)
if putErr != nil {
global.GVA_LOG.Error("function formUploader.Put() failed", zap.Any("err", putErr.Error()))
return "", "", errors.New("function formUploader.Put() failed, err:" + putErr.Error())
}
return global.GVA_CONFIG.Qiniu.ImgPath + "/" + ret.Key, ret.Key, nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@author: [ccfish86](https://github.com/ccfish86)
//@author: [SliverHorn](https://github.com/SliverHorn)
//@object: *Qiniu
//@function: DeleteFile
//@description: 删除文件
//@param: key string
//@return: error
func (*Qiniu) DeleteFile(key string) error {
mac := qbox.NewMac(global.GVA_CONFIG.Qiniu.AccessKey, global.GVA_CONFIG.Qiniu.SecretKey)
cfg := qiniuConfig()
bucketManager := storage.NewBucketManager(mac, cfg)
if err := bucketManager.Delete(global.GVA_CONFIG.Qiniu.Bucket, key); err != nil {
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
}
return nil
}
//@author: [SliverHorn](https://github.com/SliverHorn)
//@object: *Qiniu
//@function: qiniuConfig
//@description: 根据配置文件进行返回七牛云的配置
//@return: *storage.Config
func qiniuConfig() *storage.Config {
cfg := storage.Config{
UseHTTPS: global.GVA_CONFIG.Qiniu.UseHTTPS,
UseCdnDomains: global.GVA_CONFIG.Qiniu.UseCdnDomains,
}
switch global.GVA_CONFIG.Qiniu.Zone { // 根据配置文件进行初始化空间对应的机房
case "ZoneHuadong":
cfg.Zone = &storage.ZoneHuadong
case "ZoneHuabei":
cfg.Zone = &storage.ZoneHuabei
case "ZoneHuanan":
cfg.Zone = &storage.ZoneHuanan
case "ZoneBeimei":
cfg.Zone = &storage.ZoneBeimei
case "ZoneXinjiapo":
cfg.Zone = &storage.ZoneXinjiapo
}
return &cfg
}

View File

@@ -0,0 +1,61 @@
package upload
import (
"context"
"errors"
"fmt"
"mime/multipart"
"net/http"
"net/url"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/tencentyun/cos-go-sdk-v5"
"go.uber.org/zap"
)
type TencentCOS struct{}
// UploadFile upload file to COS
func (*TencentCOS) UploadFile(file *multipart.FileHeader) (string, string, error) {
client := NewClient()
f, openError := file.Open()
if openError != nil {
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
}
defer f.Close() // 创建文件 defer 关闭
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename)
_, err := client.Object.Put(context.Background(), global.GVA_CONFIG.TencentCOS.PathPrefix+"/"+fileKey, f, nil)
if err != nil {
panic(err)
}
return global.GVA_CONFIG.TencentCOS.BaseURL + "/" + global.GVA_CONFIG.TencentCOS.PathPrefix + "/" + fileKey, fileKey, nil
}
// DeleteFile delete file form COS
func (*TencentCOS) DeleteFile(key string) error {
client := NewClient()
name := global.GVA_CONFIG.TencentCOS.PathPrefix + "/" + key
_, err := client.Object.Delete(context.Background(), name)
if err != nil {
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
}
return nil
}
// NewClient init COS client
func NewClient() *cos.Client {
urlStr, _ := url.Parse("https://" + global.GVA_CONFIG.TencentCOS.Bucket + ".cos." + global.GVA_CONFIG.TencentCOS.Region + ".myqcloud.com")
baseURL := &cos.BaseURL{BucketURL: urlStr}
client := cos.NewClient(baseURL, &http.Client{
Transport: &cos.AuthorizationTransport{
SecretID: global.GVA_CONFIG.TencentCOS.SecretID,
SecretKey: global.GVA_CONFIG.TencentCOS.SecretKey,
},
})
return client
}

View File

@@ -0,0 +1,46 @@
package upload
import (
"mime/multipart"
"git.echol.cn/loser/ai_proxy/server/global"
)
// OSS 对象存储接口
// Author [SliverHorn](https://github.com/SliverHorn)
// Author [ccfish86](https://github.com/ccfish86)
type OSS interface {
UploadFile(file *multipart.FileHeader) (string, string, error)
DeleteFile(key string) error
}
// NewOss OSS的实例化方法
// Author [SliverHorn](https://github.com/SliverHorn)
// Author [ccfish86](https://github.com/ccfish86)
func NewOss() OSS {
switch global.GVA_CONFIG.System.OssType {
case "local":
return &Local{}
case "qiniu":
return &Qiniu{}
case "tencent-cos":
return &TencentCOS{}
case "aliyun-oss":
return &AliyunOSS{}
case "huawei-obs":
return HuaWeiObs
case "aws-s3":
return &AwsS3{}
case "cloudflare-r2":
return &CloudflareR2{}
case "minio":
minioClient, err := GetMinio(global.GVA_CONFIG.Minio.Endpoint, global.GVA_CONFIG.Minio.AccessKeyId, global.GVA_CONFIG.Minio.AccessKeySecret, global.GVA_CONFIG.Minio.BucketName, global.GVA_CONFIG.Minio.UseSSL)
if err != nil {
global.GVA_LOG.Warn("你配置了使用minio但是初始化失败请检查minio可用性或安全配置: " + err.Error())
panic("minio初始化失败") // 建议这样做用户自己配置了minio如果报错了还要把服务开起来使用起来也很危险
}
return minioClient
default:
return &Local{}
}
}

294
server/utils/validator.go Normal file
View File

@@ -0,0 +1,294 @@
package utils
import (
"errors"
"reflect"
"regexp"
"strconv"
"strings"
)
type Rules map[string][]string
type RulesMap map[string]Rules
var CustomizeMap = make(map[string]Rules)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: RegisterRule
//@description: 注册自定义规则方案建议在路由初始化层即注册
//@param: key string, rule Rules
//@return: err error
func RegisterRule(key string, rule Rules) (err error) {
if CustomizeMap[key] != nil {
return errors.New(key + "已注册,无法重复注册")
} else {
CustomizeMap[key] = rule
return nil
}
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: NotEmpty
//@description: 非空 不能为其对应类型的0值
//@return: string
func NotEmpty() string {
return "notEmpty"
}
// @author: [zooqkl](https://github.com/zooqkl)
// @function: RegexpMatch
// @description: 正则校验 校验输入项是否满足正则表达式
// @param: rule string
// @return: string
func RegexpMatch(rule string) string {
return "regexp=" + rule
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Lt
//@description: 小于入参(<) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
//@param: mark string
//@return: string
func Lt(mark string) string {
return "lt=" + mark
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Le
//@description: 小于等于入参(<=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
//@param: mark string
//@return: string
func Le(mark string) string {
return "le=" + mark
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Eq
//@description: 等于入参(==) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
//@param: mark string
//@return: string
func Eq(mark string) string {
return "eq=" + mark
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Ne
//@description: 不等于入参(!=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
//@param: mark string
//@return: string
func Ne(mark string) string {
return "ne=" + mark
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Ge
//@description: 大于等于入参(>=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
//@param: mark string
//@return: string
func Ge(mark string) string {
return "ge=" + mark
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Gt
//@description: 大于入参(>) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
//@param: mark string
//@return: string
func Gt(mark string) string {
return "gt=" + mark
}
//
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Verify
//@description: 校验方法
//@param: st interface{}, roleMap Rules(入参实例规则map)
//@return: err error
func Verify(st interface{}, roleMap Rules) (err error) {
compareMap := map[string]bool{
"lt": true,
"le": true,
"eq": true,
"ne": true,
"ge": true,
"gt": true,
}
typ := reflect.TypeOf(st)
val := reflect.ValueOf(st) // 获取reflect.Type类型
kd := val.Kind() // 获取到st对应的类别
if kd != reflect.Struct {
return errors.New("expect struct")
}
num := val.NumField()
// 遍历结构体的所有字段
for i := 0; i < num; i++ {
tagVal := typ.Field(i)
val := val.Field(i)
if tagVal.Type.Kind() == reflect.Struct {
if err = Verify(val.Interface(), roleMap); err != nil {
return err
}
}
if len(roleMap[tagVal.Name]) > 0 {
for _, v := range roleMap[tagVal.Name] {
switch {
case v == "notEmpty":
if isBlank(val) {
return errors.New(tagVal.Name + "值不能为空")
}
case strings.Split(v, "=")[0] == "regexp":
if !regexpMatch(strings.Split(v, "=")[1], val.String()) {
return errors.New(tagVal.Name + "格式校验不通过")
}
case compareMap[strings.Split(v, "=")[0]]:
if !compareVerify(val, v) {
return errors.New(tagVal.Name + "长度或值不在合法范围," + v)
}
}
}
}
}
return nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: compareVerify
//@description: 长度和数字的校验方法 根据类型自动校验
//@param: value reflect.Value, VerifyStr string
//@return: bool
func compareVerify(value reflect.Value, VerifyStr string) bool {
switch value.Kind() {
case reflect.String:
return compare(len([]rune(value.String())), VerifyStr)
case reflect.Slice, reflect.Array:
return compare(value.Len(), VerifyStr)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return compare(value.Uint(), VerifyStr)
case reflect.Float32, reflect.Float64:
return compare(value.Float(), VerifyStr)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return compare(value.Int(), VerifyStr)
default:
return false
}
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: isBlank
//@description: 非空校验
//@param: value reflect.Value
//@return: bool
func isBlank(value reflect.Value) bool {
switch value.Kind() {
case reflect.String, reflect.Slice:
return value.Len() == 0
case reflect.Bool:
return !value.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return value.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return value.Uint() == 0
case reflect.Float32, reflect.Float64:
return value.Float() == 0
case reflect.Interface, reflect.Ptr:
return value.IsNil()
}
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: compare
//@description: 比较函数
//@param: value interface{}, VerifyStr string
//@return: bool
func compare(value interface{}, VerifyStr string) bool {
VerifyStrArr := strings.Split(VerifyStr, "=")
val := reflect.ValueOf(value)
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
VInt, VErr := strconv.ParseInt(VerifyStrArr[1], 10, 64)
if VErr != nil {
return false
}
switch {
case VerifyStrArr[0] == "lt":
return val.Int() < VInt
case VerifyStrArr[0] == "le":
return val.Int() <= VInt
case VerifyStrArr[0] == "eq":
return val.Int() == VInt
case VerifyStrArr[0] == "ne":
return val.Int() != VInt
case VerifyStrArr[0] == "ge":
return val.Int() >= VInt
case VerifyStrArr[0] == "gt":
return val.Int() > VInt
default:
return false
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
VInt, VErr := strconv.Atoi(VerifyStrArr[1])
if VErr != nil {
return false
}
switch {
case VerifyStrArr[0] == "lt":
return val.Uint() < uint64(VInt)
case VerifyStrArr[0] == "le":
return val.Uint() <= uint64(VInt)
case VerifyStrArr[0] == "eq":
return val.Uint() == uint64(VInt)
case VerifyStrArr[0] == "ne":
return val.Uint() != uint64(VInt)
case VerifyStrArr[0] == "ge":
return val.Uint() >= uint64(VInt)
case VerifyStrArr[0] == "gt":
return val.Uint() > uint64(VInt)
default:
return false
}
case reflect.Float32, reflect.Float64:
VFloat, VErr := strconv.ParseFloat(VerifyStrArr[1], 64)
if VErr != nil {
return false
}
switch {
case VerifyStrArr[0] == "lt":
return val.Float() < VFloat
case VerifyStrArr[0] == "le":
return val.Float() <= VFloat
case VerifyStrArr[0] == "eq":
return val.Float() == VFloat
case VerifyStrArr[0] == "ne":
return val.Float() != VFloat
case VerifyStrArr[0] == "ge":
return val.Float() >= VFloat
case VerifyStrArr[0] == "gt":
return val.Float() > VFloat
default:
return false
}
default:
return false
}
}
func regexpMatch(rule, matchStr string) bool {
return regexp.MustCompile(rule).MatchString(matchStr)
}

View File

@@ -0,0 +1,38 @@
package utils
import (
"testing"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
type PageInfoTest struct {
PageInfo request.PageInfo
Name string
}
func TestVerify(t *testing.T) {
PageInfoVerify := Rules{"Page": {NotEmpty()}, "PageSize": {NotEmpty()}, "Name": {NotEmpty()}}
var testInfo PageInfoTest
testInfo.Name = "test"
testInfo.PageInfo.Page = 0
testInfo.PageInfo.PageSize = 0
err := Verify(testInfo, PageInfoVerify)
if err == nil {
t.Error("校验失败未能捕捉0值")
}
testInfo.Name = ""
testInfo.PageInfo.Page = 1
testInfo.PageInfo.PageSize = 10
err = Verify(testInfo, PageInfoVerify)
if err == nil {
t.Error("校验失败未能正常检测name为空")
}
testInfo.Name = "test"
testInfo.PageInfo.Page = 1
testInfo.PageInfo.PageSize = 10
err = Verify(testInfo, PageInfoVerify)
if err != nil {
t.Error("校验失败,未能正常通过检测")
}
}

19
server/utils/verify.go Normal file
View File

@@ -0,0 +1,19 @@
package utils
var (
IdVerify = Rules{"ID": []string{NotEmpty()}}
ApiVerify = Rules{"Path": {NotEmpty()}, "Description": {NotEmpty()}, "ApiGroup": {NotEmpty()}, "Method": {NotEmpty()}}
MenuVerify = Rules{"Path": {NotEmpty()}, "Name": {NotEmpty()}, "Component": {NotEmpty()}, "Sort": {Ge("0")}}
MenuMetaVerify = Rules{"Title": {NotEmpty()}}
LoginVerify = Rules{"Username": {NotEmpty()}, "Password": {NotEmpty()}}
RegisterVerify = Rules{"Username": {NotEmpty()}, "NickName": {NotEmpty()}, "Password": {NotEmpty()}, "AuthorityId": {NotEmpty()}}
PageInfoVerify = Rules{"Page": {NotEmpty()}, "PageSize": {NotEmpty()}}
CustomerVerify = Rules{"CustomerName": {NotEmpty()}, "CustomerPhoneData": {NotEmpty()}}
AutoCodeVerify = Rules{"Abbreviation": {NotEmpty()}, "StructName": {NotEmpty()}, "PackageName": {NotEmpty()}}
AutoPackageVerify = Rules{"PackageName": {NotEmpty()}}
AuthorityVerify = Rules{"AuthorityId": {NotEmpty()}, "AuthorityName": {NotEmpty()}}
AuthorityIdVerify = Rules{"AuthorityId": {NotEmpty()}}
OldAuthorityVerify = Rules{"OldAuthorityId": {NotEmpty()}}
ChangePasswordVerify = Rules{"Password": {NotEmpty()}, "NewPassword": {NotEmpty()}}
SetUserAuthorityVerify = Rules{"AuthorityId": {NotEmpty()}}
)

53
server/utils/zip.go Normal file
View File

@@ -0,0 +1,53 @@
package utils
import (
"archive/zip"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// 解压
func Unzip(zipFile string, destDir string) ([]string, error) {
zipReader, err := zip.OpenReader(zipFile)
var paths []string
if err != nil {
return []string{}, err
}
defer zipReader.Close()
for _, f := range zipReader.File {
if strings.Contains(f.Name, "..") {
return []string{}, fmt.Errorf("%s 文件名不合法", f.Name)
}
fpath := filepath.Join(destDir, f.Name)
paths = append(paths, fpath)
if f.FileInfo().IsDir() {
os.MkdirAll(fpath, os.ModePerm)
} else {
if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
return []string{}, err
}
inFile, err := f.Open()
if err != nil {
return []string{}, err
}
defer inFile.Close()
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return []string{}, err
}
defer outFile.Close()
_, err = io.Copy(outFile, inFile)
if err != nil {
return []string{}, err
}
}
}
return paths, nil
}