Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 287 additions & 0 deletions auto_registry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
package gen

import (
"os"
"path/filepath"
"strings"
"testing"

"gorm.io/driver/sqlite"
"gorm.io/gen/internal/model"
"gorm.io/gorm"
)

func TestAutoRegistryInitGeneration(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()

// 创建测试数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}

// 创建简单的测试表
err = db.Exec(`CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)`).Error
if err != nil {
t.Fatalf("failed to create users table: %v", err)
}

err = db.Exec(`CREATE TABLE orders (id INTEGER PRIMARY KEY, amount DECIMAL)`).Error
if err != nil {
t.Fatalf("failed to create orders table: %v", err)
}

tests := []struct {
name string
configTables []string
expectInitFunc map[string]bool // table -> shouldHaveInit
}{
{
name: "AllTables",
configTables: []string{}, // 空数组表示所有表
expectInitFunc: map[string]bool{
"users": true,
"orders": true,
},
},
{
name: "OnlyUsersTable",
configTables: []string{"users"},
expectInitFunc: map[string]bool{
"users": true,
"orders": false,
},
},
{
name: "BothTables",
configTables: []string{"users", "orders"},
expectInitFunc: map[string]bool{
"users": true,
"orders": true,
},
},
{
name: "NoTables",
configTables: []string{"nonexistent"},
expectInitFunc: map[string]bool{
"users": false,
"orders": false,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
modelDir := filepath.Join(tempDir, tt.name, "model")

// 配置生成器
g := NewGenerator(Config{
OutPath: filepath.Join(tempDir, tt.name, "query"),
ModelPkgPath: modelDir,
Mode: WithDefaultQuery | WithAutoRegistry,
})

// 根据测试配置启用自动注册
if len(tt.configTables) == 0 {
g.WithAutoRegistry() // 不传参数,所有表
} else {
g.WithAutoRegistry(tt.configTables...) // 传入指定表名
}

g.UseDB(db)
g.GenerateAllTable()
g.Execute()

// 验证每个表的 init 函数生成情况
for tableName, shouldHaveInit := range tt.expectInitFunc {
t.Run(tableName, func(t *testing.T) {
checkInitFunction(t, modelDir, tableName, shouldHaveInit)
})
}

// 验证注册表文件是否生成
registryFile := filepath.Join(modelDir, "gen.go")
if _, err := os.Stat(registryFile); os.IsNotExist(err) {
t.Errorf("registry file %s should exist", registryFile)
}
})
}
}

// checkInitFunction 检查指定表的模型文件是否包含正确的 init 函数
func checkInitFunction(t *testing.T, modelDir, tableName string, shouldHaveInit bool) {
fileName := filepath.Join(modelDir, tableName+".gen.go")

// 检查文件是否存在
if _, err := os.Stat(fileName); os.IsNotExist(err) {
t.Errorf("model file %s does not exist", fileName)
return
}

// 读取文件内容
content, err := os.ReadFile(fileName)
if err != nil {
t.Errorf("failed to read file %s: %v", fileName, err)
return
}

fileContent := string(content)

// 检查是否包含 init 函数
hasInitFunc := strings.Contains(fileContent, "func init() {")
hasRegisterCall := strings.Contains(fileContent, "RegisterModel(")

if shouldHaveInit {
if !hasInitFunc {
t.Errorf("file %s should contain 'func init() {' but doesn't", fileName)
}
if !hasRegisterCall {
t.Errorf("file %s should contain 'RegisterModel(' call but doesn't", fileName)
}

// 验证 RegisterModel 调用格式
if hasInitFunc && hasRegisterCall {
expectedModelName := getExpectedModelName(tableName)
expectedCall := "RegisterModel(&" + expectedModelName + "{}, TableName" + expectedModelName + ")"

if !strings.Contains(fileContent, expectedCall) {
t.Errorf("file %s should contain %s", fileName, expectedCall)
t.Logf("Actual file content:\n%s", fileContent)
}
}
} else {
if hasInitFunc && hasRegisterCall {
t.Errorf("file %s should not contain init function with RegisterModel call", fileName)
}
}
}

// getExpectedModelName 根据表名获取期望的模型名
func getExpectedModelName(tableName string) string {
switch tableName {
case "users":
return "User"
case "orders":
return "Order"
default:
// 简单的首字母大写
return model.TitleCase(tableName)
}
}

// TestShouldEnableAutoRegistry 测试表过滤逻辑
func TestShouldEnableAutoRegistry(t *testing.T) {
tests := []struct {
name string
configuredList []string
tableName string
expected bool
}{
{
name: "EmptyList_AllTablesEnabled",
configuredList: []string{},
tableName: "users",
expected: true,
},
{
name: "TableInList_ShouldEnable",
configuredList: []string{"users", "orders"},
tableName: "users",
expected: true,
},
{
name: "TableNotInList_ShouldDisable",
configuredList: []string{"users", "orders"},
tableName: "products",
expected: false,
},
{
name: "SingleTable_Match",
configuredList: []string{"users"},
tableName: "users",
expected: true,
},
{
name: "SingleTable_NoMatch",
configuredList: []string{"users"},
tableName: "orders",
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Generator{
Config: Config{
RegistryTableList: tt.configuredList,
},
}

result := g.shouldEnableAutoRegistry(tt.tableName)
if result != tt.expected {
t.Errorf("shouldEnableAutoRegistry(%s) = %v, want %v",
tt.tableName, result, tt.expected)
}
})
}
}

