🎨 优化项目结构 && 完善ai配置

This commit is contained in:
2026-03-03 15:39:23 +08:00
parent 557c865948
commit 2714e63d2a
585 changed files with 62223 additions and 100018 deletions

View File

@@ -1,84 +0,0 @@
package utils
import (
"errors"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/golang-jwt/jwt/v5"
)
const (
UserTypeApp = "app" // 前台用户类型标识
)
// AppJWTClaims 前台用户 JWT Claims
type AppJWTClaims struct {
UserID uint `json:"userId"`
Username string `json:"username"`
UserType string `json:"userType"` // 用户类型标识
jwt.RegisteredClaims
}
// CreateAppToken 创建前台用户 Token有效期 7 天)
func CreateAppToken(userID uint, username string) (tokenString string, expiresAt int64, err error) {
// Token 有效期为 7 天
expiresTime := time.Now().Add(7 * 24 * time.Hour)
expiresAt = expiresTime.Unix()
claims := AppJWTClaims{
UserID: userID,
Username: username,
UserType: UserTypeApp,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: global.GVA_CONFIG.JWT.Issuer,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err = token.SignedString([]byte(global.GVA_CONFIG.JWT.SigningKey))
return
}
// CreateAppRefreshToken 创建前台用户刷新 Token有效期更长
func CreateAppRefreshToken(userID uint, username string) (tokenString string, expiresAt int64, err error) {
// 刷新 Token 有效期为 7 天
expiresTime := time.Now().Add(7 * 24 * time.Hour)
expiresAt = expiresTime.Unix()
claims := AppJWTClaims{
UserID: userID,
Username: username,
UserType: UserTypeApp,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: global.GVA_CONFIG.JWT.Issuer,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err = token.SignedString([]byte(global.GVA_CONFIG.JWT.SigningKey))
return
}
// ParseAppToken 解析前台用户 Token
func ParseAppToken(tokenString string) (*AppJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AppJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(global.GVA_CONFIG.JWT.SigningKey), nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*AppJWTClaims); ok && token.Valid {
return claims, nil
}
return nil, errors.New("invalid token")
}

409
server/utils/ast/ast.go Normal file
View File

@@ -0,0 +1,409 @@
package ast
import (
"fmt"
"git.echol.cn/loser/ai_proxy/server/model/system"
"go/ast"
"go/parser"
"go/token"
"log"
)
// AddImport 增加 import 方法
func AddImport(astNode ast.Node, imp string) {
impStr := fmt.Sprintf("\"%s\"", imp)
ast.Inspect(astNode, func(node ast.Node) bool {
if genDecl, ok := node.(*ast.GenDecl); ok {
if genDecl.Tok == token.IMPORT {
for i := range genDecl.Specs {
if impNode, ok := genDecl.Specs[i].(*ast.ImportSpec); ok {
if impNode.Path.Value == impStr {
return false
}
}
}
genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: impStr,
},
})
}
}
return true
})
}
// FindFunction 查询特定function方法
func FindFunction(astNode ast.Node, FunctionName string) *ast.FuncDecl {
var funcDeclP *ast.FuncDecl
ast.Inspect(astNode, func(node ast.Node) bool {
if funcDecl, ok := node.(*ast.FuncDecl); ok {
if funcDecl.Name.String() == FunctionName {
funcDeclP = funcDecl
return false
}
}
return true
})
return funcDeclP
}
// FindArray 查询特定数组方法
func FindArray(astNode ast.Node, identName, selectorExprName string) *ast.CompositeLit {
var assignStmt *ast.CompositeLit
ast.Inspect(astNode, func(n ast.Node) bool {
switch node := n.(type) {
case *ast.AssignStmt:
for _, expr := range node.Rhs {
if exprType, ok := expr.(*ast.CompositeLit); ok {
if arrayType, ok := exprType.Type.(*ast.ArrayType); ok {
sel, ok1 := arrayType.Elt.(*ast.SelectorExpr)
x, ok2 := sel.X.(*ast.Ident)
if ok1 && ok2 && x.Name == identName && sel.Sel.Name == selectorExprName {
assignStmt = exprType
return false
}
}
}
}
}
return true
})
return assignStmt
}
func CreateMenuStructAst(menus []system.SysBaseMenu) *[]ast.Expr {
var menuElts []ast.Expr
for i := range menus {
elts := []ast.Expr{ // 结构体的字段
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "ParentId"},
Value: &ast.BasicLit{Kind: token.INT, Value: "0"},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Path"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Path)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Name"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Name)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Hidden"},
Value: &ast.Ident{Name: "false"},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Component"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Component)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Sort"},
Value: &ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", menus[i].Sort)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Meta"},
Value: &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: "model"},
Sel: &ast.Ident{Name: "Meta"},
},
Elts: []ast.Expr{
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Title"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Title)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Icon"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", menus[i].Icon)},
},
},
},
},
}
// 添加菜单参数
if len(menus[i].Parameters) > 0 {
var paramElts []ast.Expr
for _, param := range menus[i].Parameters {
paramElts = append(paramElts, &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: "model"},
Sel: &ast.Ident{Name: "SysBaseMenuParameter"},
},
Elts: []ast.Expr{
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Type"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", param.Type)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Key"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", param.Key)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Value"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", param.Value)},
},
},
})
}
elts = append(elts, &ast.KeyValueExpr{
Key: &ast.Ident{Name: "Parameters"},
Value: &ast.CompositeLit{
Type: &ast.ArrayType{
Elt: &ast.SelectorExpr{
X: &ast.Ident{Name: "model"},
Sel: &ast.Ident{Name: "SysBaseMenuParameter"},
},
},
Elts: paramElts,
},
})
}
// 添加菜单按钮
if len(menus[i].MenuBtn) > 0 {
var btnElts []ast.Expr
for _, btn := range menus[i].MenuBtn {
btnElts = append(btnElts, &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: "model"},
Sel: &ast.Ident{Name: "SysBaseMenuBtn"},
},
Elts: []ast.Expr{
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Name"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", btn.Name)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Desc"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", btn.Desc)},
},
},
})
}
elts = append(elts, &ast.KeyValueExpr{
Key: &ast.Ident{Name: "MenuBtn"},
Value: &ast.CompositeLit{
Type: &ast.ArrayType{
Elt: &ast.SelectorExpr{
X: &ast.Ident{Name: "model"},
Sel: &ast.Ident{Name: "SysBaseMenuBtn"},
},
},
Elts: btnElts,
},
})
}
menuElts = append(menuElts, &ast.CompositeLit{
Type: nil,
Elts: elts,
})
}
return &menuElts
}
func CreateApiStructAst(apis []system.SysApi) *[]ast.Expr {
var apiElts []ast.Expr
for i := range apis {
elts := []ast.Expr{ // 结构体的字段
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Path"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", apis[i].Path)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Description"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", apis[i].Description)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "ApiGroup"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", apis[i].ApiGroup)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Method"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", apis[i].Method)},
},
}
apiElts = append(apiElts, &ast.CompositeLit{
Type: nil,
Elts: elts,
})
}
return &apiElts
}
// CheckImport 检查是否存在Import
func CheckImport(file *ast.File, importPath string) bool {
for _, imp := range file.Imports {
// Remove quotes around the import path
path := imp.Path.Value[1 : len(imp.Path.Value)-1]
if path == importPath {
return true
}
}
return false
}
func clearPosition(astNode ast.Node) {
ast.Inspect(astNode, func(n ast.Node) bool {
switch node := n.(type) {
case *ast.Ident:
// 清除位置信息
node.NamePos = token.NoPos
case *ast.CallExpr:
// 清除位置信息
node.Lparen = token.NoPos
node.Rparen = token.NoPos
case *ast.BasicLit:
// 清除位置信息
node.ValuePos = token.NoPos
case *ast.SelectorExpr:
// 清除位置信息
node.Sel.NamePos = token.NoPos
case *ast.BinaryExpr:
node.OpPos = token.NoPos
case *ast.UnaryExpr:
node.OpPos = token.NoPos
case *ast.StarExpr:
node.Star = token.NoPos
}
return true
})
}
func CreateStmt(statement string) *ast.ExprStmt {
expr, err := parser.ParseExpr(statement)
if err != nil {
log.Fatal(err)
}
clearPosition(expr)
return &ast.ExprStmt{X: expr}
}
func IsBlockStmt(node ast.Node) bool {
_, ok := node.(*ast.BlockStmt)
return ok
}
func VariableExistsInBlock(block *ast.BlockStmt, varName string) bool {
exists := false
ast.Inspect(block, func(n ast.Node) bool {
switch node := n.(type) {
case *ast.AssignStmt:
for _, expr := range node.Lhs {
if ident, ok := expr.(*ast.Ident); ok && ident.Name == varName {
exists = true
return false
}
}
}
return true
})
return exists
}
func CreateDictionaryStructAst(dictionaries []system.SysDictionary) *[]ast.Expr {
var dictElts []ast.Expr
for i := range dictionaries {
statusStr := "true"
if dictionaries[i].Status != nil && !*dictionaries[i].Status {
statusStr = "false"
}
elts := []ast.Expr{
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Name"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", dictionaries[i].Name)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Type"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", dictionaries[i].Type)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Status"},
Value: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{Name: "utils"},
Sel: &ast.Ident{Name: "Pointer"},
},
Args: []ast.Expr{
&ast.Ident{Name: statusStr},
},
},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Desc"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", dictionaries[i].Desc)},
},
}
if len(dictionaries[i].SysDictionaryDetails) > 0 {
var detailElts []ast.Expr
for _, detail := range dictionaries[i].SysDictionaryDetails {
detailStatusStr := "true"
if detail.Status != nil && !*detail.Status {
detailStatusStr = "false"
}
detailElts = append(detailElts, &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: "model"},
Sel: &ast.Ident{Name: "SysDictionaryDetail"},
},
Elts: []ast.Expr{
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Label"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", detail.Label)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Value"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", detail.Value)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Extend"},
Value: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", detail.Extend)},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Status"},
Value: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{Name: "utils"},
Sel: &ast.Ident{Name: "Pointer"},
},
Args: []ast.Expr{
&ast.Ident{Name: detailStatusStr},
},
},
},
&ast.KeyValueExpr{
Key: &ast.Ident{Name: "Sort"},
Value: &ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", detail.Sort)},
},
},
})
}
elts = append(elts, &ast.KeyValueExpr{
Key: &ast.Ident{Name: "SysDictionaryDetails"},
Value: &ast.CompositeLit{
Type: &ast.ArrayType{Elt: &ast.SelectorExpr{
X: &ast.Ident{Name: "model"},
Sel: &ast.Ident{Name: "SysDictionaryDetail"},
}},
Elts: detailElts,
},
})
}
dictElts = append(dictElts, &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: "model"},
Sel: &ast.Ident{Name: "SysDictionary"},
},
Elts: elts,
})
}
return &dictElts
}

View File

@@ -0,0 +1,47 @@
package ast
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"os"
)
func ImportForAutoEnter(path string, funcName string, code string) {
src, err := os.ReadFile(path)
if err != nil {
fmt.Println(err)
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, 0)
ast.Inspect(astFile, func(node ast.Node) bool {
if typeSpec, ok := node.(*ast.TypeSpec); ok {
if typeSpec.Name.Name == funcName {
if st, ok := typeSpec.Type.(*ast.StructType); ok {
for i := range st.Fields.List {
if t, ok := st.Fields.List[i].Type.(*ast.Ident); ok {
if t.Name == code {
return false
}
}
}
sn := &ast.Field{
Type: &ast.Ident{Name: code},
}
st.Fields.List = append(st.Fields.List, sn)
}
}
}
return true
})
var out []byte
bf := bytes.NewBuffer(out)
err = printer.Fprint(bf, fileSet, astFile)
if err != nil {
return
}
_ = os.WriteFile(path, bf.Bytes(), 0666)
}

View File

@@ -0,0 +1,181 @@
package ast
import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"log"
"os"
"strconv"
"strings"
)
type Visitor struct {
ImportCode string
StructName string
PackageName string
GroupName string
}
func (vi *Visitor) Visit(node ast.Node) ast.Visitor {
switch n := node.(type) {
case *ast.GenDecl:
// 查找有没有import context包
// Notice没有考虑没有import任何包的情况
if n.Tok == token.IMPORT && vi.ImportCode != "" {
vi.addImport(n)
// 不需要再遍历子树
return nil
}
if n.Tok == token.TYPE && vi.StructName != "" && vi.PackageName != "" && vi.GroupName != "" {
vi.addStruct(n)
return nil
}
case *ast.FuncDecl:
if n.Name.Name == "Routers" {
vi.addFuncBodyVar(n)
return nil
}
}
return vi
}
func (vi *Visitor) addStruct(genDecl *ast.GenDecl) ast.Visitor {
for i := range genDecl.Specs {
switch n := genDecl.Specs[i].(type) {
case *ast.TypeSpec:
if strings.Index(n.Name.Name, "Group") > -1 {
switch t := n.Type.(type) {
case *ast.StructType:
f := &ast.Field{
Names: []*ast.Ident{
{
Name: vi.StructName,
Obj: &ast.Object{
Kind: ast.Var,
Name: vi.StructName,
},
},
},
Type: &ast.SelectorExpr{
X: &ast.Ident{
Name: vi.PackageName,
},
Sel: &ast.Ident{
Name: vi.GroupName,
},
},
}
t.Fields.List = append(t.Fields.List, f)
}
}
}
}
return vi
}
func (vi *Visitor) addImport(genDecl *ast.GenDecl) ast.Visitor {
// 是否已经import
hasImported := false
for _, v := range genDecl.Specs {
importSpec := v.(*ast.ImportSpec)
// 如果已经包含
if importSpec.Path.Value == strconv.Quote(vi.ImportCode) {
hasImported = true
}
}
if !hasImported {
genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(vi.ImportCode),
},
})
}
return vi
}
func (vi *Visitor) addFuncBodyVar(funDecl *ast.FuncDecl) ast.Visitor {
hasVar := false
for _, v := range funDecl.Body.List {
switch varSpec := v.(type) {
case *ast.AssignStmt:
for i := range varSpec.Lhs {
switch nn := varSpec.Lhs[i].(type) {
case *ast.Ident:
if nn.Name == vi.PackageName+"Router" {
hasVar = true
}
}
}
}
}
if !hasVar {
assignStmt := &ast.AssignStmt{
Lhs: []ast.Expr{
&ast.Ident{
Name: vi.PackageName + "Router",
Obj: &ast.Object{
Kind: ast.Var,
Name: vi.PackageName + "Router",
},
},
},
Tok: token.DEFINE,
Rhs: []ast.Expr{
&ast.SelectorExpr{
X: &ast.SelectorExpr{
X: &ast.Ident{
Name: "router",
},
Sel: &ast.Ident{
Name: "RouterGroupApp",
},
},
Sel: &ast.Ident{
Name: cases.Title(language.English).String(vi.PackageName),
},
},
},
}
funDecl.Body.List = append(funDecl.Body.List, funDecl.Body.List[1])
index := 1
copy(funDecl.Body.List[index+1:], funDecl.Body.List[index:])
funDecl.Body.List[index] = assignStmt
}
return vi
}
func ImportReference(filepath, importCode, structName, packageName, groupName string) error {
fSet := token.NewFileSet()
fParser, err := parser.ParseFile(fSet, filepath, nil, parser.ParseComments)
if err != nil {
return err
}
importCode = strings.TrimSpace(importCode)
v := &Visitor{
ImportCode: importCode,
StructName: structName,
PackageName: packageName,
GroupName: groupName,
}
if importCode == "" {
ast.Print(fSet, fParser)
}
ast.Walk(v, fParser)
var output []byte
buffer := bytes.NewBuffer(output)
err = format.Node(buffer, fSet, fParser)
if err != nil {
log.Fatal(err)
}
// 写回数据
return os.WriteFile(filepath, buffer.Bytes(), 0o600)
}

