Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added CLI Flags for Package name customization for Model, Table, View and Enum #424

Merged
merged 8 commits into from
Nov 26, 2024
Merged
43 changes: 32 additions & 11 deletions cmd/jet/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/go-jet/jet/v2/internal/utils/errfmt"
"github.com/go-jet/jet/v2/internal/utils/strslice"
"os"
"path/filepath"
"strings"

"github.com/go-jet/jet/v2/generator/metadata"
Expand Down Expand Up @@ -42,7 +43,11 @@ var (
ignoreViews string
ignoreEnums string

destDir string
destDir string
modelPkg string
tablePkg string
viewPkg string
enumPkg string
)

func init() {
Expand All @@ -66,17 +71,28 @@ func init() {
flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public")(PostgreSQL only)`)
flag.StringVar(&params, "params", "", "Additional connection string parameters(optional). Used only if dsn is not set.")
flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL. Used only if dsn is not set. (optional)(default "disable")(PostgreSQL only)`)
flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore`)
flag.StringVar(&ignoreViews, "ignore-views", "", `Comma-separated list of views to ignore`)
flag.StringVar(&ignoreEnums, "ignore-enums", "", `Comma-separated list of enums to ignore`)
flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore.`)
flag.StringVar(&ignoreViews, "ignore-views", "", `Comma-separated list of views to ignore.`)
flag.StringVar(&ignoreEnums, "ignore-enums", "", `Comma-separated list of enums to ignore.`)

flag.StringVar(&destDir, "path", "", "Destination dir for files generated.")
}
flag.StringVar(&destDir, "path", "", "Destination directory for files generated.")
flag.StringVar(&modelPkg, "model-pkg", "model", "Relative path for the Model files package from the destination directory.")
flag.StringVar(&tablePkg, "table-pkg", "table", "Relative path for the Table files package from the destination directory.")
flag.StringVar(&viewPkg, "view-pkg", "view", "Relative path for the View files package from the destination directory.")
flag.StringVar(&enumPkg, "enum-pkg", "enum", "Relative path for the Enum files package from the destination directory.")

func main() {
flag.Usage = usage
flag.Parse()

// Convert the OS-specific path separator to slashes for cross-platform compatibility.
destDir = filepath.ToSlash(destDir)
modelPkg = filepath.ToSlash(modelPkg)
tablePkg = filepath.ToSlash(tablePkg)
viewPkg = filepath.ToSlash(viewPkg)
enumPkg = filepath.ToSlash(enumPkg)
go-jet marked this conversation as resolved.
Show resolved Hide resolved
}

func main() {
if dsn == "" && (source == "" || host == "" || port == 0 || user == "" || dbName == "") {
printErrorAndExit("ERROR: required flag(s) missing")
}
Expand Down Expand Up @@ -170,6 +186,7 @@ func usage() {
"source", "dsn", "host", "port", "user", "password", "dbname", "schema", "params", "sslmode",
"path",
"ignore-tables", "ignore-views", "ignore-enums",
"model-pkg", "table-pkg", "view-pkg", "enum-pkg",
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to update this names as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, my bad.
Pushed the changes already.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great. LGTM. 👍

}

for _, name := range order {
Expand All @@ -186,6 +203,7 @@ func usage() {
$ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen
$ jet -source=mysql -host=localhost -port=3306 -user=jet -password=jet -dbname=jetdb -path=./gen
$ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen
$ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen -model-pkg=./entity
`)
}

