Skip to content

Commit

Permalink
sqlparser: new Equality API (vitessio#11906)
Browse files Browse the repository at this point in the history
* sqlparser: use a new Equals API

Signed-off-by: Vicent Marti <[email protected]>

* tools: refactor how asthelpergen is invoked

Signed-off-by: Vicent Marti <[email protected]>

* goimports, flags and linter

Signed-off-by: Andres Taylor <[email protected]>

Signed-off-by: Vicent Marti <[email protected]>
Signed-off-by: Andres Taylor <[email protected]>
Co-authored-by: Andres Taylor <[email protected]>
  • Loading branch information
vmg and systay authored Dec 8, 2022
1 parent c6adaa7 commit 8c1316c
Show file tree
Hide file tree
Showing 26 changed files with 1,864 additions and 1,828 deletions.
10 changes: 2 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,13 @@ parser:
demo:
go install ./examples/demo/demo.go

codegen: asthelpergen sizegen parser astfmtgen
codegen: asthelpergen sizegen parser

visitor: asthelpergen
echo "make visitor has been replaced by make asthelpergen"

asthelpergen:
go run ./go/tools/asthelpergen/main \
--in ./go/vt/sqlparser \
--iface vitess.io/vitess/go/vt/sqlparser.SQLNode \
--except "*ColName"
go generate ./go/vt/sqlparser/...

sizegen:
go run ./go/tools/sizegen/sizegen.go \
Expand All @@ -197,9 +194,6 @@ sizegen:
--gen vitess.io/vitess/go/vt/vttablet/tabletserver.TabletPlan \
--gen vitess.io/vitess/go/sqltypes.Result

astfmtgen:
go run ./go/tools/astfmtgen/main.go vitess.io/vitess/go/vt/sqlparser/...

# To pass extra flags, run test.go manually.
# For example: go run test.go -docker=false -- --extra-flag
# For more info see: go run test.go -help
Expand Down
29 changes: 18 additions & 11 deletions go/tools/asthelpergen/asthelpergen.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,23 @@ var acceptableBuildErrorsOn = map[string]any{
"ast_visit.go": nil,
}

type Options struct {
Packages []string
RootInterface string

Clone CloneOptions
Equals EqualsOptions
}

// GenerateASTHelpers loads the input code, constructs the necessary generators,
// and generates the rewriter and clone methods for the AST
func GenerateASTHelpers(packagePatterns []string, rootIface, exceptCloneType string) (map[string]*jen.File, error) {
func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) {
loaded, err := packages.Load(&packages.Config{
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule,
}, packagePatterns...)
}, options.Packages...)

if err != nil {
log.Fatal("error loading package")
return nil, err
return nil, fmt.Errorf("failed to load packages: %w", err)
}