View File

@@ -0,0 +1,166 @@
package ast
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"os"
)
// AddRegisterTablesAst 自动为 gorm.go 注册一个自动迁移
func AddRegisterTablesAst(path, funcName, pk, varName, dbName, model string) {
modelPk := fmt.Sprintf("git.echol.cn/loser/ai_proxy/server/model/%s", pk)
src, err := os.ReadFile(path)
if err != nil {
fmt.Println(err)
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, 0)
if err != nil {
fmt.Println(err)
}
AddImport(astFile, modelPk)
FuncNode := FindFunction(astFile, funcName)
if FuncNode != nil {
ast.Print(fileSet, FuncNode)
}
addDBVar(FuncNode.Body, varName, dbName)
addAutoMigrate(FuncNode.Body, varName, pk, model)
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
os.WriteFile(path, bf.Bytes(), 0666)
}
// 增加一个 db库变量
func addDBVar(astBody *ast.BlockStmt, varName, dbName string) {
if dbName == "" {
return
}
dbStr := fmt.Sprintf("\"%s\"", dbName)
for i := range astBody.List {
if assignStmt, ok := astBody.List[i].(*ast.AssignStmt); ok {
if ident, ok := assignStmt.Lhs[0].(*ast.Ident); ok {
if ident.Name == varName {
return
}
}
}
}
assignNode := &ast.AssignStmt{
Lhs: []ast.Expr{
&ast.Ident{
Name: varName,
},
},
Tok: token.DEFINE,
Rhs: []ast.Expr{
&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{
Name: "global",
},
Sel: &ast.Ident{
Name: "GetGlobalDBByDBName",
},
},
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Value: dbStr,
},
},
},
},
}
astBody.List = append([]ast.Stmt{assignNode}, astBody.List...)
}
// 为db库变量增加 AutoMigrate 方法
func addAutoMigrate(astBody *ast.BlockStmt, dbname string, pk string, model string) {
if dbname == "" {
dbname = "db"
}
flag := true
ast.Inspect(astBody, func(node ast.Node) bool {
// 首先判断需要加入的方法调用语句是否存在 不存在则直接走到下方逻辑
switch n := node.(type) {
case *ast.CallExpr:
// 判断是否找到了AutoMigrate语句
if s, ok := n.Fun.(*ast.SelectorExpr); ok {
if x, ok := s.X.(*ast.Ident); ok {
if s.Sel.Name == "AutoMigrate" && x.Name == dbname {
flag = false
if !NeedAppendModel(n, pk, model) {
return false
}
// 判断已经找到了AutoMigrate语句
n.Args = append(n.Args, &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{
Name: pk,
},
Sel: &ast.Ident{
Name: model,
},
},
})
return false
}
}
}
}
return true
//然后判断 pk.model是否存在 如果存在直接跳出 如果不存在 则向已经找到的方法调用语句的node里面push一条
})
if flag {
exprStmt := &ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{
Name: dbname,
},
Sel: &ast.Ident{
Name: "AutoMigrate",
},
},
Args: []ast.Expr{
&ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{
Name: pk,
},
Sel: &ast.Ident{
Name: model,
},
},
},
},
}}
astBody.List = append(astBody.List, exprStmt)
}
}
// NeedAppendModel 为automigrate增加实参
func NeedAppendModel(callNode ast.Node, pk string, model string) bool {
flag := true
ast.Inspect(callNode, func(node ast.Node) bool {
switch n := node.(type) {
case *ast.SelectorExpr:
if x, ok := n.X.(*ast.Ident); ok {
if n.Sel.Name == model && x.Name == pk {
flag = false
return false
}
}
}
return true
})
return flag
}

View File

@@ -0,0 +1,11 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
)
func init() {
global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("../../../")
global.GVA_CONFIG.AutoCode.Server = "server"
}

View File

@@ -0,0 +1,173 @@
package ast
import (
"bytes"
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"os"
"path/filepath"
)
func RollBackAst(pk, model string) {
RollGormBack(pk, model)
RollRouterBack(pk, model)
}
func RollGormBack(pk, model string) {
// 首先分析存在多少个ttt作为调用方的node块
// 如果多个 仅仅删除对应块即可
// 如果单个 那么还需要剔除import
path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go")
src, err := os.ReadFile(path)
if err != nil {
fmt.Println(err)
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, 0)
if err != nil {
fmt.Println(err)
}
var n *ast.CallExpr
var k int = -1
var pkNum = 0
ast.Inspect(astFile, func(node ast.Node) bool {
if node, ok := node.(*ast.CallExpr); ok {
for i := range node.Args {
pkOK := false
modelOK := false
ast.Inspect(node.Args[i], func(item ast.Node) bool {
if ii, ok := item.(*ast.Ident); ok {
if ii.Name == pk {
pkOK = true
pkNum++
}
if ii.Name == model {
modelOK = true
}
}
if pkOK && modelOK {
n = node
k = i
}
return true
})
}
}
return true
})
if k > -1 {
n.Args = append(append([]ast.Expr{}, n.Args[:k]...), n.Args[k+1:]...)
}
if pkNum == 1 {
var imI int = -1
var gp *ast.GenDecl
ast.Inspect(astFile, func(node ast.Node) bool {
if gen, ok := node.(*ast.GenDecl); ok {
for i := range gen.Specs {
if imspec, ok := gen.Specs[i].(*ast.ImportSpec); ok {
if imspec.Path.Value == "\"git.echol.cn/loser/ai_proxy/server/model/"+pk+"\"" {
gp = gen
imI = i
return false
}
}
}
}
return true
})
if imI > -1 {
gp.Specs = append(append([]ast.Spec{}, gp.Specs[:imI]...), gp.Specs[imI+1:]...)
}
}
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
os.Remove(path)
os.WriteFile(path, bf.Bytes(), 0666)
}
func RollRouterBack(pk, model string) {
// 首先抓到所有的代码块结构 {}
// 分析结构中是否存在一个变量叫做 pk+Router
// 然后获取到代码块指针 对内部需要回滚的代码进行剔除
path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go")
src, err := os.ReadFile(path)
if err != nil {
fmt.Println(err)
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, 0)
if err != nil {
fmt.Println(err)
}
var block *ast.BlockStmt
var routerStmt *ast.FuncDecl
ast.Inspect(astFile, func(node ast.Node) bool {
if n, ok := node.(*ast.FuncDecl); ok {
if n.Name.Name == "initBizRouter" {
routerStmt = n
}
}
if n, ok := node.(*ast.BlockStmt); ok {
ast.Inspect(n, func(bNode ast.Node) bool {
if in, ok := bNode.(*ast.Ident); ok {
if in.Name == pk+"Router" {
block = n
return false
}
}
return true
})
return true
}
return true
})
var k int
for i := range block.List {
if stmtNode, ok := block.List[i].(*ast.ExprStmt); ok {
ast.Inspect(stmtNode, func(node ast.Node) bool {
if n, ok := node.(*ast.Ident); ok {
if n.Name == "Init"+model+"Router" {
k = i
return false
}
}
return true
})
}
}
block.List = append(append([]ast.Stmt{}, block.List[:k]...), block.List[k+1:]...)
if len(block.List) == 1 {
// 说明这个块就没任何意义了
block.List = nil
}
for i, n := range routerStmt.Body.List {
if n, ok := n.(*ast.BlockStmt); ok {
if n.List == nil {
routerStmt.Body.List = append(append([]ast.Stmt{}, routerStmt.Body.List[:i]...), routerStmt.Body.List[i+1:]...)
i--
}
}
}
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
os.Remove(path)
os.WriteFile(path, bf.Bytes(), 0666)
}

View File

@@ -0,0 +1,135 @@
package ast
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"os"
"strings"
)
func AppendNodeToList(stmts []ast.Stmt, stmt ast.Stmt, index int) []ast.Stmt {
return append(stmts[:index], append([]ast.Stmt{stmt}, stmts[index:]...)...)
}
func AddRouterCode(path, funcName, pk, model string) {
src, err := os.ReadFile(path)
if err != nil {
fmt.Println(err)
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, parser.ParseComments)
if err != nil {
fmt.Println(err)
}
FuncNode := FindFunction(astFile, funcName)
pkName := strings.ToUpper(pk[:1]) + pk[1:]
routerName := fmt.Sprintf("%sRouter", pk)
modelName := fmt.Sprintf("Init%sRouter", model)
var bloctPre *ast.BlockStmt
for i := len(FuncNode.Body.List) - 1; i >= 0; i-- {
if block, ok := FuncNode.Body.List[i].(*ast.BlockStmt); ok {
bloctPre = block
}
}
ast.Print(fileSet, FuncNode)
if ok, b := needAppendRouter(FuncNode, pk); ok {
routerNode :=
&ast.BlockStmt{
List: []ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{
&ast.Ident{Name: routerName},
},
Tok: token.DEFINE,
Rhs: []ast.Expr{
&ast.SelectorExpr{
X: &ast.SelectorExpr{
X: &ast.Ident{Name: "router"},
Sel: &ast.Ident{Name: "RouterGroupApp"},
},
Sel: &ast.Ident{Name: pkName},
},
},
},
},
}
FuncNode.Body.List = AppendNodeToList(FuncNode.Body.List, routerNode, len(FuncNode.Body.List)-1)
bloctPre = routerNode
} else {
bloctPre = b
}
if needAppendInit(FuncNode, routerName, modelName) {
bloctPre.List = append(bloctPre.List,
&ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{Name: routerName},
Sel: &ast.Ident{Name: modelName},
},
Args: []ast.Expr{
&ast.Ident{
Name: "privateGroup",
},
&ast.Ident{
Name: "publicGroup",
},
},
},
})
}
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
os.WriteFile(path, bf.Bytes(), 0666)
}
func needAppendRouter(funcNode ast.Node, pk string) (bool, *ast.BlockStmt) {
flag := true
var block *ast.BlockStmt
ast.Inspect(funcNode, func(node ast.Node) bool {
switch n := node.(type) {
case *ast.BlockStmt:
for i := range n.List {
if assignNode, ok := n.List[i].(*ast.AssignStmt); ok {
if identNode, ok := assignNode.Lhs[0].(*ast.Ident); ok {
if identNode.Name == fmt.Sprintf("%sRouter", pk) {
flag = false
block = n
return false
}
}
}
}
}
return true
})
return flag, block
}
func needAppendInit(funcNode ast.Node, routerName string, modelName string) bool {
flag := true
ast.Inspect(funcNode, func(node ast.Node) bool {
switch n := funcNode.(type) {
case *ast.CallExpr:
if selectNode, ok := n.Fun.(*ast.SelectorExpr); ok {
x, xok := selectNode.X.(*ast.Ident)
if xok && x.Name == routerName && selectNode.Sel.Name == modelName {
flag = false
return false
}
}
}
return true
})
return flag
}

View File

@@ -0,0 +1,32 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"os"
"path/filepath"
"testing"
)
func TestAst(t *testing.T) {
filename := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "plugin.go")
fileSet := token.NewFileSet()
file, err := parser.ParseFile(fileSet, filename, nil, parser.ParseComments)
if err != nil {
t.Error(err)
return
}
err = ast.Print(fileSet, file)
if err != nil {
t.Error(err)
return
}
err = printer.Fprint(os.Stdout, token.NewFileSet(), file)
if err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,53 @@
package ast
type Type string
func (r Type) String() string {
return string(r)
}
func (r Type) Group() string {
switch r {
case TypePackageApiEnter:
return "ApiGroup"
case TypePackageRouterEnter:
return "RouterGroup"
case TypePackageServiceEnter:
return "ServiceGroup"
case TypePackageApiModuleEnter:
return "ApiGroup"
case TypePackageRouterModuleEnter:
return "RouterGroup"
case TypePackageServiceModuleEnter:
return "ServiceGroup"
case TypePluginApiEnter:
return "api"
case TypePluginRouterEnter:
return "router"
case TypePluginServiceEnter:
return "service"
default:
return ""
}
}
const (
TypePackageApiEnter = "PackageApiEnter" // server/api/v1/enter.go
TypePackageRouterEnter = "PackageRouterEnter" // server/router/enter.go
TypePackageServiceEnter = "PackageServiceEnter" // server/service/enter.go
TypePackageApiModuleEnter = "PackageApiModuleEnter" // server/api/v1/{package}/enter.go
TypePackageRouterModuleEnter = "PackageRouterModuleEnter" // server/router/{package}/enter.go
TypePackageServiceModuleEnter = "PackageServiceModuleEnter" // server/service/{package}/enter.go
TypePackageInitializeGorm = "PackageInitializeGorm" // server/initialize/gorm_biz.go
TypePackageInitializeRouter = "PackageInitializeRouter" // server/initialize/router_biz.go
TypePluginGen = "PluginGen" // server/plugin/{package}/gen/main.go
TypePluginApiEnter = "PluginApiEnter" // server/plugin/{package}/enter.go
TypePluginInitializeV1 = "PluginInitializeV1" // server/initialize/plugin_biz_v1.go
TypePluginInitializeV2 = "PluginInitializeV2" // server/plugin/register.go
TypePluginRouterEnter = "PluginRouterEnter" // server/plugin/{package}/enter.go
TypePluginServiceEnter = "PluginServiceEnter" // server/plugin/{package}/enter.go
TypePluginInitializeApi = "PluginInitializeApi" // server/plugin/{package}/initialize/api.go
TypePluginInitializeGorm = "PluginInitializeGorm" // server/plugin/{package}/initialize/gorm.go
TypePluginInitializeMenu = "PluginInitializeMenu" // server/plugin/{package}/initialize/menu.go
TypePluginInitializeRouter = "PluginInitializeRouter" // server/plugin/{package}/initialize/router.go
)

View File