// TestWithAutoRegistryConfig 测试配置方法
func TestWithAutoRegistryConfig(t *testing.T) {
tests := []struct {
name string
tableNames []string
expectedList []string
expectedMode GenerateMode
}{
{
name: "NoTables",
tableNames: []string{},
expectedList: []string{},
expectedMode: WithAutoRegistry,
},
{
name: "SingleTable",
tableNames: []string{"users"},
expectedList: []string{"users"},
expectedMode: WithAutoRegistry,
},
{
name: "MultipleTables",
tableNames: []string{"users", "orders", "products"},
expectedList: []string{"users", "orders", "products"},
expectedMode: WithAutoRegistry,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := Config{}

// 调用 WithAutoRegistry 方法
cfg.WithAutoRegistry(tt.tableNames...)

// 验证配置结果
if cfg.Mode&WithAutoRegistry == 0 {
t.Error("WithAutoRegistry mode should be enabled")
}

if len(cfg.RegistryTableList) != len(tt.expectedList) {
t.Errorf("RegistryTableList length = %d, want %d",
len(cfg.RegistryTableList), len(tt.expectedList))
}

for i, expected := range tt.expectedList {
if i >= len(cfg.RegistryTableList) || cfg.RegistryTableList[i] != expected {
actual := "<out of bounds>"
if i < len(cfg.RegistryTableList) {
actual = cfg.RegistryTableList[i]
}
t.Errorf("RegistryTableList[%d] = %s, want %s",
i, actual, expected)
}
}
})
}
}
13 changes: 13 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ const (

// WithGeneric generate code with generic
WithGeneric

// WithAutoRegistry generate init functions to auto-register models
WithAutoRegistry
)

// Config generator's basic configuration
Expand All @@ -38,6 +41,9 @@ type Config struct {
ModelPkgPath string // generated model code's package name
WithUnitTest bool // generate unit test for query code

// auto registry configuration
RegistryTableList []string // specific table names to enable auto registry, empty means all tables

// generate model global configuration
FieldNullable bool // generate pointer when field is nullable
FieldCoverable bool // generate pointer when field has default value, to fix problem zero value cannot be assign: https://gorm.io/docs/create.html#Default-Values
Expand Down Expand Up @@ -106,6 +112,13 @@ func (cfg *Config) WithJSONTagNameStrategy(ns func(columnName string) (tagConten
cfg.fieldJSONTagNS = ns
}

// WithAutoRegistry enable auto registry feature for generated models
// tableNames: optional table names to enable auto registry, if empty, all tables will be enabled
func (cfg *Config) WithAutoRegistry(tableNames ...string) {
cfg.Mode |= WithAutoRegistry
cfg.RegistryTableList = tableNames
}

// WithImportPkgPath specify import package path
func (cfg *Config) WithImportPkgPath(paths ...string) {
for i, path := range paths {
Expand Down
2 changes: 1 addition & 1 deletion examples/biz/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"fmt"

"gorm.io/gen/examples/dal/query"
"examples/dal/query"
)

var q = query.Q
Expand Down
13 changes: 9 additions & 4 deletions examples/cmd/gen/generate.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
package main

import (
"examples/conf"
"examples/dal"

"gorm.io/gen"
"gorm.io/gen/examples/conf"
"gorm.io/gen/examples/dal"
)

func init() {
dal.DB = dal.ConnectDB(conf.MySQLDSN).Debug()
dal.DB = dal.ConnectDB(conf.SQLiteDBName).Debug()

prepare(dal.DB) // prepare table for generate
}

func main() {
g := gen.NewGenerator(gen.Config{
OutPath: "../../dal/query",
OutPath: "../../dal/query",
ModelPkgPath: "../../dal/model",
})

g.UseDB(dal.DB)

// auto registry to models
g.WithAutoRegistry()

// generate all table from database
g.ApplyBasic(g.GenerateAllTable()...)

Expand Down
14 changes: 8 additions & 6 deletions examples/cmd/gen/prepare.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
// prepare table for test

const mytableSQL = "CREATE TABLE IF NOT EXISTS `mytables` (" +
" `ID` int(11) NOT NULL," +
" `username` varchar(16) DEFAULT NULL," +
" `age` int(8) NOT NULL," +
" `phone` varchar(11) NOT NULL," +
" INDEX `idx_username` (`username`)" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;"
" `ID` INTEGER NOT NULL PRIMARY KEY," +
" `username` TEXT," +
" `age` INTEGER NOT NULL," +
" `phone` TEXT NOT NULL" +
");"

const indexSQL = "CREATE INDEX IF NOT EXISTS `idx_username` ON `mytables` (`username`);"

func prepare(db *gorm.DB) {
db.Exec(mytableSQL)
db.Exec(indexSQL)
}
5 changes: 3 additions & 2 deletions examples/cmd/only_model/generate.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package main

import (
"examples/conf"
"examples/dal"

"gorm.io/gen"
"gorm.io/gen/examples/conf"
"gorm.io/gen/examples/dal"
)

func init() {
Expand Down
Loading
Loading