checkErrors(loaded, func(fileName string) bool {
Expand All @@ -212,17 +219,17 @@ func GenerateASTHelpers(packagePatterns []string, rootIface, exceptCloneType str
scopes[pkg.PkgPath] = pkg.Types.Scope()
}

pos := strings.LastIndexByte(rootIface, '.')
pos := strings.LastIndexByte(options.RootInterface, '.')
if pos < 0 {
return nil, fmt.Errorf("unexpected input type: %s", rootIface)
return nil, fmt.Errorf("unexpected input type: %s", options.RootInterface)
}

pkgname := rootIface[:pos]
typename := rootIface[pos+1:]
pkgname := options.RootInterface[:pos]
typename := options.RootInterface[pos+1:]

scope := scopes[pkgname]
if scope == nil {
return nil, fmt.Errorf("no scope found for type '%s'", rootIface)
return nil, fmt.Errorf("no scope found for type '%s'", options.RootInterface)
}

tt := scope.Lookup(typename)
Expand All @@ -233,8 +240,8 @@ func GenerateASTHelpers(packagePatterns []string, rootIface, exceptCloneType str
nt := tt.Type().(*types.Named)
pName := nt.Obj().Pkg().Name()
generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt,
newEqualsGen(pName),
newCloneGen(pName, exceptCloneType),
newEqualsGen(pName, &options.Equals),
newCloneGen(pName, &options.Clone),
newVisitGen(pName),
newRewriterGen(pName, types.TypeString(nt, noQualifier)),
)
Expand Down
8 changes: 7 additions & 1 deletion go/tools/asthelpergen/asthelpergen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ import (
)

func TestFullGeneration(t *testing.T) {
result, err := GenerateASTHelpers([]string{"./integration/..."}, "vitess.io/vitess/go/tools/asthelpergen/integration.AST", "*NoCloneType")
result, err := GenerateASTHelpers(&Options{
Packages: []string{"./integration/..."},
RootInterface: "vitess.io/vitess/go/tools/asthelpergen/integration.AST",
Clone: CloneOptions{
Exclude: []string{"*NoCloneType"},
},
})
require.NoError(t, err)

verifyErrors := VerifyFilesOnDisk(result)
Expand Down
17 changes: 11 additions & 6 deletions go/tools/asthelpergen/clone_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,31 @@ import (
"strings"

"github.com/dave/jennifer/jen"
"golang.org/x/exp/slices"
)

type CloneOptions struct {
Exclude []string
}

// cloneGen creates the deep clone methods for the AST. It works by discovering the types that it needs to support,
// starting from a root interface type. While creating the clone method for this root interface, more types that need
// to be cloned are discovered. This continues type by type until all necessary types have been traversed.
type cloneGen struct {
exceptType string
file *jen.File
exclude []string
file *jen.File
}

var _ generator = (*cloneGen)(nil)

func newCloneGen(pkgname string, exceptType string) *cloneGen {
func newCloneGen(pkgname string, options *CloneOptions) *cloneGen {
file := jen.NewFile(pkgname)
file.HeaderComment(licenseFileHeader)
file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")

return &cloneGen{
exceptType: exceptType,
file: file,
exclude: options.Exclude,
file: file,
}
}

Expand Down Expand Up @@ -222,7 +227,7 @@ func (c *cloneGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gene
//func CloneRefOfType(n *Type) *Type
funcDeclaration := jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType)

if receiveType == c.exceptType {
if slices.Contains(c.exclude, receiveType) {
c.addFunc(funcName, funcDeclaration.Block(
jen.Return(jen.Id("n")),
))
Expand Down
62 changes: 45 additions & 17 deletions go/tools/asthelpergen/equals_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,32 @@ import (
"github.com/dave/jennifer/jen"
)

const equalsName = "Equals"
const Comparator = "Comparator"

type EqualsOptions struct {
AllowCustom []string
}

type equalsGen struct {
file *jen.File
file *jen.File
comparators map[string]types.Type
}

var _ generator = (*equalsGen)(nil)

func newEqualsGen(pkgname string) *equalsGen {
func newEqualsGen(pkgname string, options *EqualsOptions) *equalsGen {
file := jen.NewFile(pkgname)
file.HeaderComment(licenseFileHeader)
file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")

customComparators := make(map[string]types.Type, len(options.AllowCustom))
for _, tt := range options.AllowCustom {
customComparators[tt] = nil
}

return &equalsGen{
file: file,
file: file,
comparators: customComparators,
}
}

Expand All @@ -47,13 +58,27 @@ func (e *equalsGen) addFunc(name string, code *jen.Statement) {
e.file.Add(code)
}

func (e *equalsGen) customComparatorField(t types.Type) string {
return printableTypeName(t) + "_"
}

func (e *equalsGen) genFile() (string, *jen.File) {
e.file.Type().Id(Comparator).StructFunc(func(g *jen.Group) {
for tname, t := range e.comparators {
if t == nil {
continue
}
method := e.customComparatorField(t)
g.Add(jen.Id(method).Func().Call(jen.List(jen.Id("a"), jen.Id("b")).Id(tname)).Bool())
}
})

return "ast_equals.go", e.file
}

func (e *equalsGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error {
/*
func EqualsAST(inA, inB AST, f ASTComparison) bool {
func (cmp *Comparator) AST(inA, inB AST) bool {
if inA == inB {
return true
}
Expand All @@ -66,7 +91,7 @@ func (e *equalsGen) interfaceMethod(t types.Type, iface *types.Interface, spi ge
if !ok {
return false
}
return EqualsSubImpl(a, b, f)
return cmp.SubImpl(a, b)
}
return false
}
Expand Down Expand Up @@ -116,11 +141,11 @@ func compareValueType(t types.Type, a, b *jen.Statement, eq bool, spi generatorS
return a.Op("!=").Add(b)
}
spi.addType(t)
var neg = "!"
if eq {
neg = ""
fcall := jen.Id("cmp").Dot(printableTypeName(t)).Call(a, b)
if !eq {
return jen.Op("!").Add(fcall)
}
return jen.Id(neg+equalsName+printableTypeName(t)).Call(a, b, jen.Id("f"))
return fcall
}

func (e *equalsGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
Expand Down Expand Up @@ -180,8 +205,6 @@ func compareAllStructFields(strct *types.Struct, spi generatorSPI) jen.Code {
}

func (e *equalsGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
typeString := types.TypeString(t, noQualifier)

/*
func EqualsRefOfType(a, b *Type, f ASTComparison) *Type {
if a == b {
Expand All @@ -206,10 +229,15 @@ func (e *equalsGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gen
jen.If(jen.Id("a == nil").Op("||").Id("b == nil")).Block(jen.Return(jen.False())),
}

if typeString == "*ColName" {
typeString := types.TypeString(t, noQualifier)

if _, ok := e.comparators[typeString]; ok {
e.comparators[typeString] = t

method := e.customComparatorField(t)
stmts = append(stmts,
jen.If(jen.Id("f").Op("!=").Nil()).Block(
jen.Return(jen.Id("f").Dot("ColNames").Call(jen.Id("a"), jen.Id("b"))),
jen.If(jen.Id("cmp").Dot(method).Op("!=").Nil()).Block(
jen.Return(jen.Id("cmp").Dot(method).Call(jen.Id("a"), jen.Id("b"))),
))
}

Expand Down Expand Up @@ -243,10 +271,10 @@ func (e *equalsGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generator

func (e *equalsGen) declareFunc(t types.Type, aArg, bArg string) (*jen.Statement, string) {
typeString := types.TypeString(t, noQualifier)
funcName := equalsName + printableTypeName(t)
funcName := printableTypeName(t)

// func EqualsFunNameS(a, b <T>, f ASTComparison) bool
return jen.Func().Id(funcName).Call(jen.Id(aArg), jen.Id(bArg).Id(typeString), jen.Id("f").Id("ASTComparison")).Bool(), funcName
return jen.Func().Params(jen.Id("cmp").Op("*").Id(Comparator)).Id(funcName).Call(jen.Id(aArg), jen.Id(bArg).Id(typeString)).Bool(), funcName
}

func (e *equalsGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
Expand Down
Loading

0 comments on commit 8c1316c

Please sign in to comment.