From 7135c524f86b19feba49660d5b8643c1f77b2054 Mon Sep 17 00:00:00 2001 From: rebelice Date: Wed, 27 May 2026 17:22:38 +0900 Subject: [PATCH] feat(oracle): add ast walker --- oracle/ast/cmd/genwalker/main.go | 248 +++++ oracle/ast/walk.go | 57 ++ oracle/ast/walk_generated.go | 1427 +++++++++++++++++++++++++++++ oracle/ast/walk_generated_test.go | 202 ++++ oracle/ast/walk_test.go | 92 ++ 5 files changed, 2026 insertions(+) create mode 100644 oracle/ast/cmd/genwalker/main.go create mode 100644 oracle/ast/walk.go create mode 100644 oracle/ast/walk_generated.go create mode 100644 oracle/ast/walk_generated_test.go create mode 100644 oracle/ast/walk_test.go diff --git a/oracle/ast/cmd/genwalker/main.go b/oracle/ast/cmd/genwalker/main.go new file mode 100644 index 00000000..c96681b9 --- /dev/null +++ b/oracle/ast/cmd/genwalker/main.go @@ -0,0 +1,248 @@ +// Command genwalker generates walk_generated.go from parsenodes.go and node.go. +// +// It scans all struct types, identifies fields whose types are Node-like +// (Node, ExprNode, TableExpr interfaces; *List; pointers to AST structs; +// slices of Node-like types), and generates the walkChildren function. +// +// Usage: +// +// go run ./oracle/ast/cmd/genwalker +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "sort" + "strings" +) + +// nodeInterfaces are the interface types that represent walkable child nodes. +var nodeInterfaces = map[string]bool{ + "Node": true, + "ExprNode": true, + "TableExpr": true, + "StmtNode": true, +} + +// excludedStructs are struct types that are NOT AST nodes (no walking needed). +var excludedStructs = map[string]bool{ + "Loc": true, + "List": true, // List is walked via walkList, not walkChildren + "String": true, // value node, no children + "Integer": true, // value node, no children + "Float": true, // value node, no children + "Boolean": true, // value node, no children +} + +func main() { + fset := token.NewFileSet() + + sources := []string{"parsenodes.go", "node.go"} + var files []*ast.File + for _, src := range sources { + f, err := parser.ParseFile(fset, src, nil, 0) + if err != nil { + fmt.Fprintf(os.Stderr, "parse %s: %v\n", src, err) + os.Exit(1) + } + files = append(files, f) + } + + // Collect all struct type names. + structNames := map[string]bool{} + type field struct { + Name string + Type string // e.g., "Node", "ExprNode", "*SelectStmt", "[]ExprNode", "[]*OrderByItem" + Kind string // "interface", "pointer", "slice_interface", "slice_pointer", "*List" + } + type structInfo struct { + Name string + Fields []field + } + + for _, f := range files { + for _, decl := range f.Decls { + gd, ok := decl.(*ast.GenDecl) + if !ok || gd.Tok != token.TYPE { + continue + } + for _, spec := range gd.Specs { + ts := spec.(*ast.TypeSpec) + if _, ok := ts.Type.(*ast.StructType); ok { + structNames[ts.Name.Name] = true + } + } + } + } + + var structs []structInfo + + for _, f := range files { + for _, decl := range f.Decls { + gd, ok := decl.(*ast.GenDecl) + if !ok || gd.Tok != token.TYPE { + continue + } + for _, spec := range gd.Specs { + ts := spec.(*ast.TypeSpec) + st, ok := ts.Type.(*ast.StructType) + if !ok { + continue + } + + var fields []field + for _, fl := range st.Fields.List { + if len(fl.Names) == 0 { + continue // embedded + } + typStr := typeString(fl.Type) + kind := classifyType(typStr, structNames) + if kind != "" { + for _, name := range fl.Names { + fields = append(fields, field{Name: name.Name, Type: typStr, Kind: kind}) + } + } + } + structs = append(structs, structInfo{Name: ts.Name.Name, Fields: fields}) + } + } + } + + sort.Slice(structs, func(i, j int) bool { + return structs[i].Name < structs[j].Name + }) + + // Generate code. + var buf bytes.Buffer + buf.WriteString("// Code generated by genwalker; DO NOT EDIT.\n\n") + buf.WriteString("package ast\n\n") + buf.WriteString("// walkChildren walks the child nodes of node, calling Walk(v, child)\n") + buf.WriteString("// for each child. This function is generated from parsenodes.go.\n") + buf.WriteString("func walkChildren(v Visitor, node Node) {\n") + buf.WriteString("\tswitch n := node.(type) {\n") + + for _, s := range structs { + if len(s.Fields) == 0 { + continue + } + fmt.Fprintf(&buf, "\tcase *%s:\n", s.Name) + for _, f := range s.Fields { + switch f.Kind { + case "interface": + // Node, ExprNode, TableExpr, StmtNode — Walk handles nil via interface check + fmt.Fprintf(&buf, "\t\tWalk(v, n.%s)\n", f.Name) + case "*List": + fmt.Fprintf(&buf, "\t\twalkList(v, n.%s)\n", f.Name) + case "pointer": + // *ConcreteStruct — need nil check + fmt.Fprintf(&buf, "\t\tif n.%s != nil {\n", f.Name) + fmt.Fprintf(&buf, "\t\t\tWalk(v, n.%s)\n", f.Name) + fmt.Fprintf(&buf, "\t\t}\n") + case "slice_interface": + // []ExprNode, []TableExpr, []Node — iterate and walk each + fmt.Fprintf(&buf, "\t\tfor _, item := range n.%s {\n", f.Name) + fmt.Fprintf(&buf, "\t\t\tWalk(v, item)\n") + fmt.Fprintf(&buf, "\t\t}\n") + case "slice_pointer": + // []*ConcreteStruct — iterate, nil check, walk + fmt.Fprintf(&buf, "\t\tfor _, item := range n.%s {\n", f.Name) + fmt.Fprintf(&buf, "\t\t\tif item != nil {\n") + fmt.Fprintf(&buf, "\t\t\t\tWalk(v, item)\n") + fmt.Fprintf(&buf, "\t\t\t}\n") + fmt.Fprintf(&buf, "\t\t}\n") + } + } + } + + buf.WriteString("\t}\n") + buf.WriteString("}\n") + + formatted, err := format.Source(buf.Bytes()) + if err != nil { + fmt.Fprintf(os.Stderr, "format: %v\n", err) + os.WriteFile("walk_generated.go", buf.Bytes(), 0644) + os.Exit(1) + } + + if err := os.WriteFile("walk_generated.go", formatted, 0644); err != nil { + fmt.Fprintf(os.Stderr, "write: %v\n", err) + os.Exit(1) + } + + cases := 0 + fields := 0 + for _, s := range structs { + if len(s.Fields) > 0 { + cases++ + fields += len(s.Fields) + } + } + fmt.Printf("Generated walk_generated.go: %d cases, %d child fields\n", cases, fields) +} + +// typeString returns the string representation of a Go type expression. +func typeString(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.Ident: + return t.Name + case *ast.StarExpr: + return "*" + typeString(t.X) + case *ast.SelectorExpr: + return typeString(t.X) + "." + t.Sel.Name + case *ast.ArrayType: + return "[]" + typeString(t.Elt) + default: + return "" + } +} + +// classifyType returns the walk kind for a field type, or "" if not walkable. +func classifyType(typStr string, structNames map[string]bool) string { + // Interface types: Node, ExprNode, TableExpr, StmtNode + if nodeInterfaces[typStr] { + return "interface" + } + + // *List + if typStr == "*List" { + return "*List" + } + + // Pointer to known struct: *SelectStmt, *Limit, etc. + if strings.HasPrefix(typStr, "*") { + name := typStr[1:] + if excludedStructs[name] { + return "" + } + if structNames[name] { + return "pointer" + } + return "" + } + + // Slice of interface: []ExprNode, []TableExpr, []Node + if strings.HasPrefix(typStr, "[]") { + elemType := typStr[2:] + if nodeInterfaces[elemType] { + return "slice_interface" + } + // Slice of pointer to known struct: []*OrderByItem, []*WindowDef + if strings.HasPrefix(elemType, "*") { + structName := elemType[1:] + if excludedStructs[structName] { + return "" + } + if structNames[structName] { + return "slice_pointer" + } + } + return "" + } + + return "" +} diff --git a/oracle/ast/walk.go b/oracle/ast/walk.go new file mode 100644 index 00000000..8a487ede --- /dev/null +++ b/oracle/ast/walk.go @@ -0,0 +1,57 @@ +//go:generate go run ./cmd/genwalker + +package ast + +// Visitor defines the interface for AST traversal. +// Visit is called for each node during a depth-first walk. +// If Visit returns a non-nil Visitor, Walk recurses into the node's children +// with the returned Visitor, then calls Visit(nil) to signal post-order. +// If Visit returns nil, children are not visited. +type Visitor interface { + Visit(node Node) Visitor +} + +// Walk traverses an AST in depth-first order. It calls v.Visit(node); +// if that returns a non-nil visitor w, it walks each child node with w, +// then calls w.Visit(nil). +func Walk(v Visitor, node Node) { + if node == nil { + return + } + w := v.Visit(node) + if w == nil { + return + } + walkChildren(w, node) + w.Visit(nil) +} + +// Inspect traverses an AST in depth-first order, calling f for each node. +// If f returns true, Inspect recurses into the node's children. +func Inspect(node Node, f func(Node) bool) { + Walk(inspector(f), node) +} + +type inspector func(Node) bool + +func (f inspector) Visit(node Node) Visitor { + if node != nil && f(node) { + return f + } + return nil +} + +// walkList visits a List node and then walks each of its items. +func walkList(v Visitor, list *List) { + if list == nil { + return + } + w := v.Visit(list) + if w == nil { + return + } + for _, item := range list.Items { + Walk(w, item) + } + w.Visit(nil) +} diff --git a/oracle/ast/walk_generated.go b/oracle/ast/walk_generated.go new file mode 100644 index 00000000..55068405 --- /dev/null +++ b/oracle/ast/walk_generated.go @@ -0,0 +1,1427 @@ +// Code generated by genwalker; DO NOT EDIT. + +package ast + +// walkChildren walks the child nodes of node, calling Walk(v, child) +// for each child. This function is generated from parsenodes.go. +func walkChildren(v Visitor, node Node) { + switch n := node.(type) { + case *AdminDDLStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Options) + case *Alias: + walkList(v, n.Cols) + case *AlterAnalyticViewStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.NewName != nil { + Walk(v, n.NewName) + } + walkList(v, n.Options) + case *AlterAttributeDimensionStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.NewName != nil { + Walk(v, n.NewName) + } + case *AlterAuditPolicyStmt: + for _, item := range n.Actions { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.ComponentActions { + if item != nil { + Walk(v, item) + } + } + case *AlterClusterStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *AlterDatabaseLinkStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *AlterDimensionStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.AddLevels { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.AddHierarchies { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.AddAttributes { + if item != nil { + Walk(v, item) + } + } + case *AlterDomainStmt: + if n.Name != nil { + Walk(v, n.Name) + } + Walk(v, n.Display) + Walk(v, n.Order) + walkList(v, n.Annotations) + case *AlterFunctionStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.CompilerParams { + if item != nil { + Walk(v, item) + } + } + case *AlterHierarchyStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.NewName != nil { + Walk(v, n.NewName) + } + case *AlterIndexStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.SplitValues) + case *AlterIndextypeStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Modifications { + if item != nil { + Walk(v, item) + } + } + if n.UsingType != nil { + Walk(v, n.UsingType) + } + case *AlterInmemoryJoinGroupStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Member != nil { + Walk(v, n.Member) + } + case *AlterJsonDualityViewStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *AlterMaterializedViewStmt: + if n.Name != nil { + Walk(v, n.Name) + } + Walk(v, n.StartWith) + Walk(v, n.Next) + if n.ScopeTable != nil { + Walk(v, n.ScopeTable) + } + walkList(v, n.Options) + case *AlterMaterializedZonemapStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *AlterMviewLogStmt: + if n.OnTable != nil { + Walk(v, n.OnTable) + } + walkList(v, n.Columns) + walkList(v, n.Options) + case *AlterOperatorStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Binding != nil { + Walk(v, n.Binding) + } + case *AlterPackageStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.CompilerParams { + if item != nil { + Walk(v, item) + } + } + case *AlterProcedureStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.CompilerParams { + if item != nil { + Walk(v, item) + } + } + case *AlterProfileStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Limits { + if item != nil { + Walk(v, item) + } + } + case *AlterResourceCostStmt: + for _, item := range n.Costs { + if item != nil { + Walk(v, item) + } + } + case *AlterRoleStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *AlterSequenceStmt: + if n.Name != nil { + Walk(v, n.Name) + } + Walk(v, n.IncrementBy) + Walk(v, n.MaxValue) + Walk(v, n.MinValue) + Walk(v, n.Cache) + Walk(v, n.RestartWith) + case *AlterSessionStmt: + walkList(v, n.SetParams) + case *AlterSynonymStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *AlterSystemStmt: + walkList(v, n.SetParams) + case *AlterTableCmd: + if n.ColumnDef != nil { + Walk(v, n.ColumnDef) + } + walkList(v, n.ColumnDefs) + if n.Constraint != nil { + Walk(v, n.Constraint) + } + walkList(v, n.Options) + case *AlterTableStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Actions) + case *AlterTablespaceStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Datafiles { + if item != nil { + Walk(v, item) + } + } + if n.Autoextend != nil { + Walk(v, n.Autoextend) + } + case *AlterTriggerStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.CompilerParams { + if item != nil { + Walk(v, item) + } + } + case *AlterTypeStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.CompilerParams { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Attributes { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.MethodParams { + if item != nil { + Walk(v, item) + } + } + if n.MethodReturn != nil { + Walk(v, n.MethodReturn) + } + Walk(v, n.LimitValue) + if n.ElementType != nil { + Walk(v, n.ElementType) + } + case *AlterUserStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Identified != nil { + Walk(v, n.Identified) + } + for _, item := range n.Quotas { + if item != nil { + Walk(v, item) + } + } + if n.DefaultRole != nil { + Walk(v, n.DefaultRole) + } + case *AlterViewStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Constraint != nil { + Walk(v, n.Constraint) + } + walkList(v, n.Annotations) + case *AnalyzeStmt: + if n.Table != nil { + Walk(v, n.Table) + } + if n.IntoTable != nil { + Walk(v, n.IntoTable) + } + case *AssociateStatisticsStmt: + for _, item := range n.Objects { + if item != nil { + Walk(v, item) + } + } + if n.Using != nil { + Walk(v, n.Using) + } + case *AttrDimAllClause: + Walk(v, n.MemberName) + Walk(v, n.MemberCaption) + Walk(v, n.MemberDesc) + case *AttrDimAttribute: + walkList(v, n.Classifications) + case *AttrDimLevel: + walkList(v, n.Classifications) + walkList(v, n.KeyAttrs) + walkList(v, n.AltKeyAttrs) + Walk(v, n.MemberName) + Walk(v, n.MemberCaption) + Walk(v, n.MemberDesc) + walkList(v, n.OrderByAttrs) + walkList(v, n.Determines) + case *AttrDimSourceClause: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.JoinCondition) + case *AuditActionEntry: + if n.Object != nil { + Walk(v, n.Object) + } + case *AuditStmt: + if n.Object != nil { + Walk(v, n.Object) + } + case *BetweenExpr: + Walk(v, n.Expr) + Walk(v, n.Low) + Walk(v, n.High) + case *BinaryExpr: + Walk(v, n.Left) + Walk(v, n.Right) + case *BoolExpr: + walkList(v, n.Args) + case *CTE: + walkList(v, n.Columns) + Walk(v, n.Query) + if n.Search != nil { + Walk(v, n.Search) + } + if n.Cycle != nil { + Walk(v, n.Cycle) + } + case *CTECycleClause: + walkList(v, n.Columns) + Walk(v, n.CycleValue) + Walk(v, n.NoCycleValue) + case *CTESearchClause: + walkList(v, n.Columns) + case *CallStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Args) + Walk(v, n.Into) + case *CaseExpr: + Walk(v, n.Arg) + walkList(v, n.Whens) + Walk(v, n.Default) + case *CaseWhen: + Walk(v, n.Condition) + Walk(v, n.Result) + case *CastExpr: + Walk(v, n.Arg) + if n.TypeName != nil { + Walk(v, n.TypeName) + } + case *ClusterColumn: + if n.DataType != nil { + Walk(v, n.DataType) + } + case *ColumnConstraint: + Walk(v, n.Expr) + if n.RefTable != nil { + Walk(v, n.RefTable) + } + walkList(v, n.RefColumns) + case *ColumnDef: + if n.TypeName != nil { + Walk(v, n.TypeName) + } + if n.Domain != nil { + Walk(v, n.Domain) + } + Walk(v, n.Default) + if n.Identity != nil { + Walk(v, n.Identity) + } + Walk(v, n.Virtual) + walkList(v, n.Constraints) + case *CommentStmt: + if n.Object != nil { + Walk(v, n.Object) + } + case *ContainersExpr: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Alias != nil { + Walk(v, n.Alias) + } + case *CreateAnalyticViewStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.UsingTable != nil { + Walk(v, n.UsingTable) + } + walkList(v, n.DimBy) + walkList(v, n.Measures) + walkList(v, n.Options) + case *CreateAttributeDimensionStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Classifications) + walkList(v, n.Sources) + walkList(v, n.Attributes) + walkList(v, n.Levels) + if n.AllClause != nil { + Walk(v, n.AllClause) + } + case *CreateAuditPolicyStmt: + for _, item := range n.Actions { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.ComponentActions { + if item != nil { + Walk(v, item) + } + } + case *CreateClusterStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Columns { + if item != nil { + Walk(v, item) + } + } + Walk(v, n.HashExpr) + case *CreateDimensionStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Levels { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Hierarchies { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Attributes { + if item != nil { + Walk(v, item) + } + } + case *CreateDomainStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.DataType != nil { + Walk(v, n.DataType) + } + Walk(v, n.Default) + walkList(v, n.Constraints) + Walk(v, n.Display) + Walk(v, n.Order) + walkList(v, n.Annotations) + walkList(v, n.EnumItems) + walkList(v, n.Columns) + if n.FlexDomainName != nil { + Walk(v, n.FlexDomainName) + } + walkList(v, n.FlexColumns) + walkList(v, n.ChooseUsing) + Walk(v, n.ChooseExpr) + case *CreateFunctionStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Parameters) + if n.ReturnType != nil { + Walk(v, n.ReturnType) + } + Walk(v, n.Body) + case *CreateHierarchyStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Classifications) + if n.UsingAttrDim != nil { + Walk(v, n.UsingAttrDim) + } + if n.LevelHier != nil { + Walk(v, n.LevelHier) + } + walkList(v, n.HierAttrs) + case *CreateIndexStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Table != nil { + Walk(v, n.Table) + } + if n.Cluster != nil { + Walk(v, n.Cluster) + } + walkList(v, n.Columns) + if n.IndexType != nil { + Walk(v, n.IndexType) + } + walkList(v, n.FromTables) + Walk(v, n.Where) + walkList(v, n.Options) + case *CreateIndextypeStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Operators { + if item != nil { + Walk(v, item) + } + } + if n.UsingType != nil { + Walk(v, n.UsingType) + } + case *CreateInmemoryJoinGroupStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Members { + if item != nil { + Walk(v, item) + } + } + case *CreateJsonDualityViewStmt: + if n.Name != nil { + Walk(v, n.Name) + } + Walk(v, n.Query) + walkList(v, n.Options) + case *CreateMaterializedZonemapStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.OnTable != nil { + Walk(v, n.OnTable) + } + if n.AsQuery != nil { + Walk(v, n.AsQuery) + } + case *CreateMviewLogStmt: + if n.OnTable != nil { + Walk(v, n.OnTable) + } + walkList(v, n.Columns) + Walk(v, n.PurgeStart) + Walk(v, n.PurgeNext) + walkList(v, n.Options) + case *CreateOperatorStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Bindings { + if item != nil { + Walk(v, item) + } + } + case *CreatePackageStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Body) + case *CreateProcedureStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Parameters) + Walk(v, n.Body) + case *CreateProfileStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Limits { + if item != nil { + Walk(v, item) + } + } + case *CreatePropertyGraphStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.VertexTables { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.EdgeTables { + if item != nil { + Walk(v, item) + } + } + if n.Options != nil { + Walk(v, n.Options) + } + case *CreateRoleStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *CreateSchemaStmt: + walkList(v, n.Stmts) + case *CreateSequenceStmt: + if n.Name != nil { + Walk(v, n.Name) + } + Walk(v, n.IncrementBy) + Walk(v, n.StartWith) + Walk(v, n.MaxValue) + Walk(v, n.MinValue) + Walk(v, n.Cache) + case *CreateSynonymStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Target != nil { + Walk(v, n.Target) + } + case *CreateTableStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Columns) + walkList(v, n.Constraints) + Walk(v, n.AsQuery) + if n.Storage != nil { + Walk(v, n.Storage) + } + if n.Partition != nil { + Walk(v, n.Partition) + } + walkList(v, n.Hints) + if n.Parent != nil { + Walk(v, n.Parent) + } + walkList(v, n.Options) + case *CreateTablespaceSetStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Datafiles { + if item != nil { + Walk(v, item) + } + } + case *CreateTablespaceStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, item := range n.Datafiles { + if item != nil { + Walk(v, item) + } + } + if n.Autoextend != nil { + Walk(v, n.Autoextend) + } + case *CreateTriggerStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Events) + if n.Table != nil { + Walk(v, n.Table) + } + Walk(v, n.When) + Walk(v, n.Body) + case *CreateTypeStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Attributes) + if n.AsTable != nil { + Walk(v, n.AsTable) + } + if n.AsVarray != nil { + Walk(v, n.AsVarray) + } + Walk(v, n.VarraySize) + walkList(v, n.Body) + case *CreateUserStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Identified != nil { + Walk(v, n.Identified) + } + for _, item := range n.Quotas { + if item != nil { + Walk(v, item) + } + } + case *CreateVectorIndexStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.TableName != nil { + Walk(v, n.TableName) + } + case *CreateViewStmt: + if n.Name != nil { + Walk(v, n.Name) + } + walkList(v, n.Columns) + Walk(v, n.Query) + Walk(v, n.StartWith) + Walk(v, n.Next) + walkList(v, n.Options) + case *CubeClause: + walkList(v, n.Args) + case *CursorExpr: + Walk(v, n.Subquery) + case *DDLOption: + walkList(v, n.Items) + case *DatafileClause: + if n.Autoextend != nil { + Walk(v, n.Autoextend) + } + case *DecodeExpr: + Walk(v, n.Arg) + walkList(v, n.Pairs) + Walk(v, n.Default) + case *DecodePair: + Walk(v, n.Search) + Walk(v, n.Result) + case *DefaultRoleClause: + for _, item := range n.Roles { + if item != nil { + Walk(v, item) + } + } + case *DeleteStmt: + if n.Table != nil { + Walk(v, n.Table) + } + if n.PartitionExt != nil { + Walk(v, n.PartitionExt) + } + if n.Alias != nil { + Walk(v, n.Alias) + } + Walk(v, n.WhereClause) + walkList(v, n.Returning) + if n.ErrorLog != nil { + Walk(v, n.ErrorLog) + } + walkList(v, n.Hints) + case *DimensionAttribute: + for _, item := range n.Columns { + if item != nil { + Walk(v, item) + } + } + case *DimensionHierarchy: + for _, item := range n.JoinKeys { + if item != nil { + Walk(v, item) + } + } + case *DimensionJoinKey: + for _, item := range n.ChildKeys { + if item != nil { + Walk(v, item) + } + } + case *DimensionLevel: + for _, item := range n.Columns { + if item != nil { + Walk(v, item) + } + } + case *DisassociateStatisticsStmt: + for _, item := range n.Objects { + if item != nil { + Walk(v, item) + } + } + case *DomainColumn: + if n.DataType != nil { + Walk(v, n.DataType) + } + walkList(v, n.Annotations) + case *DomainConstraint: + Walk(v, n.CheckExpr) + walkList(v, n.State) + case *DomainEnumItem: + Walk(v, n.Value) + case *DropStmt: + walkList(v, n.Names) + case *DropTablespaceStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *ErrorLogClause: + if n.Into != nil { + Walk(v, n.Into) + } + Walk(v, n.Tag) + Walk(v, n.Reject) + case *ExceptionHandler: + walkList(v, n.Exceptions) + walkList(v, n.Statements) + case *ExistsExpr: + Walk(v, n.Subquery) + case *ExplainPlanStmt: + if n.Into != nil { + Walk(v, n.Into) + } + Walk(v, n.Statement) + case *ExtractExpr: + Walk(v, n.Expr) + case *FetchFirstClause: + Walk(v, n.Offset) + Walk(v, n.Count) + case *FlashbackClause: + Walk(v, n.Expr) + Walk(v, n.VersionsLow) + Walk(v, n.VersionsHigh) + case *FlashbackDatabaseStmt: + if n.DatabaseName != nil { + Walk(v, n.DatabaseName) + } + Walk(v, n.ToSCN) + Walk(v, n.ToTimestamp) + case *FlashbackTableStmt: + for _, item := range n.Tables { + if item != nil { + Walk(v, item) + } + } + if n.Table != nil { + Walk(v, n.Table) + } + Walk(v, n.ToSCN) + Walk(v, n.ToTimestamp) + case *ForUpdateClause: + walkList(v, n.Tables) + Walk(v, n.Wait) + case *FuncCallExpr: + if n.FuncName != nil { + Walk(v, n.FuncName) + } + walkList(v, n.Args) + walkList(v, n.OrderBy) + if n.KeepClause != nil { + Walk(v, n.KeepClause) + } + if n.Over != nil { + Walk(v, n.Over) + } + case *GrantStmt: + walkList(v, n.Privileges) + if n.OnObject != nil { + Walk(v, n.OnObject) + } + walkList(v, n.Grantees) + case *GraphEdgeDef: + if n.Name != nil { + Walk(v, n.Name) + } + case *GraphTableDef: + if n.Name != nil { + Walk(v, n.Name) + } + case *GroupingSetsClause: + walkList(v, n.Sets) + case *HierAttr: + walkList(v, n.Classifications) + case *HierLevelClause: + walkList(v, n.Classifications) + if n.ChildOf != nil { + Walk(v, n.ChildOf) + } + case *HierarchicalClause: + Walk(v, n.ConnectBy) + Walk(v, n.StartWith) + case *IdentityClause: + walkList(v, n.Options) + case *InExpr: + Walk(v, n.Expr) + walkList(v, n.List) + case *IndexColumn: + Walk(v, n.Expr) + case *IndextypeModOp: + if n.Name != nil { + Walk(v, n.Name) + } + case *IndextypeOp: + if n.Name != nil { + Walk(v, n.Name) + } + case *InlineExternalTable: + walkList(v, n.Columns) + Walk(v, n.RejectLimit) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *InsertIntoClause: + if n.Table != nil { + Walk(v, n.Table) + } + walkList(v, n.Columns) + walkList(v, n.Values) + Walk(v, n.When) + case *InsertStmt: + if n.Table != nil { + Walk(v, n.Table) + } + if n.PartitionExt != nil { + Walk(v, n.PartitionExt) + } + if n.Alias != nil { + Walk(v, n.Alias) + } + walkList(v, n.Columns) + walkList(v, n.Values) + walkList(v, n.SetClauses) + if n.Select != nil { + Walk(v, n.Select) + } + walkList(v, n.MultiTable) + Walk(v, n.Subquery) + walkList(v, n.Returning) + if n.ErrorLog != nil { + Walk(v, n.ErrorLog) + } + walkList(v, n.Hints) + case *IntervalExpr: + Walk(v, n.Value) + case *IsExpr: + Walk(v, n.Expr) + walkList(v, n.TypeList) + case *JoinClause: + Walk(v, n.Left) + Walk(v, n.Right) + Walk(v, n.On) + walkList(v, n.Using) + case *JoinGroupMember: + if n.Table != nil { + Walk(v, n.Table) + } + case *JsonTableColumn: + if n.TypeName != nil { + Walk(v, n.TypeName) + } + Walk(v, n.Path) + if n.Nested != nil { + Walk(v, n.Nested) + } + case *JsonTableRef: + Walk(v, n.Expr) + Walk(v, n.Path) + walkList(v, n.Columns) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *KeepClause: + walkList(v, n.OrderBy) + case *LateralRef: + Walk(v, n.Subquery) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *LikeExpr: + Walk(v, n.Expr) + Walk(v, n.Pattern) + Walk(v, n.Escape) + case *List: + for _, item := range n.Items { + Walk(v, item) + } + case *LockTableItem: + if n.Table != nil { + Walk(v, n.Table) + } + case *LockTableStmt: + for _, item := range n.Tables { + if item != nil { + Walk(v, item) + } + } + Walk(v, n.Wait) + case *MatchRecognizeClause: + walkList(v, n.PartitionBy) + walkList(v, n.OrderBy) + walkList(v, n.Measures) + walkList(v, n.Subsets) + walkList(v, n.Definitions) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *MergeClause: + Walk(v, n.Condition) + walkList(v, n.UpdateSet) + Walk(v, n.UpdateWhere) + Walk(v, n.DeleteWhere) + walkList(v, n.InsertCols) + walkList(v, n.InsertVals) + Walk(v, n.InsertWhere) + case *MergeStmt: + if n.Target != nil { + Walk(v, n.Target) + } + if n.TargetAlias != nil { + Walk(v, n.TargetAlias) + } + Walk(v, n.Source) + if n.SourceAlias != nil { + Walk(v, n.SourceAlias) + } + Walk(v, n.On) + walkList(v, n.Clauses) + if n.ErrorLog != nil { + Walk(v, n.ErrorLog) + } + walkList(v, n.Hints) + case *ModelClause: + if n.CellRefOptions != nil { + Walk(v, n.CellRefOptions) + } + for _, item := range n.RefModels { + if item != nil { + Walk(v, item) + } + } + if n.MainModel != nil { + Walk(v, n.MainModel) + } + case *ModelColumnClauses: + walkList(v, n.PartitionBy) + walkList(v, n.DimensionBy) + walkList(v, n.Measures) + case *ModelForLoop: + walkList(v, n.InList) + if n.Subquery != nil { + Walk(v, n.Subquery) + } + Walk(v, n.LikePattern) + Walk(v, n.FromExpr) + Walk(v, n.ToExpr) + Walk(v, n.IncrExpr) + case *ModelMainModel: + if n.ColumnClauses != nil { + Walk(v, n.ColumnClauses) + } + if n.CellRefOptions != nil { + Walk(v, n.CellRefOptions) + } + if n.RulesClause != nil { + Walk(v, n.RulesClause) + } + case *ModelRefModel: + if n.Subquery != nil { + Walk(v, n.Subquery) + } + if n.ColumnClauses != nil { + Walk(v, n.ColumnClauses) + } + if n.CellRefOptions != nil { + Walk(v, n.CellRefOptions) + } + case *ModelRule: + Walk(v, n.CellRef) + Walk(v, n.Expr) + case *ModelRulesClause: + Walk(v, n.Iterate) + Walk(v, n.Until) + walkList(v, n.Rules) + case *MultisetExpr: + Walk(v, n.Left) + Walk(v, n.Right) + case *NoauditStmt: + if n.Object != nil { + Walk(v, n.Object) + } + case *OperatorBinding: + if n.UsingFunc != nil { + Walk(v, n.UsingFunc) + } + if n.AncillaryTo != nil { + Walk(v, n.AncillaryTo) + } + case *PLSQLAssign: + Walk(v, n.Target) + Walk(v, n.Value) + case *PLSQLBlock: + walkList(v, n.Declarations) + walkList(v, n.Statements) + walkList(v, n.Exceptions) + case *PLSQLCall: + Walk(v, n.Name) + case *PLSQLCase: + Walk(v, n.Expr) + for _, item := range n.Whens { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Else { + Walk(v, item) + } + case *PLSQLContinue: + Walk(v, n.Condition) + case *PLSQLCursorDecl: + walkList(v, n.Parameters) + Walk(v, n.Query) + case *PLSQLElsIf: + Walk(v, n.Condition) + walkList(v, n.Then) + case *PLSQLExecImmediate: + Walk(v, n.SQL) + walkList(v, n.Into) + walkList(v, n.Using) + case *PLSQLExit: + Walk(v, n.Condition) + case *PLSQLFetch: + walkList(v, n.Into) + Walk(v, n.Limit) + case *PLSQLForall: + Walk(v, n.Lower) + Walk(v, n.Upper) + Walk(v, n.Body) + case *PLSQLIf: + Walk(v, n.Condition) + walkList(v, n.Then) + walkList(v, n.ElsIfs) + walkList(v, n.Else) + case *PLSQLLoop: + Walk(v, n.Condition) + Walk(v, n.LowerBound) + Walk(v, n.UpperBound) + walkList(v, n.CursorArgs) + walkList(v, n.Statements) + case *PLSQLOpen: + walkList(v, n.Args) + Walk(v, n.ForQuery) + case *PLSQLPipeRow: + Walk(v, n.Row) + case *PLSQLPragma: + walkList(v, n.Args) + case *PLSQLReturn: + Walk(v, n.Expr) + case *PLSQLTypeDecl: + if n.ElementType != nil { + Walk(v, n.ElementType) + } + if n.IndexBy != nil { + Walk(v, n.IndexBy) + } + Walk(v, n.Limit) + walkList(v, n.Fields) + if n.ReturnType != nil { + Walk(v, n.ReturnType) + } + case *PLSQLVarDecl: + if n.TypeName != nil { + Walk(v, n.TypeName) + } + Walk(v, n.Default) + case *PLSQLWhen: + Walk(v, n.Expr) + for _, item := range n.Stmts { + Walk(v, item) + } + case *Parameter: + if n.TypeName != nil { + Walk(v, n.TypeName) + } + Walk(v, n.Default) + case *ParenExpr: + Walk(v, n.Expr) + case *PartitionClause: + walkList(v, n.Columns) + Walk(v, n.Interval) + walkList(v, n.Partitions) + if n.Subpartition != nil { + Walk(v, n.Subpartition) + } + case *PartitionDef: + walkList(v, n.Values) + case *PartitionExtClause: + walkList(v, n.Keys) + case *PivotClause: + walkList(v, n.AggFuncs) + Walk(v, n.ForCol) + walkList(v, n.ForCols) + walkList(v, n.InList) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *PurgeStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *RawStmt: + Walk(v, n.Stmt) + case *RenameStmt: + if n.OldName != nil { + Walk(v, n.OldName) + } + if n.NewName != nil { + Walk(v, n.NewName) + } + case *ResTarget: + Walk(v, n.Expr) + case *RevokeStmt: + walkList(v, n.Privileges) + if n.OnObject != nil { + Walk(v, n.OnObject) + } + walkList(v, n.Grantees) + case *RollupClause: + walkList(v, n.Args) + case *SampleClause: + Walk(v, n.Percent) + Walk(v, n.Seed) + case *SelectStmt: + if n.WithClause != nil { + Walk(v, n.WithClause) + } + walkList(v, n.TargetList) + if n.Into != nil { + Walk(v, n.Into) + } + walkList(v, n.IntoVars) + walkList(v, n.FromClause) + Walk(v, n.WhereClause) + if n.Hierarchical != nil { + Walk(v, n.Hierarchical) + } + walkList(v, n.GroupClause) + Walk(v, n.HavingClause) + if n.ModelClause != nil { + Walk(v, n.ModelClause) + } + for _, item := range n.WindowDefs { + if item != nil { + Walk(v, item) + } + } + Walk(v, n.QualifyClause) + walkList(v, n.OrderBy) + if n.ForUpdate != nil { + Walk(v, n.ForUpdate) + } + if n.FetchFirst != nil { + Walk(v, n.FetchFirst) + } + if n.Pivot != nil { + Walk(v, n.Pivot) + } + if n.Unpivot != nil { + Walk(v, n.Unpivot) + } + walkList(v, n.Hints) + if n.Larg != nil { + Walk(v, n.Larg) + } + if n.Rarg != nil { + Walk(v, n.Rarg) + } + case *SetClause: + if n.Column != nil { + Walk(v, n.Column) + } + walkList(v, n.Columns) + Walk(v, n.Value) + case *SetConstraintsStmt: + for _, item := range n.Constraints { + if item != nil { + Walk(v, item) + } + } + case *SetParam: + Walk(v, n.Value) + case *SetRoleStmt: + for _, item := range n.Roles { + if item != nil { + Walk(v, item) + } + } + for _, item := range n.Except { + if item != nil { + Walk(v, item) + } + } + case *SortBy: + Walk(v, n.Expr) + case *SubqueryExpr: + Walk(v, n.Subquery) + case *SubqueryRef: + Walk(v, n.Subquery) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *TableCollectionExpr: + Walk(v, n.Expr) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *TableConstraint: + walkList(v, n.Columns) + Walk(v, n.Expr) + if n.RefTable != nil { + Walk(v, n.RefTable) + } + walkList(v, n.RefColumns) + case *TableRef: + if n.Name != nil { + Walk(v, n.Name) + } + if n.Alias != nil { + Walk(v, n.Alias) + } + if n.Sample != nil { + Walk(v, n.Sample) + } + if n.Flashback != nil { + Walk(v, n.Flashback) + } + if n.PartitionExt != nil { + Walk(v, n.PartitionExt) + } + case *TreatExpr: + Walk(v, n.Expr) + if n.TypeName != nil { + Walk(v, n.TypeName) + } + case *TruncateStmt: + if n.Table != nil { + Walk(v, n.Table) + } + case *TypeAttribute: + if n.DataType != nil { + Walk(v, n.DataType) + } + case *TypeBodyMember: + Walk(v, n.Subprog) + case *TypeName: + walkList(v, n.Names) + walkList(v, n.TypeMods) + walkList(v, n.ArrayBounds) + case *UnaryExpr: + Walk(v, n.Operand) + case *UnpivotClause: + Walk(v, n.ValueCol) + Walk(v, n.PivotCol) + walkList(v, n.InList) + if n.Alias != nil { + Walk(v, n.Alias) + } + case *UpdateStmt: + if n.Table != nil { + Walk(v, n.Table) + } + if n.PartitionExt != nil { + Walk(v, n.PartitionExt) + } + if n.Alias != nil { + Walk(v, n.Alias) + } + walkList(v, n.SetClauses) + walkList(v, n.FromClause) + Walk(v, n.WhereClause) + walkList(v, n.Returning) + if n.ErrorLog != nil { + Walk(v, n.ErrorLog) + } + walkList(v, n.Hints) + case *UserQuotaClause: + if n.Tablespace != nil { + Walk(v, n.Tablespace) + } + case *WindowBound: + Walk(v, n.Value) + case *WindowDef: + if n.Spec != nil { + Walk(v, n.Spec) + } + case *WindowFrame: + if n.Start != nil { + Walk(v, n.Start) + } + if n.End != nil { + Walk(v, n.End) + } + case *WindowSpec: + walkList(v, n.PartitionBy) + walkList(v, n.OrderBy) + if n.Frame != nil { + Walk(v, n.Frame) + } + case *WithClause: + walkList(v, n.CTEs) + case *XmlTableColumn: + if n.TypeName != nil { + Walk(v, n.TypeName) + } + Walk(v, n.Path) + Walk(v, n.Default) + case *XmlTableRef: + Walk(v, n.XPath) + Walk(v, n.Passing) + walkList(v, n.Columns) + if n.Alias != nil { + Walk(v, n.Alias) + } + } +} diff --git a/oracle/ast/walk_generated_test.go b/oracle/ast/walk_generated_test.go new file mode 100644 index 00000000..3c7df76d --- /dev/null +++ b/oracle/ast/walk_generated_test.go @@ -0,0 +1,202 @@ +package ast + +import ( + goast "go/ast" + "go/parser" + "go/token" + "sort" + "strings" + "testing" +) + +func TestWalkGeneratedCoversNodeLikeFields(t *testing.T) { + expected := expectedWalkFields(t) + generated := generatedWalkFields(t) + + var missing []string + for structName, fields := range expected { + actualFields, ok := generated[structName] + if !ok { + missing = append(missing, structName+": missing case") + continue + } + for _, field := range fields { + if !actualFields[field] { + missing = append(missing, structName+"."+field) + } + } + } + if len(missing) > 0 { + sort.Strings(missing) + t.Fatalf("walk_generated.go is missing walkable fields:\n%s", strings.Join(missing, "\n")) + } +} + +func expectedWalkFields(t *testing.T) map[string][]string { + t.Helper() + + fset := token.NewFileSet() + var files []*goast.File + for _, src := range []string{"parsenodes.go", "node.go"} { + f, err := parser.ParseFile(fset, src, nil, 0) + if err != nil { + t.Fatalf("parse %s: %v", src, err) + } + files = append(files, f) + } + + structNames := map[string]bool{} + for _, f := range files { + for _, decl := range f.Decls { + gd, ok := decl.(*goast.GenDecl) + if !ok || gd.Tok != token.TYPE { + continue + } + for _, spec := range gd.Specs { + ts := spec.(*goast.TypeSpec) + if _, ok := ts.Type.(*goast.StructType); ok { + structNames[ts.Name.Name] = true + } + } + } + } + + out := map[string][]string{} + for _, f := range files { + for _, decl := range f.Decls { + gd, ok := decl.(*goast.GenDecl) + if !ok || gd.Tok != token.TYPE { + continue + } + for _, spec := range gd.Specs { + ts := spec.(*goast.TypeSpec) + st, ok := ts.Type.(*goast.StructType) + if !ok { + continue + } + for _, fl := range st.Fields.List { + if len(fl.Names) == 0 { + continue + } + if !isWalkableFieldType(typeStringForTest(fl.Type), structNames) { + continue + } + for _, name := range fl.Names { + out[ts.Name.Name] = append(out[ts.Name.Name], name.Name) + } + } + } + } + } + return out +} + +func generatedWalkFields(t *testing.T) map[string]map[string]bool { + t.Helper() + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "walk_generated.go", nil, 0) + if err != nil { + t.Fatalf("parse walk_generated.go: %v", err) + } + + out := map[string]map[string]bool{} + for _, decl := range f.Decls { + fn, ok := decl.(*goast.FuncDecl) + if !ok || fn.Name.Name != "walkChildren" { + continue + } + goast.Inspect(fn.Body, func(n goast.Node) bool { + sw, ok := n.(*goast.TypeSwitchStmt) + if !ok { + return true + } + for _, stmt := range sw.Body.List { + clause := stmt.(*goast.CaseClause) + structName := walkCaseStructName(clause) + if structName == "" { + continue + } + fields := map[string]bool{} + for _, bodyStmt := range clause.Body { + goast.Inspect(bodyStmt, func(n goast.Node) bool { + sel, ok := n.(*goast.SelectorExpr) + if !ok { + return true + } + if ident, ok := sel.X.(*goast.Ident); ok && ident.Name == "n" { + fields[sel.Sel.Name] = true + } + return true + }) + } + out[structName] = fields + } + return false + }) + } + return out +} + +func walkCaseStructName(clause *goast.CaseClause) string { + if len(clause.List) != 1 { + return "" + } + star, ok := clause.List[0].(*goast.StarExpr) + if !ok { + return "" + } + ident, ok := star.X.(*goast.Ident) + if !ok { + return "" + } + return ident.Name +} + +func typeStringForTest(expr goast.Expr) string { + switch t := expr.(type) { + case *goast.Ident: + return t.Name + case *goast.StarExpr: + return "*" + typeStringForTest(t.X) + case *goast.SelectorExpr: + return typeStringForTest(t.X) + "." + t.Sel.Name + case *goast.ArrayType: + return "[]" + typeStringForTest(t.Elt) + default: + return "" + } +} + +func isWalkableFieldType(typStr string, structNames map[string]bool) bool { + switch typStr { + case "Node", "ExprNode", "TableExpr", "StmtNode", "*List": + return true + } + + excludedStructs := map[string]bool{ + "Loc": true, + "List": true, + "String": true, + "Integer": true, + "Float": true, + "Boolean": true, + } + + if strings.HasPrefix(typStr, "*") { + name := typStr[1:] + return structNames[name] && !excludedStructs[name] + } + if strings.HasPrefix(typStr, "[]") { + elemType := typStr[2:] + switch elemType { + case "Node", "ExprNode", "TableExpr", "StmtNode": + return true + } + if strings.HasPrefix(elemType, "*") { + name := elemType[1:] + return structNames[name] && !excludedStructs[name] + } + } + return false +} diff --git a/oracle/ast/walk_test.go b/oracle/ast/walk_test.go new file mode 100644 index 00000000..790150c3 --- /dev/null +++ b/oracle/ast/walk_test.go @@ -0,0 +1,92 @@ +package ast + +import ( + "reflect" + "testing" +) + +func TestWalkSelectStmt(t *testing.T) { + stmt := &SelectStmt{ + TargetList: &List{Items: []Node{ + &ResTarget{Expr: &ColumnRef{Column: "id"}}, + &ResTarget{Expr: &ColumnRef{Column: "name"}}, + }}, + FromClause: &List{Items: []Node{ + &TableRef{Name: &ObjectName{Name: "users"}}, + }}, + WhereClause: &BinaryExpr{ + Op: "=", + Left: &ColumnRef{Column: "id"}, + Right: &NumberLiteral{Val: "1", Ival: 1}, + }, + OrderBy: &List{Items: []Node{ + &SortBy{Expr: &ColumnRef{Column: "name"}}, + }}, + FetchFirst: &FetchFirstClause{ + Count: &NumberLiteral{Val: "10", Ival: 10}, + }, + } + + var visited []string + Inspect(stmt, func(n Node) bool { + if n == nil { + return false + } + visited = append(visited, reflect.TypeOf(n).Elem().Name()) + return true + }) + + typeSet := map[string]bool{} + for _, v := range visited { + typeSet[v] = true + } + for _, want := range []string{"SelectStmt", "List", "ResTarget", "ColumnRef", "TableRef", "ObjectName", "BinaryExpr", "NumberLiteral", "SortBy", "FetchFirstClause"} { + if !typeSet[want] { + t.Errorf("expected to visit %s, visited: %v", want, visited) + } + } +} + +func TestWalkNil(t *testing.T) { + Walk(inspector(func(n Node) bool { return true }), nil) +} + +func TestInspectPruning(t *testing.T) { + stmt := &SelectStmt{ + WhereClause: &BinaryExpr{ + Op: "=", + Left: &ColumnRef{Column: "id"}, + Right: &NumberLiteral{Val: "1", Ival: 1}, + }, + } + + var visited []string + Inspect(stmt, func(n Node) bool { + if n == nil { + return false + } + name := reflect.TypeOf(n).Elem().Name() + visited = append(visited, name) + if name == "BinaryExpr" { + return false + } + return true + }) + + typeSet := map[string]bool{} + for _, v := range visited { + typeSet[v] = true + } + if !typeSet["SelectStmt"] { + t.Error("expected SelectStmt") + } + if !typeSet["BinaryExpr"] { + t.Error("expected BinaryExpr") + } + if typeSet["ColumnRef"] { + t.Error("ColumnRef should have been pruned") + } + if typeSet["NumberLiteral"] { + t.Error("NumberLiteral should have been pruned") + } +}