Skip to content

Commit

Permalink
Added Bi-Directional Foreign Keys (#63)
Browse files Browse the repository at this point in the history
Added Bi-Directional Foreign Keys

The foreign constraint database info is read into ForeignKey. Each column is provided with its foreign key
definition.

`Table` exposes tables it references and other tables that reference it. Table also exposes all columns it
owns that are foreign keys.

`Column` exposes both the other columns it references and other columns that reference it.

The structure defines column level foreign keys as `ForeignKeyColumn` and table level foreign keys as `ForeignKey`.
ForeignKey holds a slice of all mapped `ForeignKeyColumns`. This will allow easy iteration of the columns included in a `ForeignKey` at the table level.
  • Loading branch information
daniel-reed authored and natefinch committed Oct 29, 2017
1 parent 80d3e58 commit a894efd
Show file tree
Hide file tree
Showing 8 changed files with 660 additions and 48 deletions.
66 changes: 66 additions & 0 deletions database/drivers/mysql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,36 @@ func parse(log *log.Logger, conn string, schemaNames []string, filterTables func
}
}

foreignKeys, err := queryForeignKeys(log, db, schemaNames)
if err != nil {
return nil, err
}
for _, fk := range foreignKeys {
if !filterTables(fk.SchemaName, fk.TableName) {
log.Printf("skipping constraint %q because it is for filtered-out table %v.%v", fk.Name, fk.SchemaName, fk.TableName)
continue
}

schema, ok := schemas[fk.SchemaName]
if !ok {
log.Printf("Should be impossible: constraint %q references unknown schema %q", fk.Name, fk.SchemaName)
continue
}
table, ok := schema[fk.TableName]
if !ok {
log.Printf("Should be impossible: constraint %q references unknown table %q in schema %q", fk.Name, fk.TableName, fk.SchemaName)
continue
}

for _, col := range table {
if fk.ColumnName != col.Name {
continue
}
col.IsForeignKey = true
col.ForeignKey = fk
}
}

res := &database.Info{Schemas: make([]*database.Schema, 0, len(schemas))}
for _, schema := range schemaNames {
tables := schemas[schema]
Expand Down Expand Up @@ -151,3 +181,39 @@ func toDBColumn(c *columns.Row, log *log.Logger) (*database.Column, *database.En

return col, enum, nil
}

func queryForeignKeys(log *log.Logger, db *sql.DB, schemas []string) ([]*database.ForeignKey, error) {
// TODO: make this work with Gnorm generated types
const q = `SELECT lkc.TABLE_SCHEMA, lkc.TABLE_NAME, lkc.COLUMN_NAME, lkc.CONSTRAINT_NAME, lkc.POSITION_IN_UNIQUE_CONSTRAINT, lkc.REFERENCED_TABLE_NAME, lkc.REFERENCED_COLUMN_NAME
FROM information_schema.REFERENTIAL_CONSTRAINTS as rc
LEFT JOIN information_schema.KEY_COLUMN_USAGE as lkc
ON lkc.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA
AND lkc.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
WHERE rc.CONSTRAINT_SCHEMA IN (%s)`
spots := make([]string, len(schemas))
vals := make([]interface{}, len(schemas))
for x := range schemas {
spots[x] = "?"
vals[x] = schemas[x]
}
query := fmt.Sprintf(q, strings.Join(spots, ", "))
rows, err := db.Query(query, vals...)
if err != nil {
return nil, errors.WithMessage(err, "error querying foreign keys")
}
defer rows.Close()
var ret []*database.ForeignKey

for rows.Next() {
fk := &database.ForeignKey{}
if err := rows.Scan(&fk.SchemaName, &fk.TableName, &fk.ColumnName, &fk.Name, &fk.UniqueConstraintPosition, &fk.ForeignTableName, &fk.ForeignColumnName); err != nil {
return nil, errors.WithMessage(err, "error scanning foreign key constraint")
}
ret = append(ret, fk)
}
if rows.Err() != nil {
return nil, errors.WithMessage(rows.Err(), "error reading foreign keys")
}

return ret, nil
}
73 changes: 71 additions & 2 deletions database/drivers/postgres/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,36 @@ func parse(log *log.Logger, conn string, schemaNames []string, filterTables func
}
}

foreignKeys, err := queryForeignKeys(log, db, schemaNames)
if err != nil {
return nil, err
}
for _, fk := range foreignKeys {
if !filterTables(fk.SchemaName, fk.TableName) {
log.Printf("skipping constraint %q because it is for filtered-out table %v.%v", fk.Name, fk.SchemaName, fk.TableName)
continue
}

schema, ok := schemas[fk.SchemaName]
if !ok {
log.Printf("Should be impossible: constraint %q references unknown schema %q", fk.Name, fk.SchemaName)
continue
}
table, ok := schema[fk.TableName]
if !ok {
log.Printf("Should be impossible: constraint %q references unknown table %q in schema %q", fk.Name, fk.TableName, fk.SchemaName)
continue
}

for _, col := range table {
if fk.ColumnName != col.Name {
continue
}
col.IsForeignKey = true
col.ForeignKey = fk
}
}

