You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

185 lines
5.1 KiB
Go

11 months ago
package system
import (
"context"
"database/sql"
"errors"
"fmt"
"gorm.io/gorm"
"miniapp/global"
"miniapp/model/system/request"
"sort"
)
// Database type identifiers (the same strings InitDB matches conf.DBType
// against), followed by fmt-style templates for initialization messages.
const (
Mysql = "mysql"
Pgsql = "pgsql"
Sqlite = "sqlite"
// InitSuccess takes one %v argument; the text reads "initial data succeeded".
InitSuccess = "\n[%v] --> 初始数据成功!\n"
// InitDataExist takes two %v arguments; the text reads "initial data of <second argument> already exists".
InitDataExist = "\n[%v] --> %v 的初始数据已存在!\n"
// InitDataFailed takes two %v arguments and a %+v error; the text reads "<second argument> initial data failed".
InitDataFailed = "\n[%v] --> %v 初始数据失败! \nerr: %+v\n"
// InitDataSuccess takes two %v arguments; the text reads "<second argument> initial data succeeded".
InitDataSuccess = "\n[%v] --> %v 初始数据成功!\n"
)
// Base order values in ascending magnitude: system < internal < external.
// NOTE(review): presumably passed as the `order` argument of RegisterInit so
// that initializers are sorted into these tiers — confirm against the callers.
const (
InitOrderSystem = 10
InitOrderInternal = 1000
InitOrderExternal = 100000
)
// Sentinel errors of the initialization flow; compare with errors.Is.
// NOTE(review): the code that returns them is not in this part of the file;
// the descriptions below only restate the error messages.
var (
// ErrMissingDBContext reports that no db value was found in the context.
ErrMissingDBContext = errors.New("missing db in context")
// ErrMissingDependentContext reports that a value an initializer depends on was not found in the context.
ErrMissingDependentContext = errors.New("missing dependent value in context")
// ErrDBTypeMismatch reports a database type mismatch.
ErrDBTypeMismatch = errors.New("db type mismatch")
)
// SubInitializer is the interface used by source/*/init(); each initializer
// performs one initialization step.
//
// createTables skips MigrateTable for an initializer whose TableCreated
// returns true, and passes the context returned by MigrateTable on to the
// next initializer.
// NOTE(review): InitializeData/DataInserted presumably play the same roles
// for the data phase — their caller is not visible here, confirm.
type SubInitializer interface {
InitializerName() string // does not necessarily correspond to a single table, hence the broader name
MigrateTable(ctx context.Context) (next context.Context, err error)
InitializeData(ctx context.Context) (next context.Context, err error)
TableCreated(ctx context.Context) bool
DataInserted(ctx context.Context) bool
}
// TypedDBInitHandler runs the initializers passed to it. InitDB picks one
// implementation per database type and calls EnsureDB, InitTables, InitData
// and WriteConfig, in that order.
type TypedDBInitHandler interface {
EnsureDB(ctx context.Context, conf *request.InitDB) (context.Context, error) // creates the database; failure is a fatal error, so it is allowed to panic
WriteConfig(ctx context.Context) error // writes the configuration back
InitTables(ctx context.Context, inits initSlice) error // table-creation handler
InitData(ctx context.Context, inits initSlice) error // data-creation handler
}
// orderedInitializer pairs a SubInitializer with an order field so that
// initializers can be sorted.
type orderedInitializer struct {
order int // sort key; initSlice.Less compares this value
SubInitializer
}
// initSlice is the slice type used to order initializers by their dependencies;
// it implements sort.Interface (Len, Less, Swap) keyed on the order field.
type initSlice []*orderedInitializer
// Package-level registry filled by RegisterInit and reset to empty by InitDB
// after a successful run.
var (
// initializers holds every registered initializer; InitDB sorts it in place.
initializers initSlice
// cache indexes the registered initializers by InitializerName; RegisterInit
// uses it to detect duplicate names.
cache map[string]*orderedInitializer
)
// RegisterInit records an initialization step that InitDB will run later.
//
// order is the sort key that decides when the step runs relative to the
// others (lower values run first); i is the step itself. The step is indexed
// by i.InitializerName(), and registering two steps with the same name panics.
func RegisterInit(order int, i SubInitializer) {
	// Lazily create both registries on first use.
	if initializers == nil {
		initializers = make(initSlice, 0)
	}
	if cache == nil {
		cache = make(map[string]*orderedInitializer)
	}

	key := i.InitializerName()
	if _, duplicate := cache[key]; duplicate {
		panic(fmt.Sprintf("Name conflict on %s", key))
	}

	// The slice and the name index share the same entry.
	entry := &orderedInitializer{order: order, SubInitializer: i}
	cache[key] = entry
	initializers = append(initializers, entry)
}
/* ---- * service * ---- */
// InitDBService is a stateless service type; the InitDB method defined on
// *InitDBService is the entry point of database initialization.
type InitDBService struct{}
// InitDB creates the database and runs all registered initializers; it is the
// main entry point of the initialization flow.
//
// The steps are, in order:
//  1. sort the registered initializers by their order value;
//  2. choose the TypedDBInitHandler matching conf.DBType (MySQL is used for an
//     empty or unrecognized type) and record the type in the context under
//     the "dbtype" key;
//  3. call EnsureDB, then publish the *gorm.DB found under the "db" context
//     key as global.GVA_DB;
//  4. call InitTables, InitData and WriteConfig.
//
// On success the initializer registry is emptied, so a later call without new
// registrations returns the "no initializers" error.
//
// It returns an error when no initializers are registered, when any handler
// step fails, or ErrMissingDBContext when the context returned by EnsureDB
// does not carry a non-nil *gorm.DB.
func (initDBService *InitDBService) InitDB(conf request.InitDB) (err error) {
	ctx := context.TODO()
	if len(initializers) == 0 {
		return errors.New("无可用初始化过程,请检查初始化是否已执行完成")
	}

	// Make sure initializers that depend on others run after them.
	// Note: with a single dependency the orders can be written as B=A+1, C=A+1;
	// B and C do not depend on each other, so their relative order is irrelevant.
	// With several dependencies they can be written as C=A+B, D=A+B+C, E=A+1:
	// C is necessarily greater than A and B and therefore runs after both, D is
	// greater than A, B and C and runs after all three, while E depends only on
	// A, so whether it runs before or after C and D does not matter.
	sort.Sort(&initializers)

	// The plain string context keys are kept unchanged: the handlers live
	// outside this function and presumably read the values under these keys.
	var initHandler TypedDBInitHandler
	switch conf.DBType {
	case Pgsql:
		initHandler = NewPgsqlInitHandler()
		ctx = context.WithValue(ctx, "dbtype", Pgsql)
	case Sqlite:
		initHandler = NewSqliteInitHandler()
		ctx = context.WithValue(ctx, "dbtype", Sqlite)
	default:
		// Mysql, as well as any empty or unrecognized type, uses the MySQL handler.
		initHandler = NewMysqlInitHandler()
		ctx = context.WithValue(ctx, "dbtype", Mysql)
	}

	ctx, err = initHandler.EnsureDB(ctx, &conf)
	if err != nil {
		return err
	}

	// Checked assertion: a handler that did not store the connection must
	// produce an error instead of a panic.
	db, ok := ctx.Value("db").(*gorm.DB)
	if !ok || db == nil {
		return ErrMissingDBContext
	}
	global.GVA_DB = db

	if err = initHandler.InitTables(ctx, initializers); err != nil {
		return err
	}
	if err = initHandler.InitData(ctx, initializers); err != nil {
		return err
	}
	if err = initHandler.WriteConfig(ctx); err != nil {
		return err
	}

	// Everything ran: drop the registry so the steps are not executed twice.
	initializers = initSlice{}
	cache = map[string]*orderedInitializer{}
	return nil
}
// createDatabase creates a database; it is meant to be called from EnsureDB().
//
// It opens a connection with the given driver and dsn, verifies it with Ping,
// and executes createSql. The first error from opening, pinging or executing
// is returned. An error from closing the connection is only printed to
// standard output and does not affect the returned value.
func createDatabase(dsn string, driver string, createSql string) error {
	conn, openErr := sql.Open(driver, dsn)
	if openErr != nil {
		return openErr
	}
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			fmt.Println(closeErr)
		}
	}()

	if pingErr := conn.Ping(); pingErr != nil {
		return pingErr
	}

	_, execErr := conn.Exec(createSql)
	return execErr
}
// createTables creates the tables (the default dbInitHandler.initTables behavior).
//
// It walks inits in slice order. An initializer whose TableCreated reports
// true is skipped; otherwise MigrateTable is called and the context it returns
// is handed to the following initializers. The first MigrateTable error stops
// the loop and is returned. The context derived from ctx is cancelled when the
// function returns.
func createTables(ctx context.Context, inits initSlice) error {
	next, cancel := context.WithCancel(ctx)
	defer cancel()

	for idx := range inits {
		step := inits[idx]
		if step.TableCreated(next) {
			continue
		}
		migrated, err := step.MigrateTable(next)
		if err != nil {
			return err
		}
		next = migrated
	}
	return nil
}
/* -- sort.Interface implementation -- */

// Len returns the number of initializers in the slice.
func (s initSlice) Len() int { return len(s) }

// Less reports whether the initializer at index i has a smaller order value
// than the one at index j.
func (s initSlice) Less(i, j int) bool {
	left := s[i].order
	right := s[j].order
	return left < right
}

// Swap exchanges the initializers at indexes i and j.
func (s initSlice) Swap(i, j int) {
	tmp := s[i]
	s[i] = s[j]
	s[j] = tmp
}