@@ -0,0 +1,62 @@
package ast
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
)
// ExtractFuncSourceByPosition 根据文件路径与行号,提取包含该行的整个方法源码
// 返回:方法名、完整源码、起止行号
func ExtractFuncSourceByPosition(filePath string, line int) (name string, source string, startLine int, endLine int, err error) {
// 读取源文件
src, readErr := os.ReadFile(filePath)
if readErr != nil {
err = fmt.Errorf("read file failed: %w", readErr)
return
}
// 解析 AST
fset := token.NewFileSet()
file, parseErr := parser.ParseFile(fset, filePath, src, parser.ParseComments)
if parseErr != nil {
err = fmt.Errorf("parse file failed: %w", parseErr)
return
}
// 在 AST 中定位包含指定行号的函数声明
var target *ast.FuncDecl
ast.Inspect(file, func(n ast.Node) bool {
fd, ok := n.(*ast.FuncDecl)
if !ok {
return true
}
s := fset.Position(fd.Pos()).Line
e := fset.Position(fd.End()).Line
if line >= s && line <= e {
target = fd
startLine = s
endLine = e
return false
}
return true
})
if target == nil {
err = fmt.Errorf("no function encloses line %d in %s", line, filePath)
return
}
// 使用字节偏移精确提取源码片段(包含注释与原始格式)
start := fset.Position(target.Pos()).Offset
end := fset.Position(target.End()).Offset
if start < 0 || end > len(src) || start >= end {
err = fmt.Errorf("invalid offsets for function: start=%d end=%d len=%d", start, end, len(src))
return
}
source = string(src[start:end])
name = target.Name.Name
return
}

View File

@@ -0,0 +1,94 @@
package ast
import (
"go/ast"
"go/token"
"io"
"strings"
)
type Import struct {
Base
ImportPath string // 导包路径
}
func NewImport(importPath string) *Import {
return &Import{ImportPath: importPath}
}
func (a *Import) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
return a.Base.Parse(filename, writer)
}
func (a *Import) Rollback(file *ast.File) error {
if a.ImportPath == "" {
return nil
}
for i := 0; i < len(file.Decls); i++ {
v1, o1 := file.Decls[i].(*ast.GenDecl)
if o1 {
if v1.Tok != token.IMPORT {
break
}
for j := 0; j < len(v1.Specs); j++ {
v2, o2 := v1.Specs[j].(*ast.ImportSpec)
if o2 && strings.HasSuffix(a.ImportPath, v2.Path.Value) {
v1.Specs = append(v1.Specs[:j], v1.Specs[j+1:]...)
if len(v1.Specs) == 0 {
file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
} // 如果没有import声明就删除, 如果不删除则会出现import()
break
}
}
}
}
return nil
}
func (a *Import) Injection(file *ast.File) error {
if a.ImportPath == "" {
return nil
}
var has bool
for i := 0; i < len(file.Decls); i++ {
v1, o1 := file.Decls[i].(*ast.GenDecl)
if o1 {
if v1.Tok != token.IMPORT {
break
}
for j := 0; j < len(v1.Specs); j++ {
v2, o2 := v1.Specs[j].(*ast.ImportSpec)
if o2 && strings.HasSuffix(a.ImportPath, v2.Path.Value) {
has = true
break
}
}
if !has {
spec := &ast.ImportSpec{
Path: &ast.BasicLit{Kind: token.STRING, Value: a.ImportPath},
}
v1.Specs = append(v1.Specs, spec)
return nil
}
}
}
if !has {
decls := file.Decls
file.Decls = make([]ast.Decl, 0, len(file.Decls)+1)
decl := &ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{
&ast.ImportSpec{
Path: &ast.BasicLit{Kind: token.STRING, Value: a.ImportPath},
},
},
}
file.Decls = append(file.Decls, decl)
file.Decls = append(file.Decls, decls...)
} // 如果没有import声明就创建一个, 主要要放在第一个
return nil
}
func (a *Import) Format(filename string, writer io.Writer, file *ast.File) error {
return a.Base.Format(filename, writer, file)
}

View File

@@ -0,0 +1,17 @@
package ast
import (
"go/ast"
"io"
)
type Ast interface {
// Parse 解析文件/代码
Parse(filename string, writer io.Writer) (file *ast.File, err error)
// Rollback 回滚
Rollback(file *ast.File) error
// Injection 注入
Injection(file *ast.File) error
// Format 格式化输出
Format(filename string, writer io.Writer, file *ast.File) error
}

View File

@@ -0,0 +1,76 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/pkg/errors"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io"
"os"
"path"
"path/filepath"
"strings"
)
type Base struct{}
func (a *Base) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
fileSet := token.NewFileSet()
if writer != nil {
file, err = parser.ParseFile(fileSet, filename, nil, parser.ParseComments)
} else {
file, err = parser.ParseFile(fileSet, filename, writer, parser.ParseComments)
}
if err != nil {
return nil, errors.Wrapf(err, "[filepath:%s]打开/解析文件失败!", filename)
}
return file, nil
}
func (a *Base) Rollback(file *ast.File) error {
return nil
}
func (a *Base) Injection(file *ast.File) error {
return nil
}
func (a *Base) Format(filename string, writer io.Writer, file *ast.File) error {
fileSet := token.NewFileSet()
if writer == nil {
open, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, 0666)
defer open.Close()
if err != nil {
return errors.Wrapf(err, "[filepath:%s]打开文件失败!", filename)
}
writer = open
}
err := format.Node(writer, fileSet, file)
if err != nil {
return errors.Wrapf(err, "[filepath:%s]注入失败!", filename)
}
return nil
}
// RelativePath 绝对路径转相对路径
func (a *Base) RelativePath(filePath string) string {
server := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server)
hasServer := strings.Index(filePath, server)
if hasServer != -1 {
filePath = strings.TrimPrefix(filePath, server)
keys := strings.Split(filePath, string(filepath.Separator))
filePath = path.Join(keys...)
}
return filePath
}
// AbsolutePath 相对路径转绝对路径
func (a *Base) AbsolutePath(filePath string) string {
server := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server)
keys := strings.Split(filePath, "/")
filePath = filepath.Join(keys...)
filePath = filepath.Join(server, filePath)
return filePath
}

View File

@@ -0,0 +1,85 @@
package ast
import (
"go/ast"
"go/token"
"io"
)
// PackageEnter 模块化入口
type PackageEnter struct {
Base
Type Type // 类型
Path string // 文件路径
ImportPath string // 导包路径
StructName string // 结构体名称
PackageName string // 包名
RelativePath string // 相对路径
PackageStructName string // 包结构体名称
}
func (a *PackageEnter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.Path
a.RelativePath = a.Base.RelativePath(a.Path)
return a.Base.Parse(filename, writer)
}
a.Path = a.Base.AbsolutePath(a.RelativePath)
filename = a.Path
}
return a.Base.Parse(filename, writer)
}
func (a *PackageEnter) Rollback(file *ast.File) error {
// 无需回滚
return nil
}
func (a *PackageEnter) Injection(file *ast.File) error {
_ = NewImport(a.ImportPath).Injection(file)
ast.Inspect(file, func(n ast.Node) bool {
genDecl, ok := n.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
return true
}
for _, spec := range genDecl.Specs {
typeSpec, specok := spec.(*ast.TypeSpec)
if !specok || typeSpec.Name.Name != a.Type.Group() {
continue
}
structType, structTypeOK := typeSpec.Type.(*ast.StructType)
if !structTypeOK {
continue
}
for _, field := range structType.Fields.List {
if len(field.Names) == 1 && field.Names[0].Name == a.StructName {
return true
}
}
field := &ast.Field{
Names: []*ast.Ident{{Name: a.StructName}},
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.PackageStructName},
},
}
structType.Fields.List = append(structType.Fields.List, field)
return false
}
return true
})
return nil
}
func (a *PackageEnter) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.Path
}
return a.Base.Format(filename, writer, file)
}

View File

@@ -0,0 +1,154 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPackageEnter_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
PackageName string
PackageStructName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试ExampleApiGroup回滚",
fields: fields{
Type: TypePackageApiEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/api/v1/example"`,
StructName: "ExampleApiGroup",
PackageName: "example",
PackageStructName: "ApiGroup",
},
wantErr: false,
},
{
name: "测试ExampleRouterGroup回滚",
fields: fields{
Type: TypePackageRouterEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/router/example"`,
StructName: "Example",
PackageName: "example",
PackageStructName: "RouterGroup",
},
wantErr: false,
},
{
name: "测试ExampleServiceGroup回滚",
fields: fields{
Type: TypePackageServiceEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/service/example"`,
StructName: "ExampleServiceGroup",
PackageName: "example",
PackageStructName: "ServiceGroup",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PackageEnter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
PackageName: tt.fields.PackageName,
PackageStructName: tt.fields.PackageStructName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPackageEnter_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
PackageName string
PackageStructName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试ExampleApiGroup注入",
fields: fields{
Type: TypePackageApiEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/api/v1/example"`,
StructName: "ExampleApiGroup",
PackageName: "example",
PackageStructName: "ApiGroup",
},
},
{
name: "测试ExampleRouterGroup注入",
fields: fields{
Type: TypePackageRouterEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/router/example"`,
StructName: "Example",
PackageName: "example",
PackageStructName: "RouterGroup",
},
wantErr: false,
},
{
name: "测试ExampleServiceGroup注入",
fields: fields{
Type: TypePackageServiceEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/service/example"`,
StructName: "ExampleServiceGroup",
PackageName: "example",
PackageStructName: "ServiceGroup",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PackageEnter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
PackageName: tt.fields.PackageName,
PackageStructName: tt.fields.PackageStructName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,196 @@
package ast
import (
"fmt"
"go/ast"
"go/token"
"io"
)
// PackageInitializeGorm 包初始化gorm
type PackageInitializeGorm struct {
Base
Type Type // 类型
Path string // 文件路径
ImportPath string // 导包路径
Business string // 业务库 gva => gva, 不要传"gva"
StructName string // 结构体名称
PackageName string // 包名
RelativePath string // 相对路径
IsNew bool // 是否使用new关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
}
func (a *PackageInitializeGorm) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.Path
a.RelativePath = a.Base.RelativePath(a.Path)
return a.Base.Parse(filename, writer)
}
a.Path = a.Base.AbsolutePath(a.RelativePath)
filename = a.Path
}
return a.Base.Parse(filename, writer)
}
func (a *PackageInitializeGorm) Rollback(file *ast.File) error {
packageNameNum := 0
// 寻找目标结构
ast.Inspect(file, func(n ast.Node) bool {
// 总调用的db变量根据business来决定
varDB := a.Business + "Db"
if a.Business == "" {
varDB = "db"
}
callExpr, ok := n.(*ast.CallExpr)
if !ok {
return true
}
// 检查是不是 db.AutoMigrate() 方法
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
if !ok || selExpr.Sel.Name != "AutoMigrate" {
return true
}
// 检查调用方是不是 db
ident, ok := selExpr.X.(*ast.Ident)
if !ok || ident.Name != varDB {
return true
}
// 删除结构体参数
for i := 0; i < len(callExpr.Args); i++ {
if com, comok := callExpr.Args[i].(*ast.CompositeLit); comok {
if selector, exprok := com.Type.(*ast.SelectorExpr); exprok {
if x, identok := selector.X.(*ast.Ident); identok {
if x.Name == a.PackageName {
packageNameNum++
if selector.Sel.Name == a.StructName {
callExpr.Args = append(callExpr.Args[:i], callExpr.Args[i+1:]...)
i--
}
}
}
}
}
}
return true
})
if packageNameNum == 1 {
_ = NewImport(a.ImportPath).Rollback(file)
}
return nil
}
func (a *PackageInitializeGorm) Injection(file *ast.File) error {
_ = NewImport(a.ImportPath).Injection(file)
bizModelDecl := FindFunction(file, "bizModel")
if bizModelDecl != nil {
a.addDbVar(bizModelDecl.Body)
}
// 寻找目标结构
ast.Inspect(file, func(n ast.Node) bool {
// 总调用的db变量根据business来决定
varDB := a.Business + "Db"
if a.Business == "" {
varDB = "db"
}
callExpr, ok := n.(*ast.CallExpr)
if !ok {
return true
}
// 检查是不是 db.AutoMigrate() 方法
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
if !ok || selExpr.Sel.Name != "AutoMigrate" {
return true
}
// 检查调用方是不是 db
ident, ok := selExpr.X.(*ast.Ident)
if !ok || ident.Name != varDB {
return true
}
// 添加结构体参数
callExpr.Args = append(callExpr.Args, &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: ast.NewIdent(a.PackageName),
Sel: ast.NewIdent(a.StructName),
},
})
return true
})
return nil
}
func (a *PackageInitializeGorm) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.Path
}
return a.Base.Format(filename, writer, file)
}
// 创建businessDB变量
func (a *PackageInitializeGorm) addDbVar(astBody *ast.BlockStmt) {
for i := range astBody.List {
if assignStmt, ok := astBody.List[i].(*ast.AssignStmt); ok {
if ident, ok := assignStmt.Lhs[0].(*ast.Ident); ok {
if (a.Business == "" && ident.Name == "db") || ident.Name == a.Business+"Db" {
return
}
}
}
}
// 添加 businessDb := global.GetGlobalDBByDBName("business") 变量
assignNode := &ast.AssignStmt{
Lhs: []ast.Expr{
&ast.Ident{
Name: a.Business + "Db",
},
},
Tok: token.DEFINE,
Rhs: []ast.Expr{
&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{
Name: "global",
},
Sel: &ast.Ident{
Name: "GetGlobalDBByDBName",
},
},
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf("\"%s\"", a.Business),
},
},
},
},
}
// 添加 businessDb.AutoMigrate() 方法
autoMigrateCall := &ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{
Name: a.Business + "Db",
},
Sel: &ast.Ident{
Name: "AutoMigrate",
},
},
},
}
returnNode := astBody.List[len(astBody.List)-1]
astBody.List = append(astBody.List[:len(astBody.List)-1], assignNode, autoMigrateCall, returnNode)
}

View File