enums, err := queryEnums(log, db, schemaNames)
if err != nil {
return nil, err
Expand Down Expand Up @@ -186,10 +216,10 @@ func queryPrimaryKeys(log *log.Logger, db *sql.DB, schemas []string) ([]*databas
}
query := fmt.Sprintf(q, strings.Join(spots, ", "))
rows, err := db.Query(query, vals...)
defer rows.Close()
if err != nil {
return nil, errors.WithMessage(err, "error querying keys")
}
defer rows.Close()
var ret []*database.PrimaryKey

for rows.Next() {
Expand All @@ -202,6 +232,45 @@ func queryPrimaryKeys(log *log.Logger, db *sql.DB, schemas []string) ([]*databas
return ret, nil
}

func queryForeignKeys(log *log.Logger, db *sql.DB, schemas []string) ([]*database.ForeignKey, error) {
// TODO: make this work with Gnorm generated types
const q = `SELECT rc.constraint_schema, lkc.table_name, lkc.column_name, lkc.constraint_name, lkc.position_in_unique_constraint, fkc.table_name, fkc.column_name
FROM information_schema.referential_constraints rc
LEFT JOIN information_schema.key_column_usage lkc
ON lkc.table_schema = rc.constraint_schema
AND lkc.constraint_name = rc.constraint_name
LEFT JOIN information_schema.key_column_usage fkc
ON fkc.table_schema = rc.constraint_schema
AND fkc.ordinal_position = lkc.position_in_unique_constraint
AND fkc.constraint_name = rc.unique_constraint_name
WHERE rc.constraint_schema IN (%s)`
spots := make([]string, len(schemas))
vals := make([]interface{}, len(schemas))
for x := range schemas {
spots[x] = fmt.Sprintf("$%v", x+1)
vals[x] = schemas[x]
}
query := fmt.Sprintf(q, strings.Join(spots, ", "))
rows, err := db.Query(query, vals...)
if err != nil {
return nil, errors.WithMessage(err, "error querying foreign keys")
}
defer rows.Close()
var ret []*database.ForeignKey

for rows.Next() {
fk := &database.ForeignKey{}
if err := rows.Scan(&fk.SchemaName, &fk.TableName, &fk.ColumnName, &fk.Name, &fk.UniqueConstraintPosition, &fk.ForeignTableName, &fk.ForeignColumnName); err != nil {
return nil, errors.WithMessage(err, "error scanning foreign key constraint")
}
ret = append(ret, fk)
}
if rows.Err() != nil {
return nil, errors.WithMessage(rows.Err(), "error reading foreign keys")
}
return ret, nil
}

func queryEnums(log *log.Logger, db *sql.DB, schemas []string) (map[string][]*database.Enum, error) {
// TODO: make this work with Gnorm generated types
const q = `
Expand All @@ -220,10 +289,10 @@ func queryEnums(log *log.Logger, db *sql.DB, schemas []string) (map[string][]*da
}
query := fmt.Sprintf(q, strings.Join(spots, ", "))
rows, err := db.Query(query, vals...)
defer rows.Close()
if err != nil {
return nil, errors.WithMessage(err, "error querying enum names")
}
defer rows.Close()
ret := map[string][]*database.Enum{}
for rows.Next() {
var name, schema string
Expand Down
13 changes: 13 additions & 0 deletions database/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ type PrimaryKey struct {
Name string // the original name of the key constraint in the db
}

// Foreign Key contains the definition of a database foreign key
type ForeignKey struct {
SchemaName string // the original name of the schema in the db
TableName string // the original name of the table in the db
ColumnName string // the original name of the column in the db
Name string // the original name of the foreign key constraint in the db
UniqueConstraintPosition int // the position of the unique constraint in the db
ForeignTableName string // the original name of the table in the db for the referenced table
ForeignColumnName string // the original name of the column in the db for the referenced column
}

// Column contains data about a column in a table.
type Column struct {
Name string // the original name of the column in the DB
Expand All @@ -52,6 +63,8 @@ type Column struct {
Nullable bool // true if the column is not NON NULL
HasDefault bool // true if the column has a default
IsPrimaryKey bool // true if the column is a primary key
IsForeignKey bool // true if the column is a foreign key
ForeignKey *ForeignKey // foreign key database definition
Orig interface{} // the raw database column data
}

Expand Down
124 changes: 113 additions & 11 deletions run/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"gnorm.org/gnorm/run/data"
)

type nameConverter func(s string) (string, error)

func makeData(log *log.Logger, info *database.Info, cfg *Config) (*data.DBData, error) {
convert := func(s string) (string, error) {
buf := &bytes.Buffer{}
Expand Down Expand Up @@ -62,6 +64,8 @@ func makeData(log *log.Logger, info *database.Info, cfg *Config) (*data.DBData,
DBName: t.Name,
Schema: sch,
ColumnsByName: make(map[string]*data.Column, len(t.Columns)),
FKByName: map[string]*data.ForeignKey{},
FKRefsByName: map[string]*data.ForeignKey{},
}
sch.Tables = append(sch.Tables, table)
sch.TablesByName[table.DBName] = table
Expand All @@ -71,15 +75,18 @@ func makeData(log *log.Logger, info *database.Info, cfg *Config) (*data.DBData,
}
for _, c := range t.Columns {
col := &data.Column{
DBName: c.Name,
DBType: c.Type,
IsArray: c.IsArray,
Length: c.Length,
UserDefined: c.UserDefined,
Nullable: c.Nullable,
HasDefault: c.HasDefault,
IsPrimaryKey: c.IsPrimaryKey,
Orig: c.Orig,
Table: table,
DBName: c.Name,
DBType: c.Type,
IsArray: c.IsArray,
Length: c.Length,
UserDefined: c.UserDefined,
Nullable: c.Nullable,
HasDefault: c.HasDefault,
IsPrimaryKey: c.IsPrimaryKey,
IsFK: c.IsForeignKey,
FKColumnRefsByName: map[string]*data.ForeignKeyColumn{},
Orig: c.Orig,
}
table.Columns = append(table.Columns, col)
table.ColumnsByName[col.DBName] = col
Expand All @@ -102,12 +109,15 @@ func makeData(log *log.Logger, info *database.Info, cfg *Config) (*data.DBData,
}
table.PrimaryKeys = filterPrimaryKeyColumns(table.Columns)
}
if err = mapSchemaForeignKeyReferences(s, sch, convert); err != nil {
return nil, err
}
}
return db, nil
}

func filterPrimaryKeyColumns(columns []*data.Column) []*data.Column {
var pkColumns []*data.Column
func filterPrimaryKeyColumns(columns data.Columns) data.Columns {
var pkColumns data.Columns
for _, column := range columns {
if column.IsPrimaryKey {
pkColumns = append(pkColumns, column)
Expand All @@ -116,3 +126,95 @@ func filterPrimaryKeyColumns(columns []*data.Column) []*data.Column {

return pkColumns
}

func mapSchemaForeignKeyReferences(isch *database.Schema, sch *data.Schema, convert nameConverter) error {
for _, t := range isch.Tables {
table, ok := sch.TablesByName[t.Name]
if !ok {
log.Printf("Unmapped table %v in %v", t.Name, isch.Name)
continue
}

fkColumnsByFKNames := map[string]data.ForeignKeyColumns{}

for _, c := range t.Columns {
column, ok := table.ColumnsByName[c.Name]
if !ok {
log.Printf("Unmapped column %v in %v.%v", c.Name, isch.Name, t.Name)
continue
}

if column.IsFK {
refTable, ok := sch.TablesByName[c.ForeignKey.ForeignTableName]
if !ok {
log.Printf("Unmapped foreign table %v in %v", c.ForeignKey.ForeignTableName, isch.Name)
continue
}
refColumn, ok := refTable.ColumnsByName[c.ForeignKey.ForeignColumnName]
if !ok {
log.Printf("Unmapped foreign column %v in %v.%v", c.ForeignKey.ForeignColumnName, isch.Name, c.ForeignKey.ForeignTableName)
continue
}

fkColumn := &data.ForeignKeyColumn{
DBName: c.ForeignKey.Name,
ColumnDBName: column.DBName,
RefColumnDBName: refColumn.DBName,
Column: column,
RefColumn: refColumn,
}
column.FKColumn = fkColumn

refColumn.HasFKRef = true
refColumn.FKColumnRefs = append(refColumn.FKColumnRefs, fkColumn)
refColumn.FKColumnRefsByName[fkColumn.DBName] = fkColumn

if _, ok := fkColumnsByFKNames[fkColumn.DBName]; !ok {
fkColumnsByFKNames[fkColumn.DBName] = data.ForeignKeyColumns{fkColumn}
} else {
fkColumnsByFKNames[fkColumn.DBName] = append(fkColumnsByFKNames[fkColumn.DBName], fkColumn)
}
}
}

for _, fkc := range fkColumnsByFKNames {
err := mapForeignTable(fkc, convert)
if err != nil {
return err
}
}
}

return nil
}

func mapForeignTable(fkc data.ForeignKeyColumns, convert nameConverter) error {
if len(fkc) == 0 {
return nil
}

// All ForeignKeyColumns will point to same table/refTable and have the same name, use first one
table := fkc[0].Column.Table
refTable := fkc[0].RefColumn.Table
cName, err := convert(fkc[0].DBName)
if err != nil {
return errors.Wrap(err, "foreign key")
}

fk := &data.ForeignKey{
DBName: fkc[0].DBName,
Name: cName,
TableDBName: table.DBName,
RefTableDBName: refTable.DBName,
Table: table,
RefTable: refTable,
FKColumns: fkc,
}

table.ForeignKeys = append(table.ForeignKeys, fk)
refTable.ForeignKeyRefs = append(table.ForeignKeyRefs, fk)
table.FKByName[fk.DBName] = fk
refTable.FKRefsByName[fk.DBName] = fk

return nil
}
Loading

0 comments on commit a894efd

Please sign in to comment.