JM-WechatMini/utils/ast/ast_enter.go
2023-11-02 04:34:46 +08:00

182 lines
3.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ast
import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"log"
"os"
"strconv"
"strings"
)
// Visitor is an ast.Visitor that injects generated code into a parsed Go file
// while it is walked with ast.Walk. Each field enables one kind of edit; an
// empty field disables the edits that depend on it (see Visit).
type Visitor struct {
// ImportCode is the import path added to the file's import declaration.
ImportCode string
// StructName is the name of the field appended to struct types whose name
// contains "Group".
StructName string
// PackageName qualifies the appended field's type (PackageName.GroupName) and
// is also used to build the "<PackageName>Router" variable added to Routers.
PackageName string
// GroupName is the type name, selected from PackageName, of the appended field.
GroupName string
}
// Visit implements ast.Visitor. It applies the edits enabled by vi's fields:
//   - an import declaration gets ImportCode added (when ImportCode is set);
//   - a type declaration gets a StructName field appended to its "Group"
//     structs (when StructName, PackageName and GroupName are all set);
//   - a function named "Routers" gets a "<PackageName>Router" variable
//     inserted (when PackageName is set).
//
// It returns nil for the nodes listed above so their subtrees are not walked,
// and vi for everything else so traversal continues.
func (vi *Visitor) Visit(node ast.Node) ast.Visitor {
	switch n := node.(type) {
	case *ast.GenDecl:
		// NOTE: a file with no import declaration at all is not handled; the
		// path is only added to an existing import declaration.
		if n.Tok == token.IMPORT && vi.ImportCode != "" {
			vi.addImport(n)
			return nil
		}
		if n.Tok == token.TYPE && vi.StructName != "" && vi.PackageName != "" && vi.GroupName != "" {
			vi.addStruct(n)
			return nil
		}
	case *ast.FuncDecl:
		if n.Name.Name == "Routers" {
			// Without a package name the inserted statement would be
			// "Router := router.RouterGroupApp." (invalid Go), so only edit
			// when PackageName is set, like the struct edit above. The
			// subtree is pruned either way, as before.
			if vi.PackageName != "" {
				vi.addFuncBodyVar(n)
			}
			return nil
		}
	}
	return vi
}
// addStruct appends the field
//
//	<StructName> <PackageName>.<GroupName>
//
// to every struct type declared in genDecl whose name contains "Group".
// A struct that already declares a field named StructName is left unchanged.
// This matches the duplicate checks in addImport and addFuncBodyVar, and
// stops repeated runs from producing a duplicate (uncompilable) field.
// It always returns vi.
func (vi *Visitor) addStruct(genDecl *ast.GenDecl) ast.Visitor {
	for _, spec := range genDecl.Specs {
		typeSpec, ok := spec.(*ast.TypeSpec)
		if !ok || !strings.Contains(typeSpec.Name.Name, "Group") {
			continue
		}
		structType, ok := typeSpec.Type.(*ast.StructType)
		if !ok {
			continue
		}
		if structType.Fields == nil {
			structType.Fields = &ast.FieldList{}
		}

		alreadyPresent := false
		for _, field := range structType.Fields.List {
			for _, name := range field.Names {
				if name.Name == vi.StructName {
					alreadyPresent = true
				}
			}
		}
		if alreadyPresent {
			continue
		}

		structType.Fields.List = append(structType.Fields.List, &ast.Field{
			Names: []*ast.Ident{{
				Name: vi.StructName,
				Obj:  &ast.Object{Kind: ast.Var, Name: vi.StructName},
			}},
			Type: &ast.SelectorExpr{
				X:   ast.NewIdent(vi.PackageName),
				Sel: ast.NewIdent(vi.GroupName),
			},
		})
	}
	return vi
}
// addImport appends ImportCode to the import declaration genDecl unless that
// path is already imported there. Only the path is compared, so an import with
// an explicit name (alias, "_" or ".") also counts. Paths are compared after
// unquoting, so an import written with back quotes is recognized too.
// It always returns vi.
func (vi *Visitor) addImport(genDecl *ast.GenDecl) ast.Visitor {
	for _, spec := range genDecl.Specs {
		importSpec, ok := spec.(*ast.ImportSpec)
		if !ok || importSpec.Path == nil {
			continue
		}
		path, err := strconv.Unquote(importSpec.Path.Value)
		if err == nil && path == vi.ImportCode {
			return vi
		}
	}
	genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
		Path: &ast.BasicLit{
			Kind:  token.STRING,
			Value: strconv.Quote(vi.ImportCode),
		},
	})
	return vi
}
// addFuncBodyVar inserts the statement
//
//	<PackageName>Router := router.RouterGroupApp.<Title(PackageName)>
//
// into funDecl's body as its second statement. If the body has fewer
// statements, the new one is appended at the end instead. Nothing is inserted
// when the function has no body, or when an assignment in the body already has
// an identifier named "<PackageName>Router" on its left-hand side.
// It always returns vi.
//
// NOTE(review): position 1 presumably places the variable right after the
// function's first statement in the generated template; confirm.
func (vi *Visitor) addFuncBodyVar(funDecl *ast.FuncDecl) ast.Visitor {
	if funDecl.Body == nil {
		// A declaration without a body (e.g. implemented in assembly).
		return vi
	}

	varName := vi.PackageName + "Router"
	for _, stmt := range funDecl.Body.List {
		assign, ok := stmt.(*ast.AssignStmt)
		if !ok {
			continue
		}
		for _, lhs := range assign.Lhs {
			if ident, ok := lhs.(*ast.Ident); ok && ident.Name == varName {
				return vi
			}
		}
	}

	assignStmt := &ast.AssignStmt{
		Lhs: []ast.Expr{
			&ast.Ident{
				Name: varName,
				Obj: &ast.Object{
					Kind: ast.Var,
					Name: varName,
				},
			},
		},
		Tok: token.DEFINE,
		Rhs: []ast.Expr{
			&ast.SelectorExpr{
				X: &ast.SelectorExpr{
					X:   &ast.Ident{Name: "router"},
					Sel: &ast.Ident{Name: "RouterGroupApp"},
				},
				Sel: &ast.Ident{
					Name: cases.Title(language.English).String(vi.PackageName),
				},
			},
		},
	}

	// Insert at index 1, clamped to the list length so a body with 0 or 1
	// statements does not index out of range.
	list := funDecl.Body.List
	index := 1
	if index > len(list) {
		index = len(list)
	}
	list = append(list, nil)
	copy(list[index+1:], list[index:])
	list[index] = assignStmt
	funDecl.Body.List = list
	return vi
}
// ImportReference parses the Go source file at filepath and applies the edits
// described by a Visitor built from the remaining arguments (see Visitor.Visit).
// It then writes the gofmt-formatted result back to filepath.
//
// importCode is trimmed of surrounding white space first; when it is empty the
// parsed AST is also dumped to standard output with ast.Print.
//
// Parse, format and write errors are returned to the caller. If parsing or
// formatting fails, the file is not written. Because the file already exists,
// os.WriteFile keeps its current permissions rather than applying 0o600.
func ImportReference(filepath, importCode, structName, packageName, groupName string) error {
	fSet := token.NewFileSet()
	file, err := parser.ParseFile(fSet, filepath, nil, parser.ParseComments)
	if err != nil {
		return err
	}

	importCode = strings.TrimSpace(importCode)
	v := &Visitor{
		ImportCode:  importCode,
		StructName:  structName,
		PackageName: packageName,
		GroupName:   groupName,
	}
	if importCode == "" {
		// NOTE(review): looks like a leftover debugging aid (dumps the whole
		// AST to stdout); kept to preserve existing output.
		ast.Print(fSet, file)
	}
	ast.Walk(v, file)

	var buf bytes.Buffer
	if err := format.Node(&buf, fSet, file); err != nil {
		// Return the error instead of log.Fatal: os.Exit from a library
		// helper would terminate the whole host process.
		log.Printf("ast: formatting %s: %v", filepath, err)
		return err
	}
	return os.WriteFile(filepath, buf.Bytes(), 0o600)
}