@@ -0,0 +1,171 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPackageInitializeGorm_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
PackageName string
IsNew bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 &example.ExaFileUploadAndDownload{} 注入",
fields: fields{
Type: TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/model/example"`,
StructName: "ExaFileUploadAndDownload",
PackageName: "example",
IsNew: false,
},
},
{
name: "测试 &example.ExaCustomer{} 注入",
fields: fields{
Type: TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/model/example"`,
StructName: "ExaCustomer",
PackageName: "example",
IsNew: false,
},
},
{
name: "测试 new(example.ExaFileUploadAndDownload) 注入",
fields: fields{
Type: TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/model/example"`,
StructName: "ExaFileUploadAndDownload",
PackageName: "example",
IsNew: true,
},
},
{
name: "测试 new(example.ExaCustomer) 注入",
fields: fields{
Type: TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/model/example"`,
StructName: "ExaCustomer",
PackageName: "example",
IsNew: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PackageInitializeGorm{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
PackageName: tt.fields.PackageName,
IsNew: tt.fields.IsNew,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPackageInitializeGorm_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
PackageName string
IsNew bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 &example.ExaFileUploadAndDownload{} 回滚",
fields: fields{
Type: TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/model/example"`,
StructName: "ExaFileUploadAndDownload",
PackageName: "example",
IsNew: false,
},
},
{
name: "测试 &example.ExaCustomer{} 回滚",
fields: fields{
Type: TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/model/example"`,
StructName: "ExaCustomer",
PackageName: "example",
IsNew: false,
},
},
{
name: "测试 new(example.ExaFileUploadAndDownload) 回滚",
fields: fields{
Type: TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/model/example"`,
StructName: "ExaFileUploadAndDownload",
PackageName: "example",
IsNew: true,
},
},
{
name: "测试 new(example.ExaCustomer) 回滚",
fields: fields{
Type: TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/model/example"`,
StructName: "ExaCustomer",
PackageName: "example",
IsNew: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PackageInitializeGorm{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
PackageName: tt.fields.PackageName,
IsNew: tt.fields.IsNew,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,150 @@
package ast
import (
"fmt"
"go/ast"
"go/token"
"io"
)
// PackageInitializeRouter 包初始化路由
// ModuleName := PackageName.AppName.GroupName
// ModuleName.FunctionName(RouterGroupName)
type PackageInitializeRouter struct {
Base
Type Type // 类型
Path string // 文件路径
ImportPath string // 导包路径
RelativePath string // 相对路径
AppName string // 应用名称
GroupName string // 分组名称
ModuleName string // 模块名称
PackageName string // 包名
FunctionName string // 函数名
RouterGroupName string // 路由分组名称
LeftRouterGroupName string // 左路由分组名称
RightRouterGroupName string // 右路由分组名称
}
func (a *PackageInitializeRouter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.Path
a.RelativePath = a.Base.RelativePath(a.Path)
return a.Base.Parse(filename, writer)
}
a.Path = a.Base.AbsolutePath(a.RelativePath)
filename = a.Path
}
return a.Base.Parse(filename, writer)
}
func (a *PackageInitializeRouter) Rollback(file *ast.File) error {
funcDecl := FindFunction(file, "initBizRouter")
exprNum := 0
for i := range funcDecl.Body.List {
if IsBlockStmt(funcDecl.Body.List[i]) {
if VariableExistsInBlock(funcDecl.Body.List[i].(*ast.BlockStmt), a.ModuleName) {
for ii, stmt := range funcDecl.Body.List[i].(*ast.BlockStmt).List {
// 检查语句是否为 *ast.ExprStmt
exprStmt, ok := stmt.(*ast.ExprStmt)
if !ok {
continue
}
// 检查表达式是否为 *ast.CallExpr
callExpr, ok := exprStmt.X.(*ast.CallExpr)
if !ok {
continue
}
// 检查是否调用了我们正在寻找的函数
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
if !ok {
continue
}
// 检查调用的函数是否为 systemRouter.InitApiRouter
ident, ok := selExpr.X.(*ast.Ident)
//只要存在调用则+1
if ok && ident.Name == a.ModuleName {
exprNum++
}
//判断是否为目标结构
if !ok || ident.Name != a.ModuleName || selExpr.Sel.Name != a.FunctionName {
continue
}
exprNum--
// 从语句列表中移除。
funcDecl.Body.List[i].(*ast.BlockStmt).List = append(funcDecl.Body.List[i].(*ast.BlockStmt).List[:ii], funcDecl.Body.List[i].(*ast.BlockStmt).List[ii+1:]...)
// 如果不再存在任何调用,则删除导入和变量。
if exprNum == 0 {
funcDecl.Body.List = append(funcDecl.Body.List[:i], funcDecl.Body.List[i+1:]...)
}
break
}
break
}
}
}
return nil
}
func (a *PackageInitializeRouter) Injection(file *ast.File) error {
funcDecl := FindFunction(file, "initBizRouter")
hasRouter := false
var varBlock *ast.BlockStmt
for i := range funcDecl.Body.List {
if IsBlockStmt(funcDecl.Body.List[i]) {
if VariableExistsInBlock(funcDecl.Body.List[i].(*ast.BlockStmt), a.ModuleName) {
hasRouter = true
varBlock = funcDecl.Body.List[i].(*ast.BlockStmt)
break
}
}
}
if !hasRouter {
stmt := a.CreateAssignStmt()
varBlock = &ast.BlockStmt{
List: []ast.Stmt{
stmt,
},
}
}
routerStmt := CreateStmt(fmt.Sprintf("%s.%s(%s,%s)", a.ModuleName, a.FunctionName, a.LeftRouterGroupName, a.RightRouterGroupName))
varBlock.List = append(varBlock.List, routerStmt)
if !hasRouter {
funcDecl.Body.List = append(funcDecl.Body.List, varBlock)
}
return nil
}
func (a *PackageInitializeRouter) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.Path
}
return a.Base.Format(filename, writer, file)
}
func (a *PackageInitializeRouter) CreateAssignStmt() *ast.AssignStmt {
//创建左侧变量
ident := &ast.Ident{
Name: a.ModuleName,
}
//创建右侧的赋值语句
selector := &ast.SelectorExpr{
X: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.AppName},
},
Sel: &ast.Ident{Name: a.GroupName},
}
// 创建一个组合的赋值语句
stmt := &ast.AssignStmt{
Lhs: []ast.Expr{ident},
Tok: token.DEFINE,
Rhs: []ast.Expr{selector},
}
return stmt
}

View File

@@ -0,0 +1,158 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPackageInitializeRouter_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
AppName string
GroupName string
ModuleName string
PackageName string
FunctionName string
RouterGroupName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 InitCustomerRouter 注入",
fields: fields{
Type: TypePackageInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/router"`,
AppName: "RouterGroupApp",
GroupName: "Example",
ModuleName: "exampleRouter",
PackageName: "router",
FunctionName: "InitCustomerRouter",
RouterGroupName: "privateGroup",
},
wantErr: false,
},
{
name: "测试 InitFileUploadAndDownloadRouter 注入",
fields: fields{
Type: TypePackageInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/router"`,
AppName: "RouterGroupApp",
GroupName: "Example",
ModuleName: "exampleRouter",
PackageName: "router",
FunctionName: "InitFileUploadAndDownloadRouter",
RouterGroupName: "privateGroup",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PackageInitializeRouter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
AppName: tt.fields.AppName,
GroupName: tt.fields.GroupName,
ModuleName: tt.fields.ModuleName,
PackageName: tt.fields.PackageName,
FunctionName: tt.fields.FunctionName,
RouterGroupName: tt.fields.RouterGroupName,
LeftRouterGroupName: "privateGroup",
RightRouterGroupName: "publicGroup",
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPackageInitializeRouter_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
AppName string
GroupName string
ModuleName string
PackageName string
FunctionName string
RouterGroupName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 InitCustomerRouter 回滚",
fields: fields{
Type: TypePackageInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/router"`,
AppName: "RouterGroupApp",
GroupName: "Example",
ModuleName: "exampleRouter",
PackageName: "router",
FunctionName: "InitCustomerRouter",
RouterGroupName: "privateGroup",
},
wantErr: false,
},
{
name: "测试 InitFileUploadAndDownloadRouter 回滚",
fields: fields{
Type: TypePackageInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/router"`,
AppName: "RouterGroupApp",
GroupName: "Example",
ModuleName: "exampleRouter",
PackageName: "router",
FunctionName: "InitFileUploadAndDownloadRouter",
RouterGroupName: "privateGroup",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PackageInitializeRouter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
AppName: tt.fields.AppName,
GroupName: tt.fields.GroupName,
ModuleName: tt.fields.ModuleName,
PackageName: tt.fields.PackageName,
FunctionName: tt.fields.FunctionName,
RouterGroupName: tt.fields.RouterGroupName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,180 @@
package ast
import (
"go/ast"
"go/token"
"io"
)
// PackageModuleEnter 模块化入口
// ModuleName := PackageName.AppName.GroupName.ServiceName
type PackageModuleEnter struct {
Base
Type Type // 类型
Path string // 文件路径
ImportPath string // 导包路径
RelativePath string // 相对路径
StructName string // 结构体名称
AppName string // 应用名称
GroupName string // 分组名称
ModuleName string // 模块名称
PackageName string // 包名
ServiceName string // 服务名称
}
func (a *PackageModuleEnter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.Path
a.RelativePath = a.Base.RelativePath(a.Path)
return a.Base.Parse(filename, writer)
}
a.Path = a.Base.AbsolutePath(a.RelativePath)
filename = a.Path
}
return a.Base.Parse(filename, writer)
}
func (a *PackageModuleEnter) Rollback(file *ast.File) error {
for i := 0; i < len(file.Decls); i++ {
v1, o1 := file.Decls[i].(*ast.GenDecl)
if o1 {
for j := 0; j < len(v1.Specs); j++ {
v2, o2 := v1.Specs[j].(*ast.TypeSpec)
if o2 {
if v2.Name.Name != a.Type.Group() {
continue
}
v3, o3 := v2.Type.(*ast.StructType)
if o3 {
for k := 0; k < len(v3.Fields.List); k++ {
v4, o4 := v3.Fields.List[k].Type.(*ast.Ident)
if o4 && v4.Name == a.StructName {
v3.Fields.List = append(v3.Fields.List[:k], v3.Fields.List[k+1:]...)
}
}
}
continue
}
if a.Type == TypePackageServiceModuleEnter {
continue
}
v3, o3 := v1.Specs[j].(*ast.ValueSpec)
if o3 {
if len(v3.Names) == 1 && v3.Names[0].Name == a.ModuleName {
v1.Specs = append(v1.Specs[:j], v1.Specs[j+1:]...)
}
}
if v1.Tok == token.VAR && len(v1.Specs) == 0 {
_ = NewImport(a.ImportPath).Rollback(file)
if i == len(file.Decls) {
file.Decls = append(file.Decls[:i-1])
break
} // 空的var(), 如果不删除则会影响的注入变量, 因为识别不到*ast.ValueSpec
file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
}
}
}
}
return nil
}
func (a *PackageModuleEnter) Injection(file *ast.File) error {
_ = NewImport(a.ImportPath).Injection(file)
var hasValue bool
var hasVariables bool
for i := 0; i < len(file.Decls); i++ {
v1, o1 := file.Decls[i].(*ast.GenDecl)
if o1 {
if v1.Tok == token.VAR {
hasVariables = true
}
for j := 0; j < len(v1.Specs); j++ {
if a.Type == TypePackageServiceModuleEnter {
hasValue = true
}
v2, o2 := v1.Specs[j].(*ast.TypeSpec)
if o2 {
if v2.Name.Name != a.Type.Group() {
continue
}
v3, o3 := v2.Type.(*ast.StructType)
if o3 {
var hasStruct bool
for k := 0; k < len(v3.Fields.List); k++ {
v4, o4 := v3.Fields.List[k].Type.(*ast.Ident)
if o4 && v4.Name == a.StructName {
hasStruct = true
}
}
if !hasStruct {
field := &ast.Field{Type: &ast.Ident{Name: a.StructName}}
v3.Fields.List = append(v3.Fields.List, field)
}
}
continue
}
v3, o3 := v1.Specs[j].(*ast.ValueSpec)
if o3 {
hasVariables = true
if len(v3.Names) == 1 && v3.Names[0].Name == a.ModuleName {
hasValue = true
}
}
if v1.Tok == token.VAR && len(v1.Specs) == 0 {
hasVariables = false
} // 说明是空var()
if hasVariables && !hasValue {
spec := &ast.ValueSpec{
Names: []*ast.Ident{{Name: a.ModuleName}},
Values: []ast.Expr{
&ast.SelectorExpr{
X: &ast.SelectorExpr{
X: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.AppName},
},
Sel: &ast.Ident{Name: a.GroupName},
},
Sel: &ast.Ident{Name: a.ServiceName},
},
},
}
v1.Specs = append(v1.Specs, spec)
hasValue = true
}
}
}
}
if !hasValue && !hasVariables {
decl := &ast.GenDecl{
Tok: token.VAR,
Specs: []ast.Spec{
&ast.ValueSpec{
Names: []*ast.Ident{{Name: a.ModuleName}},
Values: []ast.Expr{
&ast.SelectorExpr{
X: &ast.SelectorExpr{
X: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.AppName},
},
Sel: &ast.Ident{Name: a.GroupName},
},
Sel: &ast.Ident{Name: a.ServiceName},
},
},
},
},
}
file.Decls = append(file.Decls, decl)
}
return nil
}
func (a *PackageModuleEnter) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.Path
}
return a.Base.Format(filename, writer, file)
}

View File

