Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type regexp and enums. #9

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 186 additions & 26 deletions dgw.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"go/format"
"io/ioutil"
"regexp"
"sort"
"strings"
"text/template"
Expand Down Expand Up @@ -37,6 +39,34 @@ func OpenDB(connStr string) (*sql.DB, error) {
return conn, nil
}

const pgLoadEnumDef = `
SELECT n.nspname AS schema,
pg_catalog.format_type ( t.oid, NULL ) AS name,
ARRAY( SELECT e.enumlabel
FROM pg_catalog.pg_enum e
WHERE e.enumtypid = t.oid
ORDER BY e.oid )
AS elements
FROM pg_catalog.pg_type t
LEFT JOIN pg_catalog.pg_namespace n
ON n.oid = t.typnamespace
WHERE ( t.typrelid = 0
OR ( SELECT c.relkind = 'c'
FROM pg_catalog.pg_class c
WHERE c.oid = t.typrelid
)
)
AND NOT EXISTS
( SELECT 1
FROM pg_catalog.pg_type el
WHERE el.oid = t.typelem
AND el.typarray = t.oid
)
AND n.nspname = $1
AND pg_catalog.pg_type_is_visible ( t.oid )
ORDER BY 1, 2;
`

const queryInterface = `
// Queryer database/sql compatible query interface
type Queryer interface {
Expand Down Expand Up @@ -106,6 +136,28 @@ type TypeMap struct {
DBTypes []string `toml:"db_types"`
NotNullGoType string `toml:"notnull_go_type"`
NullableGoType string `toml:"nullable_go_type"`

compiled bool
rePatterns []*regexp.Regexp
}

func (t *TypeMap) Match(s string) bool {
if !t.compiled {
for _, v := range t.DBTypes {
if strings.HasPrefix(v, "re/") {
t.rePatterns = append(t.rePatterns, regexp.MustCompile(v[3:]))
}
}
}
if contains(s, t.DBTypes) {
return true
}
for _, v := range t.rePatterns {
if v.MatchString(s) {
return true
}
}
return false
}

// AutoKeyMap auto generating key config
Expand All @@ -114,7 +166,7 @@ type AutoKeyMap struct {
}

// PgTypeMapConfig go/db type map struct toml config
type PgTypeMapConfig map[string]TypeMap
type PgTypeMapConfig map[string]*TypeMap

// PgTable postgres table
type PgTable struct {
Expand Down Expand Up @@ -186,6 +238,49 @@ func PgLoadTypeMapFromFile(filePath string) (*PgTypeMapConfig, error) {
return &conf, nil
}

type PgEnum struct {
Schema string
Name string
Values []string
}

type EnumValue struct {
Type *EnumType
Name string
Value string
}

type EnumType struct {
Name string
Enum *PgEnum
Comment string
Values []EnumValue
}

func PgLoadEnumDef(db Queryer, schema string) ([]*PgEnum, error) {
enumDefs, err := db.Query(pgLoadEnumDef, schema)
if err != nil {
return nil, errors.Wrap(err, "failed to load enum def")
}

enums := []*PgEnum{}
for enumDefs.Next() {
e := &PgEnum{}
var vals pq.StringArray
err := enumDefs.Scan(
&e.Schema,
&e.Name,
&vals,
)
e.Values = vals
if err != nil {
return nil, errors.Wrap(err, "failed to scan")
}
enums = append(enums, e)
}
return enums, nil
}

// PgLoadColumnDef load Postgres column definition
func PgLoadColumnDef(db Queryer, schema string, table string) ([]*PgColumn, error) {
colDefs, err := db.Query(pgLoadColumnDef, schema, table)
Expand Down Expand Up @@ -256,11 +351,10 @@ func contains(v string, l []string) bool {
}

// PgConvertType converts type
func PgConvertType(col *PgColumn, typeCfg *PgTypeMapConfig) string {
cfg := map[string]TypeMap(*typeCfg)
typ := cfg["default"].NotNullGoType
for _, v := range cfg {
if contains(col.DataType, v.DBTypes) {
func PgConvertType(col *PgColumn, typeCfg PgTypeMapConfig) string {
typ := typeCfg["default"].NotNullGoType
for _, v := range typeCfg {
if v.Match(col.DataType) {
if col.NotNull {
return v.NotNullGoType
}
Expand All @@ -271,7 +365,7 @@ func PgConvertType(col *PgColumn, typeCfg *PgTypeMapConfig) string {
}

// PgColToField converts pg column to go struct field
func PgColToField(col *PgColumn, typeCfg *PgTypeMapConfig) (*StructField, error) {
func PgColToField(col *PgColumn, typeCfg PgTypeMapConfig) (*StructField, error) {
stfType := PgConvertType(col, typeCfg)
stf := &StructField{
Name: varfmt.PublicVarName(col.Name),
Expand All @@ -282,7 +376,7 @@ func PgColToField(col *PgColumn, typeCfg *PgTypeMapConfig) (*StructField, error)
}

// PgTableToStruct converts table def to go struct
func PgTableToStruct(t *PgTable, typeCfg *PgTypeMapConfig, keyConfig *AutoKeyMap) (*Struct, error) {
func PgTableToStruct(t *PgTable, typeCfg PgTypeMapConfig, keyConfig *AutoKeyMap) (*Struct, error) {
t.setPrimaryKeyInfo(keyConfig)
s := &Struct{
Name: varfmt.PublicVarName(t.Name),
Expand All @@ -292,7 +386,7 @@ func PgTableToStruct(t *PgTable, typeCfg *PgTypeMapConfig, keyConfig *AutoKeyMap
for _, c := range t.Columns {
f, err := PgColToField(c, typeCfg)
if err != nil {
return nil, errors.Wrap(err, "faield to convert col to field")
return nil, errors.Wrap(err, "failed to convert col to field")
}
fs = append(fs, f)
}
Expand All @@ -301,7 +395,7 @@ func PgTableToStruct(t *PgTable, typeCfg *PgTypeMapConfig, keyConfig *AutoKeyMap
}

// PgExecuteDefaultTmpl execute struct template with *Struct
func PgExecuteDefaultTmpl(st *StructTmpl, path string) ([]byte, error) {
func PgExecuteDefaultTmpl(st interface{}, path string) ([]byte, error) {
var src []byte
d, err := Asset(path)
if err != nil {
Expand All @@ -323,7 +417,7 @@ func PgExecuteDefaultTmpl(st *StructTmpl, path string) ([]byte, error) {
}

// PgExecuteCustomTmpl execute custom template
func PgExecuteCustomTmpl(st *StructTmpl, customTmpl string) ([]byte, error) {
func PgExecuteCustomTmpl(st interface{}, customTmpl string) ([]byte, error) {
var src []byte
tpl, err := template.New("struct").Funcs(tmplFuncMap).Parse(customTmpl)
if err != nil {
Expand All @@ -340,34 +434,100 @@ func PgExecuteCustomTmpl(st *StructTmpl, customTmpl string) ([]byte, error) {
return src, nil
}

func getPgTypeMapConfig(typeMapPath string) (PgTypeMapConfig, error) {
cfg := make(PgTypeMapConfig)
if typeMapPath == "" {
if _, err := toml.Decode(typeMap, &cfg); err != nil {
return nil, errors.Wrap(err, "failed to read type map")
}
} else {
if _, err := toml.DecodeFile(typeMapPath, &cfg); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("failed to decode type map file %s", typeMapPath))
}
}
return cfg, nil
}

func PgEnumToType(e *PgEnum, typeCfg PgTypeMapConfig, keyConfig *AutoKeyMap) (*EnumType, error) {
en := &EnumType{
Name: varfmt.PublicVarName(e.Name),
Enum: e,
}
for _, v := range e.Values {
en.Values = append(en.Values, EnumValue{
Type: en,
Name: en.Name + "_" + varfmt.PublicVarName(v),
Value: v,
})
}
if _,ok := typeCfg[e.Name]; !ok {
typeCfg[e.Name] = &TypeMap{
DBTypes: []string{e.Name},
NotNullGoType: en.Name,
NullableGoType: "Null"+en.Name,

compiled: true,
rePatterns: nil,
}
}

return en, nil
}

func PgCreateEnums(db Queryer, schema string, cfg PgTypeMapConfig, customTmpl string) ([]byte, error) {
var src []byte

enums, err := PgLoadEnumDef(db, schema)
if err != nil {
return src, errors.Wrap(err, "failed to load enum definitions")
}

for _, pgEnum := range enums {
enum, err := PgEnumToType(pgEnum, cfg, autoGenKeyCfg)
if err != nil {
return src, errors.Wrap(err, "failed to convert enum definition to type")
}

if customTmpl != "" {
tmpl, err := ioutil.ReadFile(customTmpl)
if err != nil {
return nil, err
}
s, err := PgExecuteCustomTmpl(enum, string(tmpl))
if err != nil {
return nil, errors.Wrap(err, "PgExecuteCustomTmpl failed")
}
src = append(src, s...)
} else {
s, err := PgExecuteDefaultTmpl(enum, "template/enum.tmpl")
if err != nil {
return src, errors.Wrap(err, "failed to execute template")
}
src = append(src, s...)
}
}
return src, nil
}

// PgCreateStruct creates struct from given schema
func PgCreateStruct(
db Queryer, schema, typeMapPath, pkgName, customTmpl string, exTbls []string) ([]byte, error) {
db Queryer, schema string, cfg PgTypeMapConfig, pkgName, customTmpl string, exTbls []string) ([]byte, error) {
var src []byte
pkgDef := []byte(fmt.Sprintf("package %s\n\n", pkgName))
src = append(src, pkgDef...)

tbls, err := PgLoadTableDef(db, schema)
if err != nil {
return src, errors.Wrap(err, "faield to load table definitions")
}
cfg := &PgTypeMapConfig{}
if typeMapPath == "" {
if _, err := toml.Decode(typeMap, cfg); err != nil {
return src, errors.Wrap(err, "faield to read type map")
}
} else {
if _, err := toml.DecodeFile(typeMapPath, cfg); err != nil {
return src, errors.Wrap(err, fmt.Sprintf("failed to decode type map file %s", typeMapPath))
}
return src, errors.Wrap(err, "failed to load table definitions")
}

for _, tbl := range tbls {
if contains(tbl.Name, exTbls) {
continue
}
st, err := PgTableToStruct(tbl, cfg, autoGenKeyCfg)
if err != nil {
return src, errors.Wrap(err, "faield to convert table definition to struct")
return src, errors.Wrap(err, "failed to convert table definition to struct")
}
if customTmpl != "" {
tmpl, err := ioutil.ReadFile(customTmpl)
Expand All @@ -382,11 +542,11 @@ func PgCreateStruct(
} else {
s, err := PgExecuteDefaultTmpl(&StructTmpl{Struct: st}, "template/struct.tmpl")
if err != nil {
return src, errors.Wrap(err, "faield to execute template")
return src, errors.Wrap(err, "failed to execute template")
}
m, err := PgExecuteDefaultTmpl(&StructTmpl{Struct: st}, "template/method.tmpl")
if err != nil {
return src, errors.Wrap(err, "faield to execute template")
return src, errors.Wrap(err, "failed to execute template")
}
src = append(src, s...)
src = append(src, m...)
Expand Down
10 changes: 5 additions & 5 deletions dgw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func testSetupStruct(t *testing.T, conn *sql.DB) []*Struct {

var sts []*Struct
for _, tbl := range tbls {
st, err := PgTableToStruct(tbl, &defaultTypeMapCfg, autoGenKeyCfg)
st, err := PgTableToStruct(tbl, defaultTypeMapCfg, autoGenKeyCfg)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestPgColToField(t *testing.T) {
}

for _, c := range cols {
f, err := PgColToField(c, &defaultTypeMapCfg)
f, err := PgColToField(c, defaultTypeMapCfg)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestPgTableToStruct(t *testing.T) {
}

for _, tbl := range tbls {
st, err := PgTableToStruct(tbl, &defaultTypeMapCfg, autoGenKeyCfg)
st, err := PgTableToStruct(tbl, defaultTypeMapCfg, autoGenKeyCfg)
if err != nil {
t.Fatal(err)
}
Expand All @@ -142,7 +142,7 @@ func TestPgTableToMethod(t *testing.T) {
t.Fatal(err)
}
for _, tbl := range tbls {
st, err := PgTableToStruct(tbl, &defaultTypeMapCfg, autoGenKeyCfg)
st, err := PgTableToStruct(tbl, defaultTypeMapCfg, autoGenKeyCfg)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -173,7 +173,7 @@ func TestPgExecuteCustomTemplate(t *testing.T) {
t.Fatal(err)
}
for _, tbl := range tbls {
st, err := PgTableToStruct(tbl, &defaultTypeMapCfg, autoGenKeyCfg)
st, err := PgTableToStruct(tbl, defaultTypeMapCfg, autoGenKeyCfg)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading