diff --git a/cmd/jet/main.go b/cmd/jet/main.go index b56ef31a..243a40de 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -42,7 +42,11 @@ var ( ignoreViews string ignoreEnums string - destDir string + destDir string + modelPkg string + tablePkg string + viewPkg string + enumPkg string ) func init() { @@ -66,11 +70,15 @@ func init() { flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public")(PostgreSQL only)`) flag.StringVar(¶ms, "params", "", "Additional connection string parameters(optional). Used only if dsn is not set.") flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL. Used only if dsn is not set. (optional)(default "disable")(PostgreSQL only)`) - flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore`) - flag.StringVar(&ignoreViews, "ignore-views", "", `Comma-separated list of views to ignore`) - flag.StringVar(&ignoreEnums, "ignore-enums", "", `Comma-separated list of enums to ignore`) - - flag.StringVar(&destDir, "path", "", "Destination dir for files generated.") + flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore.`) + flag.StringVar(&ignoreViews, "ignore-views", "", `Comma-separated list of views to ignore.`) + flag.StringVar(&ignoreEnums, "ignore-enums", "", `Comma-separated list of enums to ignore.`) + + flag.StringVar(&destDir, "path", "", "Destination directory for files generated.") + flag.StringVar(&modelPkg, "rel-model-path", "model", "Relative path for the Model files package from the destination directory.") + flag.StringVar(&tablePkg, "rel-table-path", "table", "Relative path for the Table files package from the destination directory.") + flag.StringVar(&viewPkg, "rel-view-path", "view", "Relative path for the View files package from the destination directory.") + flag.StringVar(&enumPkg, "rel-enum-path", "enum", "Relative path for the Enum files package from the destination directory.") } func main() { @@ -170,6 +178,7 @@ func usage() { "source", "dsn", "host", "port", "user", "password", "dbname", "schema", "params", "sslmode", "path", "ignore-tables", "ignore-views", "ignore-enums", + "rel-model-path", "rel-table-path", "rel-view-path", "rel-enum-path", } for _, name := range order { @@ -186,6 +195,7 @@ func usage() { $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen $ jet -source=mysql -host=localhost -port=3306 -user=jet -password=jet -dbname=jetdb -path=./gen $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen + $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen -rel-model-path=./entity `) } @@ -246,7 +256,7 @@ func genTemplate(dialect jet.Dialect, ignoreTables []string, ignoreViews []strin return template.Default(dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). - UseModel(template.DefaultModel(). + UseModel(template.DefaultModel().UsePath(modelPkg). UseTable(func(table metadata.Table) template.TableModel { if shouldSkipTable(table) { return template.TableModel{Skip: true} @@ -271,19 +281,22 @@ func genTemplate(dialect jet.Dialect, ignoreTables []string, ignoreViews []strin if shouldSkipTable(table) { return template.TableSQLBuilder{Skip: true} } - return template.DefaultTableSQLBuilder(table) + + return template.DefaultTableSQLBuilder(table).UsePath(tablePkg) }). UseView(func(table metadata.Table) template.ViewSQLBuilder { if shouldSkipView(table) { return template.ViewSQLBuilder{Skip: true} } - return template.DefaultViewSQLBuilder(table) + + return template.DefaultViewSQLBuilder(table).UsePath(viewPkg) }). UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder { if shouldSkipEnum(enum) { return template.EnumSQLBuilder{Skip: true} } - return template.DefaultEnumSQLBuilder(enum) + + return template.DefaultEnumSQLBuilder(enum).UsePath(enumPkg) }), ) }) diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index ce6acd16..0b503ef8 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -4,7 +4,7 @@ import ( "database/sql" "fmt" "net/url" - "path" + "path/filepath" "strconv" "github.com/go-jet/jet/v2/generator/metadata" @@ -66,7 +66,7 @@ func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) er return fmt.Errorf("failed to get '%s' schema metadata: %w", schema, err) } - dirPath := path.Join(destDir, cfg.Database) + dirPath := filepath.Join(destDir, cfg.Database) err = template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate) if err != nil { diff --git a/generator/template/model_template.go b/generator/template/model_template.go index f89ebd1b..990b4096 100644 --- a/generator/template/model_template.go +++ b/generator/template/model_template.go @@ -6,7 +6,7 @@ import ( "github.com/go-jet/jet/v2/internal/utils/dbidentifier" "github.com/google/uuid" "github.com/jackc/pgtype" - "path" + "path/filepath" "reflect" "strings" "time" @@ -23,7 +23,7 @@ type Model struct { // PackageName returns package name of model types func (m Model) PackageName() string { - return path.Base(m.Path) + return filepath.Base(m.Path) } // UsePath returns new Model template with replaced file path diff --git a/generator/template/process.go b/generator/template/process.go index 5abef879..371c89b6 100644 --- a/generator/template/process.go +++ b/generator/template/process.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "github.com/go-jet/jet/v2/internal/utils/filesys" - "path" + "path/filepath" "strings" "text/template" @@ -20,7 +20,7 @@ func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemp } schemaTemplate := generatorTemplate.Schema(schemaMetaData) - schemaPath := path.Join(dirPath, schemaTemplate.Path) + schemaPath := filepath.Join(dirPath, schemaTemplate.Path) fmt.Println("Destination directory:", schemaPath) fmt.Println("Cleaning up destination directory...") @@ -50,7 +50,7 @@ func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate return nil } - modelDirPath := path.Join(dirPath, modelTemplate.Path) + modelDirPath := filepath.Join(dirPath, modelTemplate.Path) err := filesys.EnsureDirPathExist(modelDirPath) if err != nil { @@ -83,7 +83,7 @@ func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metad return nil } - sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path) + sqlBuilderPath := filepath.Join(dirPath, sqlBuilderTemplate.Path) err := processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate) if err != nil { @@ -117,7 +117,7 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData [] continue } - enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path) + enumSQLBuilderPath := filepath.Join(dirPath, enumTemplate.Path) err := filesys.EnsureDirPathExist(enumSQLBuilderPath) if err != nil { @@ -182,7 +182,7 @@ func processTableSQLBuilder(fileTypes, dirPath string, continue } - tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilder.Path) + tableSQLBuilderPath := filepath.Join(dirPath, tableSQLBuilder.Path) err := filesys.EnsureDirPathExist(tableSQLBuilderPath) if err != nil { @@ -255,7 +255,7 @@ func generateUseSchemaFunc(dirPath, fileTypes string, builders []TableSQLBuilder return fmt.Errorf("failed to generate use schema template: %w", err) } - basePath := path.Join(dirPath, builders[0].Path) + basePath := filepath.Join(dirPath, builders[0].Path) fileName := fileTypes + "_use_schema" err = filesys.FormatAndSaveGoFile(basePath, fileName, text) diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index 16f88b84..3b158b85 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/internal/utils/dbidentifier" - "path" + "path/filepath" "slices" "strings" "unicode" @@ -90,7 +90,7 @@ func DefaultViewSQLBuilder(viewMetaData metadata.Table) ViewSQLBuilder { // PackageName returns package name of table sql builder types func (tb TableSQLBuilder) PackageName() string { - return path.Base(tb.Path) + return filepath.Base(tb.Path) } // UsePath returns new TableSQLBuilder with new relative path set @@ -228,7 +228,7 @@ func DefaultEnumSQLBuilder(enumMetaData metadata.Enum) EnumSQLBuilder { // PackageName returns enum sql builder package name func (e EnumSQLBuilder) PackageName() string { - return path.Base(e.Path) + return filepath.Base(e.Path) } // UsePath returns new EnumSQLBuilder with new path set diff --git a/tests/internal/utils/file/file.go b/tests/internal/utils/file/file.go index 73095ea3..98438b13 100644 --- a/tests/internal/utils/file/file.go +++ b/tests/internal/utils/file/file.go @@ -3,13 +3,13 @@ package file import ( "github.com/stretchr/testify/require" "os" - "path" + "path/filepath" "testing" ) // Exists expects file to exist on path constructed from pathElems and returns content of the file func Exists(t *testing.T, pathElems ...string) (fileContent string) { - modelFilePath := path.Join(pathElems...) + modelFilePath := filepath.Join(pathElems...) file, err := os.ReadFile(modelFilePath) // #nosec G304 require.Nil(t, err) require.NotEmpty(t, file) @@ -18,7 +18,7 @@ func Exists(t *testing.T, pathElems ...string) (fileContent string) { // NotExists expects file not to exist on path constructed from pathElems func NotExists(t *testing.T, pathElems ...string) { - modelFilePath := path.Join(pathElems...) + modelFilePath := filepath.Join(pathElems...) _, err := os.ReadFile(modelFilePath) // #nosec G304 require.True(t, os.IsNotExist(err)) } diff --git a/tests/mysql/generator_template_test.go b/tests/mysql/generator_template_test.go index a257f31c..8c539e26 100644 --- a/tests/mysql/generator_template_test.go +++ b/tests/mysql/generator_template_test.go @@ -12,18 +12,18 @@ import ( "github.com/go-jet/jet/v2/tests/dbconfig" file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" "github.com/stretchr/testify/require" - "path" + "path/filepath" "testing" ) const tempTestDir = "./.tempTestDir" -var defaultModelPath = path.Join(tempTestDir, "dvds/model") -var defaultActorModelFilePath = path.Join(tempTestDir, "dvds/model", "actor.go") -var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table") -var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "dvds/view") -var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "dvds/enum") -var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table", "actor.go") +var defaultModelPath = filepath.Join(tempTestDir, "dvds/model") +var defaultActorModelFilePath = filepath.Join(tempTestDir, "dvds/model", "actor.go") +var defaultTableSQLBuilderFilePath = filepath.Join(tempTestDir, "dvds/table") +var defaultViewSQLBuilderFilePath = filepath.Join(tempTestDir, "dvds/view") +var defaultEnumSQLBuilderFilePath = filepath.Join(tempTestDir, "dvds/enum") +var defaultActorSQLBuilderFilePath = filepath.Join(tempTestDir, "dvds/table", "actor.go") func dbConnection(dbName string) mysql2.DBConnection { if sourceIsMariaDB() { diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index 65603cd7..9d323bd9 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -3,7 +3,7 @@ package postgres import ( "database/sql" "fmt" - "path" + "path/filepath" "testing" "github.com/go-jet/jet/v2/generator/metadata" @@ -20,13 +20,13 @@ import ( const tempTestDir = "./.tempTestDir" -var defaultModelPath = path.Join(tempTestDir, "jetdb/dvds/model") -var defaultSqlBuilderPath = path.Join(tempTestDir, "jetdb/dvds/table") -var defaultActorModelFilePath = path.Join(tempTestDir, "jetdb/dvds/model", "actor.go") -var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table") -var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/view") -var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/enum") -var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table", "actor.go") +var defaultModelPath = filepath.Join(tempTestDir, "jetdb/dvds/model") +var defaultSqlBuilderPath = filepath.Join(tempTestDir, "jetdb/dvds/table") +var defaultActorModelFilePath = filepath.Join(tempTestDir, "jetdb/dvds/model", "actor.go") +var defaultTableSQLBuilderFilePath = filepath.Join(tempTestDir, "jetdb/dvds/table") +var defaultViewSQLBuilderFilePath = filepath.Join(tempTestDir, "jetdb/dvds/view") +var defaultEnumSQLBuilderFilePath = filepath.Join(tempTestDir, "jetdb/dvds/enum") +var defaultActorSQLBuilderFilePath = filepath.Join(tempTestDir, "jetdb/dvds/table", "actor.go") var dbConnection = postgres.DBConnection{ Host: dbconfig.PgHost, diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 87a7eeec..93fb9c3a 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -6,6 +6,7 @@ import ( "os/exec" "path/filepath" "reflect" + "regexp" "strconv" "testing" @@ -108,6 +109,76 @@ func TestCmdGenerator(t *testing.T) { require.NoError(t, err) } +func TestCmdGeneratorWithPkgNames(t *testing.T) { + err := os.RemoveAll(genTestDir2) + require.NoError(t, err) + + // Testing with custom package paths + modelPath := "./newmodel" + tablePath := "./newtable" + viewPath := "./newview" + enumPath := "./newenum" + + cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", + "-port="+strconv.Itoa(dbconfig.PgPort), + "-user=jet", + "-password=jet", + "-schema=dvds", + "-path="+genTestDir2, + "-rel-model-path="+modelPath, + "-rel-table-path="+tablePath, + "-rel-view-path="+viewPath, + "-rel-enum-path="+enumPath) + + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err = cmd.Run() + require.NoError(t, err) + + assertGeneratedFilesWithPkgNames( + t, + modelPath, + tablePath, + viewPath, + enumPath, + ) + + err = os.RemoveAll(genTestDir2) + require.NoError(t, err) + + // Testing with nested paths + modelPath = "./db/newmodel" + tablePath = "./db/newtable" + viewPath = "./db/newview" + enumPath = "./db/newenum" + + cmd = exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", + "-port="+strconv.Itoa(dbconfig.PgPort), + "-user=jet", + "-password=jet", + "-schema=dvds", + "-path="+genTestDir2, + "-rel-model-path="+modelPath, + "-rel-table-path="+tablePath, + "-rel-view-path="+viewPath, + "-rel-enum-path="+enumPath) + + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err = cmd.Run() + require.NoError(t, err) + + assertGeneratedFilesWithPkgNames( + t, + modelPath, + tablePath, + viewPath, + enumPath, + ) +} + func TestGeneratorIgnoreTables(t *testing.T) { tests := []struct { name string @@ -311,6 +382,90 @@ func assertGeneratedFiles(t *testing.T) { testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", actorModelFile) } +func assertGeneratedFilesWithPkgNames(t *testing.T, modelPkgPath, tablePkgPath, viewPkgPath, enumPkgPath string) { + // We can get the package names from the base of the package paths for + // replacing package names in the default file content strings + modelPkg := filepath.Base(modelPkgPath) + tablePkg := filepath.Base(tablePkgPath) + viewPkg := filepath.Base(viewPkgPath) + enumPkg := filepath.Base(enumPkgPath) + + // Table SQL Builder files + testutils.AssertFileNamesEqual( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath), + "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go", "table_use_schema.go", + ) + + testutils.AssertFileContent( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath, "actor.go"), + getFileContentWithNewPkg(tablePkg, actorSQLBuilderFile), + ) + + testutils.AssertFileContent( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath, "table_use_schema.go"), + getFileContentWithNewPkg(tablePkg, tableUseSchemaFile), + ) + + // View SQL Builder files + testutils.AssertFileNamesEqual( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath), + "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", + "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go", "view_use_schema.go", + ) + + testutils.AssertFileContent(t, + filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath, "actor_info.go"), + getFileContentWithNewPkg(viewPkg, actorInfoSQLBuilderFile), + ) + + testutils.AssertFileContent( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath, "view_use_schema.go"), + getFileContentWithNewPkg(viewPkg, viewUseSchemaFile), + ) + + // Enums SQL Builder files + testutils.AssertFileNamesEqual( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", enumPkgPath), + "mpaa_rating.go", + ) + + testutils.AssertFileContent( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", enumPkgPath, "mpaa_rating.go"), + getFileContentWithNewPkg(enumPkg, mpaaRatingEnumFile), + ) + + // Model files + testutils.AssertFileNamesEqual( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", modelPkgPath), + "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go", + "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go", + ) + + testutils.AssertFileContent( + t, + filepath.Join("./.gentestdata2/jetdb/dvds/", modelPkgPath, "actor.go"), + getFileContentWithNewPkg(modelPkg, actorModelFile), + ) +} + +func getFileContentWithNewPkg(pkgName, fileContent string) string { + regex := regexp.MustCompile(`package \w+`) + return regex.ReplaceAllString(fileContent, "package "+pkgName) +} + var mpaaRatingEnumFile = ` // // Code generated by go-jet DO NOT EDIT.