@@ -0,0 +1,185 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPackageModuleEnter_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
AppName string
GroupName string
ModuleName string
PackageName string
ServiceName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 FileUploadAndDownloadRouter 回滚",
fields: fields{
Type: TypePackageRouterModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", "example", "enter.go"),
ImportPath: `api "git.echol.cn/loser/ai_proxy/server/api/v1"`,
StructName: "FileUploadAndDownloadRouter",
AppName: "ApiGroupApp",
GroupName: "ExampleApiGroup",
ModuleName: "exaFileUploadAndDownloadApi",
PackageName: "api",
ServiceName: "FileUploadAndDownloadApi",
},
wantErr: false,
},
{
name: "测试 FileUploadAndDownloadApi 回滚",
fields: fields{
Type: TypePackageApiModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", "example", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/service"`,
StructName: "FileUploadAndDownloadApi",
AppName: "ServiceGroupApp",
GroupName: "ExampleServiceGroup",
ModuleName: "fileUploadAndDownloadService",
PackageName: "service",
ServiceName: "FileUploadAndDownloadService",
},
wantErr: false,
},
{
name: "测试 FileUploadAndDownloadService 回滚",
fields: fields{
Type: TypePackageServiceModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", "example", "enter.go"),
ImportPath: ``,
StructName: "FileUploadAndDownloadService",
AppName: "",
GroupName: "",
ModuleName: "",
PackageName: "",
ServiceName: "",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PackageModuleEnter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
AppName: tt.fields.AppName,
GroupName: tt.fields.GroupName,
ModuleName: tt.fields.ModuleName,
PackageName: tt.fields.PackageName,
ServiceName: tt.fields.ServiceName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPackageModuleEnter_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
AppName string
GroupName string
ModuleName string
PackageName string
ServiceName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 FileUploadAndDownloadRouter 注入",
fields: fields{
Type: TypePackageRouterModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", "example", "enter.go"),
ImportPath: `api "git.echol.cn/loser/ai_proxy/server/api/v1"`,
StructName: "FileUploadAndDownloadRouter",
AppName: "ApiGroupApp",
GroupName: "ExampleApiGroup",
ModuleName: "exaFileUploadAndDownloadApi",
PackageName: "api",
ServiceName: "FileUploadAndDownloadApi",
},
wantErr: false,
},
{
name: "测试 FileUploadAndDownloadApi 注入",
fields: fields{
Type: TypePackageApiModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", "example", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/service"`,
StructName: "FileUploadAndDownloadApi",
AppName: "ServiceGroupApp",
GroupName: "ExampleServiceGroup",
ModuleName: "fileUploadAndDownloadService",
PackageName: "service",
ServiceName: "FileUploadAndDownloadService",
},
wantErr: false,
},
{
name: "测试 FileUploadAndDownloadService 注入",
fields: fields{
Type: TypePackageServiceModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", "example", "enter.go"),
ImportPath: ``,
StructName: "FileUploadAndDownloadService",
AppName: "",
GroupName: "",
ModuleName: "",
PackageName: "",
ServiceName: "",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PackageModuleEnter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
AppName: tt.fields.AppName,
GroupName: tt.fields.GroupName,
ModuleName: tt.fields.ModuleName,
PackageName: tt.fields.PackageName,
ServiceName: tt.fields.ServiceName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,167 @@
package ast
import (
"go/ast"
"go/token"
"io"
)
// PluginEnter 插件化入口
// ModuleName := PackageName.GroupName.ServiceName
type PluginEnter struct {
Base
Type Type // 类型
Path string // 文件路径
ImportPath string // 导包路径
RelativePath string // 相对路径
StructName string // 结构体名称
StructCamelName string // 结构体小驼峰名称
ModuleName string // 模块名称
GroupName string // 分组名称
PackageName string // 包名
ServiceName string // 服务名称
}
func (a *PluginEnter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.Path
a.RelativePath = a.Base.RelativePath(a.Path)
return a.Base.Parse(filename, writer)
}
a.Path = a.Base.AbsolutePath(a.RelativePath)
filename = a.Path
}
return a.Base.Parse(filename, writer)
}
func (a *PluginEnter) Rollback(file *ast.File) error {
//回滚结构体内内容
var structType *ast.StructType
ast.Inspect(file, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.TypeSpec:
if s, ok := x.Type.(*ast.StructType); ok {
structType = s
for i, field := range x.Type.(*ast.StructType).Fields.List {
if len(field.Names) > 0 && field.Names[0].Name == a.StructName {
s.Fields.List = append(s.Fields.List[:i], s.Fields.List[i+1:]...)
return false
}
}
}
}
return true
})
if len(structType.Fields.List) == 0 {
_ = NewImport(a.ImportPath).Rollback(file)
}
if a.Type == TypePluginServiceEnter {
return nil
}
//回滚变量内容
ast.Inspect(file, func(n ast.Node) bool {
genDecl, ok := n.(*ast.GenDecl)
if ok && genDecl.Tok == token.VAR {
for i, spec := range genDecl.Specs {
valueSpec, vsok := spec.(*ast.ValueSpec)
if vsok {
for _, name := range valueSpec.Names {
if name.Name == a.ModuleName {
genDecl.Specs = append(genDecl.Specs[:i], genDecl.Specs[i+1:]...)
return false
}
}
}
}
}
return true
})
return nil
}
func (a *PluginEnter) Injection(file *ast.File) error {
_ = NewImport(a.ImportPath).Injection(file)
has := false
hasVar := false
var firstStruct *ast.StructType
var varSpec *ast.GenDecl
//寻找是否存在结构且定位
ast.Inspect(file, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.TypeSpec:
if s, ok := x.Type.(*ast.StructType); ok {
firstStruct = s
for _, field := range x.Type.(*ast.StructType).Fields.List {
if len(field.Names) > 0 && field.Names[0].Name == a.StructName {
has = true
return false
}
}
}
}
return true
})
if !has {
field := &ast.Field{
Names: []*ast.Ident{{Name: a.StructName}},
Type: &ast.Ident{Name: a.StructCamelName},
}
firstStruct.Fields.List = append(firstStruct.Fields.List, field)
}
if a.Type == TypePluginServiceEnter {
return nil
}
//寻找是否存在变量且定位
ast.Inspect(file, func(n ast.Node) bool {
genDecl, ok := n.(*ast.GenDecl)
if ok && genDecl.Tok == token.VAR {
for _, spec := range genDecl.Specs {
valueSpec, vsok := spec.(*ast.ValueSpec)
if vsok {
varSpec = genDecl
for _, name := range valueSpec.Names {
if name.Name == a.ModuleName {
hasVar = true
return false
}
}
}
}
}
return true
})
if !hasVar {
spec := &ast.ValueSpec{
Names: []*ast.Ident{{Name: a.ModuleName}},
Values: []ast.Expr{
&ast.SelectorExpr{
X: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.GroupName},
},
Sel: &ast.Ident{Name: a.ServiceName},
},
},
}
varSpec.Specs = append(varSpec.Specs, spec)
}
return nil
}
func (a *PluginEnter) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.Path
}
return a.Base.Format(filename, writer, file)
}

View File

@@ -0,0 +1,200 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPluginEnter_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
StructCamelName string
ModuleName string
GroupName string
PackageName string
ServiceName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 Gva插件UserApi 注入",
fields: fields{
Type: TypePluginApiEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "api", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/service"`,
StructName: "User",
StructCamelName: "user",
ModuleName: "serviceUser",
GroupName: "Service",
PackageName: "service",
ServiceName: "User",
},
wantErr: false,
},
{
name: "测试 Gva插件UserRouter 注入",
fields: fields{
Type: TypePluginRouterEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "router", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/api"`,
StructName: "User",
StructCamelName: "user",
ModuleName: "userApi",
GroupName: "Api",
PackageName: "api",
ServiceName: "User",
},
wantErr: false,
},
{
name: "测试 Gva插件UserService 注入",
fields: fields{
Type: TypePluginServiceEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "service", "enter.go"),
ImportPath: "",
StructName: "User",
StructCamelName: "user",
ModuleName: "",
GroupName: "",
PackageName: "",
ServiceName: "",
},
wantErr: false,
},
{
name: "测试 gva的User 注入",
fields: fields{
Type: TypePluginServiceEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "service", "enter.go"),
ImportPath: "",
StructName: "User",
StructCamelName: "user",
ModuleName: "",
GroupName: "",
PackageName: "",
ServiceName: "",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PluginEnter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
StructCamelName: tt.fields.StructCamelName,
ModuleName: tt.fields.ModuleName,
GroupName: tt.fields.GroupName,
PackageName: tt.fields.PackageName,
ServiceName: tt.fields.ServiceName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPluginEnter_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
StructCamelName string
ModuleName string
GroupName string
PackageName string
ServiceName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 Gva插件UserRouter 回滚",
fields: fields{
Type: TypePluginRouterEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "router", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/api"`,
StructName: "User",
StructCamelName: "user",
ModuleName: "userApi",
GroupName: "Api",
PackageName: "api",
ServiceName: "User",
},
wantErr: false,
},
{
name: "测试 Gva插件UserApi 回滚",
fields: fields{
Type: TypePluginApiEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "api", "enter.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/service"`,
StructName: "User",
StructCamelName: "user",
ModuleName: "serviceUser",
GroupName: "Service",
PackageName: "service",
ServiceName: "User",
},
wantErr: false,
},
{
name: "测试 Gva插件UserService 回滚",
fields: fields{
Type: TypePluginServiceEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "service", "enter.go"),
ImportPath: "",
StructName: "User",
StructCamelName: "user",
ModuleName: "",
GroupName: "",
PackageName: "",
ServiceName: "",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PluginEnter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
StructCamelName: tt.fields.StructCamelName,
ModuleName: tt.fields.ModuleName,
GroupName: tt.fields.GroupName,
PackageName: tt.fields.PackageName,
ServiceName: tt.fields.ServiceName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,189 @@
package ast
import (
"go/ast"
"go/token"
"io"
)
type PluginGen struct {
Base
Type Type // 类型
Path string // 文件路径
ImportPath string // 导包路径
RelativePath string // 相对路径
StructName string // 结构体名称
PackageName string // 包名
IsNew bool // 是否使用new关键字
}
func (a *PluginGen) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.Path
a.RelativePath = a.Base.RelativePath(a.Path)
return a.Base.Parse(filename, writer)
}
a.Path = a.Base.AbsolutePath(a.RelativePath)
filename = a.Path
}
return a.Base.Parse(filename, writer)
}
func (a *PluginGen) Rollback(file *ast.File) error {
for i := 0; i < len(file.Decls); i++ {
v1, o1 := file.Decls[i].(*ast.FuncDecl)
if o1 {
for j := 0; j < len(v1.Body.List); j++ {
v2, o2 := v1.Body.List[j].(*ast.ExprStmt)
if o2 {
v3, o3 := v2.X.(*ast.CallExpr)
if o3 {
v4, o4 := v3.Fun.(*ast.SelectorExpr)
if o4 {
if v4.Sel.Name != "ApplyBasic" {
continue
}
for k := 0; k < len(v3.Args); k++ {
v5, o5 := v3.Args[k].(*ast.CallExpr)
if o5 {
v6, o6 := v5.Fun.(*ast.Ident)
if o6 {
if v6.Name != "new" {
continue
}
for l := 0; l < len(v5.Args); l++ {
v7, o7 := v5.Args[l].(*ast.SelectorExpr)
if o7 {
v8, o8 := v7.X.(*ast.Ident)
if o8 {
if v8.Name == a.PackageName && v7.Sel.Name == a.StructName {
v3.Args = append(v3.Args[:k], v3.Args[k+1:]...)
continue
}
}
}
}
}
}
if k >= len(v3.Args) {
break
}
v6, o6 := v3.Args[k].(*ast.CompositeLit)
if o6 {
v7, o7 := v6.Type.(*ast.SelectorExpr)
if o7 {
v8, o8 := v7.X.(*ast.Ident)
if o8 {
if v8.Name == a.PackageName && v7.Sel.Name == a.StructName {
v3.Args = append(v3.Args[:k], v3.Args[k+1:]...)
continue
}
}
}
}
}
if len(v3.Args) == 0 {
_ = NewImport(a.ImportPath).Rollback(file)
}
}
}
}
}
}
}
return nil
}
func (a *PluginGen) Injection(file *ast.File) error {
_ = NewImport(a.ImportPath).Injection(file)
for i := 0; i < len(file.Decls); i++ {
v1, o1 := file.Decls[i].(*ast.FuncDecl)
if o1 {
for j := 0; j < len(v1.Body.List); j++ {
v2, o2 := v1.Body.List[j].(*ast.ExprStmt)
if o2 {
v3, o3 := v2.X.(*ast.CallExpr)
if o3 {
v4, o4 := v3.Fun.(*ast.SelectorExpr)
if o4 {
if v4.Sel.Name != "ApplyBasic" {
continue
}
var has bool
for k := 0; k < len(v3.Args); k++ {
v5, o5 := v3.Args[k].(*ast.CallExpr)
if o5 {
v6, o6 := v5.Fun.(*ast.Ident)
if o6 {
if v6.Name != "new" {
continue
}
for l := 0; l < len(v5.Args); l++ {
v7, o7 := v5.Args[l].(*ast.SelectorExpr)
if o7 {
v8, o8 := v7.X.(*ast.Ident)
if o8 {
if v8.Name == a.PackageName && v7.Sel.Name == a.StructName {
has = true
break
}
}
}
}
}
}
v6, o6 := v3.Args[k].(*ast.CompositeLit)
if o6 {
v7, o7 := v6.Type.(*ast.SelectorExpr)
if o7 {
v8, o8 := v7.X.(*ast.Ident)
if o8 {
if v8.Name == a.PackageName && v7.Sel.Name == a.StructName {
has = true
break
}
}
}
}
}
if !has {
if a.IsNew {
arg := &ast.CallExpr{
Fun: &ast.Ident{Name: "\n\t\tnew"},
Args: []ast.Expr{
&ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.StructName},
},
},
}
v3.Args = append(v3.Args, arg)
v3.Args = append(v3.Args, &ast.BasicLit{
Kind: token.STRING,
Value: "\n",
})
break
}
arg := &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.StructName},
},
}
v3.Args = append(v3.Args, arg)
}
}
}
}
}
}
}
return nil
}
func (a *PluginGen) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.Path
}
return a.Base.Format(filename, writer, file)
}

View File

@@ -0,0 +1,127 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPluginGenModel_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
PackageName string
StructName string
IsNew bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 GvaUser 结构体注入",
fields: fields{
Type: TypePluginGen,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "gen", "main.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
PackageName: "model",
StructName: "User",
IsNew: false,
},
},
{
name: "测试 GvaUser 结构体注入",
fields: fields{
Type: TypePluginGen,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "gen", "main.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
PackageName: "model",
StructName: "User",
IsNew: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PluginGen{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
PackageName: tt.fields.PackageName,
StructName: tt.fields.StructName,
IsNew: tt.fields.IsNew,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPluginGenModel_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
PackageName string
StructName string
IsNew bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 GvaUser 回滚",
fields: fields{
Type: TypePluginGen,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "gen", "main.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
PackageName: "model",
StructName: "User",
IsNew: false,
},
},
{
name: "测试 GvaUser 回滚",
fields: fields{
Type: TypePluginGen,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "gen", "main.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
PackageName: "model",
StructName: "User",
IsNew: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PluginGen{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
PackageName: tt.fields.PackageName,
StructName: tt.fields.StructName,
IsNew: tt.fields.IsNew,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,111 @@
package ast
import (
"go/ast"
"io"
)
type PluginInitializeGorm struct {
Base
Type Type // 类型
Path string // 文件路径
ImportPath string // 导包路径
RelativePath string // 相对路径
StructName string // 结构体名称
PackageName string // 包名
IsNew bool // 是否使用new关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
}
func (a *PluginInitializeGorm) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.Path
a.RelativePath = a.Base.RelativePath(a.Path)
return a.Base.Parse(filename, writer)
}
a.Path = a.Base.AbsolutePath(a.RelativePath)
filename = a.Path
}
return a.Base.Parse(filename, writer)
}
func (a *PluginInitializeGorm) Rollback(file *ast.File) error {
var needRollBackImport bool
ast.Inspect(file, func(n ast.Node) bool {
callExpr, ok := n.(*ast.CallExpr)
if !ok {
return true
}
selExpr, seok := callExpr.Fun.(*ast.SelectorExpr)
if !seok || selExpr.Sel.Name != "AutoMigrate" {
return true
}
if len(callExpr.Args) <= 1 {
needRollBackImport = true
}
// 删除指定的参数
for i, arg := range callExpr.Args {
compLit, cok := arg.(*ast.CompositeLit)
if !cok {
continue
}
cselExpr, sok := compLit.Type.(*ast.SelectorExpr)
if !sok {
continue
}
ident, idok := cselExpr.X.(*ast.Ident)
if idok && ident.Name == a.PackageName && cselExpr.Sel.Name == a.StructName {
// 删除参数
callExpr.Args = append(callExpr.Args[:i], callExpr.Args[i+1:]...)
break
}
}
return true
})
if needRollBackImport {
_ = NewImport(a.ImportPath).Rollback(file)
}
return nil
}
func (a *PluginInitializeGorm) Injection(file *ast.File) error {
_ = NewImport(a.ImportPath).Injection(file)
var call *ast.CallExpr
ast.Inspect(file, func(n ast.Node) bool {
callExpr, ok := n.(*ast.CallExpr)
if !ok {
return true
}
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
if ok && selExpr.Sel.Name == "AutoMigrate" {
call = callExpr
return false
}
return true
})
arg := &ast.CompositeLit{
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.StructName},
},
}
call.Args = append(call.Args, arg)
return nil
}
func (a *PluginInitializeGorm) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.Path
}
return a.Base.Format(filename, writer, file)
}

View File

@@ -0,0 +1,138 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPluginInitializeGorm_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
PackageName string
IsNew bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 &model.User{} 注入",
fields: fields{
Type: TypePluginInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
StructName: "User",
PackageName: "model",
IsNew: false,
},
},
{
name: "测试 new(model.ExaCustomer) 注入",
fields: fields{
Type: TypePluginInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
StructName: "User",
PackageName: "model",
IsNew: true,
},
},
{
name: "测试 new(model.SysUsers) 注入",
fields: fields{
Type: TypePluginInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
StructName: "SysUser",
PackageName: "model",
IsNew: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PluginInitializeGorm{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
PackageName: tt.fields.PackageName,
IsNew: tt.fields.IsNew,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPluginInitializeGorm_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
StructName string
PackageName string
IsNew bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 &model.User{} 回滚",
fields: fields{
Type: TypePluginInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
StructName: "User",
PackageName: "model",
IsNew: false,
},
},
{
name: "测试 new(model.ExaCustomer) 回滚",
fields: fields{
Type: TypePluginInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "gorm.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/model"`,
StructName: "User",
PackageName: "model",
IsNew: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PluginInitializeGorm{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
StructName: tt.fields.StructName,
PackageName: tt.fields.PackageName,
IsNew: tt.fields.IsNew,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,124 @@
package ast
import (
"fmt"
"go/ast"
"io"
)
// PluginInitializeRouter 插件初始化路由
// PackageName.AppName.GroupName.FunctionName()
type PluginInitializeRouter struct {
Base
Type Type // 类型
Path string // 文件路径
ImportPath string // 导包路径
ImportGlobalPath string // 导包全局变量路径
ImportMiddlewarePath string // 导包中间件路径
RelativePath string // 相对路径
AppName string // 应用名称
GroupName string // 分组名称
PackageName string // 包名
FunctionName string // 函数名
LeftRouterGroupName string // 左路由分组名称
RightRouterGroupName string // 右路由分组名称
}
func (a *PluginInitializeRouter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.Path
a.RelativePath = a.Base.RelativePath(a.Path)
return a.Base.Parse(filename, writer)
}
a.Path = a.Base.AbsolutePath(a.RelativePath)
filename = a.Path
}
return a.Base.Parse(filename, writer)
}
func (a *PluginInitializeRouter) Rollback(file *ast.File) error {
funcDecl := FindFunction(file, "Router")
delI := 0
routerNum := 0
for i := len(funcDecl.Body.List) - 1; i >= 0; i-- {
stmt, ok := funcDecl.Body.List[i].(*ast.ExprStmt)
if !ok {
continue
}
callExpr, ok := stmt.X.(*ast.CallExpr)
if !ok {
continue
}
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
if !ok {
continue
}
ident, ok := selExpr.X.(*ast.SelectorExpr)
if ok {
if iExpr, ieok := ident.X.(*ast.SelectorExpr); ieok {
if iden, idok := iExpr.X.(*ast.Ident); idok {
if iden.Name == "router" {
routerNum++
}
}
}
if ident.Sel.Name == a.GroupName && selExpr.Sel.Name == a.FunctionName {
// 删除语句
delI = i
}
}
}
funcDecl.Body.List = append(funcDecl.Body.List[:delI], funcDecl.Body.List[delI+1:]...)
if routerNum <= 1 {
_ = NewImport(a.ImportPath).Rollback(file)
}
return nil
}
func (a *PluginInitializeRouter) Injection(file *ast.File) error {
_ = NewImport(a.ImportPath).Injection(file)
funcDecl := FindFunction(file, "Router")
var exists bool
ast.Inspect(funcDecl, func(n ast.Node) bool {
callExpr, ok := n.(*ast.CallExpr)
if !ok {
return true
}
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
if !ok {
return true
}
ident, ok := selExpr.X.(*ast.SelectorExpr)
if ok && ident.Sel.Name == a.GroupName && selExpr.Sel.Name == a.FunctionName {
exists = true
return false
}
return true
})
if !exists {
stmtStr := fmt.Sprintf("%s.%s.%s.%s(%s, %s)", a.PackageName, a.AppName, a.GroupName, a.FunctionName, a.LeftRouterGroupName, a.RightRouterGroupName)
stmt := CreateStmt(stmtStr)
funcDecl.Body.List = append(funcDecl.Body.List, stmt)
}
return nil
}
func (a *PluginInitializeRouter) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.Path
}
return a.Base.Format(filename, writer, file)
}