Expand Down Expand Up @@ -246,7 +264,7 @@ func genTemplate(dialect jet.Dialect, ignoreTables []string, ignoreViews []strin
return template.Default(dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseModel(template.DefaultModel().UsePath(modelPkg).
UseTable(func(table metadata.Table) template.TableModel {
if shouldSkipTable(table) {
return template.TableModel{Skip: true}
Expand All @@ -271,19 +289,22 @@ func genTemplate(dialect jet.Dialect, ignoreTables []string, ignoreViews []strin
if shouldSkipTable(table) {
return template.TableSQLBuilder{Skip: true}
}
return template.DefaultTableSQLBuilder(table)

return template.DefaultTableSQLBuilder(table).UsePath(tablePkg)
}).
UseView(func(table metadata.Table) template.ViewSQLBuilder {
if shouldSkipView(table) {
return template.ViewSQLBuilder{Skip: true}
}
return template.DefaultViewSQLBuilder(table)

return template.DefaultViewSQLBuilder(table).UsePath(viewPkg)
}).
UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder {
if shouldSkipEnum(enum) {
return template.EnumSQLBuilder{Skip: true}
}
return template.DefaultEnumSQLBuilder(enum)

return template.DefaultEnumSQLBuilder(enum).UsePath(enumPkg)
}),
)
})
Expand Down
155 changes: 155 additions & 0 deletions tests/postgres/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os/exec"
"path/filepath"
"reflect"
"regexp"
"strconv"
"testing"

Expand Down Expand Up @@ -108,6 +109,76 @@ func TestCmdGenerator(t *testing.T) {
require.NoError(t, err)
}

func TestCmdGeneratorWithPkgNames(t *testing.T) {
err := os.RemoveAll(genTestDir2)
require.NoError(t, err)

// Testing with custom package paths
modelPath := "./newmodel"
tablePath := "./newtable"
viewPath := "./newview"
enumPath := "./newenum"

cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost",
"-port="+strconv.Itoa(dbconfig.PgPort),
"-user=jet",
"-password=jet",
"-schema=dvds",
"-path="+genTestDir2,
"-model-pkg="+modelPath,
"-table-pkg="+tablePath,
"-view-pkg="+viewPath,
"-enum-pkg="+enumPath)

cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout

err = cmd.Run()
require.NoError(t, err)

assertGeneratedFilesWithPkgNames(
t,
modelPath,
tablePath,
viewPath,
enumPath,
)

err = os.RemoveAll(genTestDir2)
require.NoError(t, err)

// Testing with nested paths
modelPath = "./db/newmodel"
tablePath = "./db/newtable"
viewPath = "./db/newview"
enumPath = "./db/newenum"

cmd = exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost",
"-port="+strconv.Itoa(dbconfig.PgPort),
"-user=jet",
"-password=jet",
"-schema=dvds",
"-path="+genTestDir2,
"-model-pkg="+modelPath,
"-table-pkg="+tablePath,
"-view-pkg="+viewPath,
"-enum-pkg="+enumPath)

cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout

err = cmd.Run()
require.NoError(t, err)

assertGeneratedFilesWithPkgNames(
t,
modelPath,
tablePath,
viewPath,
enumPath,
)
}

func TestGeneratorIgnoreTables(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -311,6 +382,90 @@ func assertGeneratedFiles(t *testing.T) {
testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", actorModelFile)
}

func assertGeneratedFilesWithPkgNames(t *testing.T, modelPkgPath, tablePkgPath, viewPkgPath, enumPkgPath string) {
// We can get the package names from the base of the package paths for
// replacing package names in the default file content strings
modelPkg := filepath.Base(modelPkgPath)
tablePkg := filepath.Base(tablePkgPath)
viewPkg := filepath.Base(viewPkgPath)
enumPkg := filepath.Base(enumPkgPath)

// Table SQL Builder files
testutils.AssertFileNamesEqual(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath),
"actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "table_use_schema.go",
)

testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath, "actor.go"),
getFileContentWithNewPkg(tablePkg, actorSQLBuilderFile),
)

testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath, "table_use_schema.go"),
getFileContentWithNewPkg(tablePkg, tableUseSchemaFile),
)

// View SQL Builder files
testutils.AssertFileNamesEqual(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath),
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go", "view_use_schema.go",
)

testutils.AssertFileContent(t,
filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath, "actor_info.go"),
getFileContentWithNewPkg(viewPkg, actorInfoSQLBuilderFile),
)

testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath, "view_use_schema.go"),
getFileContentWithNewPkg(viewPkg, viewUseSchemaFile),
)

// Enums SQL Builder files
testutils.AssertFileNamesEqual(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", enumPkgPath),
"mpaa_rating.go",
)

testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", enumPkgPath, "mpaa_rating.go"),
getFileContentWithNewPkg(enumPkg, mpaaRatingEnumFile),
)

// Model files
testutils.AssertFileNamesEqual(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", modelPkgPath),
"actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go",
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go",
)

testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", modelPkgPath, "actor.go"),
getFileContentWithNewPkg(modelPkg, actorModelFile),
)
}

func getFileContentWithNewPkg(pkgName, fileContent string) string {
regex := regexp.MustCompile(`package \w+`)
return regex.ReplaceAllString(fileContent, "package "+pkgName)
}

var mpaaRatingEnumFile = `
//
// Code generated by go-jet DO NOT EDIT.
Expand Down
Loading