✨ init project
This commit is contained in:
231
utils/ast/ast.go
Normal file
231
utils/ast/ast.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.echol.cn/loser/lckt/model/system"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"log"
|
||||
)
|
||||
|
||||
// AddImport 增加 import 方法
|
||||
func AddImport(astNode ast.Node, imp string) {
|
||||
impStr := fmt.Sprintf("\"%s\"", imp)
|
||||
ast.Inspect(astNode, func(node ast.Node) bool {
|
||||
if genDecl, ok := node.(*ast.GenDecl); ok {
|
||||
if genDecl.Tok == token.IMPORT {
|
||||
for i := range genDecl.Specs {
|
||||
if impNode, ok := genDecl.Specs[i].(*ast.ImportSpec); ok {
|
||||
if impNode.Path.Value == impStr {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
|
||||
Path: &ast.BasicLit{
|
||||
Kind: token.STRING,
|
||||
Value: impStr,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// FindFunction 查询特定function方法
|
||||
func FindFunction(astNode ast.Node, FunctionName string) *ast.FuncDecl {
|
||||
var funcDeclP *ast.FuncDecl
|
||||
ast.Inspect(astNode, func(node ast.Node) bool {
|
||||
if funcDecl, ok := node.(*ast.FuncDecl); ok {
|
||||
if funcDecl.Name.String() == FunctionName {
|
||||
funcDeclP = funcDecl
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return funcDeclP
|
||||
}
|
||||
|
||||
// FindArray 查询特定数组方法
|
||||
func FindArray(astNode ast.Node, identName, selectorExprName string) *ast.CompositeLit {
|
||||
var assignStmt *ast.CompositeLit
|
||||
ast.Inspect(astNode, func(n ast.Node) bool {
|
||||
switch node := n.(type) {
|
||||
case *ast.AssignStmt:
|
||||
for _, expr := range node.Rhs {
|
||||
if exprType, ok := expr.(*ast.CompositeLit); ok {
|
||||
if arrayType, ok := exprType.Type.(*ast.ArrayType); ok {
|
||||
sel, ok1 := arrayType.Elt.(*ast.SelectorExpr)
|
||||
x, ok2 := sel.X.(*ast.Ident)
|
||||
if ok1 && ok2 && x.Name == identName && sel.Sel.Name == selectorExprName {
|
||||
assignStmt = exprType
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return assignStmt
|
||||
}
|
||||
|
||||
func CreateMenuStructAst(menus []system.SysBaseMenu) *[]ast.Expr {
|
||||
var menuElts []ast.Expr
|
||||
for i := range menus {
|
||||
elts := []ast.Expr{ // 结构体的字段
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "ParentId"},
|
||||
Value: &ast.BasicLit{Kind: token.INT, Value: "0"},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Path"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Path)},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Name"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Name)},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Hidden"},
|
||||
Value: &ast.Ident{Name: "false"},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Component"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Component)},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Sort"},
|
||||
Value: &ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", menus[i].Sort)},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Meta"},
|
||||
Value: &ast.CompositeLit{
|
||||
Type: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: "model"},
|
||||
Sel: &ast.Ident{Name: "Meta"},
|
||||
},
|
||||
Elts: []ast.Expr{
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Title"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Title)},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Icon"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Icon)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
menuElts = append(menuElts, &ast.CompositeLit{
|
||||
Type: nil,
|
||||
Elts: elts,
|
||||
})
|
||||
}
|
||||
return &menuElts
|
||||
}
|
||||
|
||||
func CreateApiStructAst(apis []system.SysApi) *[]ast.Expr {
|
||||
var apiElts []ast.Expr
|
||||
for i := range apis {
|
||||
elts := []ast.Expr{ // 结构体的字段
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Path"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", apis[i].Path)},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Description"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", apis[i].Description)},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "ApiGroup"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", apis[i].ApiGroup)},
|
||||
},
|
||||
&ast.KeyValueExpr{
|
||||
Key: &ast.Ident{Name: "Method"},
|
||||
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", apis[i].Method)},
|
||||
},
|
||||
}
|
||||
apiElts = append(apiElts, &ast.CompositeLit{
|
||||
Type: nil,
|
||||
Elts: elts,
|
||||
})
|
||||
}
|
||||
return &apiElts
|
||||
}
|
||||
|
||||
// CheckImport 检查是否存在Import
|
||||
func CheckImport(file *ast.File, importPath string) bool {
|
||||
for _, imp := range file.Imports {
|
||||
// Remove quotes around the import path
|
||||
path := imp.Path.Value[1 : len(imp.Path.Value)-1]
|
||||
|
||||
if path == importPath {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func clearPosition(astNode ast.Node) {
|
||||
ast.Inspect(astNode, func(n ast.Node) bool {
|
||||
switch node := n.(type) {
|
||||
case *ast.Ident:
|
||||
// 清除位置信息
|
||||
node.NamePos = token.NoPos
|
||||
case *ast.CallExpr:
|
||||
// 清除位置信息
|
||||
node.Lparen = token.NoPos
|
||||
node.Rparen = token.NoPos
|
||||
case *ast.BasicLit:
|
||||
// 清除位置信息
|
||||
node.ValuePos = token.NoPos
|
||||
case *ast.SelectorExpr:
|
||||
// 清除位置信息
|
||||
node.Sel.NamePos = token.NoPos
|
||||
case *ast.BinaryExpr:
|
||||
node.OpPos = token.NoPos
|
||||
case *ast.UnaryExpr:
|
||||
node.OpPos = token.NoPos
|
||||
case *ast.StarExpr:
|
||||
node.Star = token.NoPos
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func CreateStmt(statement string) *ast.ExprStmt {
|
||||
expr, err := parser.ParseExpr(statement)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
clearPosition(expr)
|
||||
return &ast.ExprStmt{X: expr}
|
||||
}
|
||||
|
||||
func IsBlockStmt(node ast.Node) bool {
|
||||
_, ok := node.(*ast.BlockStmt)
|
||||
return ok
|
||||
}
|
||||
|
||||
func VariableExistsInBlock(block *ast.BlockStmt, varName string) bool {
|
||||
exists := false
|
||||
ast.Inspect(block, func(n ast.Node) bool {
|
||||
switch node := n.(type) {
|
||||
case *ast.AssignStmt:
|
||||
for _, expr := range node.Lhs {
|
||||
if ident, ok := expr.(*ast.Ident); ok && ident.Name == varName {
|
||||
exists = true
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return exists
|
||||
}
|
47
utils/ast/ast_auto_enter.go
Normal file
47
utils/ast/ast_auto_enter.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"os"
|
||||
)
|
||||
|
||||
func ImportForAutoEnter(path string, funcName string, code string) {
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
fileSet := token.NewFileSet()
|
||||
astFile, err := parser.ParseFile(fileSet, "", src, 0)
|
||||
ast.Inspect(astFile, func(node ast.Node) bool {
|
||||
if typeSpec, ok := node.(*ast.TypeSpec); ok {
|
||||
if typeSpec.Name.Name == funcName {
|
||||
if st, ok := typeSpec.Type.(*ast.StructType); ok {
|
||||
for i := range st.Fields.List {
|
||||
if t, ok := st.Fields.List[i].Type.(*ast.Ident); ok {
|
||||
if t.Name == code {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
sn := &ast.Field{
|
||||
Type: &ast.Ident{Name: code},
|
||||
}
|
||||
st.Fields.List = append(st.Fields.List, sn)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
var out []byte
|
||||
bf := bytes.NewBuffer(out)
|
||||
err = printer.Fprint(bf, fileSet, astFile)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = os.WriteFile(path, bf.Bytes(), 0666)
|
||||
}
|
181
utils/ast/ast_enter.go
Normal file
181
utils/ast/ast_enter.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"go/ast"
|
||||
"go/format"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Visitor struct {
|
||||
ImportCode string
|
||||
StructName string
|
||||
PackageName string
|
||||
GroupName string
|
||||
}
|
||||
|
||||
func (vi *Visitor) Visit(node ast.Node) ast.Visitor {
|
||||
switch n := node.(type) {
|
||||
case *ast.GenDecl:
|
||||
// 查找有没有import context包
|
||||
// Notice:没有考虑没有import任何包的情况
|
||||
if n.Tok == token.IMPORT && vi.ImportCode != "" {
|
||||
vi.addImport(n)
|
||||
// 不需要再遍历子树
|
||||
return nil
|
||||
}
|
||||
if n.Tok == token.TYPE && vi.StructName != "" && vi.PackageName != "" && vi.GroupName != "" {
|
||||
vi.addStruct(n)
|
||||
return nil
|
||||
}
|
||||
case *ast.FuncDecl:
|
||||
if n.Name.Name == "Routers" {
|
||||
vi.addFuncBodyVar(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
}
|
||||
return vi
|
||||
}
|
||||
|
||||
func (vi *Visitor) addStruct(genDecl *ast.GenDecl) ast.Visitor {
|
||||
for i := range genDecl.Specs {
|
||||
switch n := genDecl.Specs[i].(type) {
|
||||
case *ast.TypeSpec:
|
||||
if strings.Index(n.Name.Name, "Group") > -1 {
|
||||
switch t := n.Type.(type) {
|
||||
case *ast.StructType:
|
||||
f := &ast.Field{
|
||||
Names: []*ast.Ident{
|
||||
{
|
||||
Name: vi.StructName,
|
||||
Obj: &ast.Object{
|
||||
Kind: ast.Var,
|
||||
Name: vi.StructName,
|
||||
},
|
||||
},
|
||||
},
|
||||
Type: &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
Name: vi.PackageName,
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: vi.GroupName,
|
||||
},
|
||||
},
|
||||
}
|
||||
t.Fields.List = append(t.Fields.List, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return vi
|
||||
}
|
||||
|
||||
func (vi *Visitor) addImport(genDecl *ast.GenDecl) ast.Visitor {
|
||||
// 是否已经import
|
||||
hasImported := false
|
||||
for _, v := range genDecl.Specs {
|
||||
importSpec := v.(*ast.ImportSpec)
|
||||
// 如果已经包含
|
||||
if importSpec.Path.Value == strconv.Quote(vi.ImportCode) {
|
||||
hasImported = true
|
||||
}
|
||||
}
|
||||
if !hasImported {
|
||||
genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
|
||||
Path: &ast.BasicLit{
|
||||
Kind: token.STRING,
|
||||
Value: strconv.Quote(vi.ImportCode),
|
||||
},
|
||||
})
|
||||
}
|
||||
return vi
|
||||
}
|
||||
|
||||
func (vi *Visitor) addFuncBodyVar(funDecl *ast.FuncDecl) ast.Visitor {
|
||||
hasVar := false
|
||||
for _, v := range funDecl.Body.List {
|
||||
switch varSpec := v.(type) {
|
||||
case *ast.AssignStmt:
|
||||
for i := range varSpec.Lhs {
|
||||
switch nn := varSpec.Lhs[i].(type) {
|
||||
case *ast.Ident:
|
||||
if nn.Name == vi.PackageName+"Router" {
|
||||
hasVar = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasVar {
|
||||
assignStmt := &ast.AssignStmt{
|
||||
Lhs: []ast.Expr{
|
||||
&ast.Ident{
|
||||
Name: vi.PackageName + "Router",
|
||||
Obj: &ast.Object{
|
||||
Kind: ast.Var,
|
||||
Name: vi.PackageName + "Router",
|
||||
},
|
||||
},
|
||||
},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{
|
||||
&ast.SelectorExpr{
|
||||
X: &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
Name: "router",
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: "RouterGroupApp",
|
||||
},
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: cases.Title(language.English).String(vi.PackageName),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
funDecl.Body.List = append(funDecl.Body.List, funDecl.Body.List[1])
|
||||
index := 1
|
||||
copy(funDecl.Body.List[index+1:], funDecl.Body.List[index:])
|
||||
funDecl.Body.List[index] = assignStmt
|
||||
}
|
||||
return vi
|
||||
}
|
||||
|
||||
func ImportReference(filepath, importCode, structName, packageName, groupName string) error {
|
||||
fSet := token.NewFileSet()
|
||||
fParser, err := parser.ParseFile(fSet, filepath, nil, parser.ParseComments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
importCode = strings.TrimSpace(importCode)
|
||||
v := &Visitor{
|
||||
ImportCode: importCode,
|
||||
StructName: structName,
|
||||
PackageName: packageName,
|
||||
GroupName: groupName,
|
||||
}
|
||||
if importCode == "" {
|
||||
ast.Print(fSet, fParser)
|
||||
}
|
||||
|
||||
ast.Walk(v, fParser)
|
||||
|
||||
var output []byte
|
||||
buffer := bytes.NewBuffer(output)
|
||||
err = format.Node(buffer, fSet, fParser)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// 写回数据
|
||||
return os.WriteFile(filepath, buffer.Bytes(), 0o600)
|
||||
}
|
166
utils/ast/ast_gorm.go
Normal file
166
utils/ast/ast_gorm.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"os"
|
||||
)
|
||||
|
||||
// AddRegisterTablesAst 自动为 gorm.go 注册一个自动迁移
|
||||
func AddRegisterTablesAst(path, funcName, pk, varName, dbName, model string) {
|
||||
modelPk := fmt.Sprintf("git.echol.cn/loser/lckt/model/%s", pk)
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
fileSet := token.NewFileSet()
|
||||
astFile, err := parser.ParseFile(fileSet, "", src, 0)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
AddImport(astFile, modelPk)
|
||||
FuncNode := FindFunction(astFile, funcName)
|
||||
if FuncNode != nil {
|
||||
ast.Print(fileSet, FuncNode)
|
||||
}
|
||||
addDBVar(FuncNode.Body, varName, dbName)
|
||||
addAutoMigrate(FuncNode.Body, varName, pk, model)
|
||||
var out []byte
|
||||
bf := bytes.NewBuffer(out)
|
||||
printer.Fprint(bf, fileSet, astFile)
|
||||
|
||||
os.WriteFile(path, bf.Bytes(), 0666)
|
||||
}
|
||||
|
||||
// 增加一个 db库变量
|
||||
func addDBVar(astBody *ast.BlockStmt, varName, dbName string) {
|
||||
if dbName == "" {
|
||||
return
|
||||
}
|
||||
dbStr := fmt.Sprintf("\"%s\"", dbName)
|
||||
|
||||
for i := range astBody.List {
|
||||
if assignStmt, ok := astBody.List[i].(*ast.AssignStmt); ok {
|
||||
if ident, ok := assignStmt.Lhs[0].(*ast.Ident); ok {
|
||||
if ident.Name == varName {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
assignNode := &ast.AssignStmt{
|
||||
Lhs: []ast.Expr{
|
||||
&ast.Ident{
|
||||
Name: varName,
|
||||
},
|
||||
},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{
|
||||
&ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
Name: "global",
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: "GetGlobalDBByDBName",
|
||||
},
|
||||
},
|
||||
Args: []ast.Expr{
|
||||
&ast.BasicLit{
|
||||
Kind: token.STRING,
|
||||
Value: dbStr,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
astBody.List = append([]ast.Stmt{assignNode}, astBody.List...)
|
||||
}
|
||||
|
||||
// 为db库变量增加 AutoMigrate 方法
|
||||
func addAutoMigrate(astBody *ast.BlockStmt, dbname string, pk string, model string) {
|
||||
if dbname == "" {
|
||||
dbname = "db"
|
||||
}
|
||||
flag := true
|
||||
ast.Inspect(astBody, func(node ast.Node) bool {
|
||||
// 首先判断需要加入的方法调用语句是否存在 不存在则直接走到下方逻辑
|
||||
switch n := node.(type) {
|
||||
case *ast.CallExpr:
|
||||
// 判断是否找到了AutoMigrate语句
|
||||
if s, ok := n.Fun.(*ast.SelectorExpr); ok {
|
||||
if x, ok := s.X.(*ast.Ident); ok {
|
||||
if s.Sel.Name == "AutoMigrate" && x.Name == dbname {
|
||||
flag = false
|
||||
if !NeedAppendModel(n, pk, model) {
|
||||
return false
|
||||
}
|
||||
// 判断已经找到了AutoMigrate语句
|
||||
n.Args = append(n.Args, &ast.CompositeLit{
|
||||
Type: &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
Name: pk,
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: model,
|
||||
},
|
||||
},
|
||||
})
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
//然后判断 pk.model是否存在 如果存在直接跳出 如果不存在 则向已经找到的方法调用语句的node里面push一条
|
||||
})
|
||||
|
||||
if flag {
|
||||
exprStmt := &ast.ExprStmt{
|
||||
X: &ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
Name: dbname,
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: "AutoMigrate",
|
||||
},
|
||||
},
|
||||
Args: []ast.Expr{
|
||||
&ast.CompositeLit{
|
||||
Type: &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
Name: pk,
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: model,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
astBody.List = append(astBody.List, exprStmt)
|
||||
}
|
||||
}
|
||||
|
||||
// NeedAppendModel 为automigrate增加实参
|
||||
func NeedAppendModel(callNode ast.Node, pk string, model string) bool {
|
||||
flag := true
|
||||
ast.Inspect(callNode, func(node ast.Node) bool {
|
||||
switch n := node.(type) {
|
||||
case *ast.SelectorExpr:
|
||||
if x, ok := n.X.(*ast.Ident); ok {
|
||||
if n.Sel.Name == model && x.Name == pk {
|
||||
flag = false
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return flag
|
||||
}
|
11
utils/ast/ast_init_test.go
Normal file
11
utils/ast/ast_init_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func init() {
|
||||
global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("../../../")
|
||||
global.GVA_CONFIG.AutoCode.Server = "server"
|
||||
}
|
173
utils/ast/ast_rollback.go
Normal file
173
utils/ast/ast_rollback.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func RollBackAst(pk, model string) {
|
||||
RollGormBack(pk, model)
|
||||
RollRouterBack(pk, model)
|
||||
}
|
||||
|
||||
func RollGormBack(pk, model string) {
|
||||
|
||||
// 首先分析存在多少个ttt作为调用方的node块
|
||||
// 如果多个 仅仅删除对应块即可
|
||||
// 如果单个 那么还需要剔除import
|
||||
path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go")
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
fileSet := token.NewFileSet()
|
||||
astFile, err := parser.ParseFile(fileSet, "", src, 0)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
var n *ast.CallExpr
|
||||
var k int = -1
|
||||
var pkNum = 0
|
||||
ast.Inspect(astFile, func(node ast.Node) bool {
|
||||
if node, ok := node.(*ast.CallExpr); ok {
|
||||
for i := range node.Args {
|
||||
pkOK := false
|
||||
modelOK := false
|
||||
ast.Inspect(node.Args[i], func(item ast.Node) bool {
|
||||
if ii, ok := item.(*ast.Ident); ok {
|
||||
if ii.Name == pk {
|
||||
pkOK = true
|
||||
pkNum++
|
||||
}
|
||||
if ii.Name == model {
|
||||
modelOK = true
|
||||
}
|
||||
}
|
||||
if pkOK && modelOK {
|
||||
n = node
|
||||
k = i
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if k > -1 {
|
||||
n.Args = append(append([]ast.Expr{}, n.Args[:k]...), n.Args[k+1:]...)
|
||||
}
|
||||
if pkNum == 1 {
|
||||
var imI int = -1
|
||||
var gp *ast.GenDecl
|
||||
ast.Inspect(astFile, func(node ast.Node) bool {
|
||||
if gen, ok := node.(*ast.GenDecl); ok {
|
||||
for i := range gen.Specs {
|
||||
if imspec, ok := gen.Specs[i].(*ast.ImportSpec); ok {
|
||||
if imspec.Path.Value == "\"git.echol.cn/loser/lckt/model/"+pk+"\"" {
|
||||
gp = gen
|
||||
imI = i
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if imI > -1 {
|
||||
gp.Specs = append(append([]ast.Spec{}, gp.Specs[:imI]...), gp.Specs[imI+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
var out []byte
|
||||
bf := bytes.NewBuffer(out)
|
||||
printer.Fprint(bf, fileSet, astFile)
|
||||
os.Remove(path)
|
||||
os.WriteFile(path, bf.Bytes(), 0666)
|
||||
|
||||
}
|
||||
|
||||
func RollRouterBack(pk, model string) {
|
||||
|
||||
// 首先抓到所有的代码块结构 {}
|
||||
// 分析结构中是否存在一个变量叫做 pk+Router
|
||||
// 然后获取到代码块指针 对内部需要回滚的代码进行剔除
|
||||
path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go")
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
fileSet := token.NewFileSet()
|
||||
astFile, err := parser.ParseFile(fileSet, "", src, 0)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
var block *ast.BlockStmt
|
||||
var routerStmt *ast.FuncDecl
|
||||
|
||||
ast.Inspect(astFile, func(node ast.Node) bool {
|
||||
if n, ok := node.(*ast.FuncDecl); ok {
|
||||
if n.Name.Name == "initBizRouter" {
|
||||
routerStmt = n
|
||||
}
|
||||
}
|
||||
|
||||
if n, ok := node.(*ast.BlockStmt); ok {
|
||||
ast.Inspect(n, func(bNode ast.Node) bool {
|
||||
if in, ok := bNode.(*ast.Ident); ok {
|
||||
if in.Name == pk+"Router" {
|
||||
block = n
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
}
|
||||
return true
|
||||
})
|
||||
var k int
|
||||
for i := range block.List {
|
||||
if stmtNode, ok := block.List[i].(*ast.ExprStmt); ok {
|
||||
ast.Inspect(stmtNode, func(node ast.Node) bool {
|
||||
if n, ok := node.(*ast.Ident); ok {
|
||||
if n.Name == "Init"+model+"Router" {
|
||||
k = i
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
block.List = append(append([]ast.Stmt{}, block.List[:k]...), block.List[k+1:]...)
|
||||
|
||||
if len(block.List) == 1 {
|
||||
// 说明这个块就没任何意义了
|
||||
block.List = nil
|
||||
}
|
||||
|
||||
for i, n := range routerStmt.Body.List {
|
||||
if n, ok := n.(*ast.BlockStmt); ok {
|
||||
if n.List == nil {
|
||||
routerStmt.Body.List = append(append([]ast.Stmt{}, routerStmt.Body.List[:i]...), routerStmt.Body.List[i+1:]...)
|
||||
i--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var out []byte
|
||||
bf := bytes.NewBuffer(out)
|
||||
printer.Fprint(bf, fileSet, astFile)
|
||||
os.Remove(path)
|
||||
os.WriteFile(path, bf.Bytes(), 0666)
|
||||
}
|
135
utils/ast/ast_router.go
Normal file
135
utils/ast/ast_router.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func AppendNodeToList(stmts []ast.Stmt, stmt ast.Stmt, index int) []ast.Stmt {
|
||||
return append(stmts[:index], append([]ast.Stmt{stmt}, stmts[index:]...)...)
|
||||
}
|
||||
|
||||
func AddRouterCode(path, funcName, pk, model string) {
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
fileSet := token.NewFileSet()
|
||||
astFile, err := parser.ParseFile(fileSet, "", src, parser.ParseComments)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
FuncNode := FindFunction(astFile, funcName)
|
||||
|
||||
pkName := strings.ToUpper(pk[:1]) + pk[1:]
|
||||
routerName := fmt.Sprintf("%sRouter", pk)
|
||||
modelName := fmt.Sprintf("Init%sRouter", model)
|
||||
var bloctPre *ast.BlockStmt
|
||||
for i := len(FuncNode.Body.List) - 1; i >= 0; i-- {
|
||||
if block, ok := FuncNode.Body.List[i].(*ast.BlockStmt); ok {
|
||||
bloctPre = block
|
||||
}
|
||||
}
|
||||
ast.Print(fileSet, FuncNode)
|
||||
if ok, b := needAppendRouter(FuncNode, pk); ok {
|
||||
routerNode :=
|
||||
&ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{
|
||||
&ast.Ident{Name: routerName},
|
||||
},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{
|
||||
&ast.SelectorExpr{
|
||||
X: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: "router"},
|
||||
Sel: &ast.Ident{Name: "RouterGroupApp"},
|
||||
},
|
||||
Sel: &ast.Ident{Name: pkName},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
FuncNode.Body.List = AppendNodeToList(FuncNode.Body.List, routerNode, len(FuncNode.Body.List)-1)
|
||||
bloctPre = routerNode
|
||||
} else {
|
||||
bloctPre = b
|
||||
}
|
||||
|
||||
if needAppendInit(FuncNode, routerName, modelName) {
|
||||
bloctPre.List = append(bloctPre.List,
|
||||
&ast.ExprStmt{
|
||||
X: &ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: routerName},
|
||||
Sel: &ast.Ident{Name: modelName},
|
||||
},
|
||||
Args: []ast.Expr{
|
||||
&ast.Ident{
|
||||
Name: "privateGroup",
|
||||
},
|
||||
&ast.Ident{
|
||||
Name: "publicGroup",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
var out []byte
|
||||
bf := bytes.NewBuffer(out)
|
||||
printer.Fprint(bf, fileSet, astFile)
|
||||
os.WriteFile(path, bf.Bytes(), 0666)
|
||||
}
|
||||
|
||||
func needAppendRouter(funcNode ast.Node, pk string) (bool, *ast.BlockStmt) {
|
||||
flag := true
|
||||
var block *ast.BlockStmt
|
||||
ast.Inspect(funcNode, func(node ast.Node) bool {
|
||||
switch n := node.(type) {
|
||||
case *ast.BlockStmt:
|
||||
for i := range n.List {
|
||||
if assignNode, ok := n.List[i].(*ast.AssignStmt); ok {
|
||||
if identNode, ok := assignNode.Lhs[0].(*ast.Ident); ok {
|
||||
if identNode.Name == fmt.Sprintf("%sRouter", pk) {
|
||||
flag = false
|
||||
block = n
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return true
|
||||
})
|
||||
return flag, block
|
||||
}
|
||||
|
||||
func needAppendInit(funcNode ast.Node, routerName string, modelName string) bool {
|
||||
flag := true
|
||||
ast.Inspect(funcNode, func(node ast.Node) bool {
|
||||
switch n := funcNode.(type) {
|
||||
case *ast.CallExpr:
|
||||
if selectNode, ok := n.Fun.(*ast.SelectorExpr); ok {
|
||||
x, xok := selectNode.X.(*ast.Ident)
|
||||
if xok && x.Name == routerName && selectNode.Sel.Name == modelName {
|
||||
flag = false
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return flag
|
||||
}
|
32
utils/ast/ast_test.go
Normal file
32
utils/ast/ast_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/printer"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAst(t *testing.T) {
|
||||
filename := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "plugin.go")
|
||||
fileSet := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fileSet, filename, nil, parser.ParseComments)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
err = ast.Print(fileSet, file)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
err = printer.Fprint(os.Stdout, token.NewFileSet(), file)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
}
|
53
utils/ast/ast_type.go
Normal file
53
utils/ast/ast_type.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package ast
|
||||
|
||||
type Type string
|
||||
|
||||
func (r Type) String() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func (r Type) Group() string {
|
||||
switch r {
|
||||
case TypePackageApiEnter:
|
||||
return "ApiGroup"
|
||||
case TypePackageRouterEnter:
|
||||
return "RouterGroup"
|
||||
case TypePackageServiceEnter:
|
||||
return "ServiceGroup"
|
||||
case TypePackageApiModuleEnter:
|
||||
return "ApiGroup"
|
||||
case TypePackageRouterModuleEnter:
|
||||
return "RouterGroup"
|
||||
case TypePackageServiceModuleEnter:
|
||||
return "ServiceGroup"
|
||||
case TypePluginApiEnter:
|
||||
return "api"
|
||||
case TypePluginRouterEnter:
|
||||
return "router"
|
||||
case TypePluginServiceEnter:
|
||||
return "service"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
TypePackageApiEnter = "PackageApiEnter" // server/api/v1/enter.go
|
||||
TypePackageRouterEnter = "PackageRouterEnter" // server/router/enter.go
|
||||
TypePackageServiceEnter = "PackageServiceEnter" // server/service/enter.go
|
||||
TypePackageApiModuleEnter = "PackageApiModuleEnter" // server/api/v1/{package}/enter.go
|
||||
TypePackageRouterModuleEnter = "PackageRouterModuleEnter" // server/router/{package}/enter.go
|
||||
TypePackageServiceModuleEnter = "PackageServiceModuleEnter" // server/service/{package}/enter.go
|
||||
TypePackageInitializeGorm = "PackageInitializeGorm" // server/initialize/gorm_biz.go
|
||||
TypePackageInitializeRouter = "PackageInitializeRouter" // server/initialize/router_biz.go
|
||||
TypePluginGen = "PluginGen" // server/plugin/{package}/gen/main.go
|
||||
TypePluginApiEnter = "PluginApiEnter" // server/plugin/{package}/enter.go
|
||||
TypePluginInitializeV1 = "PluginInitializeV1" // server/initialize/plugin_biz_v1.go
|
||||
TypePluginInitializeV2 = "PluginInitializeV2" // server/initialize/plugin_biz_v2.go
|
||||
TypePluginRouterEnter = "PluginRouterEnter" // server/plugin/{package}/enter.go
|
||||
TypePluginServiceEnter = "PluginServiceEnter" // server/plugin/{package}/enter.go
|
||||
TypePluginInitializeApi = "PluginInitializeApi" // server/plugin/{package}/initialize/api.go
|
||||
TypePluginInitializeGorm = "PluginInitializeGorm" // server/plugin/{package}/initialize/gorm.go
|
||||
TypePluginInitializeMenu = "PluginInitializeMenu" // server/plugin/{package}/initialize/menu.go
|
||||
TypePluginInitializeRouter = "PluginInitializeRouter" // server/plugin/{package}/initialize/router.go
|
||||
)
|
94
utils/ast/import.go
Normal file
94
utils/ast/import.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Import struct {
|
||||
Base
|
||||
ImportPath string // 导包路径
|
||||
}
|
||||
|
||||
func NewImport(importPath string) *Import {
|
||||
return &Import{ImportPath: importPath}
|
||||
}
|
||||
|
||||
func (a *Import) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *Import) Rollback(file *ast.File) error {
|
||||
if a.ImportPath == "" {
|
||||
return nil
|
||||
}
|
||||
for i := 0; i < len(file.Decls); i++ {
|
||||
v1, o1 := file.Decls[i].(*ast.GenDecl)
|
||||
if o1 {
|
||||
if v1.Tok != token.IMPORT {
|
||||
break
|
||||
}
|
||||
for j := 0; j < len(v1.Specs); j++ {
|
||||
v2, o2 := v1.Specs[j].(*ast.ImportSpec)
|
||||
if o2 && strings.HasSuffix(a.ImportPath, v2.Path.Value) {
|
||||
v1.Specs = append(v1.Specs[:j], v1.Specs[j+1:]...)
|
||||
if len(v1.Specs) == 0 {
|
||||
file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
|
||||
} // 如果没有import声明,就删除, 如果不删除则会出现import()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Import) Injection(file *ast.File) error {
|
||||
if a.ImportPath == "" {
|
||||
return nil
|
||||
}
|
||||
var has bool
|
||||
for i := 0; i < len(file.Decls); i++ {
|
||||
v1, o1 := file.Decls[i].(*ast.GenDecl)
|
||||
if o1 {
|
||||
if v1.Tok != token.IMPORT {
|
||||
break
|
||||
}
|
||||
for j := 0; j < len(v1.Specs); j++ {
|
||||
v2, o2 := v1.Specs[j].(*ast.ImportSpec)
|
||||
if o2 && strings.HasSuffix(a.ImportPath, v2.Path.Value) {
|
||||
has = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !has {
|
||||
spec := &ast.ImportSpec{
|
||||
Path: &ast.BasicLit{Kind: token.STRING, Value: a.ImportPath},
|
||||
}
|
||||
v1.Specs = append(v1.Specs, spec)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if !has {
|
||||
decls := file.Decls
|
||||
file.Decls = make([]ast.Decl, 0, len(file.Decls)+1)
|
||||
decl := &ast.GenDecl{
|
||||
Tok: token.IMPORT,
|
||||
Specs: []ast.Spec{
|
||||
&ast.ImportSpec{
|
||||
Path: &ast.BasicLit{Kind: token.STRING, Value: a.ImportPath},
|
||||
},
|
||||
},
|
||||
}
|
||||
file.Decls = append(file.Decls, decl)
|
||||
file.Decls = append(file.Decls, decls...)
|
||||
} // 如果没有import声明,就创建一个, 主要要放在第一个
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Import) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
17
utils/ast/interfaces.go
Normal file
17
utils/ast/interfaces.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Ast interface {
|
||||
// Parse 解析文件/代码
|
||||
Parse(filename string, writer io.Writer) (file *ast.File, err error)
|
||||
// Rollback 回滚
|
||||
Rollback(file *ast.File) error
|
||||
// Injection 注入
|
||||
Injection(file *ast.File) error
|
||||
// Format 格式化输出
|
||||
Format(filename string, writer io.Writer, file *ast.File) error
|
||||
}
|
76
utils/ast/interfaces_base.go
Normal file
76
utils/ast/interfaces_base.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"github.com/pkg/errors"
|
||||
"go/ast"
|
||||
"go/format"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Base struct{}
|
||||
|
||||
func (a *Base) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
fileSet := token.NewFileSet()
|
||||
if writer != nil {
|
||||
file, err = parser.ParseFile(fileSet, filename, nil, parser.ParseComments)
|
||||
} else {
|
||||
file, err = parser.ParseFile(fileSet, filename, writer, parser.ParseComments)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "[filepath:%s]打开/解析文件失败!", filename)
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (a *Base) Rollback(file *ast.File) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Base) Injection(file *ast.File) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Base) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
fileSet := token.NewFileSet()
|
||||
if writer == nil {
|
||||
open, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, 0666)
|
||||
defer open.Close()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "[filepath:%s]打开文件失败!", filename)
|
||||
}
|
||||
writer = open
|
||||
}
|
||||
err := format.Node(writer, fileSet, file)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "[filepath:%s]注入失败!", filename)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelativePath 绝对路径转相对路径
|
||||
func (a *Base) RelativePath(filePath string) string {
|
||||
server := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server)
|
||||
hasServer := strings.Index(filePath, server)
|
||||
if hasServer != -1 {
|
||||
filePath = strings.TrimPrefix(filePath, server)
|
||||
keys := strings.Split(filePath, string(filepath.Separator))
|
||||
filePath = path.Join(keys...)
|
||||
}
|
||||
return filePath
|
||||
}
|
||||
|
||||
// AbsolutePath 相对路径转绝对路径
|
||||
func (a *Base) AbsolutePath(filePath string) string {
|
||||
server := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server)
|
||||
keys := strings.Split(filePath, "/")
|
||||
filePath = filepath.Join(keys...)
|
||||
filePath = filepath.Join(server, filePath)
|
||||
return filePath
|
||||
}
|
85
utils/ast/package_enter.go
Normal file
85
utils/ast/package_enter.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"io"
|
||||
)
|
||||
|
||||
// PackageEnter 模块化入口
|
||||
type PackageEnter struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
ImportPath string // 导包路径
|
||||
StructName string // 结构体名称
|
||||
PackageName string // 包名
|
||||
RelativePath string // 相对路径
|
||||
PackageStructName string // 包结构体名称
|
||||
}
|
||||
|
||||
func (a *PackageEnter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.Path
|
||||
a.RelativePath = a.Base.RelativePath(a.Path)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.Path = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *PackageEnter) Rollback(file *ast.File) error {
|
||||
// 无需回滚
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PackageEnter) Injection(file *ast.File) error {
|
||||
_ = NewImport(a.ImportPath).Injection(file)
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
genDecl, ok := n.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.TYPE {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, spec := range genDecl.Specs {
|
||||
typeSpec, specok := spec.(*ast.TypeSpec)
|
||||
if !specok || typeSpec.Name.Name != a.Type.Group() {
|
||||
continue
|
||||
}
|
||||
|
||||
structType, structTypeOK := typeSpec.Type.(*ast.StructType)
|
||||
if !structTypeOK {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, field := range structType.Fields.List {
|
||||
if len(field.Names) == 1 && field.Names[0].Name == a.StructName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
field := &ast.Field{
|
||||
Names: []*ast.Ident{{Name: a.StructName}},
|
||||
Type: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: a.PackageName},
|
||||
Sel: &ast.Ident{Name: a.PackageStructName},
|
||||
},
|
||||
}
|
||||
structType.Fields.List = append(structType.Fields.List, field)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PackageEnter) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
154
utils/ast/package_enter_test.go
Normal file
154
utils/ast/package_enter_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPackageEnter_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
PackageName string
|
||||
PackageStructName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试ExampleApiGroup回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageApiEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/api/v1/example"`,
|
||||
StructName: "ExampleApiGroup",
|
||||
PackageName: "example",
|
||||
PackageStructName: "ApiGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试ExampleRouterGroup回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageRouterEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/router/example"`,
|
||||
StructName: "Example",
|
||||
PackageName: "example",
|
||||
PackageStructName: "RouterGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试ExampleServiceGroup回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageServiceEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/service/example"`,
|
||||
StructName: "ExampleServiceGroup",
|
||||
PackageName: "example",
|
||||
PackageStructName: "ServiceGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PackageEnter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
PackageStructName: tt.fields.PackageStructName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackageEnter_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
PackageName string
|
||||
PackageStructName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试ExampleApiGroup注入",
|
||||
fields: fields{
|
||||
Type: TypePackageApiEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/api/v1/example"`,
|
||||
StructName: "ExampleApiGroup",
|
||||
PackageName: "example",
|
||||
PackageStructName: "ApiGroup",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试ExampleRouterGroup注入",
|
||||
fields: fields{
|
||||
Type: TypePackageRouterEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/router/example"`,
|
||||
StructName: "Example",
|
||||
PackageName: "example",
|
||||
PackageStructName: "RouterGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试ExampleServiceGroup注入",
|
||||
fields: fields{
|
||||
Type: TypePackageServiceEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/service/example"`,
|
||||
StructName: "ExampleServiceGroup",
|
||||
PackageName: "example",
|
||||
PackageStructName: "ServiceGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PackageEnter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
PackageStructName: tt.fields.PackageStructName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
196
utils/ast/package_initialize_gorm.go
Normal file
196
utils/ast/package_initialize_gorm.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"io"
|
||||
)
|
||||
|
||||
// PackageInitializeGorm 包初始化gorm
|
||||
type PackageInitializeGorm struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
ImportPath string // 导包路径
|
||||
Business string // 业务库 gva => gva, 不要传"gva"
|
||||
StructName string // 结构体名称
|
||||
PackageName string // 包名
|
||||
RelativePath string // 相对路径
|
||||
IsNew bool // 是否使用new关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
|
||||
}
|
||||
|
||||
func (a *PackageInitializeGorm) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.Path
|
||||
a.RelativePath = a.Base.RelativePath(a.Path)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.Path = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *PackageInitializeGorm) Rollback(file *ast.File) error {
|
||||
packageNameNum := 0
|
||||
// 寻找目标结构
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
// 总调用的db变量根据business来决定
|
||||
varDB := a.Business + "Db"
|
||||
|
||||
if a.Business == "" {
|
||||
varDB = "db"
|
||||
}
|
||||
|
||||
callExpr, ok := n.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是不是 db.AutoMigrate() 方法
|
||||
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
|
||||
if !ok || selExpr.Sel.Name != "AutoMigrate" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查调用方是不是 db
|
||||
ident, ok := selExpr.X.(*ast.Ident)
|
||||
if !ok || ident.Name != varDB {
|
||||
return true
|
||||
}
|
||||
|
||||
// 删除结构体参数
|
||||
for i := 0; i < len(callExpr.Args); i++ {
|
||||
if com, comok := callExpr.Args[i].(*ast.CompositeLit); comok {
|
||||
if selector, exprok := com.Type.(*ast.SelectorExpr); exprok {
|
||||
if x, identok := selector.X.(*ast.Ident); identok {
|
||||
if x.Name == a.PackageName {
|
||||
packageNameNum++
|
||||
if selector.Sel.Name == a.StructName {
|
||||
callExpr.Args = append(callExpr.Args[:i], callExpr.Args[i+1:]...)
|
||||
i--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if packageNameNum == 1 {
|
||||
_ = NewImport(a.ImportPath).Rollback(file)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PackageInitializeGorm) Injection(file *ast.File) error {
|
||||
_ = NewImport(a.ImportPath).Injection(file)
|
||||
bizModelDecl := FindFunction(file, "bizModel")
|
||||
if bizModelDecl != nil {
|
||||
a.addDbVar(bizModelDecl.Body)
|
||||
}
|
||||
// 寻找目标结构
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
// 总调用的db变量根据business来决定
|
||||
varDB := a.Business + "Db"
|
||||
|
||||
if a.Business == "" {
|
||||
varDB = "db"
|
||||
}
|
||||
|
||||
callExpr, ok := n.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是不是 db.AutoMigrate() 方法
|
||||
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
|
||||
if !ok || selExpr.Sel.Name != "AutoMigrate" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查调用方是不是 db
|
||||
ident, ok := selExpr.X.(*ast.Ident)
|
||||
if !ok || ident.Name != varDB {
|
||||
return true
|
||||
}
|
||||
|
||||
// 添加结构体参数
|
||||
callExpr.Args = append(callExpr.Args, &ast.CompositeLit{
|
||||
Type: &ast.SelectorExpr{
|
||||
X: ast.NewIdent(a.PackageName),
|
||||
Sel: ast.NewIdent(a.StructName),
|
||||
},
|
||||
})
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PackageInitializeGorm) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
||||
|
||||
// 创建businessDB变量
|
||||
func (a *PackageInitializeGorm) addDbVar(astBody *ast.BlockStmt) {
|
||||
for i := range astBody.List {
|
||||
if assignStmt, ok := astBody.List[i].(*ast.AssignStmt); ok {
|
||||
if ident, ok := assignStmt.Lhs[0].(*ast.Ident); ok {
|
||||
if (a.Business == "" && ident.Name == "db") || ident.Name == a.Business+"Db" {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加 businessDb := global.GetGlobalDBByDBName("business") 变量
|
||||
assignNode := &ast.AssignStmt{
|
||||
Lhs: []ast.Expr{
|
||||
&ast.Ident{
|
||||
Name: a.Business + "Db",
|
||||
},
|
||||
},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{
|
||||
&ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
Name: "global",
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: "GetGlobalDBByDBName",
|
||||
},
|
||||
},
|
||||
Args: []ast.Expr{
|
||||
&ast.BasicLit{
|
||||
Kind: token.STRING,
|
||||
Value: fmt.Sprintf("\"%s\"", a.Business),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 添加 businessDb.AutoMigrate() 方法
|
||||
autoMigrateCall := &ast.ExprStmt{
|
||||
X: &ast.CallExpr{
|
||||
Fun: &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
Name: a.Business + "Db",
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
Name: "AutoMigrate",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
returnNode := astBody.List[len(astBody.List)-1]
|
||||
astBody.List = append(astBody.List[:len(astBody.List)-1], assignNode, autoMigrateCall, returnNode)
|
||||
}
|
171
utils/ast/package_initialize_gorm_test.go
Normal file
171
utils/ast/package_initialize_gorm_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPackageInitializeGorm_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
PackageName string
|
||||
IsNew bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 &example.ExaFileUploadAndDownload{} 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/model/example"`,
|
||||
StructName: "ExaFileUploadAndDownload",
|
||||
PackageName: "example",
|
||||
IsNew: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 &example.ExaCustomer{} 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/model/example"`,
|
||||
StructName: "ExaCustomer",
|
||||
PackageName: "example",
|
||||
IsNew: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 new(example.ExaFileUploadAndDownload) 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/model/example"`,
|
||||
StructName: "ExaFileUploadAndDownload",
|
||||
PackageName: "example",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 new(example.ExaCustomer) 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/model/example"`,
|
||||
StructName: "ExaCustomer",
|
||||
PackageName: "example",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PackageInitializeGorm{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
IsNew: tt.fields.IsNew,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackageInitializeGorm_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
PackageName string
|
||||
IsNew bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 &example.ExaFileUploadAndDownload{} 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/model/example"`,
|
||||
StructName: "ExaFileUploadAndDownload",
|
||||
PackageName: "example",
|
||||
IsNew: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 &example.ExaCustomer{} 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/model/example"`,
|
||||
StructName: "ExaCustomer",
|
||||
PackageName: "example",
|
||||
IsNew: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 new(example.ExaFileUploadAndDownload) 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/model/example"`,
|
||||
StructName: "ExaFileUploadAndDownload",
|
||||
PackageName: "example",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 new(example.ExaCustomer) 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/model/example"`,
|
||||
StructName: "ExaCustomer",
|
||||
PackageName: "example",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PackageInitializeGorm{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
IsNew: tt.fields.IsNew,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
150
utils/ast/package_initialize_router.go
Normal file
150
utils/ast/package_initialize_router.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"io"
|
||||
)
|
||||
|
||||
// PackageInitializeRouter 包初始化路由
|
||||
// ModuleName := PackageName.AppName.GroupName
|
||||
// ModuleName.FunctionName(RouterGroupName)
|
||||
type PackageInitializeRouter struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
ImportPath string // 导包路径
|
||||
RelativePath string // 相对路径
|
||||
AppName string // 应用名称
|
||||
GroupName string // 分组名称
|
||||
ModuleName string // 模块名称
|
||||
PackageName string // 包名
|
||||
FunctionName string // 函数名
|
||||
RouterGroupName string // 路由分组名称
|
||||
LeftRouterGroupName string // 左路由分组名称
|
||||
RightRouterGroupName string // 右路由分组名称
|
||||
}
|
||||
|
||||
func (a *PackageInitializeRouter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.Path
|
||||
a.RelativePath = a.Base.RelativePath(a.Path)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.Path = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *PackageInitializeRouter) Rollback(file *ast.File) error {
|
||||
funcDecl := FindFunction(file, "initBizRouter")
|
||||
exprNum := 0
|
||||
for i := range funcDecl.Body.List {
|
||||
if IsBlockStmt(funcDecl.Body.List[i]) {
|
||||
if VariableExistsInBlock(funcDecl.Body.List[i].(*ast.BlockStmt), a.ModuleName) {
|
||||
for ii, stmt := range funcDecl.Body.List[i].(*ast.BlockStmt).List {
|
||||
// 检查语句是否为 *ast.ExprStmt
|
||||
exprStmt, ok := stmt.(*ast.ExprStmt)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// 检查表达式是否为 *ast.CallExpr
|
||||
callExpr, ok := exprStmt.X.(*ast.CallExpr)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// 检查是否调用了我们正在寻找的函数
|
||||
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// 检查调用的函数是否为 systemRouter.InitApiRouter
|
||||
ident, ok := selExpr.X.(*ast.Ident)
|
||||
//只要存在调用则+1
|
||||
if ok && ident.Name == a.ModuleName {
|
||||
exprNum++
|
||||
}
|
||||
//判断是否为目标结构
|
||||
if !ok || ident.Name != a.ModuleName || selExpr.Sel.Name != a.FunctionName {
|
||||
continue
|
||||
}
|
||||
exprNum--
|
||||
// 从语句列表中移除。
|
||||
funcDecl.Body.List[i].(*ast.BlockStmt).List = append(funcDecl.Body.List[i].(*ast.BlockStmt).List[:ii], funcDecl.Body.List[i].(*ast.BlockStmt).List[ii+1:]...)
|
||||
// 如果不再存在任何调用,则删除导入和变量。
|
||||
if exprNum == 0 {
|
||||
funcDecl.Body.List = append(funcDecl.Body.List[:i], funcDecl.Body.List[i+1:]...)
|
||||
}
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PackageInitializeRouter) Injection(file *ast.File) error {
|
||||
funcDecl := FindFunction(file, "initBizRouter")
|
||||
hasRouter := false
|
||||
var varBlock *ast.BlockStmt
|
||||
for i := range funcDecl.Body.List {
|
||||
if IsBlockStmt(funcDecl.Body.List[i]) {
|
||||
if VariableExistsInBlock(funcDecl.Body.List[i].(*ast.BlockStmt), a.ModuleName) {
|
||||
hasRouter = true
|
||||
varBlock = funcDecl.Body.List[i].(*ast.BlockStmt)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasRouter {
|
||||
stmt := a.CreateAssignStmt()
|
||||
varBlock = &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
stmt,
|
||||
},
|
||||
}
|
||||
}
|
||||
routerStmt := CreateStmt(fmt.Sprintf("%s.%s(%s,%s)", a.ModuleName, a.FunctionName, a.LeftRouterGroupName, a.RightRouterGroupName))
|
||||
varBlock.List = append(varBlock.List, routerStmt)
|
||||
if !hasRouter {
|
||||
funcDecl.Body.List = append(funcDecl.Body.List, varBlock)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PackageInitializeRouter) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
||||
|
||||
func (a *PackageInitializeRouter) CreateAssignStmt() *ast.AssignStmt {
|
||||
//创建左侧变量
|
||||
ident := &ast.Ident{
|
||||
Name: a.ModuleName,
|
||||
}
|
||||
|
||||
//创建右侧的赋值语句
|
||||
selector := &ast.SelectorExpr{
|
||||
X: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: a.PackageName},
|
||||
Sel: &ast.Ident{Name: a.AppName},
|
||||
},
|
||||
Sel: &ast.Ident{Name: a.GroupName},
|
||||
}
|
||||
|
||||
// 创建一个组合的赋值语句
|
||||
stmt := &ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ident},
|
||||
Tok: token.DEFINE,
|
||||
Rhs: []ast.Expr{selector},
|
||||
}
|
||||
|
||||
return stmt
|
||||
}
|
158
utils/ast/package_initialize_router_test.go
Normal file
158
utils/ast/package_initialize_router_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPackageInitializeRouter_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
AppName string
|
||||
GroupName string
|
||||
ModuleName string
|
||||
PackageName string
|
||||
FunctionName string
|
||||
RouterGroupName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 InitCustomerRouter 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeRouter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/router"`,
|
||||
AppName: "RouterGroupApp",
|
||||
GroupName: "Example",
|
||||
ModuleName: "exampleRouter",
|
||||
PackageName: "router",
|
||||
FunctionName: "InitCustomerRouter",
|
||||
RouterGroupName: "privateGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 InitFileUploadAndDownloadRouter 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeRouter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/router"`,
|
||||
AppName: "RouterGroupApp",
|
||||
GroupName: "Example",
|
||||
ModuleName: "exampleRouter",
|
||||
PackageName: "router",
|
||||
FunctionName: "InitFileUploadAndDownloadRouter",
|
||||
RouterGroupName: "privateGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PackageInitializeRouter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
AppName: tt.fields.AppName,
|
||||
GroupName: tt.fields.GroupName,
|
||||
ModuleName: tt.fields.ModuleName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
FunctionName: tt.fields.FunctionName,
|
||||
RouterGroupName: tt.fields.RouterGroupName,
|
||||
LeftRouterGroupName: "privateGroup",
|
||||
RightRouterGroupName: "publicGroup",
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackageInitializeRouter_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
AppName string
|
||||
GroupName string
|
||||
ModuleName string
|
||||
PackageName string
|
||||
FunctionName string
|
||||
RouterGroupName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
|
||||
{
|
||||
name: "测试 InitCustomerRouter 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeRouter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/router"`,
|
||||
AppName: "RouterGroupApp",
|
||||
GroupName: "Example",
|
||||
ModuleName: "exampleRouter",
|
||||
PackageName: "router",
|
||||
FunctionName: "InitCustomerRouter",
|
||||
RouterGroupName: "privateGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 InitFileUploadAndDownloadRouter 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageInitializeRouter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/router"`,
|
||||
AppName: "RouterGroupApp",
|
||||
GroupName: "Example",
|
||||
ModuleName: "exampleRouter",
|
||||
PackageName: "router",
|
||||
FunctionName: "InitFileUploadAndDownloadRouter",
|
||||
RouterGroupName: "privateGroup",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PackageInitializeRouter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
AppName: tt.fields.AppName,
|
||||
GroupName: tt.fields.GroupName,
|
||||
ModuleName: tt.fields.ModuleName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
FunctionName: tt.fields.FunctionName,
|
||||
RouterGroupName: tt.fields.RouterGroupName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
180
utils/ast/package_module_enter.go
Normal file
180
utils/ast/package_module_enter.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"io"
|
||||
)
|
||||
|
||||
// PackageModuleEnter 模块化入口
|
||||
// ModuleName := PackageName.AppName.GroupName.ServiceName
|
||||
type PackageModuleEnter struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
ImportPath string // 导包路径
|
||||
RelativePath string // 相对路径
|
||||
StructName string // 结构体名称
|
||||
AppName string // 应用名称
|
||||
GroupName string // 分组名称
|
||||
ModuleName string // 模块名称
|
||||
PackageName string // 包名
|
||||
ServiceName string // 服务名称
|
||||
}
|
||||
|
||||
func (a *PackageModuleEnter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.Path
|
||||
a.RelativePath = a.Base.RelativePath(a.Path)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.Path = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *PackageModuleEnter) Rollback(file *ast.File) error {
|
||||
for i := 0; i < len(file.Decls); i++ {
|
||||
v1, o1 := file.Decls[i].(*ast.GenDecl)
|
||||
if o1 {
|
||||
for j := 0; j < len(v1.Specs); j++ {
|
||||
v2, o2 := v1.Specs[j].(*ast.TypeSpec)
|
||||
if o2 {
|
||||
if v2.Name.Name != a.Type.Group() {
|
||||
continue
|
||||
}
|
||||
v3, o3 := v2.Type.(*ast.StructType)
|
||||
if o3 {
|
||||
for k := 0; k < len(v3.Fields.List); k++ {
|
||||
v4, o4 := v3.Fields.List[k].Type.(*ast.Ident)
|
||||
if o4 && v4.Name == a.StructName {
|
||||
v3.Fields.List = append(v3.Fields.List[:k], v3.Fields.List[k+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if a.Type == TypePackageServiceModuleEnter {
|
||||
continue
|
||||
}
|
||||
v3, o3 := v1.Specs[j].(*ast.ValueSpec)
|
||||
if o3 {
|
||||
if len(v3.Names) == 1 && v3.Names[0].Name == a.ModuleName {
|
||||
v1.Specs = append(v1.Specs[:j], v1.Specs[j+1:]...)
|
||||
}
|
||||
}
|
||||
if v1.Tok == token.VAR && len(v1.Specs) == 0 {
|
||||
_ = NewImport(a.ImportPath).Rollback(file)
|
||||
if i == len(file.Decls) {
|
||||
file.Decls = append(file.Decls[:i-1])
|
||||
break
|
||||
} // 空的var(), 如果不删除则会影响的注入变量, 因为识别不到*ast.ValueSpec
|
||||
file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PackageModuleEnter) Injection(file *ast.File) error {
|
||||
_ = NewImport(a.ImportPath).Injection(file)
|
||||
var hasValue bool
|
||||
var hasVariables bool
|
||||
for i := 0; i < len(file.Decls); i++ {
|
||||
v1, o1 := file.Decls[i].(*ast.GenDecl)
|
||||
if o1 {
|
||||
if v1.Tok == token.VAR {
|
||||
hasVariables = true
|
||||
}
|
||||
for j := 0; j < len(v1.Specs); j++ {
|
||||
if a.Type == TypePackageServiceModuleEnter {
|
||||
hasValue = true
|
||||
}
|
||||
v2, o2 := v1.Specs[j].(*ast.TypeSpec)
|
||||
if o2 {
|
||||
if v2.Name.Name != a.Type.Group() {
|
||||
continue
|
||||
}
|
||||
v3, o3 := v2.Type.(*ast.StructType)
|
||||
if o3 {
|
||||
var hasStruct bool
|
||||
for k := 0; k < len(v3.Fields.List); k++ {
|
||||
v4, o4 := v3.Fields.List[k].Type.(*ast.Ident)
|
||||
if o4 && v4.Name == a.StructName {
|
||||
hasStruct = true
|
||||
}
|
||||
}
|
||||
if !hasStruct {
|
||||
field := &ast.Field{Type: &ast.Ident{Name: a.StructName}}
|
||||
v3.Fields.List = append(v3.Fields.List, field)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
v3, o3 := v1.Specs[j].(*ast.ValueSpec)
|
||||
if o3 {
|
||||
hasVariables = true
|
||||
if len(v3.Names) == 1 && v3.Names[0].Name == a.ModuleName {
|
||||
hasValue = true
|
||||
}
|
||||
}
|
||||
if v1.Tok == token.VAR && len(v1.Specs) == 0 {
|
||||
hasVariables = false
|
||||
} // 说明是空var()
|
||||
if hasVariables && !hasValue {
|
||||
spec := &ast.ValueSpec{
|
||||
Names: []*ast.Ident{{Name: a.ModuleName}},
|
||||
Values: []ast.Expr{
|
||||
&ast.SelectorExpr{
|
||||
X: &ast.SelectorExpr{
|
||||
X: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: a.PackageName},
|
||||
Sel: &ast.Ident{Name: a.AppName},
|
||||
},
|
||||
Sel: &ast.Ident{Name: a.GroupName},
|
||||
},
|
||||
Sel: &ast.Ident{Name: a.ServiceName},
|
||||
},
|
||||
},
|
||||
}
|
||||
v1.Specs = append(v1.Specs, spec)
|
||||
hasValue = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasValue && !hasVariables {
|
||||
decl := &ast.GenDecl{
|
||||
Tok: token.VAR,
|
||||
Specs: []ast.Spec{
|
||||
&ast.ValueSpec{
|
||||
Names: []*ast.Ident{{Name: a.ModuleName}},
|
||||
Values: []ast.Expr{
|
||||
&ast.SelectorExpr{
|
||||
X: &ast.SelectorExpr{
|
||||
X: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: a.PackageName},
|
||||
Sel: &ast.Ident{Name: a.AppName},
|
||||
},
|
||||
Sel: &ast.Ident{Name: a.GroupName},
|
||||
},
|
||||
Sel: &ast.Ident{Name: a.ServiceName},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
file.Decls = append(file.Decls, decl)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PackageModuleEnter) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
185
utils/ast/package_module_enter_test.go
Normal file
185
utils/ast/package_module_enter_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPackageModuleEnter_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
AppName string
|
||||
GroupName string
|
||||
ModuleName string
|
||||
PackageName string
|
||||
ServiceName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 FileUploadAndDownloadRouter 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageRouterModuleEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", "example", "enter.go"),
|
||||
ImportPath: `api "git.echol.cn/loser/lckt/api/v1"`,
|
||||
StructName: "FileUploadAndDownloadRouter",
|
||||
AppName: "ApiGroupApp",
|
||||
GroupName: "ExampleApiGroup",
|
||||
ModuleName: "exaFileUploadAndDownloadApi",
|
||||
PackageName: "api",
|
||||
ServiceName: "FileUploadAndDownloadApi",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 FileUploadAndDownloadApi 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageApiModuleEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", "example", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/service"`,
|
||||
StructName: "FileUploadAndDownloadApi",
|
||||
AppName: "ServiceGroupApp",
|
||||
GroupName: "ExampleServiceGroup",
|
||||
ModuleName: "fileUploadAndDownloadService",
|
||||
PackageName: "service",
|
||||
ServiceName: "FileUploadAndDownloadService",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 FileUploadAndDownloadService 回滚",
|
||||
fields: fields{
|
||||
Type: TypePackageServiceModuleEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", "example", "enter.go"),
|
||||
ImportPath: ``,
|
||||
StructName: "FileUploadAndDownloadService",
|
||||
AppName: "",
|
||||
GroupName: "",
|
||||
ModuleName: "",
|
||||
PackageName: "",
|
||||
ServiceName: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PackageModuleEnter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
AppName: tt.fields.AppName,
|
||||
GroupName: tt.fields.GroupName,
|
||||
ModuleName: tt.fields.ModuleName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
ServiceName: tt.fields.ServiceName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackageModuleEnter_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
AppName string
|
||||
GroupName string
|
||||
ModuleName string
|
||||
PackageName string
|
||||
ServiceName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 FileUploadAndDownloadRouter 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageRouterModuleEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", "example", "enter.go"),
|
||||
ImportPath: `api "git.echol.cn/loser/lckt/api/v1"`,
|
||||
StructName: "FileUploadAndDownloadRouter",
|
||||
AppName: "ApiGroupApp",
|
||||
GroupName: "ExampleApiGroup",
|
||||
ModuleName: "exaFileUploadAndDownloadApi",
|
||||
PackageName: "api",
|
||||
ServiceName: "FileUploadAndDownloadApi",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 FileUploadAndDownloadApi 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageApiModuleEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", "example", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/service"`,
|
||||
StructName: "FileUploadAndDownloadApi",
|
||||
AppName: "ServiceGroupApp",
|
||||
GroupName: "ExampleServiceGroup",
|
||||
ModuleName: "fileUploadAndDownloadService",
|
||||
PackageName: "service",
|
||||
ServiceName: "FileUploadAndDownloadService",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 FileUploadAndDownloadService 注入",
|
||||
fields: fields{
|
||||
Type: TypePackageServiceModuleEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", "example", "enter.go"),
|
||||
ImportPath: ``,
|
||||
StructName: "FileUploadAndDownloadService",
|
||||
AppName: "",
|
||||
GroupName: "",
|
||||
ModuleName: "",
|
||||
PackageName: "",
|
||||
ServiceName: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PackageModuleEnter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
AppName: tt.fields.AppName,
|
||||
GroupName: tt.fields.GroupName,
|
||||
ModuleName: tt.fields.ModuleName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
ServiceName: tt.fields.ServiceName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
167
utils/ast/plugin_enter.go
Normal file
167
utils/ast/plugin_enter.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"io"
|
||||
)
|
||||
|
||||
// PluginEnter 插件化入口
|
||||
// ModuleName := PackageName.GroupName.ServiceName
|
||||
type PluginEnter struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
ImportPath string // 导包路径
|
||||
RelativePath string // 相对路径
|
||||
StructName string // 结构体名称
|
||||
StructCamelName string // 结构体小驼峰名称
|
||||
ModuleName string // 模块名称
|
||||
GroupName string // 分组名称
|
||||
PackageName string // 包名
|
||||
ServiceName string // 服务名称
|
||||
}
|
||||
|
||||
func (a *PluginEnter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.Path
|
||||
a.RelativePath = a.Base.RelativePath(a.Path)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.Path = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *PluginEnter) Rollback(file *ast.File) error {
|
||||
//回滚结构体内内容
|
||||
var structType *ast.StructType
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
switch x := n.(type) {
|
||||
case *ast.TypeSpec:
|
||||
if s, ok := x.Type.(*ast.StructType); ok {
|
||||
structType = s
|
||||
for i, field := range x.Type.(*ast.StructType).Fields.List {
|
||||
if len(field.Names) > 0 && field.Names[0].Name == a.StructName {
|
||||
s.Fields.List = append(s.Fields.List[:i], s.Fields.List[i+1:]...)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(structType.Fields.List) == 0 {
|
||||
_ = NewImport(a.ImportPath).Rollback(file)
|
||||
}
|
||||
|
||||
if a.Type == TypePluginServiceEnter {
|
||||
return nil
|
||||
}
|
||||
|
||||
//回滚变量内容
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
genDecl, ok := n.(*ast.GenDecl)
|
||||
if ok && genDecl.Tok == token.VAR {
|
||||
for i, spec := range genDecl.Specs {
|
||||
valueSpec, vsok := spec.(*ast.ValueSpec)
|
||||
if vsok {
|
||||
for _, name := range valueSpec.Names {
|
||||
if name.Name == a.ModuleName {
|
||||
genDecl.Specs = append(genDecl.Specs[:i], genDecl.Specs[i+1:]...)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginEnter) Injection(file *ast.File) error {
|
||||
_ = NewImport(a.ImportPath).Injection(file)
|
||||
|
||||
has := false
|
||||
hasVar := false
|
||||
var firstStruct *ast.StructType
|
||||
var varSpec *ast.GenDecl
|
||||
//寻找是否存在结构且定位
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
switch x := n.(type) {
|
||||
case *ast.TypeSpec:
|
||||
if s, ok := x.Type.(*ast.StructType); ok {
|
||||
firstStruct = s
|
||||
for _, field := range x.Type.(*ast.StructType).Fields.List {
|
||||
if len(field.Names) > 0 && field.Names[0].Name == a.StructName {
|
||||
has = true
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !has {
|
||||
field := &ast.Field{
|
||||
Names: []*ast.Ident{{Name: a.StructName}},
|
||||
Type: &ast.Ident{Name: a.StructCamelName},
|
||||
}
|
||||
firstStruct.Fields.List = append(firstStruct.Fields.List, field)
|
||||
}
|
||||
|
||||
if a.Type == TypePluginServiceEnter {
|
||||
return nil
|
||||
}
|
||||
|
||||
//寻找是否存在变量且定位
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
genDecl, ok := n.(*ast.GenDecl)
|
||||
if ok && genDecl.Tok == token.VAR {
|
||||
for _, spec := range genDecl.Specs {
|
||||
valueSpec, vsok := spec.(*ast.ValueSpec)
|
||||
if vsok {
|
||||
varSpec = genDecl
|
||||
for _, name := range valueSpec.Names {
|
||||
if name.Name == a.ModuleName {
|
||||
hasVar = true
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !hasVar {
|
||||
spec := &ast.ValueSpec{
|
||||
Names: []*ast.Ident{{Name: a.ModuleName}},
|
||||
Values: []ast.Expr{
|
||||
&ast.SelectorExpr{
|
||||
X: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: a.PackageName},
|
||||
Sel: &ast.Ident{Name: a.GroupName},
|
||||
},
|
||||
Sel: &ast.Ident{Name: a.ServiceName},
|
||||
},
|
||||
},
|
||||
}
|
||||
varSpec.Specs = append(varSpec.Specs, spec)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginEnter) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
200
utils/ast/plugin_enter_test.go
Normal file
200
utils/ast/plugin_enter_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPluginEnter_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
StructCamelName string
|
||||
ModuleName string
|
||||
GroupName string
|
||||
PackageName string
|
||||
ServiceName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 Gva插件UserApi 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginApiEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "api", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/service"`,
|
||||
StructName: "User",
|
||||
StructCamelName: "user",
|
||||
ModuleName: "serviceUser",
|
||||
GroupName: "Service",
|
||||
PackageName: "service",
|
||||
ServiceName: "User",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 Gva插件UserRouter 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginRouterEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "router", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/api"`,
|
||||
StructName: "User",
|
||||
StructCamelName: "user",
|
||||
ModuleName: "userApi",
|
||||
GroupName: "Api",
|
||||
PackageName: "api",
|
||||
ServiceName: "User",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 Gva插件UserService 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginServiceEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "service", "enter.go"),
|
||||
ImportPath: "",
|
||||
StructName: "User",
|
||||
StructCamelName: "user",
|
||||
ModuleName: "",
|
||||
GroupName: "",
|
||||
PackageName: "",
|
||||
ServiceName: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 gva的User 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginServiceEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "service", "enter.go"),
|
||||
ImportPath: "",
|
||||
StructName: "User",
|
||||
StructCamelName: "user",
|
||||
ModuleName: "",
|
||||
GroupName: "",
|
||||
PackageName: "",
|
||||
ServiceName: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PluginEnter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
StructCamelName: tt.fields.StructCamelName,
|
||||
ModuleName: tt.fields.ModuleName,
|
||||
GroupName: tt.fields.GroupName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
ServiceName: tt.fields.ServiceName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginEnter_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
StructCamelName string
|
||||
ModuleName string
|
||||
GroupName string
|
||||
PackageName string
|
||||
ServiceName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 Gva插件UserRouter 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginRouterEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "router", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/api"`,
|
||||
StructName: "User",
|
||||
StructCamelName: "user",
|
||||
ModuleName: "userApi",
|
||||
GroupName: "Api",
|
||||
PackageName: "api",
|
||||
ServiceName: "User",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 Gva插件UserApi 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginApiEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "api", "enter.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/service"`,
|
||||
StructName: "User",
|
||||
StructCamelName: "user",
|
||||
ModuleName: "serviceUser",
|
||||
GroupName: "Service",
|
||||
PackageName: "service",
|
||||
ServiceName: "User",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 Gva插件UserService 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginServiceEnter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "service", "enter.go"),
|
||||
ImportPath: "",
|
||||
StructName: "User",
|
||||
StructCamelName: "user",
|
||||
ModuleName: "",
|
||||
GroupName: "",
|
||||
PackageName: "",
|
||||
ServiceName: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PluginEnter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
StructCamelName: tt.fields.StructCamelName,
|
||||
ModuleName: tt.fields.ModuleName,
|
||||
GroupName: tt.fields.GroupName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
ServiceName: tt.fields.ServiceName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
189
utils/ast/plugin_gen.go
Normal file
189
utils/ast/plugin_gen.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"io"
|
||||
)
|
||||
|
||||
type PluginGen struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
ImportPath string // 导包路径
|
||||
RelativePath string // 相对路径
|
||||
StructName string // 结构体名称
|
||||
PackageName string // 包名
|
||||
IsNew bool // 是否使用new关键字
|
||||
}
|
||||
|
||||
func (a *PluginGen) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.Path
|
||||
a.RelativePath = a.Base.RelativePath(a.Path)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.Path = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
func (a *PluginGen) Rollback(file *ast.File) error {
|
||||
for i := 0; i < len(file.Decls); i++ {
|
||||
v1, o1 := file.Decls[i].(*ast.FuncDecl)
|
||||
if o1 {
|
||||
for j := 0; j < len(v1.Body.List); j++ {
|
||||
v2, o2 := v1.Body.List[j].(*ast.ExprStmt)
|
||||
if o2 {
|
||||
v3, o3 := v2.X.(*ast.CallExpr)
|
||||
if o3 {
|
||||
v4, o4 := v3.Fun.(*ast.SelectorExpr)
|
||||
if o4 {
|
||||
if v4.Sel.Name != "ApplyBasic" {
|
||||
continue
|
||||
}
|
||||
for k := 0; k < len(v3.Args); k++ {
|
||||
v5, o5 := v3.Args[k].(*ast.CallExpr)
|
||||
if o5 {
|
||||
v6, o6 := v5.Fun.(*ast.Ident)
|
||||
if o6 {
|
||||
if v6.Name != "new" {
|
||||
continue
|
||||
}
|
||||
for l := 0; l < len(v5.Args); l++ {
|
||||
v7, o7 := v5.Args[l].(*ast.SelectorExpr)
|
||||
if o7 {
|
||||
v8, o8 := v7.X.(*ast.Ident)
|
||||
if o8 {
|
||||
if v8.Name == a.PackageName && v7.Sel.Name == a.StructName {
|
||||
v3.Args = append(v3.Args[:k], v3.Args[k+1:]...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if k >= len(v3.Args) {
|
||||
break
|
||||
}
|
||||
v6, o6 := v3.Args[k].(*ast.CompositeLit)
|
||||
if o6 {
|
||||
v7, o7 := v6.Type.(*ast.SelectorExpr)
|
||||
if o7 {
|
||||
v8, o8 := v7.X.(*ast.Ident)
|
||||
if o8 {
|
||||
if v8.Name == a.PackageName && v7.Sel.Name == a.StructName {
|
||||
v3.Args = append(v3.Args[:k], v3.Args[k+1:]...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(v3.Args) == 0 {
|
||||
_ = NewImport(a.ImportPath).Rollback(file)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginGen) Injection(file *ast.File) error {
|
||||
_ = NewImport(a.ImportPath).Injection(file)
|
||||
for i := 0; i < len(file.Decls); i++ {
|
||||
v1, o1 := file.Decls[i].(*ast.FuncDecl)
|
||||
if o1 {
|
||||
for j := 0; j < len(v1.Body.List); j++ {
|
||||
v2, o2 := v1.Body.List[j].(*ast.ExprStmt)
|
||||
if o2 {
|
||||
v3, o3 := v2.X.(*ast.CallExpr)
|
||||
if o3 {
|
||||
v4, o4 := v3.Fun.(*ast.SelectorExpr)
|
||||
if o4 {
|
||||
if v4.Sel.Name != "ApplyBasic" {
|
||||
continue
|
||||
}
|
||||
var has bool
|
||||
for k := 0; k < len(v3.Args); k++ {
|
||||
v5, o5 := v3.Args[k].(*ast.CallExpr)
|
||||
if o5 {
|
||||
v6, o6 := v5.Fun.(*ast.Ident)
|
||||
if o6 {
|
||||
if v6.Name != "new" {
|
||||
continue
|
||||
}
|
||||
for l := 0; l < len(v5.Args); l++ {
|
||||
v7, o7 := v5.Args[l].(*ast.SelectorExpr)
|
||||
if o7 {
|
||||
v8, o8 := v7.X.(*ast.Ident)
|
||||
if o8 {
|
||||
if v8.Name == a.PackageName && v7.Sel.Name == a.StructName {
|
||||
has = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
v6, o6 := v3.Args[k].(*ast.CompositeLit)
|
||||
if o6 {
|
||||
v7, o7 := v6.Type.(*ast.SelectorExpr)
|
||||
if o7 {
|
||||
v8, o8 := v7.X.(*ast.Ident)
|
||||
if o8 {
|
||||
if v8.Name == a.PackageName && v7.Sel.Name == a.StructName {
|
||||
has = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !has {
|
||||
if a.IsNew {
|
||||
arg := &ast.CallExpr{
|
||||
Fun: &ast.Ident{Name: "\n\t\tnew"},
|
||||
Args: []ast.Expr{
|
||||
&ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: a.PackageName},
|
||||
Sel: &ast.Ident{Name: a.StructName},
|
||||
},
|
||||
},
|
||||
}
|
||||
v3.Args = append(v3.Args, arg)
|
||||
v3.Args = append(v3.Args, &ast.BasicLit{
|
||||
Kind: token.STRING,
|
||||
Value: "\n",
|
||||
})
|
||||
break
|
||||
}
|
||||
arg := &ast.CompositeLit{
|
||||
Type: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: a.PackageName},
|
||||
Sel: &ast.Ident{Name: a.StructName},
|
||||
},
|
||||
}
|
||||
v3.Args = append(v3.Args, arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginGen) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
127
utils/ast/plugin_gen_test.go
Normal file
127
utils/ast/plugin_gen_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPluginGenModel_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
PackageName string
|
||||
StructName string
|
||||
IsNew bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 GvaUser 结构体注入",
|
||||
fields: fields{
|
||||
Type: TypePluginGen,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "gen", "main.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
PackageName: "model",
|
||||
StructName: "User",
|
||||
IsNew: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 GvaUser 结构体注入",
|
||||
fields: fields{
|
||||
Type: TypePluginGen,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "gen", "main.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
PackageName: "model",
|
||||
StructName: "User",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PluginGen{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
PackageName: tt.fields.PackageName,
|
||||
StructName: tt.fields.StructName,
|
||||
IsNew: tt.fields.IsNew,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginGenModel_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
PackageName string
|
||||
StructName string
|
||||
IsNew bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 GvaUser 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginGen,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "gen", "main.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
PackageName: "model",
|
||||
StructName: "User",
|
||||
IsNew: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 GvaUser 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginGen,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "gen", "main.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
PackageName: "model",
|
||||
StructName: "User",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PluginGen{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
PackageName: tt.fields.PackageName,
|
||||
StructName: tt.fields.StructName,
|
||||
IsNew: tt.fields.IsNew,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
111
utils/ast/plugin_initialize_gorm.go
Normal file
111
utils/ast/plugin_initialize_gorm.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"io"
|
||||
)
|
||||
|
||||
type PluginInitializeGorm struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
ImportPath string // 导包路径
|
||||
RelativePath string // 相对路径
|
||||
StructName string // 结构体名称
|
||||
PackageName string // 包名
|
||||
IsNew bool // 是否使用new关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
|
||||
}
|
||||
|
||||
func (a *PluginInitializeGorm) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.Path
|
||||
a.RelativePath = a.Base.RelativePath(a.Path)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.Path = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *PluginInitializeGorm) Rollback(file *ast.File) error {
|
||||
var needRollBackImport bool
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
callExpr, ok := n.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
selExpr, seok := callExpr.Fun.(*ast.SelectorExpr)
|
||||
if !seok || selExpr.Sel.Name != "AutoMigrate" {
|
||||
return true
|
||||
}
|
||||
if len(callExpr.Args) <= 1 {
|
||||
needRollBackImport = true
|
||||
}
|
||||
// 删除指定的参数
|
||||
for i, arg := range callExpr.Args {
|
||||
compLit, cok := arg.(*ast.CompositeLit)
|
||||
if !cok {
|
||||
continue
|
||||
}
|
||||
|
||||
cselExpr, sok := compLit.Type.(*ast.SelectorExpr)
|
||||
if !sok {
|
||||
continue
|
||||
}
|
||||
|
||||
ident, idok := cselExpr.X.(*ast.Ident)
|
||||
if idok && ident.Name == a.PackageName && cselExpr.Sel.Name == a.StructName {
|
||||
// 删除参数
|
||||
callExpr.Args = append(callExpr.Args[:i], callExpr.Args[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if needRollBackImport {
|
||||
_ = NewImport(a.ImportPath).Rollback(file)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginInitializeGorm) Injection(file *ast.File) error {
|
||||
_ = NewImport(a.ImportPath).Injection(file)
|
||||
var call *ast.CallExpr
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
callExpr, ok := n.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
|
||||
if ok && selExpr.Sel.Name == "AutoMigrate" {
|
||||
call = callExpr
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
arg := &ast.CompositeLit{
|
||||
Type: &ast.SelectorExpr{
|
||||
X: &ast.Ident{Name: a.PackageName},
|
||||
Sel: &ast.Ident{Name: a.StructName},
|
||||
},
|
||||
}
|
||||
|
||||
call.Args = append(call.Args, arg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginInitializeGorm) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
138
utils/ast/plugin_initialize_gorm_test.go
Normal file
138
utils/ast/plugin_initialize_gorm_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPluginInitializeGorm_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
PackageName string
|
||||
IsNew bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 &model.User{} 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
StructName: "User",
|
||||
PackageName: "model",
|
||||
IsNew: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 new(model.ExaCustomer) 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
StructName: "User",
|
||||
PackageName: "model",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 new(model.SysUsers) 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
StructName: "SysUser",
|
||||
PackageName: "model",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PluginInitializeGorm{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
IsNew: tt.fields.IsNew,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginInitializeGorm_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
StructName string
|
||||
PackageName string
|
||||
IsNew bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 &model.User{} 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
StructName: "User",
|
||||
PackageName: "model",
|
||||
IsNew: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "测试 new(model.ExaCustomer) 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeGorm,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/model"`,
|
||||
StructName: "User",
|
||||
PackageName: "model",
|
||||
IsNew: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PluginInitializeGorm{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: tt.fields.StructName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
IsNew: tt.fields.IsNew,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
124
utils/ast/plugin_initialize_router.go
Normal file
124
utils/ast/plugin_initialize_router.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"io"
|
||||
)
|
||||
|
||||
// PluginInitializeRouter 插件初始化路由
|
||||
// PackageName.AppName.GroupName.FunctionName()
|
||||
type PluginInitializeRouter struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
ImportPath string // 导包路径
|
||||
ImportGlobalPath string // 导包全局变量路径
|
||||
ImportMiddlewarePath string // 导包中间件路径
|
||||
RelativePath string // 相对路径
|
||||
AppName string // 应用名称
|
||||
GroupName string // 分组名称
|
||||
PackageName string // 包名
|
||||
FunctionName string // 函数名
|
||||
LeftRouterGroupName string // 左路由分组名称
|
||||
RightRouterGroupName string // 右路由分组名称
|
||||
}
|
||||
|
||||
func (a *PluginInitializeRouter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.Path
|
||||
a.RelativePath = a.Base.RelativePath(a.Path)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.Path = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *PluginInitializeRouter) Rollback(file *ast.File) error {
|
||||
funcDecl := FindFunction(file, "Router")
|
||||
delI := 0
|
||||
routerNum := 0
|
||||
for i := len(funcDecl.Body.List) - 1; i >= 0; i-- {
|
||||
stmt, ok := funcDecl.Body.List[i].(*ast.ExprStmt)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
callExpr, ok := stmt.X.(*ast.CallExpr)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
ident, ok := selExpr.X.(*ast.SelectorExpr)
|
||||
|
||||
if ok {
|
||||
if iExpr, ieok := ident.X.(*ast.SelectorExpr); ieok {
|
||||
if iden, idok := iExpr.X.(*ast.Ident); idok {
|
||||
if iden.Name == "router" {
|
||||
routerNum++
|
||||
}
|
||||
}
|
||||
}
|
||||
if ident.Sel.Name == a.GroupName && selExpr.Sel.Name == a.FunctionName {
|
||||
// 删除语句
|
||||
delI = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
funcDecl.Body.List = append(funcDecl.Body.List[:delI], funcDecl.Body.List[delI+1:]...)
|
||||
|
||||
if routerNum <= 1 {
|
||||
_ = NewImport(a.ImportPath).Rollback(file)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginInitializeRouter) Injection(file *ast.File) error {
|
||||
_ = NewImport(a.ImportPath).Injection(file)
|
||||
funcDecl := FindFunction(file, "Router")
|
||||
|
||||
var exists bool
|
||||
|
||||
ast.Inspect(funcDecl, func(n ast.Node) bool {
|
||||
callExpr, ok := n.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
ident, ok := selExpr.X.(*ast.SelectorExpr)
|
||||
if ok && ident.Sel.Name == a.GroupName && selExpr.Sel.Name == a.FunctionName {
|
||||
exists = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !exists {
|
||||
stmtStr := fmt.Sprintf("%s.%s.%s.%s(%s, %s)", a.PackageName, a.AppName, a.GroupName, a.FunctionName, a.LeftRouterGroupName, a.RightRouterGroupName)
|
||||
stmt := CreateStmt(stmtStr)
|
||||
funcDecl.Body.List = append(funcDecl.Body.List, stmt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginInitializeRouter) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.Path
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
155
utils/ast/plugin_initialize_router_test.go
Normal file
155
utils/ast/plugin_initialize_router_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPluginInitializeRouter_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
AppName string
|
||||
GroupName string
|
||||
PackageName string
|
||||
FunctionName string
|
||||
LeftRouterGroupName string
|
||||
RightRouterGroupName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 Gva插件User 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeRouter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "router.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/router"`,
|
||||
AppName: "Router",
|
||||
GroupName: "User",
|
||||
PackageName: "router",
|
||||
FunctionName: "Init",
|
||||
LeftRouterGroupName: "public",
|
||||
RightRouterGroupName: "private",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 中文 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeRouter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "router.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/router"`,
|
||||
AppName: "Router",
|
||||
GroupName: "U中文",
|
||||
PackageName: "router",
|
||||
FunctionName: "Init",
|
||||
LeftRouterGroupName: "public",
|
||||
RightRouterGroupName: "private",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PluginInitializeRouter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
AppName: tt.fields.AppName,
|
||||
GroupName: tt.fields.GroupName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
FunctionName: tt.fields.FunctionName,
|
||||
LeftRouterGroupName: tt.fields.LeftRouterGroupName,
|
||||
RightRouterGroupName: tt.fields.RightRouterGroupName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginInitializeRouter_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
ImportPath string
|
||||
AppName string
|
||||
GroupName string
|
||||
PackageName string
|
||||
FunctionName string
|
||||
LeftRouterGroupName string
|
||||
RightRouterGroupName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 Gva插件User 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeRouter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "router.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/router"`,
|
||||
AppName: "Router",
|
||||
GroupName: "User",
|
||||
PackageName: "router",
|
||||
FunctionName: "Init",
|
||||
LeftRouterGroupName: "public",
|
||||
RightRouterGroupName: "private",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "测试 中文 注入",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeRouter,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "router.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva/router"`,
|
||||
AppName: "Router",
|
||||
GroupName: "U中文",
|
||||
PackageName: "router",
|
||||
FunctionName: "Init",
|
||||
LeftRouterGroupName: "public",
|
||||
RightRouterGroupName: "private",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &PluginInitializeRouter{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
AppName: tt.fields.AppName,
|
||||
GroupName: tt.fields.GroupName,
|
||||
PackageName: tt.fields.PackageName,
|
||||
FunctionName: tt.fields.FunctionName,
|
||||
LeftRouterGroupName: tt.fields.LeftRouterGroupName,
|
||||
RightRouterGroupName: tt.fields.RightRouterGroupName,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
52
utils/ast/plugin_initialize_v2.go
Normal file
52
utils/ast/plugin_initialize_v2.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"io"
|
||||
)
|
||||
|
||||
type PluginInitializeV2 struct {
|
||||
Base
|
||||
Type Type // 类型
|
||||
Path string // 文件路径
|
||||
PluginPath string // 插件路径
|
||||
RelativePath string // 相对路径
|
||||
ImportPath string // 导包路径
|
||||
StructName string // 结构体名称
|
||||
PackageName string // 包名
|
||||
}
|
||||
|
||||
func (a *PluginInitializeV2) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
|
||||
if filename == "" {
|
||||
if a.RelativePath == "" {
|
||||
filename = a.PluginPath
|
||||
a.RelativePath = a.Base.RelativePath(a.PluginPath)
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
a.PluginPath = a.Base.AbsolutePath(a.RelativePath)
|
||||
filename = a.PluginPath
|
||||
}
|
||||
return a.Base.Parse(filename, writer)
|
||||
}
|
||||
|
||||
func (a *PluginInitializeV2) Injection(file *ast.File) error {
|
||||
if !CheckImport(file, a.ImportPath) {
|
||||
NewImport(a.ImportPath).Injection(file)
|
||||
funcDecl := FindFunction(file, "bizPluginV2")
|
||||
stmt := CreateStmt(fmt.Sprintf("PluginInitV2(engine, %s.Plugin)", a.PackageName))
|
||||
funcDecl.Body.List = append(funcDecl.Body.List, stmt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginInitializeV2) Rollback(file *ast.File) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *PluginInitializeV2) Format(filename string, writer io.Writer, file *ast.File) error {
|
||||
if filename == "" {
|
||||
filename = a.PluginPath
|
||||
}
|
||||
return a.Base.Format(filename, writer, file)
|
||||
}
|
100
utils/ast/plugin_initialize_v2_test.go
Normal file
100
utils/ast/plugin_initialize_v2_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPluginInitialize_Injection(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
PluginPath string
|
||||
ImportPath string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 Gva插件 注册注入",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeV2,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "plugin_biz_v2.go"),
|
||||
PluginPath: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "plugin.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva"`,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := PluginInitializeV2{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
PluginPath: tt.fields.PluginPath,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Injection(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginInitialize_Rollback(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Path string
|
||||
PluginPath string
|
||||
ImportPath string
|
||||
PluginName string
|
||||
StructName string
|
||||
PackageName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "测试 Gva插件 回滚",
|
||||
fields: fields{
|
||||
Type: TypePluginInitializeV2,
|
||||
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "plugin_biz_v2.go"),
|
||||
PluginPath: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "plugin.go"),
|
||||
ImportPath: `"git.echol.cn/loser/lckt/plugin/gva"`,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := PluginInitializeV2{
|
||||
Type: tt.fields.Type,
|
||||
Path: tt.fields.Path,
|
||||
PluginPath: tt.fields.PluginPath,
|
||||
ImportPath: tt.fields.ImportPath,
|
||||
StructName: "Plugin",
|
||||
PackageName: "gva",
|
||||
}
|
||||
file, err := a.Parse(a.Path, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
a.Rollback(file)
|
||||
err = a.Format(a.Path, nil, file)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
112
utils/breakpoint_continue.go
Normal file
112
utils/breakpoint_continue.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 前端传来文件片与当前片为什么文件的第几片
|
||||
// 后端拿到以后比较次分片是否上传 或者是否为不完全片
|
||||
// 前端发送每片多大
|
||||
// 前端告知是否为最后一片且是否完成
|
||||
|
||||
const (
|
||||
breakpointDir = "./breakpointDir/"
|
||||
finishDir = "./fileDir/"
|
||||
)
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: BreakPointContinue
|
||||
//@description: 断点续传
|
||||
//@param: content []byte, fileName string, contentNumber int, contentTotal int, fileMd5 string
|
||||
//@return: error, string
|
||||
|
||||
func BreakPointContinue(content []byte, fileName string, contentNumber int, contentTotal int, fileMd5 string) (string, error) {
|
||||
path := breakpointDir + fileMd5 + "/"
|
||||
err := os.MkdirAll(path, os.ModePerm)
|
||||
if err != nil {
|
||||
return path, err
|
||||
}
|
||||
pathC, err := makeFileContent(content, fileName, path, contentNumber)
|
||||
return pathC, err
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: CheckMd5
|
||||
//@description: 检查Md5
|
||||
//@param: content []byte, chunkMd5 string
|
||||
//@return: CanUpload bool
|
||||
|
||||
func CheckMd5(content []byte, chunkMd5 string) (CanUpload bool) {
|
||||
fileMd5 := MD5V(content)
|
||||
if fileMd5 == chunkMd5 {
|
||||
return true // 可以继续上传
|
||||
} else {
|
||||
return false // 切片不完整,废弃
|
||||
}
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: makeFileContent
|
||||
//@description: 创建切片内容
|
||||
//@param: content []byte, fileName string, FileDir string, contentNumber int
|
||||
//@return: string, error
|
||||
|
||||
func makeFileContent(content []byte, fileName string, FileDir string, contentNumber int) (string, error) {
|
||||
if strings.Contains(fileName, "..") || strings.Contains(FileDir, "..") {
|
||||
return "", errors.New("文件名或路径不合法")
|
||||
}
|
||||
path := FileDir + fileName + "_" + strconv.Itoa(contentNumber)
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return path, err
|
||||
} else {
|
||||
_, err = f.Write(content)
|
||||
if err != nil {
|
||||
return path, err
|
||||
}
|
||||
}
|
||||
defer f.Close()
|
||||
return path, nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: makeFileContent
|
||||
//@description: 创建切片文件
|
||||
//@param: fileName string, FileMd5 string
|
||||
//@return: error, string
|
||||
|
||||
func MakeFile(fileName string, FileMd5 string) (string, error) {
|
||||
rd, err := os.ReadDir(breakpointDir + FileMd5)
|
||||
if err != nil {
|
||||
return finishDir + fileName, err
|
||||
}
|
||||
_ = os.MkdirAll(finishDir, os.ModePerm)
|
||||
fd, err := os.OpenFile(finishDir+fileName, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return finishDir + fileName, err
|
||||
}
|
||||
defer fd.Close()
|
||||
for k := range rd {
|
||||
content, _ := os.ReadFile(breakpointDir + FileMd5 + "/" + fileName + "_" + strconv.Itoa(k))
|
||||
_, err = fd.Write(content)
|
||||
if err != nil {
|
||||
_ = os.Remove(finishDir + fileName)
|
||||
return finishDir + fileName, err
|
||||
}
|
||||
}
|
||||
return finishDir + fileName, nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: RemoveChunk
|
||||
//@description: 移除切片
|
||||
//@param: FileMd5 string
|
||||
//@return: error
|
||||
|
||||
func RemoveChunk(FileMd5 string) error {
|
||||
err := os.RemoveAll(breakpointDir + FileMd5)
|
||||
return err
|
||||
}
|
60
utils/captcha/redis.go
Normal file
60
utils/captcha/redis.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func NewDefaultRedisStore() *RedisStore {
|
||||
return &RedisStore{
|
||||
Expiration: time.Second * 180,
|
||||
PreKey: "CAPTCHA_",
|
||||
Context: context.TODO(),
|
||||
}
|
||||
}
|
||||
|
||||
type RedisStore struct {
|
||||
Expiration time.Duration
|
||||
PreKey string
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
func (rs *RedisStore) UseWithCtx(ctx context.Context) base64Captcha.Store {
|
||||
rs.Context = ctx
|
||||
return rs
|
||||
}
|
||||
|
||||
func (rs *RedisStore) Set(id string, value string) error {
|
||||
err := global.GVA_REDIS.Set(rs.Context, rs.PreKey+id, value, rs.Expiration).Err()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("RedisStoreSetError!", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *RedisStore) Get(key string, clear bool) string {
|
||||
val, err := global.GVA_REDIS.Get(rs.Context, key).Result()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("RedisStoreGetError!", zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
if clear {
|
||||
err := global.GVA_REDIS.Del(rs.Context, key).Err()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("RedisStoreClearError!", zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func (rs *RedisStore) Verify(id, answer string, clear bool) bool {
|
||||
key := rs.PreKey + id
|
||||
v := rs.Get(key, clear)
|
||||
return v == answer
|
||||
}
|
148
utils/claims.go
Normal file
148
utils/claims.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"git.echol.cn/loser/lckt/model/system"
|
||||
systemReq "git.echol.cn/loser/lckt/model/system/request"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func ClearToken(c *gin.Context) {
|
||||
// 增加cookie x-token 向来源的web添加
|
||||
host, _, err := net.SplitHostPort(c.Request.Host)
|
||||
if err != nil {
|
||||
host = c.Request.Host
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
c.SetCookie("x-token", "", -1, "/", "", false, false)
|
||||
} else {
|
||||
c.SetCookie("x-token", "", -1, "/", host, false, false)
|
||||
}
|
||||
}
|
||||
|
||||
func SetToken(c *gin.Context, token string, maxAge int) {
|
||||
// 增加cookie x-token 向来源的web添加
|
||||
host, _, err := net.SplitHostPort(c.Request.Host)
|
||||
if err != nil {
|
||||
host = c.Request.Host
|
||||
}
|
||||
|
||||
if net.ParseIP(host) != nil {
|
||||
c.SetCookie("x-token", token, maxAge, "/", "", false, false)
|
||||
} else {
|
||||
c.SetCookie("x-token", token, maxAge, "/", host, false, false)
|
||||
}
|
||||
}
|
||||
|
||||
func GetToken(c *gin.Context) string {
|
||||
token := c.Request.Header.Get("x-token")
|
||||
if token == "" {
|
||||
j := NewJWT()
|
||||
token, _ = c.Cookie("x-token")
|
||||
claims, err := j.ParseToken(token)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("重新写入cookie token失败,未能成功解析token,请检查请求头是否存在x-token且claims是否为规定结构")
|
||||
return token
|
||||
}
|
||||
SetToken(c, token, int((claims.ExpiresAt.Unix()-time.Now().Unix())/60))
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func GetClaims(c *gin.Context) (*systemReq.CustomClaims, error) {
|
||||
token := GetToken(c)
|
||||
j := NewJWT()
|
||||
claims, err := j.ParseToken(token)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("从Gin的Context中获取从jwt解析信息失败, 请检查请求头是否存在x-token且claims是否为规定结构")
|
||||
}
|
||||
return claims, err
|
||||
}
|
||||
|
||||
// GetUserID 从Gin的Context中获取从jwt解析出来的用户ID
|
||||
func GetUserID(c *gin.Context) uint {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return 0
|
||||
} else {
|
||||
return cl.BaseClaims.ID
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse.BaseClaims.ID
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserUuid 从Gin的Context中获取从jwt解析出来的用户UUID
|
||||
func GetUserUuid(c *gin.Context) uuid.UUID {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return uuid.UUID{}
|
||||
} else {
|
||||
return cl.UUID
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse.UUID
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAuthorityId 从Gin的Context中获取从jwt解析出来的用户角色id
|
||||
func GetUserAuthorityId(c *gin.Context) uint {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return 0
|
||||
} else {
|
||||
return cl.AuthorityId
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse.AuthorityId
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserInfo 从Gin的Context中获取从jwt解析出来的用户角色id
|
||||
func GetUserInfo(c *gin.Context) *systemReq.CustomClaims {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return nil
|
||||
} else {
|
||||
return cl
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserName 从Gin的Context中获取从jwt解析出来的用户名
|
||||
func GetUserName(c *gin.Context) string {
|
||||
if claims, exists := c.Get("claims"); !exists {
|
||||
if cl, err := GetClaims(c); err != nil {
|
||||
return ""
|
||||
} else {
|
||||
return cl.Username
|
||||
}
|
||||
} else {
|
||||
waitUse := claims.(*systemReq.CustomClaims)
|
||||
return waitUse.Username
|
||||
}
|
||||
}
|
||||
|
||||
func LoginToken(user system.Login) (token string, claims systemReq.CustomClaims, err error) {
|
||||
j := NewJWT()
|
||||
claims = j.CreateClaims(systemReq.BaseClaims{
|
||||
UUID: user.GetUUID(),
|
||||
ID: user.GetUserId(),
|
||||
NickName: user.GetNickname(),
|
||||
Username: user.GetUsername(),
|
||||
AuthorityId: user.GetAuthorityId(),
|
||||
})
|
||||
token, err = j.CreateToken(claims)
|
||||
return
|
||||
}
|
124
utils/directory.go
Normal file
124
utils/directory.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: PathExists
|
||||
//@description: 文件目录是否存在
|
||||
//@param: path string
|
||||
//@return: bool, error
|
||||
|
||||
func PathExists(path string) (bool, error) {
|
||||
fi, err := os.Stat(path)
|
||||
if err == nil {
|
||||
if fi.IsDir() {
|
||||
return true, nil
|
||||
}
|
||||
return false, errors.New("存在同名文件")
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: CreateDir
|
||||
//@description: 批量创建文件夹
|
||||
//@param: dirs ...string
|
||||
//@return: err error
|
||||
|
||||
func CreateDir(dirs ...string) (err error) {
|
||||
for _, v := range dirs {
|
||||
exist, err := PathExists(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exist {
|
||||
global.GVA_LOG.Debug("create directory" + v)
|
||||
if err := os.MkdirAll(v, os.ModePerm); err != nil {
|
||||
global.GVA_LOG.Error("create directory"+v, zap.Any(" error:", err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
//@author: [songzhibin97](https://github.com/songzhibin97)
|
||||
//@function: FileMove
|
||||
//@description: 文件移动供外部调用
|
||||
//@param: src string, dst string(src: 源位置,绝对路径or相对路径, dst: 目标位置,绝对路径or相对路径,必须为文件夹)
|
||||
//@return: err error
|
||||
|
||||
func FileMove(src string, dst string) (err error) {
|
||||
if dst == "" {
|
||||
return nil
|
||||
}
|
||||
src, err = filepath.Abs(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst, err = filepath.Abs(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revoke := false
|
||||
dir := filepath.Dir(dst)
|
||||
Redirect:
|
||||
_, err = os.Stat(dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !revoke {
|
||||
revoke = true
|
||||
goto Redirect
|
||||
}
|
||||
}
|
||||
return os.Rename(src, dst)
|
||||
}
|
||||
|
||||
func DeLFile(filePath string) error {
|
||||
return os.RemoveAll(filePath)
|
||||
}
|
||||
|
||||
//@author: [songzhibin97](https://github.com/songzhibin97)
|
||||
//@function: TrimSpace
|
||||
//@description: 去除结构体空格
|
||||
//@param: target interface (target: 目标结构体,传入必须是指针类型)
|
||||
//@return: null
|
||||
|
||||
func TrimSpace(target interface{}) {
|
||||
t := reflect.TypeOf(target)
|
||||
if t.Kind() != reflect.Ptr {
|
||||
return
|
||||
}
|
||||
t = t.Elem()
|
||||
v := reflect.ValueOf(target).Elem()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
switch v.Field(i).Kind() {
|
||||
case reflect.String:
|
||||
v.Field(i).SetString(strings.TrimSpace(v.Field(i).String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FileExist 判断文件是否存在
|
||||
func FileExist(path string) bool {
|
||||
fi, err := os.Lstat(path)
|
||||
if err == nil {
|
||||
return !fi.IsDir()
|
||||
}
|
||||
return !os.IsNotExist(err)
|
||||
}
|
108
utils/fmt_plus.go
Normal file
108
utils/fmt_plus.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.echol.cn/loser/lckt/model/common"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: StructToMap
|
||||
//@description: 利用反射将结构体转化为map
|
||||
//@param: obj interface{}
|
||||
//@return: map[string]interface{}
|
||||
|
||||
func StructToMap(obj interface{}) map[string]interface{} {
|
||||
obj1 := reflect.TypeOf(obj)
|
||||
obj2 := reflect.ValueOf(obj)
|
||||
|
||||
data := make(map[string]interface{})
|
||||
for i := 0; i < obj1.NumField(); i++ {
|
||||
if obj1.Field(i).Tag.Get("mapstructure") != "" {
|
||||
data[obj1.Field(i).Tag.Get("mapstructure")] = obj2.Field(i).Interface()
|
||||
} else {
|
||||
data[obj1.Field(i).Name] = obj2.Field(i).Interface()
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: ArrayToString
|
||||
//@description: 将数组格式化为字符串
|
||||
//@param: array []interface{}
|
||||
//@return: string
|
||||
|
||||
func ArrayToString(array []interface{}) string {
|
||||
return strings.Replace(strings.Trim(fmt.Sprint(array), "[]"), " ", ",", -1)
|
||||
}
|
||||
|
||||
func Pointer[T any](in T) (out *T) {
|
||||
return &in
|
||||
}
|
||||
|
||||
func FirstUpper(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToUpper(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
func FirstLower(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
// MaheHump 将字符串转换为驼峰命名
|
||||
func MaheHump(s string) string {
|
||||
words := strings.Split(s, "-")
|
||||
|
||||
for i := 1; i < len(words); i++ {
|
||||
words[i] = strings.Title(words[i])
|
||||
}
|
||||
|
||||
return strings.Join(words, "")
|
||||
}
|
||||
|
||||
// RandomString 随机字符串
|
||||
func RandomString(n int) string {
|
||||
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
b := make([]rune, n)
|
||||
for i := range b {
|
||||
b[i] = letters[RandomInt(0, len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func RandomInt(min, max int) int {
|
||||
return min + rand.Intn(max-min)
|
||||
}
|
||||
|
||||
// BuildTree 用于构建一个树形结构
|
||||
func BuildTree[T common.TreeNode[T]](nodes []T) []T {
|
||||
nodeMap := make(map[int]T)
|
||||
// 创建一个基本map
|
||||
for i := range nodes {
|
||||
nodeMap[nodes[i].GetID()] = nodes[i]
|
||||
}
|
||||
|
||||
for i := range nodes {
|
||||
if nodes[i].GetParentID() != 0 {
|
||||
parent := nodeMap[nodes[i].GetParentID()]
|
||||
parent.SetChildren(nodes[i])
|
||||
}
|
||||
}
|
||||
|
||||
var rootNodes []T
|
||||
|
||||
for i := range nodeMap {
|
||||
if nodeMap[i].GetParentID() == 0 {
|
||||
rootNodes = append(rootNodes, nodeMap[i])
|
||||
}
|
||||
}
|
||||
return rootNodes
|
||||
}
|
31
utils/hash.go
Normal file
31
utils/hash.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// BcryptHash 使用 bcrypt 对密码进行加密
|
||||
func BcryptHash(password string) string {
|
||||
bytes, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// BcryptCheck 对比明文密码和数据库的哈希值
|
||||
func BcryptCheck(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: MD5V
|
||||
//@description: md5加密
|
||||
//@param: str []byte
|
||||
//@return: string
|
||||
|
||||
func MD5V(str []byte, b ...byte) string {
|
||||
h := md5.New()
|
||||
h.Write(str)
|
||||
return hex.EncodeToString(h.Sum(b))
|
||||
}
|
29
utils/human_duration.go
Normal file
29
utils/human_duration.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ParseDuration(d string) (time.Duration, error) {
|
||||
d = strings.TrimSpace(d)
|
||||
dr, err := time.ParseDuration(d)
|
||||
if err == nil {
|
||||
return dr, nil
|
||||
}
|
||||
if strings.Contains(d, "d") {
|
||||
index := strings.Index(d, "d")
|
||||
|
||||
hour, _ := strconv.Atoi(d[:index])
|
||||
dr = time.Hour * 24 * time.Duration(hour)
|
||||
ndr, err := time.ParseDuration(d[index+1:])
|
||||
if err != nil {
|
||||
return dr, nil
|
||||
}
|
||||
return dr + ndr, nil
|
||||
}
|
||||
|
||||
dv, err := strconv.ParseInt(d, 10, 64)
|
||||
return time.Duration(dv), err
|
||||
}
|
49
utils/human_duration_test.go
Normal file
49
utils/human_duration_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseDuration(t *testing.T) {
|
||||
type args struct {
|
||||
d string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "5h20m",
|
||||
args: args{"5h20m"},
|
||||
want: time.Hour*5 + 20*time.Minute,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "1d5h20m",
|
||||
args: args{"1d5h20m"},
|
||||
want: 24*time.Hour + time.Hour*5 + 20*time.Minute,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "1d",
|
||||
args: args{"1d"},
|
||||
want: 24 * time.Hour,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseDuration(tt.args.d)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseDuration() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseDuration() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
34
utils/json.go
Normal file
34
utils/json.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetJSONKeys(jsonStr string) (keys []string, err error) {
|
||||
// 使用json.Decoder,以便在解析过程中记录键的顺序
|
||||
dec := json.NewDecoder(strings.NewReader(jsonStr))
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 确保数据是一个对象
|
||||
if t != json.Delim('{') {
|
||||
return nil, err
|
||||
}
|
||||
for dec.More() {
|
||||
t, err = dec.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, t.(string))
|
||||
|
||||
// 解析值
|
||||
var value interface{}
|
||||
err = dec.Decode(&value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
53
utils/json_test.go
Normal file
53
utils/json_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetJSONKeys(t *testing.T) {
|
||||
var jsonStr = `
|
||||
{
|
||||
"Name": "test",
|
||||
"TableName": "test",
|
||||
"TemplateID": "test",
|
||||
"TemplateInfo": "test",
|
||||
"Limit": 0
|
||||
}`
|
||||
keys, err := GetJSONKeys(jsonStr)
|
||||
if err != nil {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
return
|
||||
}
|
||||
if len(keys) != 5 {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
return
|
||||
}
|
||||
if keys[0] != "Name" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
if keys[1] != "TableName" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
if keys[2] != "TemplateID" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
if keys[3] != "TemplateInfo" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
if keys[4] != "Limit" {
|
||||
t.Errorf("GetJSONKeys failed" + err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println(keys)
|
||||
}
|
87
utils/jwt.go
Normal file
87
utils/jwt.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"git.echol.cn/loser/lckt/model/system/request"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type JWT struct {
|
||||
SigningKey []byte
|
||||
}
|
||||
|
||||
var (
|
||||
TokenValid = errors.New("未知错误")
|
||||
TokenExpired = errors.New("token已过期")
|
||||
TokenNotValidYet = errors.New("token尚未激活")
|
||||
TokenMalformed = errors.New("这不是一个token")
|
||||
TokenSignatureInvalid = errors.New("无效签名")
|
||||
TokenInvalid = errors.New("无法处理此token")
|
||||
)
|
||||
|
||||
func NewJWT() *JWT {
|
||||
return &JWT{
|
||||
[]byte(global.GVA_CONFIG.JWT.SigningKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (j *JWT) CreateClaims(baseClaims request.BaseClaims) request.CustomClaims {
|
||||
bf, _ := ParseDuration(global.GVA_CONFIG.JWT.BufferTime)
|
||||
ep, _ := ParseDuration(global.GVA_CONFIG.JWT.ExpiresTime)
|
||||
claims := request.CustomClaims{
|
||||
BaseClaims: baseClaims,
|
||||
BufferTime: int64(bf / time.Second), // 缓冲时间1天 缓冲时间内会获得新的token刷新令牌 此时一个用户会存在两个有效令牌 但是前端只留一个 另一个会丢失
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Audience: jwt.ClaimStrings{"GVA"}, // 受众
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-1000)), // 签名生效时间
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ep)), // 过期时间 7天 配置文件
|
||||
Issuer: global.GVA_CONFIG.JWT.Issuer, // 签名的发行者
|
||||
},
|
||||
}
|
||||
return claims
|
||||
}
|
||||
|
||||
// CreateToken 创建一个token
|
||||
func (j *JWT) CreateToken(claims request.CustomClaims) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(j.SigningKey)
|
||||
}
|
||||
|
||||
// CreateTokenByOldToken 旧token 换新token 使用归并回源避免并发问题
|
||||
func (j *JWT) CreateTokenByOldToken(oldToken string, claims request.CustomClaims) (string, error) {
|
||||
v, err, _ := global.GVA_Concurrency_Control.Do("JWT:"+oldToken, func() (interface{}, error) {
|
||||
return j.CreateToken(claims)
|
||||
})
|
||||
return v.(string), err
|
||||
}
|
||||
|
||||
// ParseToken 解析 token
|
||||
func (j *JWT) ParseToken(tokenString string) (*request.CustomClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &request.CustomClaims{}, func(token *jwt.Token) (i interface{}, e error) {
|
||||
return j.SigningKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, jwt.ErrTokenExpired):
|
||||
return nil, TokenExpired
|
||||
case errors.Is(err, jwt.ErrTokenMalformed):
|
||||
return nil, TokenMalformed
|
||||
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
|
||||
return nil, TokenSignatureInvalid
|
||||
case errors.Is(err, jwt.ErrTokenNotValidYet):
|
||||
return nil, TokenNotValidYet
|
||||
default:
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
}
|
||||
if token != nil {
|
||||
if claims, ok := token.Claims.(*request.CustomClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
}
|
||||
return nil, TokenValid
|
||||
}
|
18
utils/plugin/plugin.go
Normal file
18
utils/plugin/plugin.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
OnlyFuncName = "Plugin"
|
||||
)
|
||||
|
||||
// Plugin 插件模式接口化
|
||||
type Plugin interface {
|
||||
// Register 注册路由
|
||||
Register(group *gin.RouterGroup)
|
||||
|
||||
// RouterPath 用户返回注册路由
|
||||
RouterPath() string
|
||||
}
|
11
utils/plugin/v2/plugin.go
Normal file
11
utils/plugin/v2/plugin.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Plugin 插件模式接口化v2
|
||||
type Plugin interface {
|
||||
// Register 注册路由
|
||||
Register(group *gin.Engine)
|
||||
}
|
18
utils/reload.go
Normal file
18
utils/reload.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func Reload() error {
|
||||
if runtime.GOOS == "windows" {
|
||||
return errors.New("系统不支持")
|
||||
}
|
||||
pid := os.Getpid()
|
||||
cmd := exec.Command("kill", "-1", strconv.Itoa(pid))
|
||||
return cmd.Run()
|
||||
}
|
62
utils/request/http.go
Normal file
62
utils/request/http.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func HttpRequest(
|
||||
urlStr string,
|
||||
method string,
|
||||
headers map[string]string,
|
||||
params map[string]string,
|
||||
data any) (*http.Response, error) {
|
||||
// 创建URL
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 添加查询参数
|
||||
query := u.Query()
|
||||
for k, v := range params {
|
||||
query.Set(k, v)
|
||||
}
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
// 将数据编码为JSON
|
||||
buf := new(bytes.Buffer)
|
||||
if data != nil {
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf = bytes.NewBuffer(b)
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequest(method, u.String(), buf)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
if data != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 返回响应,让调用者处理
|
||||
return resp, nil
|
||||
}
|
126
utils/server.go
Normal file
126
utils/server.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
)
|
||||
|
||||
const (
|
||||
B = 1
|
||||
KB = 1024 * B
|
||||
MB = 1024 * KB
|
||||
GB = 1024 * MB
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Os Os `json:"os"`
|
||||
Cpu Cpu `json:"cpu"`
|
||||
Ram Ram `json:"ram"`
|
||||
Disk []Disk `json:"disk"`
|
||||
}
|
||||
|
||||
type Os struct {
|
||||
GOOS string `json:"goos"`
|
||||
NumCPU int `json:"numCpu"`
|
||||
Compiler string `json:"compiler"`
|
||||
GoVersion string `json:"goVersion"`
|
||||
NumGoroutine int `json:"numGoroutine"`
|
||||
}
|
||||
|
||||
type Cpu struct {
|
||||
Cpus []float64 `json:"cpus"`
|
||||
Cores int `json:"cores"`
|
||||
}
|
||||
|
||||
type Ram struct {
|
||||
UsedMB int `json:"usedMb"`
|
||||
TotalMB int `json:"totalMb"`
|
||||
UsedPercent int `json:"usedPercent"`
|
||||
}
|
||||
|
||||
type Disk struct {
|
||||
MountPoint string `json:"mountPoint"`
|
||||
UsedMB int `json:"usedMb"`
|
||||
UsedGB int `json:"usedGb"`
|
||||
TotalMB int `json:"totalMb"`
|
||||
TotalGB int `json:"totalGb"`
|
||||
UsedPercent int `json:"usedPercent"`
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@function: InitCPU
|
||||
//@description: OS信息
|
||||
//@return: o Os, err error
|
||||
|
||||
func InitOS() (o Os) {
|
||||
o.GOOS = runtime.GOOS
|
||||
o.NumCPU = runtime.NumCPU()
|
||||
o.Compiler = runtime.Compiler
|
||||
o.GoVersion = runtime.Version()
|
||||
o.NumGoroutine = runtime.NumGoroutine()
|
||||
return o
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@function: InitCPU
|
||||
//@description: CPU信息
|
||||
//@return: c Cpu, err error
|
||||
|
||||
func InitCPU() (c Cpu, err error) {
|
||||
if cores, err := cpu.Counts(false); err != nil {
|
||||
return c, err
|
||||
} else {
|
||||
c.Cores = cores
|
||||
}
|
||||
if cpus, err := cpu.Percent(time.Duration(200)*time.Millisecond, true); err != nil {
|
||||
return c, err
|
||||
} else {
|
||||
c.Cpus = cpus
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@function: InitRAM
|
||||
//@description: RAM信息
|
||||
//@return: r Ram, err error
|
||||
|
||||
func InitRAM() (r Ram, err error) {
|
||||
if u, err := mem.VirtualMemory(); err != nil {
|
||||
return r, err
|
||||
} else {
|
||||
r.UsedMB = int(u.Used) / MB
|
||||
r.TotalMB = int(u.Total) / MB
|
||||
r.UsedPercent = int(u.UsedPercent)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@function: InitDisk
|
||||
//@description: 硬盘信息
|
||||
//@return: d Disk, err error
|
||||
|
||||
func InitDisk() (d []Disk, err error) {
|
||||
for i := range global.GVA_CONFIG.DiskList {
|
||||
mp := global.GVA_CONFIG.DiskList[i].MountPoint
|
||||
if u, err := disk.Usage(mp); err != nil {
|
||||
return d, err
|
||||
} else {
|
||||
d = append(d, Disk{
|
||||
MountPoint: mp,
|
||||
UsedMB: int(u.Used) / MB,
|
||||
UsedGB: int(u.Used) / GB,
|
||||
TotalMB: int(u.Total) / MB,
|
||||
TotalGB: int(u.Total) / GB,
|
||||
UsedPercent: int(u.UsedPercent),
|
||||
})
|
||||
}
|
||||
}
|
||||
return d, nil
|
||||
}
|
229
utils/timer/timed_task.go
Normal file
229
utils/timer/timed_task.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package timer
|
||||
|
||||
import (
|
||||
"github.com/robfig/cron/v3"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Timer interface {
|
||||
// 寻找所有Cron
|
||||
FindCronList() map[string]*taskManager
|
||||
// 添加Task 方法形式以秒的形式加入
|
||||
AddTaskByFuncWithSecond(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) // 添加Task Func以秒的形式加入
|
||||
// 添加Task 接口形式以秒的形式加入
|
||||
AddTaskByJobWithSeconds(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error)
|
||||
// 通过函数的方法添加任务
|
||||
AddTaskByFunc(cronName string, spec string, task func(), taskName string, option ...cron.Option) (cron.EntryID, error)
|
||||
// 通过接口的方法添加任务 要实现一个带有 Run方法的接口触发
|
||||
AddTaskByJob(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error)
|
||||
// 获取对应taskName的cron 可能会为空
|
||||
FindCron(cronName string) (*taskManager, bool)
|
||||
// 指定cron开始执行
|
||||
StartCron(cronName string)
|
||||
// 指定cron停止执行
|
||||
StopCron(cronName string)
|
||||
// 查找指定cron下的指定task
|
||||
FindTask(cronName string, taskName string) (*task, bool)
|
||||
// 根据id删除指定cron下的指定task
|
||||
RemoveTask(cronName string, id int)
|
||||
// 根据taskName删除指定cron下的指定task
|
||||
RemoveTaskByName(cronName string, taskName string)
|
||||
// 清理掉指定cronName
|
||||
Clear(cronName string)
|
||||
// 停止所有的cron
|
||||
Close()
|
||||
}
|
||||
|
||||
type task struct {
|
||||
EntryID cron.EntryID
|
||||
Spec string
|
||||
TaskName string
|
||||
}
|
||||
|
||||
type taskManager struct {
|
||||
corn *cron.Cron
|
||||
tasks map[cron.EntryID]*task
|
||||
}
|
||||
|
||||
// timer 定时任务管理
|
||||
type timer struct {
|
||||
cronList map[string]*taskManager
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// AddTaskByFunc 通过函数的方法添加任务
|
||||
func (t *timer) AddTaskByFunc(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if _, ok := t.cronList[cronName]; !ok {
|
||||
tasks := make(map[cron.EntryID]*task)
|
||||
t.cronList[cronName] = &taskManager{
|
||||
corn: cron.New(option...),
|
||||
tasks: tasks,
|
||||
}
|
||||
}
|
||||
id, err := t.cronList[cronName].corn.AddFunc(spec, fun)
|
||||
t.cronList[cronName].corn.Start()
|
||||
t.cronList[cronName].tasks[id] = &task{
|
||||
EntryID: id,
|
||||
Spec: spec,
|
||||
TaskName: taskName,
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// AddTaskByFuncWithSecond 通过函数的方法使用WithSeconds添加任务
|
||||
func (t *timer) AddTaskByFuncWithSecond(cronName string, spec string, fun func(), taskName string, option ...cron.Option) (cron.EntryID, error) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
option = append(option, cron.WithSeconds())
|
||||
if _, ok := t.cronList[cronName]; !ok {
|
||||
tasks := make(map[cron.EntryID]*task)
|
||||
t.cronList[cronName] = &taskManager{
|
||||
corn: cron.New(option...),
|
||||
tasks: tasks,
|
||||
}
|
||||
}
|
||||
id, err := t.cronList[cronName].corn.AddFunc(spec, fun)
|
||||
t.cronList[cronName].corn.Start()
|
||||
t.cronList[cronName].tasks[id] = &task{
|
||||
EntryID: id,
|
||||
Spec: spec,
|
||||
TaskName: taskName,
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// AddTaskByJob 通过接口的方法添加任务
|
||||
func (t *timer) AddTaskByJob(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if _, ok := t.cronList[cronName]; !ok {
|
||||
tasks := make(map[cron.EntryID]*task)
|
||||
t.cronList[cronName] = &taskManager{
|
||||
corn: cron.New(option...),
|
||||
tasks: tasks,
|
||||
}
|
||||
}
|
||||
id, err := t.cronList[cronName].corn.AddJob(spec, job)
|
||||
t.cronList[cronName].corn.Start()
|
||||
t.cronList[cronName].tasks[id] = &task{
|
||||
EntryID: id,
|
||||
Spec: spec,
|
||||
TaskName: taskName,
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// AddTaskByJobWithSeconds 通过接口的方法添加任务
|
||||
func (t *timer) AddTaskByJobWithSeconds(cronName string, spec string, job interface{ Run() }, taskName string, option ...cron.Option) (cron.EntryID, error) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
option = append(option, cron.WithSeconds())
|
||||
if _, ok := t.cronList[cronName]; !ok {
|
||||
tasks := make(map[cron.EntryID]*task)
|
||||
t.cronList[cronName] = &taskManager{
|
||||
corn: cron.New(option...),
|
||||
tasks: tasks,
|
||||
}
|
||||
}
|
||||
id, err := t.cronList[cronName].corn.AddJob(spec, job)
|
||||
t.cronList[cronName].corn.Start()
|
||||
t.cronList[cronName].tasks[id] = &task{
|
||||
EntryID: id,
|
||||
Spec: spec,
|
||||
TaskName: taskName,
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
// FindCron 获取对应cronName的cron 可能会为空
|
||||
func (t *timer) FindCron(cronName string) (*taskManager, bool) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
v, ok := t.cronList[cronName]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// FindTask 获取对应cronName的cron 可能会为空
|
||||
func (t *timer) FindTask(cronName string, taskName string) (*task, bool) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
v, ok := t.cronList[cronName]
|
||||
if !ok {
|
||||
return nil, ok
|
||||
}
|
||||
for _, t2 := range v.tasks {
|
||||
if t2.TaskName == taskName {
|
||||
return t2, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindCronList 获取所有的任务列表
|
||||
func (t *timer) FindCronList() map[string]*taskManager {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
return t.cronList
|
||||
}
|
||||
|
||||
// StartCron 开始任务
|
||||
func (t *timer) StartCron(cronName string) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if v, ok := t.cronList[cronName]; ok {
|
||||
v.corn.Start()
|
||||
}
|
||||
}
|
||||
|
||||
// StopCron 停止任务
|
||||
func (t *timer) StopCron(cronName string) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if v, ok := t.cronList[cronName]; ok {
|
||||
v.corn.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveTask 从cronName 删除指定任务
|
||||
func (t *timer) RemoveTask(cronName string, id int) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if v, ok := t.cronList[cronName]; ok {
|
||||
v.corn.Remove(cron.EntryID(id))
|
||||
delete(v.tasks, cron.EntryID(id))
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveTaskByName 从cronName 使用taskName 删除指定任务
|
||||
func (t *timer) RemoveTaskByName(cronName string, taskName string) {
|
||||
fTask, ok := t.FindTask(cronName, taskName)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
t.RemoveTask(cronName, int(fTask.EntryID))
|
||||
}
|
||||
|
||||
// Clear 清除任务
|
||||
func (t *timer) Clear(cronName string) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
if v, ok := t.cronList[cronName]; ok {
|
||||
v.corn.Stop()
|
||||
delete(t.cronList, cronName)
|
||||
}
|
||||
}
|
||||
|
||||
// Close 释放资源
|
||||
func (t *timer) Close() {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
for _, v := range t.cronList {
|
||||
v.corn.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func NewTimerTask() Timer {
|
||||
return &timer{cronList: make(map[string]*taskManager)}
|
||||
}
|
72
utils/timer/timed_task_test.go
Normal file
72
utils/timer/timed_task_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package timer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var job = mockJob{}
|
||||
|
||||
type mockJob struct{}
|
||||
|
||||
func (job mockJob) Run() {
|
||||
mockFunc()
|
||||
}
|
||||
|
||||
func mockFunc() {
|
||||
time.Sleep(time.Second)
|
||||
fmt.Println("1s...")
|
||||
}
|
||||
|
||||
func TestNewTimerTask(t *testing.T) {
|
||||
tm := NewTimerTask()
|
||||
_tm := tm.(*timer)
|
||||
|
||||
{
|
||||
_, err := tm.AddTaskByFunc("func", "@every 1s", mockFunc, "测试mockfunc")
|
||||
assert.Nil(t, err)
|
||||
_, ok := _tm.cronList["func"]
|
||||
if !ok {
|
||||
t.Error("no find func")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
_, err := tm.AddTaskByJob("job", "@every 1s", job, "测试job mockfunc")
|
||||
assert.Nil(t, err)
|
||||
_, ok := _tm.cronList["job"]
|
||||
if !ok {
|
||||
t.Error("no find job")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
_, ok := tm.FindCron("func")
|
||||
if !ok {
|
||||
t.Error("no find func")
|
||||
}
|
||||
_, ok = tm.FindCron("job")
|
||||
if !ok {
|
||||
t.Error("no find job")
|
||||
}
|
||||
_, ok = tm.FindCron("none")
|
||||
if ok {
|
||||
t.Error("find none")
|
||||
}
|
||||
}
|
||||
{
|
||||
tm.Clear("func")
|
||||
_, ok := tm.FindCron("func")
|
||||
if ok {
|
||||
t.Error("find func")
|
||||
}
|
||||
}
|
||||
{
|
||||
a := tm.FindCronList()
|
||||
b, c := tm.FindCron("job")
|
||||
fmt.Println(a, b, c)
|
||||
}
|
||||
}
|
75
utils/upload/aliyun_oss.go
Normal file
75
utils/upload/aliyun_oss.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"github.com/aliyun/aliyun-oss-go-sdk/oss"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AliyunOSS struct{}
|
||||
|
||||
func (*AliyunOSS) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
bucket, err := NewBucket()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function AliyunOSS.NewBucket() Failed", zap.Any("err", err.Error()))
|
||||
return "", "", errors.New("function AliyunOSS.NewBucket() Failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
// 读取本地文件。
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() Failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() Failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
// 上传阿里云路径 文件名格式 自己可以改 建议保证唯一性
|
||||
// yunFileTmpPath := filepath.Join("uploads", time.Now().Format("2006-01-02")) + "/" + file.Filename
|
||||
yunFileTmpPath := global.GVA_CONFIG.AliyunOSS.BasePath + "/" + "uploads" + "/" + time.Now().Format("2006-01-02") + "/" + file.Filename
|
||||
|
||||
// 上传文件流。
|
||||
err = bucket.PutObject(yunFileTmpPath, f)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function formUploader.Put() Failed", zap.Any("err", err.Error()))
|
||||
return "", "", errors.New("function formUploader.Put() Failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
return global.GVA_CONFIG.AliyunOSS.BucketUrl + "/" + yunFileTmpPath, yunFileTmpPath, nil
|
||||
}
|
||||
|
||||
func (*AliyunOSS) DeleteFile(key string) error {
|
||||
bucket, err := NewBucket()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function AliyunOSS.NewBucket() Failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function AliyunOSS.NewBucket() Failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
// 删除单个文件。objectName表示删除OSS文件时需要指定包含文件后缀在内的完整路径,例如abc/efg/123.jpg。
|
||||
// 如需删除文件夹,请将objectName设置为对应的文件夹名称。如果文件夹非空,则需要将文件夹下的所有object删除后才能删除该文件夹。
|
||||
err = bucket.DeleteObject(key)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewBucket() (*oss.Bucket, error) {
|
||||
// 创建OSSClient实例。
|
||||
client, err := oss.New(global.GVA_CONFIG.AliyunOSS.Endpoint, global.GVA_CONFIG.AliyunOSS.AccessKeyId, global.GVA_CONFIG.AliyunOSS.AccessKeySecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取存储空间。
|
||||
bucket, err := client.Bucket(global.GVA_CONFIG.AliyunOSS.BucketName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bucket, nil
|
||||
}
|
97
utils/upload/aws_s3.go
Normal file
97
utils/upload/aws_s3.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AwsS3 struct{}
|
||||
|
||||
//@author: [WqyJh](https://github.com/WqyJh)
|
||||
//@object: *AwsS3
|
||||
//@function: UploadFile
|
||||
//@description: Upload file to Aws S3 using aws-sdk-go. See https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/s3-example-basic-bucket-operations.html#s3-examples-bucket-ops-upload-file-to-bucket
|
||||
//@param: file *multipart.FileHeader
|
||||
//@return: string, string, error
|
||||
|
||||
func (*AwsS3) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
session := newSession()
|
||||
uploader := s3manager.NewUploader(session)
|
||||
|
||||
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename)
|
||||
filename := global.GVA_CONFIG.AwsS3.PathPrefix + "/" + fileKey
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
|
||||
_, err := uploader.Upload(&s3manager.UploadInput{
|
||||
Bucket: aws.String(global.GVA_CONFIG.AwsS3.Bucket),
|
||||
Key: aws.String(filename),
|
||||
Body: f,
|
||||
})
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function uploader.Upload() failed", zap.Any("err", err.Error()))
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return global.GVA_CONFIG.AwsS3.BaseURL + "/" + filename, fileKey, nil
|
||||
}
|
||||
|
||||
//@author: [WqyJh](https://github.com/WqyJh)
|
||||
//@object: *AwsS3
|
||||
//@function: DeleteFile
|
||||
//@description: Delete file from Aws S3 using aws-sdk-go. See https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/s3-example-basic-bucket-operations.html#s3-examples-bucket-ops-delete-bucket-item
|
||||
//@param: file *multipart.FileHeader
|
||||
//@return: string, string, error
|
||||
|
||||
func (*AwsS3) DeleteFile(key string) error {
|
||||
session := newSession()
|
||||
svc := s3.New(session)
|
||||
filename := global.GVA_CONFIG.AwsS3.PathPrefix + "/" + key
|
||||
bucket := global.GVA_CONFIG.AwsS3.Bucket
|
||||
|
||||
_, err := svc.DeleteObject(&s3.DeleteObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(filename),
|
||||
})
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function svc.DeleteObject() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function svc.DeleteObject() failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
_ = svc.WaitUntilObjectNotExists(&s3.HeadObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(filename),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// newSession Create S3 session
|
||||
func newSession() *session.Session {
|
||||
sess, _ := session.NewSession(&aws.Config{
|
||||
Region: aws.String(global.GVA_CONFIG.AwsS3.Region),
|
||||
Endpoint: aws.String(global.GVA_CONFIG.AwsS3.Endpoint), //minio在这里设置地址,可以兼容
|
||||
S3ForcePathStyle: aws.Bool(global.GVA_CONFIG.AwsS3.S3ForcePathStyle),
|
||||
DisableSSL: aws.Bool(global.GVA_CONFIG.AwsS3.DisableSSL),
|
||||
Credentials: credentials.NewStaticCredentials(
|
||||
global.GVA_CONFIG.AwsS3.SecretID,
|
||||
global.GVA_CONFIG.AwsS3.SecretKey,
|
||||
"",
|
||||
),
|
||||
})
|
||||
return sess
|
||||
}
|
85
utils/upload/cloudflare_r2.go
Normal file
85
utils/upload/cloudflare_r2.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type CloudflareR2 struct{}
|
||||
|
||||
func (c *CloudflareR2) UploadFile(file *multipart.FileHeader) (fileUrl string, fileName string, err error) {
|
||||
session := c.newSession()
|
||||
client := s3manager.NewUploader(session)
|
||||
|
||||
fileKey := fmt.Sprintf("%d_%s", time.Now().Unix(), file.Filename)
|
||||
fileName = fmt.Sprintf("%s/%s", global.GVA_CONFIG.CloudflareR2.Path, fileKey)
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
|
||||
input := &s3manager.UploadInput{
|
||||
Bucket: aws.String(global.GVA_CONFIG.CloudflareR2.Bucket),
|
||||
Key: aws.String(fileName),
|
||||
Body: f,
|
||||
}
|
||||
|
||||
_, err = client.Upload(input)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function uploader.Upload() failed", zap.Any("err", err.Error()))
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s", global.GVA_CONFIG.CloudflareR2.BaseURL,
|
||||
fileName),
|
||||
fileKey,
|
||||
nil
|
||||
}
|
||||
|
||||
func (c *CloudflareR2) DeleteFile(key string) error {
|
||||
session := newSession()
|
||||
svc := s3.New(session)
|
||||
filename := global.GVA_CONFIG.CloudflareR2.Path + "/" + key
|
||||
bucket := global.GVA_CONFIG.CloudflareR2.Bucket
|
||||
|
||||
_, err := svc.DeleteObject(&s3.DeleteObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(filename),
|
||||
})
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function svc.DeleteObject() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function svc.DeleteObject() failed, err:" + err.Error())
|
||||
}
|
||||
|
||||
_ = svc.WaitUntilObjectNotExists(&s3.HeadObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(filename),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*CloudflareR2) newSession() *session.Session {
|
||||
endpoint := fmt.Sprintf("%s.r2.cloudflarestorage.com", global.GVA_CONFIG.CloudflareR2.AccountID)
|
||||
|
||||
return session.Must(session.NewSession(&aws.Config{
|
||||
Region: aws.String("auto"),
|
||||
Endpoint: aws.String(endpoint),
|
||||
Credentials: credentials.NewStaticCredentials(
|
||||
global.GVA_CONFIG.CloudflareR2.AccessKeyID,
|
||||
global.GVA_CONFIG.CloudflareR2.SecretAccessKey,
|
||||
"",
|
||||
),
|
||||
}))
|
||||
}
|
109
utils/upload/local.go
Normal file
109
utils/upload/local.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"git.echol.cn/loser/lckt/utils"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
type Local struct{}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@author: [ccfish86](https://github.com/ccfish86)
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Local
|
||||
//@function: UploadFile
|
||||
//@description: 上传文件
|
||||
//@param: file *multipart.FileHeader
|
||||
//@return: string, string, error
|
||||
|
||||
func (*Local) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
// 读取文件后缀
|
||||
ext := filepath.Ext(file.Filename)
|
||||
// 读取文件名并加密
|
||||
name := strings.TrimSuffix(file.Filename, ext)
|
||||
name = utils.MD5V([]byte(name))
|
||||
// 拼接新文件名
|
||||
filename := name + "_" + time.Now().Format("20060102150405") + ext
|
||||
// 尝试创建此路径
|
||||
mkdirErr := os.MkdirAll(global.GVA_CONFIG.Local.StorePath, os.ModePerm)
|
||||
if mkdirErr != nil {
|
||||
global.GVA_LOG.Error("function os.MkdirAll() failed", zap.Any("err", mkdirErr.Error()))
|
||||
return "", "", errors.New("function os.MkdirAll() failed, err:" + mkdirErr.Error())
|
||||
}
|
||||
// 拼接路径和文件名
|
||||
p := global.GVA_CONFIG.Local.StorePath + "/" + filename
|
||||
filepath := global.GVA_CONFIG.Local.Path + "/" + filename
|
||||
|
||||
f, openError := file.Open() // 读取文件
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
|
||||
out, createErr := os.Create(p)
|
||||
if createErr != nil {
|
||||
global.GVA_LOG.Error("function os.Create() failed", zap.Any("err", createErr.Error()))
|
||||
|
||||
return "", "", errors.New("function os.Create() failed, err:" + createErr.Error())
|
||||
}
|
||||
defer out.Close() // 创建文件 defer 关闭
|
||||
|
||||
_, copyErr := io.Copy(out, f) // 传输(拷贝)文件
|
||||
if copyErr != nil {
|
||||
global.GVA_LOG.Error("function io.Copy() failed", zap.Any("err", copyErr.Error()))
|
||||
return "", "", errors.New("function io.Copy() failed, err:" + copyErr.Error())
|
||||
}
|
||||
return filepath, filename, nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@author: [ccfish86](https://github.com/ccfish86)
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Local
|
||||
//@function: DeleteFile
|
||||
//@description: 删除文件
|
||||
//@param: key string
|
||||
//@return: error
|
||||
|
||||
func (*Local) DeleteFile(key string) error {
|
||||
// 检查 key 是否为空
|
||||
if key == "" {
|
||||
return errors.New("key不能为空")
|
||||
}
|
||||
|
||||
// 验证 key 是否包含非法字符或尝试访问存储路径之外的文件
|
||||
if strings.Contains(key, "..") || strings.ContainsAny(key, `\/:*?"<>|`) {
|
||||
return errors.New("非法的key")
|
||||
}
|
||||
|
||||
p := filepath.Join(global.GVA_CONFIG.Local.StorePath, key)
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(p); os.IsNotExist(err) {
|
||||
return errors.New("文件不存在")
|
||||
}
|
||||
|
||||
// 使用文件锁防止并发删除
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
err := os.Remove(p)
|
||||
if err != nil {
|
||||
return errors.New("文件删除失败: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
99
utils/upload/minio_oss.go
Normal file
99
utils/upload/minio_oss.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"git.echol.cn/loser/lckt/utils"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var MinioClient *Minio // 优化性能,但是不支持动态配置
|
||||
|
||||
type Minio struct {
|
||||
Client *minio.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
func GetMinio(endpoint, accessKeyID, secretAccessKey, bucketName string, useSSL bool) (*Minio, error) {
|
||||
if MinioClient != nil {
|
||||
return MinioClient, nil
|
||||
}
|
||||
// Initialize minio client object.
|
||||
minioClient, err := minio.New(endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
|
||||
Secure: useSSL, // Set to true if using https
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 尝试创建bucket
|
||||
err = minioClient.MakeBucket(context.Background(), bucketName, minio.MakeBucketOptions{})
|
||||
if err != nil {
|
||||
// Check to see if we already own this bucket (which happens if you run this twice)
|
||||
exists, errBucketExists := minioClient.BucketExists(context.Background(), bucketName)
|
||||
if errBucketExists == nil && exists {
|
||||
// log.Printf("We already own %s\n", bucketName)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
MinioClient = &Minio{Client: minioClient, bucket: bucketName}
|
||||
return MinioClient, nil
|
||||
}
|
||||
|
||||
func (m *Minio) UploadFile(file *multipart.FileHeader) (filePathres, key string, uploadErr error) {
|
||||
f, openError := file.Open()
|
||||
// mutipart.File to os.File
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() Failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() Failed, err:" + openError.Error())
|
||||
}
|
||||
|
||||
filecontent := bytes.Buffer{}
|
||||
_, err := io.Copy(&filecontent, f)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("读取文件失败", zap.Any("err", err.Error()))
|
||||
return "", "", errors.New("读取文件失败, err:" + err.Error())
|
||||
}
|
||||
f.Close() // 创建文件 defer 关闭
|
||||
|
||||
// 对文件名进行加密存储
|
||||
ext := filepath.Ext(file.Filename)
|
||||
filename := utils.MD5V([]byte(strings.TrimSuffix(file.Filename, ext))) + ext
|
||||
if global.GVA_CONFIG.Minio.BasePath == "" {
|
||||
filePathres = "uploads" + "/" + time.Now().Format("2006-01-02") + "/" + filename
|
||||
} else {
|
||||
filePathres = global.GVA_CONFIG.Minio.BasePath + "/" + time.Now().Format("2006-01-02") + "/" + filename
|
||||
}
|
||||
|
||||
// 设置超时10分钟
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10)
|
||||
defer cancel()
|
||||
|
||||
// Upload the file with PutObject 大文件自动切换为分片上传
|
||||
info, err := m.Client.PutObject(ctx, global.GVA_CONFIG.Minio.BucketName, filePathres, &filecontent, file.Size, minio.PutObjectOptions{ContentType: "application/octet-stream"})
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("上传文件到minio失败", zap.Any("err", err.Error()))
|
||||
return "", "", errors.New("上传文件到minio失败, err:" + err.Error())
|
||||
}
|
||||
return global.GVA_CONFIG.Minio.BucketUrl + "/" + info.Key, filePathres, nil
|
||||
}
|
||||
|
||||
func (m *Minio) DeleteFile(key string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
// Delete the object from MinIO
|
||||
err := m.Client.RemoveObject(ctx, m.bucket, key, minio.RemoveObjectOptions{})
|
||||
return err
|
||||
}
|
69
utils/upload/obs.go
Normal file
69
utils/upload/obs.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"mime/multipart"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"github.com/huaweicloud/huaweicloud-sdk-go-obs/obs"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var HuaWeiObs = new(Obs)
|
||||
|
||||
type Obs struct{}
|
||||
|
||||
func NewHuaWeiObsClient() (client *obs.ObsClient, err error) {
|
||||
return obs.New(global.GVA_CONFIG.HuaWeiObs.AccessKey, global.GVA_CONFIG.HuaWeiObs.SecretKey, global.GVA_CONFIG.HuaWeiObs.Endpoint)
|
||||
}
|
||||
|
||||
func (o *Obs) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
// var open multipart.File
|
||||
open, err := file.Open()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer open.Close()
|
||||
filename := file.Filename
|
||||
input := &obs.PutObjectInput{
|
||||
PutObjectBasicInput: obs.PutObjectBasicInput{
|
||||
ObjectOperationInput: obs.ObjectOperationInput{
|
||||
Bucket: global.GVA_CONFIG.HuaWeiObs.Bucket,
|
||||
Key: filename,
|
||||
},
|
||||
HttpHeader: obs.HttpHeader{
|
||||
ContentType: file.Header.Get("content-type"),
|
||||
},
|
||||
},
|
||||
Body: open,
|
||||
}
|
||||
|
||||
var client *obs.ObsClient
|
||||
client, err = NewHuaWeiObsClient()
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(err, "获取华为对象存储对象失败!")
|
||||
}
|
||||
|
||||
_, err = client.PutObject(input)
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(err, "文件上传失败!")
|
||||
}
|
||||
filepath := global.GVA_CONFIG.HuaWeiObs.Path + "/" + filename
|
||||
return filepath, filename, err
|
||||
}
|
||||
|
||||
func (o *Obs) DeleteFile(key string) error {
|
||||
client, err := NewHuaWeiObsClient()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "获取华为对象存储对象失败!")
|
||||
}
|
||||
input := &obs.DeleteObjectInput{
|
||||
Bucket: global.GVA_CONFIG.HuaWeiObs.Bucket,
|
||||
Key: key,
|
||||
}
|
||||
var output *obs.DeleteObjectOutput
|
||||
output, err = client.DeleteObject(input)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "删除对象(%s)失败!, output: %v", key, output)
|
||||
}
|
||||
return nil
|
||||
}
|
96
utils/upload/qiniu.go
Normal file
96
utils/upload/qiniu.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||
"github.com/qiniu/go-sdk/v7/storage"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Qiniu struct{}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@author: [ccfish86](https://github.com/ccfish86)
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Qiniu
|
||||
//@function: UploadFile
|
||||
//@description: 上传文件
|
||||
//@param: file *multipart.FileHeader
|
||||
//@return: string, string, error
|
||||
|
||||
func (*Qiniu) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
putPolicy := storage.PutPolicy{Scope: global.GVA_CONFIG.Qiniu.Bucket}
|
||||
mac := qbox.NewMac(global.GVA_CONFIG.Qiniu.AccessKey, global.GVA_CONFIG.Qiniu.SecretKey)
|
||||
upToken := putPolicy.UploadToken(mac)
|
||||
cfg := qiniuConfig()
|
||||
formUploader := storage.NewFormUploader(cfg)
|
||||
ret := storage.PutRet{}
|
||||
putExtra := storage.PutExtra{Params: map[string]string{"x:name": "github logo"}}
|
||||
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename) // 文件名格式 自己可以改 建议保证唯一性
|
||||
putErr := formUploader.Put(context.Background(), &ret, upToken, fileKey, f, file.Size, &putExtra)
|
||||
if putErr != nil {
|
||||
global.GVA_LOG.Error("function formUploader.Put() failed", zap.Any("err", putErr.Error()))
|
||||
return "", "", errors.New("function formUploader.Put() failed, err:" + putErr.Error())
|
||||
}
|
||||
return global.GVA_CONFIG.Qiniu.ImgPath + "/" + ret.Key, ret.Key, nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@author: [ccfish86](https://github.com/ccfish86)
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Qiniu
|
||||
//@function: DeleteFile
|
||||
//@description: 删除文件
|
||||
//@param: key string
|
||||
//@return: error
|
||||
|
||||
func (*Qiniu) DeleteFile(key string) error {
|
||||
mac := qbox.NewMac(global.GVA_CONFIG.Qiniu.AccessKey, global.GVA_CONFIG.Qiniu.SecretKey)
|
||||
cfg := qiniuConfig()
|
||||
bucketManager := storage.NewBucketManager(mac, cfg)
|
||||
if err := bucketManager.Delete(global.GVA_CONFIG.Qiniu.Bucket, key); err != nil {
|
||||
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//@author: [SliverHorn](https://github.com/SliverHorn)
|
||||
//@object: *Qiniu
|
||||
//@function: qiniuConfig
|
||||
//@description: 根据配置文件进行返回七牛云的配置
|
||||
//@return: *storage.Config
|
||||
|
||||
func qiniuConfig() *storage.Config {
|
||||
cfg := storage.Config{
|
||||
UseHTTPS: global.GVA_CONFIG.Qiniu.UseHTTPS,
|
||||
UseCdnDomains: global.GVA_CONFIG.Qiniu.UseCdnDomains,
|
||||
}
|
||||
switch global.GVA_CONFIG.Qiniu.Zone { // 根据配置文件进行初始化空间对应的机房
|
||||
case "ZoneHuadong":
|
||||
cfg.Zone = &storage.ZoneHuadong
|
||||
case "ZoneHuabei":
|
||||
cfg.Zone = &storage.ZoneHuabei
|
||||
case "ZoneHuanan":
|
||||
cfg.Zone = &storage.ZoneHuanan
|
||||
case "ZoneBeimei":
|
||||
cfg.Zone = &storage.ZoneBeimei
|
||||
case "ZoneXinjiapo":
|
||||
cfg.Zone = &storage.ZoneXinjiapo
|
||||
}
|
||||
return &cfg
|
||||
}
|
61
utils/upload/tencent_cos.go
Normal file
61
utils/upload/tencent_cos.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
|
||||
"github.com/tencentyun/cos-go-sdk-v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type TencentCOS struct{}
|
||||
|
||||
// UploadFile upload file to COS
|
||||
func (*TencentCOS) UploadFile(file *multipart.FileHeader) (string, string, error) {
|
||||
client := NewClient()
|
||||
f, openError := file.Open()
|
||||
if openError != nil {
|
||||
global.GVA_LOG.Error("function file.Open() failed", zap.Any("err", openError.Error()))
|
||||
return "", "", errors.New("function file.Open() failed, err:" + openError.Error())
|
||||
}
|
||||
defer f.Close() // 创建文件 defer 关闭
|
||||
fileKey := fmt.Sprintf("%d%s", time.Now().Unix(), file.Filename)
|
||||
|
||||
_, err := client.Object.Put(context.Background(), global.GVA_CONFIG.TencentCOS.PathPrefix+"/"+fileKey, f, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return global.GVA_CONFIG.TencentCOS.BaseURL + "/" + global.GVA_CONFIG.TencentCOS.PathPrefix + "/" + fileKey, fileKey, nil
|
||||
}
|
||||
|
||||
// DeleteFile delete file form COS
|
||||
func (*TencentCOS) DeleteFile(key string) error {
|
||||
client := NewClient()
|
||||
name := global.GVA_CONFIG.TencentCOS.PathPrefix + "/" + key
|
||||
_, err := client.Object.Delete(context.Background(), name)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("function bucketManager.Delete() failed", zap.Any("err", err.Error()))
|
||||
return errors.New("function bucketManager.Delete() failed, err:" + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewClient init COS client
|
||||
func NewClient() *cos.Client {
|
||||
urlStr, _ := url.Parse("https://" + global.GVA_CONFIG.TencentCOS.Bucket + ".cos." + global.GVA_CONFIG.TencentCOS.Region + ".myqcloud.com")
|
||||
baseURL := &cos.BaseURL{BucketURL: urlStr}
|
||||
client := cos.NewClient(baseURL, &http.Client{
|
||||
Transport: &cos.AuthorizationTransport{
|
||||
SecretID: global.GVA_CONFIG.TencentCOS.SecretID,
|
||||
SecretKey: global.GVA_CONFIG.TencentCOS.SecretKey,
|
||||
},
|
||||
})
|
||||
return client
|
||||
}
|
46
utils/upload/upload.go
Normal file
46
utils/upload/upload.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"mime/multipart"
|
||||
|
||||
"git.echol.cn/loser/lckt/global"
|
||||
)
|
||||
|
||||
// OSS 对象存储接口
|
||||
// Author [SliverHorn](https://github.com/SliverHorn)
|
||||
// Author [ccfish86](https://github.com/ccfish86)
|
||||
type OSS interface {
|
||||
UploadFile(file *multipart.FileHeader) (string, string, error)
|
||||
DeleteFile(key string) error
|
||||
}
|
||||
|
||||
// NewOss OSS的实例化方法
|
||||
// Author [SliverHorn](https://github.com/SliverHorn)
|
||||
// Author [ccfish86](https://github.com/ccfish86)
|
||||
func NewOss() OSS {
|
||||
switch global.GVA_CONFIG.System.OssType {
|
||||
case "local":
|
||||
return &Local{}
|
||||
case "qiniu":
|
||||
return &Qiniu{}
|
||||
case "tencent-cos":
|
||||
return &TencentCOS{}
|
||||
case "aliyun-oss":
|
||||
return &AliyunOSS{}
|
||||
case "huawei-obs":
|
||||
return HuaWeiObs
|
||||
case "aws-s3":
|
||||
return &AwsS3{}
|
||||
case "cloudflare-r2":
|
||||
return &CloudflareR2{}
|
||||
case "minio":
|
||||
minioClient, err := GetMinio(global.GVA_CONFIG.Minio.Endpoint, global.GVA_CONFIG.Minio.AccessKeyId, global.GVA_CONFIG.Minio.AccessKeySecret, global.GVA_CONFIG.Minio.BucketName, global.GVA_CONFIG.Minio.UseSSL)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Warn("你配置了使用minio,但是初始化失败,请检查minio可用性或安全配置: " + err.Error())
|
||||
panic("minio初始化失败") // 建议这样做,用户自己配置了minio,如果报错了还要把服务开起来,使用起来也很危险
|
||||
}
|
||||
return minioClient
|
||||
default:
|
||||
return &Local{}
|
||||
}
|
||||
}
|
294
utils/validator.go
Normal file
294
utils/validator.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Rules map[string][]string
|
||||
|
||||
type RulesMap map[string]Rules
|
||||
|
||||
var CustomizeMap = make(map[string]Rules)
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: RegisterRule
|
||||
//@description: 注册自定义规则方案建议在路由初始化层即注册
|
||||
//@param: key string, rule Rules
|
||||
//@return: err error
|
||||
|
||||
func RegisterRule(key string, rule Rules) (err error) {
|
||||
if CustomizeMap[key] != nil {
|
||||
return errors.New(key + "已注册,无法重复注册")
|
||||
} else {
|
||||
CustomizeMap[key] = rule
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: NotEmpty
|
||||
//@description: 非空 不能为其对应类型的0值
|
||||
//@return: string
|
||||
|
||||
func NotEmpty() string {
|
||||
return "notEmpty"
|
||||
}
|
||||
|
||||
// @author: [zooqkl](https://github.com/zooqkl)
|
||||
// @function: RegexpMatch
|
||||
// @description: 正则校验 校验输入项是否满足正则表达式
|
||||
// @param: rule string
|
||||
// @return: string
|
||||
|
||||
func RegexpMatch(rule string) string {
|
||||
return "regexp=" + rule
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Lt
|
||||
//@description: 小于入参(<) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Lt(mark string) string {
|
||||
return "lt=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Le
|
||||
//@description: 小于等于入参(<=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Le(mark string) string {
|
||||
return "le=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Eq
|
||||
//@description: 等于入参(==) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Eq(mark string) string {
|
||||
return "eq=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Ne
|
||||
//@description: 不等于入参(!=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Ne(mark string) string {
|
||||
return "ne=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Ge
|
||||
//@description: 大于等于入参(>=) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Ge(mark string) string {
|
||||
return "ge=" + mark
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Gt
|
||||
//@description: 大于入参(>) 如果为string array Slice则为长度比较 如果是 int uint float 则为数值比较
|
||||
//@param: mark string
|
||||
//@return: string
|
||||
|
||||
func Gt(mark string) string {
|
||||
return "gt=" + mark
|
||||
}
|
||||
|
||||
//
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: Verify
|
||||
//@description: 校验方法
|
||||
//@param: st interface{}, roleMap Rules(入参实例,规则map)
|
||||
//@return: err error
|
||||
|
||||
func Verify(st interface{}, roleMap Rules) (err error) {
|
||||
compareMap := map[string]bool{
|
||||
"lt": true,
|
||||
"le": true,
|
||||
"eq": true,
|
||||
"ne": true,
|
||||
"ge": true,
|
||||
"gt": true,
|
||||
}
|
||||
|
||||
typ := reflect.TypeOf(st)
|
||||
val := reflect.ValueOf(st) // 获取reflect.Type类型
|
||||
|
||||
kd := val.Kind() // 获取到st对应的类别
|
||||
if kd != reflect.Struct {
|
||||
return errors.New("expect struct")
|
||||
}
|
||||
num := val.NumField()
|
||||
// 遍历结构体的所有字段
|
||||
for i := 0; i < num; i++ {
|
||||
tagVal := typ.Field(i)
|
||||
val := val.Field(i)
|
||||
if tagVal.Type.Kind() == reflect.Struct {
|
||||
if err = Verify(val.Interface(), roleMap); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(roleMap[tagVal.Name]) > 0 {
|
||||
for _, v := range roleMap[tagVal.Name] {
|
||||
switch {
|
||||
case v == "notEmpty":
|
||||
if isBlank(val) {
|
||||
return errors.New(tagVal.Name + "值不能为空")
|
||||
}
|
||||
case strings.Split(v, "=")[0] == "regexp":
|
||||
if !regexpMatch(strings.Split(v, "=")[1], val.String()) {
|
||||
return errors.New(tagVal.Name + "格式校验不通过")
|
||||
}
|
||||
case compareMap[strings.Split(v, "=")[0]]:
|
||||
if !compareVerify(val, v) {
|
||||
return errors.New(tagVal.Name + "长度或值不在合法范围," + v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: compareVerify
|
||||
//@description: 长度和数字的校验方法 根据类型自动校验
|
||||
//@param: value reflect.Value, VerifyStr string
|
||||
//@return: bool
|
||||
|
||||
func compareVerify(value reflect.Value, VerifyStr string) bool {
|
||||
switch value.Kind() {
|
||||
case reflect.String:
|
||||
return compare(len([]rune(value.String())), VerifyStr)
|
||||
case reflect.Slice, reflect.Array:
|
||||
return compare(value.Len(), VerifyStr)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return compare(value.Uint(), VerifyStr)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return compare(value.Float(), VerifyStr)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return compare(value.Int(), VerifyStr)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: isBlank
|
||||
//@description: 非空校验
|
||||
//@param: value reflect.Value
|
||||
//@return: bool
|
||||
|
||||
func isBlank(value reflect.Value) bool {
|
||||
switch value.Kind() {
|
||||
case reflect.String, reflect.Slice:
|
||||
return value.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !value.Bool()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return value.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return value.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return value.Float() == 0
|
||||
case reflect.Interface, reflect.Ptr:
|
||||
return value.IsNil()
|
||||
}
|
||||
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
//@function: compare
|
||||
//@description: 比较函数
|
||||
//@param: value interface{}, VerifyStr string
|
||||
//@return: bool
|
||||
|
||||
func compare(value interface{}, VerifyStr string) bool {
|
||||
VerifyStrArr := strings.Split(VerifyStr, "=")
|
||||
val := reflect.ValueOf(value)
|
||||
switch val.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
VInt, VErr := strconv.ParseInt(VerifyStrArr[1], 10, 64)
|
||||
if VErr != nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case VerifyStrArr[0] == "lt":
|
||||
return val.Int() < VInt
|
||||
case VerifyStrArr[0] == "le":
|
||||
return val.Int() <= VInt
|
||||
case VerifyStrArr[0] == "eq":
|
||||
return val.Int() == VInt
|
||||
case VerifyStrArr[0] == "ne":
|
||||
return val.Int() != VInt
|
||||
case VerifyStrArr[0] == "ge":
|
||||
return val.Int() >= VInt
|
||||
case VerifyStrArr[0] == "gt":
|
||||
return val.Int() > VInt
|
||||
default:
|
||||
return false
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
VInt, VErr := strconv.Atoi(VerifyStrArr[1])
|
||||
if VErr != nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case VerifyStrArr[0] == "lt":
|
||||
return val.Uint() < uint64(VInt)
|
||||
case VerifyStrArr[0] == "le":
|
||||
return val.Uint() <= uint64(VInt)
|
||||
case VerifyStrArr[0] == "eq":
|
||||
return val.Uint() == uint64(VInt)
|
||||
case VerifyStrArr[0] == "ne":
|
||||
return val.Uint() != uint64(VInt)
|
||||
case VerifyStrArr[0] == "ge":
|
||||
return val.Uint() >= uint64(VInt)
|
||||
case VerifyStrArr[0] == "gt":
|
||||
return val.Uint() > uint64(VInt)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
VFloat, VErr := strconv.ParseFloat(VerifyStrArr[1], 64)
|
||||
if VErr != nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case VerifyStrArr[0] == "lt":
|
||||
return val.Float() < VFloat
|
||||
case VerifyStrArr[0] == "le":
|
||||
return val.Float() <= VFloat
|
||||
case VerifyStrArr[0] == "eq":
|
||||
return val.Float() == VFloat
|
||||
case VerifyStrArr[0] == "ne":
|
||||
return val.Float() != VFloat
|
||||
case VerifyStrArr[0] == "ge":
|
||||
return val.Float() >= VFloat
|
||||
case VerifyStrArr[0] == "gt":
|
||||
return val.Float() > VFloat
|
||||
default:
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func regexpMatch(rule, matchStr string) bool {
|
||||
return regexp.MustCompile(rule).MatchString(matchStr)
|
||||
}
|
37
utils/validator_test.go
Normal file
37
utils/validator_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"git.echol.cn/loser/lckt/model/common/request"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type PageInfoTest struct {
|
||||
PageInfo request.PageInfo
|
||||
Name string
|
||||
}
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
PageInfoVerify := Rules{"Page": {NotEmpty()}, "PageSize": {NotEmpty()}, "Name": {NotEmpty()}}
|
||||
var testInfo PageInfoTest
|
||||
testInfo.Name = "test"
|
||||
testInfo.PageInfo.Page = 0
|
||||
testInfo.PageInfo.PageSize = 0
|
||||
err := Verify(testInfo, PageInfoVerify)
|
||||
if err == nil {
|
||||
t.Error("校验失败,未能捕捉0值")
|
||||
}
|
||||
testInfo.Name = ""
|
||||
testInfo.PageInfo.Page = 1
|
||||
testInfo.PageInfo.PageSize = 10
|
||||
err = Verify(testInfo, PageInfoVerify)
|
||||
if err == nil {
|
||||
t.Error("校验失败,未能正常检测name为空")
|
||||
}
|
||||
testInfo.Name = "test"
|
||||
testInfo.PageInfo.Page = 1
|
||||
testInfo.PageInfo.PageSize = 10
|
||||
err = Verify(testInfo, PageInfoVerify)
|
||||
if err != nil {
|
||||
t.Error("校验失败,未能正常通过检测")
|
||||
}
|
||||
}
|
19
utils/verify.go
Normal file
19
utils/verify.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
var (
|
||||
IdVerify = Rules{"ID": []string{NotEmpty()}}
|
||||
ApiVerify = Rules{"Path": {NotEmpty()}, "Description": {NotEmpty()}, "ApiGroup": {NotEmpty()}, "Method": {NotEmpty()}}
|
||||
MenuVerify = Rules{"Path": {NotEmpty()}, "Name": {NotEmpty()}, "Component": {NotEmpty()}, "Sort": {Ge("0")}}
|
||||
MenuMetaVerify = Rules{"Title": {NotEmpty()}}
|
||||
LoginVerify = Rules{"CaptchaId": {NotEmpty()}, "Username": {NotEmpty()}, "Password": {NotEmpty()}}
|
||||
RegisterVerify = Rules{"Username": {NotEmpty()}, "NickName": {NotEmpty()}, "Password": {NotEmpty()}, "AuthorityId": {NotEmpty()}}
|
||||
PageInfoVerify = Rules{"Page": {NotEmpty()}, "PageSize": {NotEmpty()}}
|
||||
CustomerVerify = Rules{"CustomerName": {NotEmpty()}, "CustomerPhoneData": {NotEmpty()}}
|
||||
AutoCodeVerify = Rules{"Abbreviation": {NotEmpty()}, "StructName": {NotEmpty()}, "PackageName": {NotEmpty()}}
|
||||
AutoPackageVerify = Rules{"PackageName": {NotEmpty()}}
|
||||
AuthorityVerify = Rules{"AuthorityId": {NotEmpty()}, "AuthorityName": {NotEmpty()}}
|
||||
AuthorityIdVerify = Rules{"AuthorityId": {NotEmpty()}}
|
||||
OldAuthorityVerify = Rules{"OldAuthorityId": {NotEmpty()}}
|
||||
ChangePasswordVerify = Rules{"Password": {NotEmpty()}, "NewPassword": {NotEmpty()}}
|
||||
SetUserAuthorityVerify = Rules{"AuthorityId": {NotEmpty()}}
|
||||
)
|
53
utils/zip.go
Normal file
53
utils/zip.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 解压
|
||||
func Unzip(zipFile string, destDir string) ([]string, error) {
|
||||
zipReader, err := zip.OpenReader(zipFile)
|
||||
var paths []string
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
defer zipReader.Close()
|
||||
|
||||
for _, f := range zipReader.File {
|
||||
if strings.Contains(f.Name, "..") {
|
||||
return []string{}, fmt.Errorf("%s 文件名不合法", f.Name)
|
||||
}
|
||||
fpath := filepath.Join(destDir, f.Name)
|
||||
paths = append(paths, fpath)
|
||||
if f.FileInfo().IsDir() {
|
||||
os.MkdirAll(fpath, os.ModePerm)
|
||||
} else {
|
||||
if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
inFile, err := f.Open()
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
defer inFile.Close()
|
||||
|
||||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
_, err = io.Copy(outFile, inFile)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return paths, nil
|
||||
}
|
Reference in New Issue
Block a user