View File

@@ -0,0 +1,155 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPluginInitializeRouter_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
AppName string
GroupName string
PackageName string
FunctionName string
LeftRouterGroupName string
RightRouterGroupName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 Gva插件User 注入",
fields: fields{
Type: TypePluginInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "router.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/router"`,
AppName: "Router",
GroupName: "User",
PackageName: "router",
FunctionName: "Init",
LeftRouterGroupName: "public",
RightRouterGroupName: "private",
},
wantErr: false,
},
{
name: "测试 中文 注入",
fields: fields{
Type: TypePluginInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "router.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/router"`,
AppName: "Router",
GroupName: "U中文",
PackageName: "router",
FunctionName: "Init",
LeftRouterGroupName: "public",
RightRouterGroupName: "private",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PluginInitializeRouter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
AppName: tt.fields.AppName,
GroupName: tt.fields.GroupName,
PackageName: tt.fields.PackageName,
FunctionName: tt.fields.FunctionName,
LeftRouterGroupName: tt.fields.LeftRouterGroupName,
RightRouterGroupName: tt.fields.RightRouterGroupName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPluginInitializeRouter_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
ImportPath string
AppName string
GroupName string
PackageName string
FunctionName string
LeftRouterGroupName string
RightRouterGroupName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 Gva插件User 回滚",
fields: fields{
Type: TypePluginInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "router.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/router"`,
AppName: "Router",
GroupName: "User",
PackageName: "router",
FunctionName: "Init",
LeftRouterGroupName: "public",
RightRouterGroupName: "private",
},
wantErr: false,
},
{
name: "测试 中文 注入",
fields: fields{
Type: TypePluginInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "initialize", "router.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva/router"`,
AppName: "Router",
GroupName: "U中文",
PackageName: "router",
FunctionName: "Init",
LeftRouterGroupName: "public",
RightRouterGroupName: "private",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &PluginInitializeRouter{
Type: tt.fields.Type,
Path: tt.fields.Path,
ImportPath: tt.fields.ImportPath,
AppName: tt.fields.AppName,
GroupName: tt.fields.GroupName,
PackageName: tt.fields.PackageName,
FunctionName: tt.fields.FunctionName,
LeftRouterGroupName: tt.fields.LeftRouterGroupName,
RightRouterGroupName: tt.fields.RightRouterGroupName,
}
file, err := a.Parse(a.Path, nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format(a.Path, nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,82 @@
package ast
import (
"go/ast"
"go/token"
"io"
"strconv"
"strings"
)
type PluginInitializeV2 struct {
Base
Type Type // 类型
Path string // 文件路径
PluginPath string // 插件路径
RelativePath string // 相对路径
ImportPath string // 导包路径
StructName string // 结构体名称
PackageName string // 包名
}
func (a *PluginInitializeV2) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
if filename == "" {
if a.RelativePath == "" {
filename = a.PluginPath
a.RelativePath = a.Base.RelativePath(a.PluginPath)
return a.Base.Parse(filename, writer)
}
a.PluginPath = a.Base.AbsolutePath(a.RelativePath)
filename = a.PluginPath
}
return a.Base.Parse(filename, writer)
}
func (a *PluginInitializeV2) Injection(file *ast.File) error {
importPath := strings.TrimSpace(a.ImportPath)
if importPath == "" {
return nil
}
importPath = strings.Trim(importPath, "\"")
if importPath == "" || CheckImport(file, importPath) {
return nil
}
importSpec := &ast.ImportSpec{
Name: ast.NewIdent("_"),
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath)},
}
var importDecl *ast.GenDecl
for _, decl := range file.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok {
continue
}
if genDecl.Tok == token.IMPORT {
importDecl = genDecl
break
}
}
if importDecl == nil {
file.Decls = append([]ast.Decl{
&ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{importSpec},
},
}, file.Decls...)
return nil
}
importDecl.Specs = append(importDecl.Specs, importSpec)
return nil
}
func (a *PluginInitializeV2) Rollback(file *ast.File) error {
return nil
}
func (a *PluginInitializeV2) Format(filename string, writer io.Writer, file *ast.File) error {
if filename == "" {
filename = a.PluginPath
}
return a.Base.Format(filename, writer, file)
}

View File

@@ -0,0 +1,100 @@
package ast
import (
"git.echol.cn/loser/ai_proxy/server/global"
"path/filepath"
"testing"
)
func TestPluginInitialize_Injection(t *testing.T) {
type fields struct {
Type Type
Path string
PluginPath string
ImportPath string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 Gva插件 注册注入",
fields: fields{
Type: TypePluginInitializeV2,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "plugin.go"),
PluginPath: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "register.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva"`,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := PluginInitializeV2{
Type: tt.fields.Type,
Path: tt.fields.Path,
PluginPath: tt.fields.PluginPath,
ImportPath: tt.fields.ImportPath,
}
file, err := a.Parse("", nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Injection(file)
err = a.Format("", nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Injection() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestPluginInitialize_Rollback(t *testing.T) {
type fields struct {
Type Type
Path string
PluginPath string
ImportPath string
PluginName string
StructName string
PackageName string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "测试 Gva插件 回滚",
fields: fields{
Type: TypePluginInitializeV2,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "gva", "plugin.go"),
PluginPath: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "register.go"),
ImportPath: `"git.echol.cn/loser/ai_proxy/server/plugin/gva"`,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := PluginInitializeV2{
Type: tt.fields.Type,
Path: tt.fields.Path,
PluginPath: tt.fields.PluginPath,
ImportPath: tt.fields.ImportPath,
StructName: "Plugin",
PackageName: "gva",
}
file, err := a.Parse("", nil)
if err != nil {
t.Errorf("Parse() error = %v, wantErr %v", err, tt.wantErr)
}
a.Rollback(file)
err = a.Format("", nil, file)
if (err != nil) != tt.wantErr {
t.Errorf("Rollback() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,713 @@
package autocode
import (
"fmt"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
"slices"
"strings"
"text/template"
)
// GetTemplateFuncMap 返回模板函数映射,用于在模板中使用
func GetTemplateFuncMap() template.FuncMap {
return template.FuncMap{
"title": strings.Title,
"GenerateField": GenerateField,
"GenerateSearchField": GenerateSearchField,
"GenerateSearchConditions": GenerateSearchConditions,
"GenerateSearchFormItem": GenerateSearchFormItem,
"GenerateTableColumn": GenerateTableColumn,
"GenerateFormItem": GenerateFormItem,
"GenerateDescriptionItem": GenerateDescriptionItem,
"GenerateDefaultFormValue": GenerateDefaultFormValue,
}
}
// 渲染Model中的字段
func GenerateField(field systemReq.AutoCodeField) string {
// 构建gorm标签
gormTag := ``
if field.FieldIndexType != "" {
gormTag += field.FieldIndexType + ";"
}
if field.PrimaryKey {
gormTag += "primarykey;"
}
if field.DefaultValue != "" {
gormTag += fmt.Sprintf("default:%s;", field.DefaultValue)
}
if field.Comment != "" {
gormTag += fmt.Sprintf("comment:%s;", field.Comment)
}
gormTag += "column:" + field.ColumnName + ";"
// 对于int类型根据DataTypeLong决定具体的Go类型不使用size标签
if field.DataTypeLong != "" && field.FieldType != "enum" && field.FieldType != "int" {
gormTag += fmt.Sprintf("size:%s;", field.DataTypeLong)
}
requireTag := ` binding:"required"` + "`"
// 根据字段类型构建不同的字段定义
var result string
switch field.FieldType {
case "enum":
result = fmt.Sprintf(`%s string `+"`"+`json:"%s" form:"%s" gorm:"%stype:enum(%s);"`+"`",
field.FieldName, field.FieldJson, field.FieldJson, gormTag, field.DataTypeLong)
case "picture", "video":
tagContent := fmt.Sprintf(`json:"%s" form:"%s" gorm:"%s"`,
field.FieldJson, field.FieldJson, gormTag)
result = fmt.Sprintf(`%s string `+"`"+`%s`+"`"+``, field.FieldName, tagContent)
case "file", "pictures", "array":
tagContent := fmt.Sprintf(`json:"%s" form:"%s" gorm:"%s"`,
field.FieldJson, field.FieldJson, gormTag)
result = fmt.Sprintf(`%s datatypes.JSON `+"`"+`%s swaggertype:"array,object"`+"`"+``,
field.FieldName, tagContent)
case "richtext":
tagContent := fmt.Sprintf(`json:"%s" form:"%s" gorm:"%s`,
field.FieldJson, field.FieldJson, gormTag)
result = fmt.Sprintf(`%s *string `+"`"+`%stype:text;"`+"`"+``,
field.FieldName, tagContent)
case "json":
tagContent := fmt.Sprintf(`json:"%s" form:"%s" gorm:"%s"`,
field.FieldJson, field.FieldJson, gormTag)
result = fmt.Sprintf(`%s datatypes.JSON `+"`"+`%s swaggertype:"object"`+"`"+``,
field.FieldName, tagContent)
default:
tagContent := fmt.Sprintf(`json:"%s" form:"%s" gorm:"%s"`,
field.FieldJson, field.FieldJson, gormTag)
// 对于int类型根据DataTypeLong决定具体的Go类型
var fieldType string
if field.FieldType == "int" {
switch field.DataTypeLong {
case "1", "2", "3":
fieldType = "int8"
case "4", "5":
fieldType = "int16"
case "6", "7", "8", "9", "10":
fieldType = "int32"
case "11", "12", "13", "14", "15", "16", "17", "18", "19", "20":
fieldType = "int64"
default:
fieldType = "int64"
}
} else {
fieldType = field.FieldType
}
result = fmt.Sprintf(`%s *%s `+"`"+`%s`+"`"+``,
field.FieldName, fieldType, tagContent)
}
if field.Require {
result = result[0:len(result)-1] + requireTag
}
// 添加字段描述
if field.FieldDesc != "" {
result += fmt.Sprintf(" //%s", field.FieldDesc)
}
return result
}
// 格式化搜索条件语句
func GenerateSearchConditions(fields []*systemReq.AutoCodeField) string {
var conditions []string
for _, field := range fields {
if field.FieldSearchType == "" {
continue
}
var condition string
if slices.Contains([]string{"enum", "pictures", "picture", "video", "json", "richtext", "array"}, field.FieldType) {
if field.FieldType == "enum" {
if field.FieldSearchType == "LIKE" {
condition = fmt.Sprintf(`
if info.%s != "" {
db = db.Where("%s LIKE ?", "%%"+ info.%s+"%%")
}`,
field.FieldName, field.ColumnName, field.FieldName)
} else {
condition = fmt.Sprintf(`
if info.%s != "" {
db = db.Where("%s %s ?", info.%s)
}`,
field.FieldName, field.ColumnName, field.FieldSearchType, field.FieldName)
}
} else {
condition = fmt.Sprintf(`
if info.%s != "" {
// TODO 数据类型为复杂类型,请根据业务需求自行实现复杂类型的查询业务
}`, field.FieldName)
}
} else if field.FieldSearchType == "BETWEEN" || field.FieldSearchType == "NOT BETWEEN" {
if field.FieldType == "time.Time" {
condition = fmt.Sprintf(`
if len(info.%sRange) == 2 {
db = db.Where("%s %s ? AND ? ", info.%sRange[0], info.%sRange[1])
}`,
field.FieldName, field.ColumnName, field.FieldSearchType, field.FieldName, field.FieldName)
} else {
condition = fmt.Sprintf(`
if info.Start%s != nil && info.End%s != nil {
db = db.Where("%s %s ? AND ? ", *info.Start%s, *info.End%s)
}`,
field.FieldName, field.FieldName, field.ColumnName,
field.FieldSearchType, field.FieldName, field.FieldName)
}
} else {
nullCheck := "info." + field.FieldName + " != nil"
if field.FieldType == "string" {
condition = fmt.Sprintf(`
if %s && *info.%s != "" {`, nullCheck, field.FieldName)
} else {
condition = fmt.Sprintf(`
if %s {`, nullCheck)
}
if field.FieldSearchType == "LIKE" {
condition += fmt.Sprintf(`
db = db.Where("%s LIKE ?", "%%"+ *info.%s+"%%")
}`,
field.ColumnName, field.FieldName)
} else {
condition += fmt.Sprintf(`
db = db.Where("%s %s ?", *info.%s)
}`,
field.ColumnName, field.FieldSearchType, field.FieldName)
}
}
conditions = append(conditions, condition)
}
return strings.Join(conditions, "")
}
// 格式化前端搜索条件
func GenerateSearchFormItem(field systemReq.AutoCodeField) string {
// 开始构建表单项
result := fmt.Sprintf(`<el-form-item label="%s" prop="%s">
`, field.FieldDesc, field.FieldJson)
// 根据字段属性生成不同的输入类型
if field.FieldType == "bool" {
result += fmt.Sprintf(` <el-select v-model="searchInfo.%s" clearable placeholder="请选择">
`, field.FieldJson)
result += ` <el-option key="true" label="是" value="true"></el-option>
`
result += ` <el-option key="false" label="否" value="false"></el-option>
`
result += ` </el-select>
`
} else if field.DictType != "" {
multipleAttr := ""
if field.FieldType == "array" {
multipleAttr = "multiple "
}
result += fmt.Sprintf(` <el-tree-select v-model="searchInfo.%s" placeholder="请选择%s" :data="%sOptions" style="width:100%%" filterable :clearable="%v" check-strictly %s></el-tree-select>
`,
field.FieldJson, field.FieldDesc, field.DictType, field.Clearable, multipleAttr)
} else if field.CheckDataSource {
multipleAttr := ""
if field.DataSource.Association == 2 {
multipleAttr = "multiple "
}
result += fmt.Sprintf(` <el-select %sv-model="searchInfo.%s" filterable placeholder="请选择%s" :clearable="%v">
`,
multipleAttr, field.FieldJson, field.FieldDesc, field.Clearable)
result += fmt.Sprintf(` <el-option v-for="(item,key) in dataSource.%s" :key="key" :label="item.label" :value="item.value" />
`,
field.FieldJson)
result += ` </el-select>
`
} else if field.FieldType == "float64" || field.FieldType == "int" {
if field.FieldSearchType == "BETWEEN" || field.FieldSearchType == "NOT BETWEEN" {
result += fmt.Sprintf(` <el-input class="!w-40" v-model.number="searchInfo.start%s" placeholder="最小值" />
`, field.FieldName)
result += `
`
result += fmt.Sprintf(` <el-input class="!w-40" v-model.number="searchInfo.end%s" placeholder="最大值" />
`, field.FieldName)
} else {
result += fmt.Sprintf(` <el-input v-model.number="searchInfo.%s" placeholder="搜索条件" />
`, field.FieldJson)
}
} else if field.FieldType == "time.Time" {
if field.FieldSearchType == "BETWEEN" || field.FieldSearchType == "NOT BETWEEN" {
result += ` <template #label>
`
result += ` <span>
`
result += fmt.Sprintf(` %s
`, field.FieldDesc)
result += ` <el-tooltip content="搜索范围是开始日期(包含)至结束日期(不包含)">
`
result += ` <el-icon><QuestionFilled /></el-icon>
`
result += ` </el-tooltip>
`
result += ` </span>
`
result += ` </template>
`
result += fmt.Sprintf(`<el-date-picker class="!w-380px" v-model="searchInfo.%sRange" type="datetimerange" range-separator="至" start-placeholder="开始时间" end-placeholder="结束时间"></el-date-picker>`, field.FieldJson)
} else {
result += fmt.Sprintf(`<el-date-picker v-model="searchInfo.%s" type="datetime" placeholder="搜索条件"></el-date-picker>`, field.FieldJson)
}
} else {
result += fmt.Sprintf(` <el-input v-model="searchInfo.%s" placeholder="搜索条件" />
`, field.FieldJson)
}
// 关闭表单项
result += `</el-form-item>`
return result
}
// GenerateTableColumn generates HTML for table column based on field properties
func GenerateTableColumn(field systemReq.AutoCodeField) string {
// Add sortable attribute if needed
sortAttr := ""
if field.Sort {
sortAttr = " sortable"
}
// Handle different field types
if field.CheckDataSource {
result := fmt.Sprintf(`<el-table-column%s align="left" label="%s" prop="%s" width="120">
`,
sortAttr, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
if field.DataSource.Association == 2 {
result += fmt.Sprintf(` <el-tag v-for="(item,key) in filterDataSource(dataSource.%s,scope.row.%s)" :key="key">
`,
field.FieldJson, field.FieldJson)
result += ` {{ item }}
`
result += ` </el-tag>
`
} else {
result += fmt.Sprintf(` <span>{{ filterDataSource(dataSource.%s,scope.row.%s) }}</span>
`,
field.FieldJson, field.FieldJson)
}
result += ` </template>
`
result += `</el-table-column>`
return result
} else if field.DictType != "" {
result := fmt.Sprintf(`<el-table-column%s align="left" label="%s" prop="%s" width="120">
`,
sortAttr, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
if field.FieldType == "array" {
result += fmt.Sprintf(` <el-tag class="mr-1" v-for="item in scope.row.%s" :key="item"> {{ filterDict(item,%sOptions) }}</el-tag>
`,
field.FieldJson, field.DictType)
} else {
result += fmt.Sprintf(` {{ filterDict(scope.row.%s,%sOptions) }}
`,
field.FieldJson, field.DictType)
}
result += ` </template>
`
result += `</el-table-column>`
return result
} else if field.FieldType == "bool" {
result := fmt.Sprintf(`<el-table-column%s align="left" label="%s" prop="%s" width="120">
`,
sortAttr, field.FieldDesc, field.FieldJson)
result += fmt.Sprintf(` <template #default="scope">{{ formatBoolean(scope.row.%s) }}</template>
`, field.FieldJson)
result += `</el-table-column>`
return result
} else if field.FieldType == "time.Time" {
result := fmt.Sprintf(`<el-table-column%s align="left" label="%s" prop="%s" width="180">
`,
sortAttr, field.FieldDesc, field.FieldJson)
result += fmt.Sprintf(` <template #default="scope">{{ formatDate(scope.row.%s) }}</template>
`, field.FieldJson)
result += `</el-table-column>`
return result
} else if field.FieldType == "picture" {
result := fmt.Sprintf(`<el-table-column label="%s" prop="%s" width="200">
`, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
result += fmt.Sprintf(` <el-image preview-teleported style="width: 100px; height: 100px" :src="getUrl(scope.row.%s)" fit="cover"/>
`, field.FieldJson)
result += ` </template>
`
result += `</el-table-column>`
return result
} else if field.FieldType == "pictures" {
result := fmt.Sprintf(`<el-table-column label="%s" prop="%s" width="200">
`, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
result += ` <div class="multiple-img-box">
`
result += fmt.Sprintf(` <el-image preview-teleported v-for="(item,index) in scope.row.%s" :key="index" style="width: 80px; height: 80px" :src="getUrl(item)" fit="cover"/>
`, field.FieldJson)
result += ` </div>
`
result += ` </template>
`
result += `</el-table-column>`
return result
} else if field.FieldType == "video" {
result := fmt.Sprintf(`<el-table-column label="%s" prop="%s" width="200">
`, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
result += ` <video
`
result += ` style="width: 100px; height: 100px"
`
result += ` muted
`
result += ` preload="metadata"
`
result += ` >
`
result += fmt.Sprintf(` <source :src="getUrl(scope.row.%s) + '#t=1'">
`, field.FieldJson)
result += ` </video>
`
result += ` </template>
`
result += `</el-table-column>`
return result
} else if field.FieldType == "richtext" {
result := fmt.Sprintf(`<el-table-column label="%s" prop="%s" width="200">
`, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
result += ` [富文本内容]
`
result += ` </template>
`
result += `</el-table-column>`
return result
} else if field.FieldType == "file" {
result := fmt.Sprintf(`<el-table-column label="%s" prop="%s" width="200">
`, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
result += ` <div class="file-list">
`
result += fmt.Sprintf(` <el-tag v-for="file in scope.row.%s" :key="file.uid" @click="onDownloadFile(file.url)">{{ file.name }}</el-tag>
`, field.FieldJson)
result += ` </div>
`
result += ` </template>
`
result += `</el-table-column>`
return result
} else if field.FieldType == "json" {
result := fmt.Sprintf(`<el-table-column label="%s" prop="%s" width="200">
`, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
result += ` [JSON]
`
result += ` </template>
`
result += `</el-table-column>`
return result
} else if field.FieldType == "array" {
result := fmt.Sprintf(`<el-table-column label="%s" prop="%s" width="200">
`, field.FieldDesc, field.FieldJson)
result += ` <template #default="scope">
`
result += fmt.Sprintf(` <ArrayCtrl v-model="scope.row.%s"/>
`, field.FieldJson)
result += ` </template>
`
result += `</el-table-column>`
return result
} else {
return fmt.Sprintf(`<el-table-column%s align="left" label="%s" prop="%s" width="120" />
`,
sortAttr, field.FieldDesc, field.FieldJson)
}
}
func GenerateFormItem(field systemReq.AutoCodeField) string {
// 开始构建表单项
result := fmt.Sprintf(`<el-form-item label="%s:" prop="%s">
`, field.FieldDesc, field.FieldJson)
// 处理不同字段类型
if field.CheckDataSource {
multipleAttr := ""
if field.DataSource.Association == 2 {
multipleAttr = " multiple"
}
result += fmt.Sprintf(` <el-select%s v-model="formData.%s" placeholder="请选择%s" filterable style="width:100%%" :clearable="%v">
`,
multipleAttr, field.FieldJson, field.FieldDesc, field.Clearable)
result += fmt.Sprintf(` <el-option v-for="(item,key) in dataSource.%s" :key="key" :label="item.label" :value="item.value" />
`,
field.FieldJson)
result += ` </el-select>
`
} else {
switch field.FieldType {
case "bool":
result += fmt.Sprintf(` <el-switch v-model="formData.%s" active-color="#13ce66" inactive-color="#ff4949" active-text="是" inactive-text="否" clearable ></el-switch>
`,
field.FieldJson)
case "string":
if field.DictType != "" {
result += fmt.Sprintf(` <el-tree-select v-model="formData.%s" placeholder="请选择%s" :data="%sOptions" style="width:100%%" filterable :clearable="%v" check-strictly></el-tree-select>
`,
field.FieldJson, field.FieldDesc, field.DictType, field.Clearable)
} else {
result += fmt.Sprintf(` <el-input v-model="formData.%s" :clearable="%v" placeholder="请输入%s" />
`,
field.FieldJson, field.Clearable, field.FieldDesc)
}
case "richtext":
result += fmt.Sprintf(` <RichEdit v-model="formData.%s"/>
`, field.FieldJson)
case "json":
result += fmt.Sprintf(` // 此字段为json结构可以前端自行控制展示和数据绑定模式 需绑定json的key为 formData.%s 后端会按照json的类型进行存取
`, field.FieldJson)
result += fmt.Sprintf(` {{ formData.%s }}
`, field.FieldJson)
case "array":
if field.DictType != "" {
result += fmt.Sprintf(` <el-select multiple v-model="formData.%s" placeholder="请选择%s" filterable style="width:100%%" :clearable="%v">
`,
field.FieldJson, field.FieldDesc, field.Clearable)
result += fmt.Sprintf(` <el-option v-for="(item,key) in %sOptions" :key="key" :label="item.label" :value="item.value" />
`,
field.DictType)
result += ` </el-select>
`
} else {
result += fmt.Sprintf(` <ArrayCtrl v-model="formData.%s" editable/>
`, field.FieldJson)
}
case "int":
result += fmt.Sprintf(` <el-input v-model.number="formData.%s" :clearable="%v" placeholder="请输入%s" />
`,
field.FieldJson, field.Clearable, field.FieldDesc)
case "time.Time":
result += fmt.Sprintf(` <el-date-picker v-model="formData.%s" type="date" style="width:100%%" placeholder="选择日期" :clearable="%v" />
`,
field.FieldJson, field.Clearable)
case "float64":
result += fmt.Sprintf(` <el-input-number v-model="formData.%s" style="width:100%%" :precision="2" :clearable="%v" />
`,
field.FieldJson, field.Clearable)
case "enum":
result += fmt.Sprintf(` <el-select v-model="formData.%s" placeholder="请选择%s" style="width:100%%" filterable :clearable="%v">
`,
field.FieldJson, field.FieldDesc, field.Clearable)
result += fmt.Sprintf(` <el-option v-for="item in [%s]" :key="item" :label="item" :value="item" />
`,
field.DataTypeLong)
result += ` </el-select>
`
case "picture":
result += fmt.Sprintf(` <SelectImage
v-model="formData.%s"
file-type="image"
/>
`, field.FieldJson)
case "pictures":
result += fmt.Sprintf(` <SelectImage
multiple
v-model="formData.%s"
file-type="image"
/>
`, field.FieldJson)
case "video":
result += fmt.Sprintf(` <SelectImage
v-model="formData.%s"
file-type="video"
/>
`, field.FieldJson)
case "file":
result += fmt.Sprintf(` <SelectFile v-model="formData.%s" />
`, field.FieldJson)
}
}
// 关闭表单项
result += `</el-form-item>`
return result
}
func GenerateDescriptionItem(field systemReq.AutoCodeField) string {
// 开始构建描述项
result := fmt.Sprintf(`<el-descriptions-item label="%s">
`, field.FieldDesc)
if field.CheckDataSource {
result += ` <template #default="scope">
`
if field.DataSource.Association == 2 {
result += fmt.Sprintf(` <el-tag v-for="(item,key) in filterDataSource(dataSource.%s,detailForm.%s)" :key="key">
`,
field.FieldJson, field.FieldJson)
result += ` {{ item }}
`
result += ` </el-tag>
`
} else {
result += fmt.Sprintf(` <span>{{ filterDataSource(dataSource.%s,detailForm.%s) }}</span>
`,
field.FieldJson, field.FieldJson)
}
result += ` </template>
`
} else if field.FieldType != "picture" && field.FieldType != "pictures" &&
field.FieldType != "file" && field.FieldType != "array" &&
field.FieldType != "richtext" {
result += fmt.Sprintf(` {{ detailForm.%s }}
`, field.FieldJson)
} else {
switch field.FieldType {
case "picture":
result += fmt.Sprintf(` <el-image style="width: 50px; height: 50px" :preview-src-list="returnArrImg(detailForm.%s)" :src="getUrl(detailForm.%s)" fit="cover" />
`,
field.FieldJson, field.FieldJson)
case "array":
result += fmt.Sprintf(` <ArrayCtrl v-model="detailForm.%s"/>
`, field.FieldJson)
case "pictures":
result += fmt.Sprintf(` <el-image style="width: 50px; height: 50px; margin-right: 10px" :preview-src-list="returnArrImg(detailForm.%s)" :initial-index="index" v-for="(item,index) in detailForm.%s" :key="index" :src="getUrl(item)" fit="cover" />
`,
field.FieldJson, field.FieldJson)
case "richtext":
result += fmt.Sprintf(` <RichView v-model="detailForm.%s" />
`, field.FieldJson)
case "file":
result += fmt.Sprintf(` <div class="fileBtn" v-for="(item,index) in detailForm.%s" :key="index">
`, field.FieldJson)
result += ` <el-button type="primary" text bg @click="onDownloadFile(item.url)">
`
result += ` <el-icon style="margin-right: 5px"><Download /></el-icon>
`
result += ` {{ item.name }}
`
result += ` </el-button>
`
result += ` </div>
`
}
}
// 关闭描述项
result += `</el-descriptions-item>`
return result
}
func GenerateDefaultFormValue(field systemReq.AutoCodeField) string {
// 根据字段类型确定默认值
var defaultValue string
switch field.FieldType {
case "bool":
defaultValue = "false"
case "string", "richtext":
defaultValue = "''"
case "int":
if field.DataSource != nil { // 检查数据源是否存在
defaultValue = "undefined"
} else {
defaultValue = "0"
}
case "time.Time":
defaultValue = "new Date()"
case "float64":
defaultValue = "0"
case "picture", "video":
defaultValue = "\"\""
case "pictures", "file", "array":
defaultValue = "[]"
case "json":
defaultValue = "{}"
default:
defaultValue = "null"
}
// 返回格式化后的默认值字符串
return fmt.Sprintf(`%s: %s,`, field.FieldJson, defaultValue)
}
// GenerateSearchField 根据字段属性生成搜索结构体中的字段定义
func GenerateSearchField(field systemReq.AutoCodeField) string {
var result string
if field.FieldSearchType == "" {
return "" // 如果没有搜索类型,返回空字符串
}
if field.FieldSearchType == "BETWEEN" || field.FieldSearchType == "NOT BETWEEN" {
// 生成范围搜索字段
// time 的情况
if field.FieldType == "time.Time" {
result = fmt.Sprintf("%sRange []time.Time `json:\"%sRange\" form:\"%sRange[]\"`",
field.FieldName, field.FieldJson, field.FieldJson)
} else {
startField := fmt.Sprintf("Start%s *%s `json:\"start%s\" form:\"start%s\"`",
field.FieldName, field.FieldType, field.FieldName, field.FieldName)
endField := fmt.Sprintf("End%s *%s `json:\"end%s\" form:\"end%s\"`",
field.FieldName, field.FieldType, field.FieldName, field.FieldName)
result = startField + "\n" + endField
}
} else {
// 生成普通搜索字段
if field.FieldType == "enum" || field.FieldType == "picture" ||
field.FieldType == "pictures" || field.FieldType == "video" ||
field.FieldType == "json" || field.FieldType == "richtext" || field.FieldType == "array" || field.FieldType == "file" {
result = fmt.Sprintf("%s string `json:\"%s\" form:\"%s\"` ",
field.FieldName, field.FieldJson, field.FieldJson)
} else {
result = fmt.Sprintf("%s *%s `json:\"%s\" form:\"%s\"` ",
field.FieldName, field.FieldType, field.FieldJson, field.FieldJson)
}
}
return result
}

View File

@@ -1,285 +0,0 @@
package utils
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"image"
"image/png"
"io"
)
// CharacterCardV2 SillyTavern 角色卡 V2 格式
type CharacterCardV2 struct {
Spec string `json:"spec"`
SpecVersion string `json:"spec_version"`
Data CharacterCardV2Data `json:"data"`
}
type CharacterCardV2Data struct {
Name string `json:"name"`
Description string `json:"description"`
Personality string `json:"personality"`
Scenario string `json:"scenario"`
FirstMes string `json:"first_mes"`
MesExample string `json:"mes_example"`
CreatorNotes string `json:"creator_notes"`
SystemPrompt string `json:"system_prompt"`
PostHistoryInstructions string `json:"post_history_instructions"`
Tags []string `json:"tags"`
Creator string `json:"creator"`
CharacterVersion string `json:"character_version"`
AlternateGreetings []string `json:"alternate_greetings"`
CharacterBook map[string]interface{} `json:"character_book,omitempty"`
Extensions map[string]interface{} `json:"extensions"`
}
// ExtractCharacterFromPNG 从 PNG 图片中提取角色卡数据
func ExtractCharacterFromPNG(pngData []byte) (*CharacterCardV2, error) {
reader := bytes.NewReader(pngData)
// 验证 PNG 格式(解码但不保存图片)
_, err := png.Decode(reader)
if err != nil {
return nil, errors.New("无效的 PNG 文件")
}
// 重新读取以获取 tEXt chunks
reader.Seek(0, 0)
// 查找 tEXt chunk 中的 "chara" 字段
charaJSON, err := extractTextChunk(reader, "chara")
if err != nil {
return nil, errors.New("PNG 中没有找到角色卡数据")
}
// 尝试 Base64 解码
decodedJSON, err := base64.StdEncoding.DecodeString(charaJSON)
if err != nil {
// 如果不是 Base64直接使用原始 JSON
decodedJSON = []byte(charaJSON)
}
// 解析 JSON
var card CharacterCardV2
err = json.Unmarshal(decodedJSON, &card)
if err != nil {
return nil, errors.New("解析角色卡数据失败: " + err.Error())
}
return &card, nil
}
// extractTextChunk 从 PNG 中提取指定 key 的 tEXt chunk
func extractTextChunk(r io.Reader, key string) (string, error) {
// 跳过 PNG signature (8 bytes)
signature := make([]byte, 8)
if _, err := io.ReadFull(r, signature); err != nil {
return "", err
}
// 验证 PNG signature
expectedSig := []byte{137, 80, 78, 71, 13, 10, 26, 10}
if !bytes.Equal(signature, expectedSig) {
return "", errors.New("invalid PNG signature")
}
// 读取所有 chunks
for {
// 读取 chunk length (4 bytes)
lengthBytes := make([]byte, 4)
if _, err := io.ReadFull(r, lengthBytes); err != nil {
if err == io.EOF {
break
}
return "", err
}
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 |
uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
// 读取 chunk type (4 bytes)
chunkType := make([]byte, 4)
if _, err := io.ReadFull(r, chunkType); err != nil {
return "", err
}
// 读取 chunk data
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return "", err
}
// 读取 CRC (4 bytes)
crc := make([]byte, 4)
if _, err := io.ReadFull(r, crc); err != nil {
return "", err
}
// 检查是否是 tEXt chunk
if string(chunkType) == "tEXt" {
// tEXt chunk 格式: keyword\0text
nullIndex := bytes.IndexByte(data, 0)
if nullIndex == -1 {
continue
}
keyword := string(data[:nullIndex])
text := string(data[nullIndex+1:])
if keyword == key {
return text, nil
}
}
// IEND chunk 表示结束
if string(chunkType) == "IEND" {
break
}
}
return "", errors.New("text chunk not found")
}
// EmbedCharacterToPNG 将角色卡数据嵌入到 PNG 图片中
func EmbedCharacterToPNG(img image.Image, card *CharacterCardV2) ([]byte, error) {
// 序列化角色卡数据
cardJSON, err := json.Marshal(card)
if err != nil {
return nil, err
}
// Base64 编码
encodedJSON := base64.StdEncoding.EncodeToString(cardJSON)
// 创建一个 buffer 来写入 PNG
var buf bytes.Buffer
// 写入 PNG signature
buf.Write([]byte{137, 80, 78, 71, 13, 10, 26, 10})
// 编码原始图片到临时 buffer
var imgBuf bytes.Buffer
if err := png.Encode(&imgBuf, img); err != nil {
return nil, err
}
// 跳过原始 PNG 的 signature
imgData := imgBuf.Bytes()[8:]
// 将原始图片的 chunks 复制到输出,在 IEND 之前插入 tEXt chunk
r := bytes.NewReader(imgData)
for {
// 读取 chunk length
lengthBytes := make([]byte, 4)
if _, err := io.ReadFull(r, lengthBytes); err != nil {
if err == io.EOF {
break
}
return nil, err
}
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 |
uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
// 读取 chunk type
chunkType := make([]byte, 4)
if _, err := io.ReadFull(r, chunkType); err != nil {
return nil, err
}
// 读取 chunk data
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return nil, err
}
// 读取 CRC
crc := make([]byte, 4)
if _, err := io.ReadFull(r, crc); err != nil {
return nil, err
}
// 如果是 IEND chunk先写入 tEXt chunk
if string(chunkType) == "IEND" {
// 写入 tEXt chunk
writeTextChunk(&buf, "chara", encodedJSON)
}
// 写入原始 chunk
buf.Write(lengthBytes)
buf.Write(chunkType)
buf.Write(data)
buf.Write(crc)
if string(chunkType) == "IEND" {
break
}
}
return buf.Bytes(), nil
}
// writeTextChunk 写入 tEXt chunk
func writeTextChunk(w io.Writer, keyword, text string) error {
data := append([]byte(keyword), 0)
data = append(data, []byte(text)...)
// 写入 length
length := uint32(len(data))
lengthBytes := []byte{
byte(length >> 24),
byte(length >> 16),
byte(length >> 8),
byte(length),
}
w.Write(lengthBytes)
// 写入 type
w.Write([]byte("tEXt"))
// 写入 data
w.Write(data)
// 计算并写入 CRC
crcData := append([]byte("tEXt"), data...)
crc := calculateCRC(crcData)
crcBytes := []byte{
byte(crc >> 24),
byte(crc >> 16),
byte(crc >> 8),
byte(crc),
}
w.Write(crcBytes)
return nil
}
// calculateCRC 计算 CRC32
func calculateCRC(data []byte) uint32 {
crc := uint32(0xFFFFFFFF)
for _, b := range data {
crc ^= uint32(b)
for i := 0; i < 8; i++ {
if crc&1 != 0 {
crc = (crc >> 1) ^ 0xEDB88320
} else {
crc >>= 1
}
}
}
return crc ^ 0xFFFFFFFF
}
// ParseCharacterCardJSON 解析 JSON 格式的角色卡
func ParseCharacterCardJSON(jsonData []byte) (*CharacterCardV2, error) {
var card CharacterCardV2
err := json.Unmarshal(jsonData, &card)
if err != nil {
return nil, errors.New("解析角色卡 JSON 失败: " + err.Error())
}
return &card, nil
}

View File

@@ -2,11 +2,10 @@ package utils
import (
"fmt"
"git.echol.cn/loser/ai_proxy/server/model/common"
"math/rand"
"reflect"
"strings"
"git.echol.cn/loser/ai_proxy/server/model/common"
)
//@author: [piexlmax](https://github.com/piexlmax)

View File

@@ -3,7 +3,6 @@ package utils
import (
"crypto/md5"
"encoding/hex"
"golang.org/x/crypto/bcrypt"
)

View File

@@ -1,26 +0,0 @@
package utils
import (
"strconv"
"github.com/gin-gonic/gin"
)
// StringToUint 字符串转 uint
func StringToUint(s string) (uint, error) {
val, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return 0, err
}
return uint(val), nil
}
// GetIntQuery 获取查询参数int类型
func GetIntQuery(c *gin.Context, key string, defaultValue int) int {
if val := c.Query(key); val != "" {
if intVal, err := strconv.Atoi(val); err == nil {
return intVal
}
}
return defaultValue
}

View File

@@ -1,22 +0,0 @@
package utils
import (
"strconv"
"github.com/gin-gonic/gin"
)
// GetUintParam 从 URL 参数中获取 uint 值
func GetUintParam(c *gin.Context, key string) uint {
val := c.Param(key)
if val == "" {
return 0
}
uintVal, err := strconv.ParseUint(val, 10, 32)
if err != nil {
return 0
}
return uint(uintVal)
}

View File

@@ -1,15 +0,0 @@
package utils
import (
"crypto/rand"
"encoding/hex"
)
// GenerateRandomString 生成随机字符串
func GenerateRandomString(length int) (string, error) {
bytes := make([]byte, length/2+1)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes)[:length], nil
}

View File

@@ -1,11 +1,10 @@
package utils
import (
"git.echol.cn/loser/ai_proxy/server/global"
"runtime"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/mem"

View File

@@ -1,9 +1,8 @@
package timer
import (
"sync"
"github.com/robfig/cron/v3"
"sync"
)
type Timer interface {

View File

@@ -1,9 +1,8 @@
package utils
import (
"testing"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
"testing"
)
type PageInfoTest struct {