diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 00000000..feb47b76 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,353 @@ +# Plan to Unskip All Query Tests + +## Overview + +This plan outlines the strategy to unskip all 1,011 currently skipped query tests in the T-SQL parser. The tests are organized in `/home/user/teesql/parser/testdata/` with each test containing: +- `metadata.json` - Skip flag (`{"skip": true}` or `{"skip": false}`) +- `query.sql` - T-SQL query to parse +- `ast.json` - Expected AST output + +### Current Status +- **Total tests:** 1,023 +- **Skipped tests:** 1,011 (98.9%) +- **Active tests:** 12 (1.1%) + +### Currently Implemented Features +- SELECT statements (basic: columns, FROM, aliases) +- PRINT statements +- THROW statements +- ALTER TABLE DROP INDEX +- DROP DATABASE SCOPED CREDENTIAL +- REVERT statements + +--- + +## Phase 1: Complete SELECT Statement Support + +**Goal:** Unskip `SelectStatementTests` and related baseline tests + +### 1.1 Core SELECT Enhancements +- [ ] **TOP clause** - `SELECT TOP 10 ...`, `TOP (n) PERCENT WITH TIES` +- [ ] **INTO clause** - `SELECT ... INTO table FROM ...` +- [ ] **Column aliases** - `AS alias`, `[column name]` without AS +- [ ] **Bracketed identifiers** - `[schema].[table].[column]` + +### 1.2 WHERE Clause +- [ ] **Comparison operators** - `=`, `<>`, `<`, `>`, `<=`, `>=` +- [ ] **Boolean operators** - `AND`, `OR`, `NOT` +- [ ] **IN expressions** - `col IN (1, 2, 3)` +- [ ] **BETWEEN expressions** - `col BETWEEN 1 AND 10` +- [ ] **LIKE expressions** - `col LIKE 'pattern%'` +- [ ] **IS NULL / IS NOT NULL** + +### 1.3 GROUP BY and HAVING +- [ ] **Basic GROUP BY** - `GROUP BY col1, col2` +- [ ] **GROUP BY ALL** +- [ ] **WITH ROLLUP / WITH CUBE** +- [ ] **HAVING clause** + +### 1.4 ORDER BY +- [ ] **ORDER BY clause** - `ORDER BY col ASC/DESC` +- [ ] **Multiple columns** +- [ ] **Ordinal references** - `ORDER BY 1, 2` + +### 1.5 JOINs +- [ ] **INNER JOIN** +- [ ] **LEFT/RIGHT/FULL OUTER JOIN** +- [ ] **CROSS JOIN** +- [ ] **JOIN hints** (LOOP, HASH, MERGE) + +### 1.6 Set Operations +- [ ] **UNION / UNION ALL** +- [ ] **EXCEPT** +- [ ] **INTERSECT** + +### 1.7 Subqueries +- [ ] **Scalar subqueries** - `(SELECT ...)` +- [ ] **Table subqueries** - `FROM (SELECT ...) AS t` +- [ ] **EXISTS / NOT EXISTS** + +### 1.8 Tests to Unskip +- `SelectStatementTests` → `SelectStatementTests/metadata.json` +- `Baselines*_SelectStatementTests` variants + +--- + +## Phase 2: Expression Support + +**Goal:** Support all expression types used across tests + +### 2.1 Literals +- [ ] **Numeric literals** - integers, decimals, floats +- [ ] **Binary literals** - `0x...` +- [ ] **National strings** - `N'...'` +- [ ] **GUID literals** - `{guid'...'}` +- [ ] **Date/time literals** +- [ ] **NULL literal** + +### 2.2 Arithmetic Expressions +- [ ] **Multiplication / Division** - `*`, `/`, `%` +- [ ] **Unary minus/plus** +- [ ] **Bitwise operators** - `&`, `|`, `^`, `~` + +### 2.3 Function Calls +- [ ] **Scalar functions** - `GETDATE()`, `ISNULL()`, etc. +- [ ] **Aggregate functions** - `COUNT()`, `SUM()`, `AVG()`, `MIN()`, `MAX()` +- [ ] **Window functions** - `ROW_NUMBER() OVER(...)` +- [ ] **CAST / CONVERT** +- [ ] **CASE expressions** + +### 2.4 Special Expressions +- [ ] **COALESCE** +- [ ] **NULLIF** +- [ ] **IIF** +- [ ] **Collation expressions** + +--- + +## Phase 3: DML Statements + +### 3.1 INSERT Statement +- [ ] **INSERT INTO ... VALUES** +- [ ] **INSERT INTO ... SELECT** +- [ ] **INSERT INTO ... 
EXEC** +- [ ] **DEFAULT VALUES** +- [ ] **OUTPUT clause** + +**Tests:** `InsertStatementTests`, related baselines + +### 3.2 UPDATE Statement +- [ ] **UPDATE ... SET** +- [ ] **UPDATE with FROM clause** +- [ ] **UPDATE with JOINs** +- [ ] **OUTPUT clause** + +**Tests:** `UpdateStatementTests`, related baselines + +### 3.3 DELETE Statement +- [ ] **DELETE FROM** +- [ ] **DELETE with JOINs** +- [ ] **OUTPUT clause** +- [ ] **TRUNCATE TABLE** + +**Tests:** `DeleteStatementTests`, `TruncateTableStatementTests`, related baselines + +### 3.4 MERGE Statement +- [ ] **MERGE ... USING ... ON** +- [ ] **WHEN MATCHED / NOT MATCHED** +- [ ] **OUTPUT clause** + +**Tests:** `MergeStatementTests*` + +--- + +## Phase 4: DDL Statements - Tables and Indexes + +### 4.1 CREATE TABLE +- [ ] **Column definitions** +- [ ] **Data types** (all SQL Server types) +- [ ] **Constraints** (PRIMARY KEY, FOREIGN KEY, UNIQUE, CHECK, DEFAULT) +- [ ] **Computed columns** +- [ ] **Temporal tables** +- [ ] **Partitioning** + +**Tests:** `CreateTableTests*` + +### 4.2 ALTER TABLE +- [ ] **ADD column** +- [ ] **ALTER COLUMN** +- [ ] **DROP COLUMN** +- [ ] **ADD/DROP CONSTRAINT** + +**Tests:** `AlterTableStatementTests*` + +### 4.3 CREATE/ALTER/DROP INDEX +- [ ] **Clustered/Nonclustered indexes** +- [ ] **INCLUDE columns** +- [ ] **WHERE clause (filtered)** +- [ ] **Index options** + +**Tests:** `CreateIndexStatementTests*`, `AlterIndexStatementTests*` + +--- + +## Phase 5: Programmability + +### 5.1 Variables and Control Flow +- [ ] **DECLARE** - variables, table variables +- [ ] **SET** - variable assignment +- [ ] **IF...ELSE** +- [ ] **WHILE** +- [ ] **BEGIN...END blocks** +- [ ] **TRY...CATCH** +- [ ] **GOTO/LABEL** +- [ ] **RETURN** +- [ ] **WAITFOR** + +**Tests:** `DeclareStatementTests`, `SetStatementTests`, `IfStatementTests`, `WhileStatementTests`, `TryCatchStatementTests`, etc. + +### 5.2 Stored Procedures +- [ ] **CREATE/ALTER PROCEDURE** +- [ ] **EXECUTE/EXEC** +- [ ] **Parameters (IN, OUT, DEFAULT)** +- [ ] **WITH options** (RECOMPILE, ENCRYPTION, etc.) + +**Tests:** `CreateProcedureStatementTests*`, `AlterProcedureStatementTests*`, `ExecuteStatementTests*` + +### 5.3 Functions +- [ ] **CREATE/ALTER FUNCTION** +- [ ] **Scalar functions** +- [ ] **Table-valued functions** +- [ ] **Inline table-valued functions** + +**Tests:** `CreateFunctionStatementTests*`, `AlterFunctionStatementTests*` + +### 5.4 Triggers +- [ ] **CREATE/ALTER TRIGGER** +- [ ] **DML triggers** +- [ ] **DDL triggers** +- [ ] **Logon triggers** + +**Tests:** `CreateTriggerStatementTests*`, `AlterTriggerStatementTests*` + +--- + +## Phase 6: DDL Statements - Schema Objects + +### 6.1 Views +- [ ] **CREATE/ALTER VIEW** +- [ ] **WITH CHECK OPTION** +- [ ] **WITH SCHEMABINDING** + +**Tests:** `CreateViewStatementTests*`, `AlterViewStatementTests*` + +### 6.2 Schemas and Users +- [ ] **CREATE/ALTER SCHEMA** +- [ ] **CREATE/ALTER USER** +- [ ] **CREATE/ALTER LOGIN** +- [ ] **CREATE/ALTER ROLE** + +**Tests:** `CreateSchemaStatementTests*`, `CreateUserStatementTests*`, etc. 
+ +### 6.3 Other DDL Objects +- [ ] **Sequences** +- [ ] **Synonyms** +- [ ] **Types** (user-defined types) +- [ ] **Assemblies** +- [ ] **Certificates and Keys** +- [ ] **Credentials** + +--- + +## Phase 7: Database Management + +### 7.1 Database Statements +- [ ] **CREATE/ALTER DATABASE** +- [ ] **DROP DATABASE** +- [ ] **USE database** +- [ ] **Database options** + +**Tests:** `AlterCreateDatabaseStatementTests*`, `AlterDatabaseOptionsTests*` + +### 7.2 Backup and Restore +- [ ] **BACKUP DATABASE/LOG** +- [ ] **RESTORE DATABASE/LOG** + +**Tests:** `BackupStatementTests*`, `RestoreStatementTests*` + +### 7.3 Server-level +- [ ] **Server configuration** +- [ ] **Endpoints** +- [ ] **Linked servers** + +--- + +## Phase 8: Advanced Features + +### 8.1 Common Table Expressions (CTEs) +- [ ] **WITH ... AS (SELECT ...)** +- [ ] **Recursive CTEs** + +**Tests:** `CTEStatementTests*` + +### 8.2 XML Features +- [ ] **FOR XML** +- [ ] **OPENXML** +- [ ] **XML methods** (query, value, nodes, etc.) + +**Tests:** `ForXmlTests*`, `OpenXmlStatementTests*` + +### 8.3 JSON Features (SQL 2016+) +- [ ] **FOR JSON** +- [ ] **OPENJSON** +- [ ] **JSON functions** + +**Tests:** `JsonFunctionTests*` + +### 8.4 Fulltext Search +- [ ] **CONTAINS** +- [ ] **FREETEXT** +- [ ] **Fulltext indexes** + +**Tests:** `ContainsStatementTests*`, `FulltextTests*` + +### 8.5 Spatial Data +- [ ] **Geometry/Geography types** +- [ ] **Spatial methods** + +--- + +## Phase 9: Baseline Tests + +Once statement types are implemented, unskip corresponding baseline tests: +- `Baselines80_*` - SQL Server 2000 +- `Baselines90_*` - SQL Server 2005 +- `Baselines100_*` - SQL Server 2008 +- `Baselines110_*` - SQL Server 2012 +- `Baselines120_*` - SQL Server 2014 +- `Baselines130_*` - SQL Server 2016 +- `Baselines140_*` - SQL Server 2017 +- `Baselines150_*` - SQL Server 2019 +- `Baselines160_*` - SQL Server 2022 +- `Baselines170_*` - Future versions +- `BaselinesCommon_*` - Common tests + +--- + +## Implementation Strategy + +### For Each Feature: + +1. **Analyze test files** - Read the `query.sql` and `ast.json` for relevant tests +2. **Implement lexer tokens** - Add any new tokens to `/parser/lexer.go` +3. **Add AST types** - Create new files in `/ast/` for new node types +4. **Implement parser** - Add parsing logic to `/parser/parser.go` +5. **Add JSON marshaling** - Add `*ToJSON` functions in parser +6. **Run tests** - Execute `go test ./parser/...` +7. **Unskip tests** - Change `"skip": true` to `"skip": false` in `metadata.json` +8. **Commit** - Commit changes with descriptive message + +### Priority Order: + +1. **High Priority** - Complete SELECT support (enables many baseline tests) +2. **Medium Priority** - DML statements (INSERT, UPDATE, DELETE) +3. **Medium Priority** - Control flow (IF, WHILE, TRY/CATCH) +4. **Lower Priority** - Complex DDL (stored procedures, functions) +5. 
**Lowest Priority** - Advanced features (XML, JSON, Fulltext) + +--- + +## Success Metrics + +- [ ] All 1,023 tests pass without skipping +- [ ] All statement types properly generate matching AST JSON +- [ ] No regressions in currently passing tests + +--- + +## Notes + +- Tests are organized in `testdata/` with related baselines prefixed by version +- The parser uses a hand-written recursive descent approach +- AST JSON format follows the Microsoft SqlScriptDOM conventions +- Some tests may require version-specific behavior diff --git a/ast/begin_end_block_statement.go b/ast/begin_end_block_statement.go new file mode 100644 index 00000000..75707590 --- /dev/null +++ b/ast/begin_end_block_statement.go @@ -0,0 +1,14 @@ +package ast + +// BeginEndBlockStatement represents a BEGIN...END block. +type BeginEndBlockStatement struct { + StatementList *StatementList `json:"StatementList,omitempty"` +} + +func (b *BeginEndBlockStatement) node() {} +func (b *BeginEndBlockStatement) statement() {} + +// StatementList is a list of statements. +type StatementList struct { + Statements []Statement `json:"Statements,omitempty"` +} diff --git a/ast/binary_query_expression.go b/ast/binary_query_expression.go new file mode 100644 index 00000000..30146091 --- /dev/null +++ b/ast/binary_query_expression.go @@ -0,0 +1,13 @@ +package ast + +// BinaryQueryExpression represents UNION, EXCEPT, or INTERSECT queries. +type BinaryQueryExpression struct { + BinaryQueryExpressionType string `json:"BinaryQueryExpressionType,omitempty"` + All bool `json:"All"` + FirstQueryExpression QueryExpression `json:"FirstQueryExpression,omitempty"` + SecondQueryExpression QueryExpression `json:"SecondQueryExpression,omitempty"` + OrderByClause *OrderByClause `json:"OrderByClause,omitempty"` +} + +func (*BinaryQueryExpression) node() {} +func (*BinaryQueryExpression) queryExpression() {} diff --git a/ast/break_statement.go b/ast/break_statement.go new file mode 100644 index 00000000..57b144ee --- /dev/null +++ b/ast/break_statement.go @@ -0,0 +1,7 @@ +package ast + +// BreakStatement represents a BREAK statement. +type BreakStatement struct{} + +func (b *BreakStatement) node() {} +func (b *BreakStatement) statement() {} diff --git a/ast/continue_statement.go b/ast/continue_statement.go new file mode 100644 index 00000000..0015f201 --- /dev/null +++ b/ast/continue_statement.go @@ -0,0 +1,7 @@ +package ast + +// ContinueStatement represents a CONTINUE statement. +type ContinueStatement struct{} + +func (c *ContinueStatement) node() {} +func (c *ContinueStatement) statement() {} diff --git a/ast/create_schema_statement.go b/ast/create_schema_statement.go new file mode 100644 index 00000000..f7976362 --- /dev/null +++ b/ast/create_schema_statement.go @@ -0,0 +1,11 @@ +package ast + +// CreateSchemaStatement represents a CREATE SCHEMA statement. 
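+// For example, CREATE SCHEMA Sales AUTHORIZATION dbo yields Name "Sales" and Owner "dbo";
+// any table, view, or grant definitions embedded in the statement go into StatementList.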
+type CreateSchemaStatement struct { + Name *Identifier `json:"Name,omitempty"` + Owner *Identifier `json:"Owner,omitempty"` + StatementList *StatementList `json:"StatementList,omitempty"` +} + +func (c *CreateSchemaStatement) node() {} +func (c *CreateSchemaStatement) statement() {} diff --git a/ast/create_table_statement.go b/ast/create_table_statement.go new file mode 100644 index 00000000..59f31443 --- /dev/null +++ b/ast/create_table_statement.go @@ -0,0 +1,108 @@ +package ast + +// CreateTableStatement represents a CREATE TABLE statement +type CreateTableStatement struct { + SchemaObjectName *SchemaObjectName + AsEdge bool + AsFileTable bool + AsNode bool + Definition *TableDefinition +} + +func (s *CreateTableStatement) node() {} +func (s *CreateTableStatement) statement() {} + +// TableDefinition represents a table definition +type TableDefinition struct { + ColumnDefinitions []*ColumnDefinition + TableConstraints []TableConstraint + Indexes []*IndexDefinition +} + +func (t *TableDefinition) node() {} + +// ColumnDefinition represents a column definition in CREATE TABLE +type ColumnDefinition struct { + ColumnIdentifier *Identifier + DataType DataTypeReference + Collation *Identifier + DefaultConstraint *DefaultConstraintDefinition + IdentityOptions *IdentityOptions + Constraints []ConstraintDefinition + IsPersisted bool + IsRowGuidCol bool + IsHidden bool + IsMasked bool + Nullable *NullableConstraintDefinition +} + +func (c *ColumnDefinition) node() {} + +// DataTypeReference is an interface for data type references +type DataTypeReference interface { + Node + dataTypeReference() +} + +// DefaultConstraintDefinition represents a DEFAULT constraint +type DefaultConstraintDefinition struct { + ConstraintIdentifier *Identifier + Expression ScalarExpression +} + +func (d *DefaultConstraintDefinition) node() {} + +// IdentityOptions represents IDENTITY options +type IdentityOptions struct { + IdentitySeed ScalarExpression + IdentityIncrement ScalarExpression + NotForReplication bool +} + +func (i *IdentityOptions) node() {} + +// ConstraintDefinition is an interface for constraint definitions +type ConstraintDefinition interface { + Node + constraintDefinition() +} + +// NullableConstraintDefinition represents a NULL or NOT NULL constraint +type NullableConstraintDefinition struct { + Nullable bool +} + +func (n *NullableConstraintDefinition) node() {} +func (n *NullableConstraintDefinition) constraintDefinition() {} + +// TableConstraint is an interface for table-level constraints +type TableConstraint interface { + Node + tableConstraint() +} + +// IndexDefinition represents an index definition within CREATE TABLE +type IndexDefinition struct { + Name *Identifier + Columns []*ColumnWithSortOrder + Unique bool +} + +func (i *IndexDefinition) node() {} + +// ColumnWithSortOrder represents a column with optional sort order +type ColumnWithSortOrder struct { + Column *ColumnReferenceExpression + SortOrder SortOrder +} + +func (c *ColumnWithSortOrder) node() {} + +// SortOrder represents sort order (ASC/DESC) +type SortOrder int + +const ( + SortOrderNotSpecified SortOrder = iota + SortOrderAscending + SortOrderDescending +) diff --git a/ast/create_view_statement.go b/ast/create_view_statement.go new file mode 100644 index 00000000..951a1ef4 --- /dev/null +++ b/ast/create_view_statement.go @@ -0,0 +1,19 @@ +package ast + +// CreateViewStatement represents a CREATE VIEW statement. 
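+// For example, CREATE VIEW dbo.ActiveUsers (Id, Name) AS SELECT ... WITH CHECK OPTION
+// fills SchemaObjectName, Columns, and SelectStatement and sets WithCheckOption;
+// WITH SCHEMABINDING would be recorded as a ViewOption.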
+type CreateViewStatement struct { + SchemaObjectName *SchemaObjectName `json:"SchemaObjectName,omitempty"` + Columns []*Identifier `json:"Columns,omitempty"` + SelectStatement *SelectStatement `json:"SelectStatement,omitempty"` + WithCheckOption bool `json:"WithCheckOption"` + ViewOptions []ViewOption `json:"ViewOptions,omitempty"` + IsMaterialized bool `json:"IsMaterialized"` +} + +func (c *CreateViewStatement) node() {} +func (c *CreateViewStatement) statement() {} + +// ViewOption represents a view option like SCHEMABINDING. +type ViewOption struct { + OptionKind string `json:"OptionKind,omitempty"` +} diff --git a/ast/cursor_id.go b/ast/cursor_id.go new file mode 100644 index 00000000..3f758d56 --- /dev/null +++ b/ast/cursor_id.go @@ -0,0 +1,7 @@ +package ast + +// CursorId represents a cursor identifier. +type CursorId struct { + IsGlobal bool `json:"IsGlobal"` + Name *IdentifierOrValueExpression `json:"Name,omitempty"` +} diff --git a/ast/declare_variable_statement.go b/ast/declare_variable_statement.go new file mode 100644 index 00000000..52a2596d --- /dev/null +++ b/ast/declare_variable_statement.go @@ -0,0 +1,26 @@ +package ast + +// DeclareVariableStatement represents a DECLARE statement. +type DeclareVariableStatement struct { + Declarations []*DeclareVariableElement `json:"Declarations,omitempty"` +} + +func (d *DeclareVariableStatement) node() {} +func (d *DeclareVariableStatement) statement() {} + +// DeclareVariableElement represents a single variable declaration. +type DeclareVariableElement struct { + VariableName *Identifier `json:"VariableName,omitempty"` + DataType *SqlDataTypeReference `json:"DataType,omitempty"` + Value ScalarExpression `json:"Value,omitempty"` +} + +// SqlDataTypeReference represents a SQL data type. +type SqlDataTypeReference struct { + SqlDataTypeOption string `json:"SqlDataTypeOption,omitempty"` + Parameters []ScalarExpression `json:"Parameters,omitempty"` + Name *SchemaObjectName `json:"Name,omitempty"` +} + +func (s *SqlDataTypeReference) node() {} +func (s *SqlDataTypeReference) dataTypeReference() {} diff --git a/ast/default_literal.go b/ast/default_literal.go new file mode 100644 index 00000000..90c42a7a --- /dev/null +++ b/ast/default_literal.go @@ -0,0 +1,10 @@ +package ast + +// DefaultLiteral represents a DEFAULT literal. +type DefaultLiteral struct { + LiteralType string `json:"LiteralType,omitempty"` + Value string `json:"Value,omitempty"` +} + +func (d *DefaultLiteral) node() {} +func (d *DefaultLiteral) scalarExpression() {} diff --git a/ast/delete_statement.go b/ast/delete_statement.go new file mode 100644 index 00000000..a783b3df --- /dev/null +++ b/ast/delete_statement.go @@ -0,0 +1,17 @@ +package ast + +// DeleteStatement represents a DELETE statement. +type DeleteStatement struct { + DeleteSpecification *DeleteSpecification `json:"DeleteSpecification,omitempty"` + OptimizerHints []*OptimizerHint `json:"OptimizerHints,omitempty"` +} + +func (d *DeleteStatement) node() {} +func (d *DeleteStatement) statement() {} + +// DeleteSpecification contains the details of a DELETE. +type DeleteSpecification struct { + Target TableReference `json:"Target,omitempty"` + FromClause *FromClause `json:"FromClause,omitempty"` + WhereClause *WhereClause `json:"WhereClause,omitempty"` +} diff --git a/ast/execute_statement.go b/ast/execute_statement.go new file mode 100644 index 00000000..c778239e --- /dev/null +++ b/ast/execute_statement.go @@ -0,0 +1,46 @@ +package ast + +// ExecuteStatement represents an EXECUTE/EXEC statement. 
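+// For example, EXEC dbo.usp_GetOrders 5 produces an ExecutableProcedureReference whose
+// Parameters contain a single ExecuteParameter with ParameterValue 5.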
+type ExecuteStatement struct { + ExecuteSpecification *ExecuteSpecification `json:"ExecuteSpecification,omitempty"` +} + +func (e *ExecuteStatement) node() {} +func (e *ExecuteStatement) statement() {} + +// ExecuteSpecification contains the details of an EXECUTE. +type ExecuteSpecification struct { + Variable *VariableReference `json:"Variable,omitempty"` + ExecutableEntity ExecutableEntity `json:"ExecutableEntity,omitempty"` +} + +// ExecutableEntity is an interface for executable entities. +type ExecutableEntity interface { + executableEntity() +} + +// ExecutableProcedureReference represents a procedure reference to execute. +type ExecutableProcedureReference struct { + ProcedureReference *ProcedureReferenceName `json:"ProcedureReference,omitempty"` + Parameters []*ExecuteParameter `json:"Parameters,omitempty"` +} + +func (e *ExecutableProcedureReference) executableEntity() {} + +// ProcedureReferenceName holds either a variable or a procedure reference. +type ProcedureReferenceName struct { + ProcedureVariable *VariableReference `json:"ProcedureVariable,omitempty"` + ProcedureReference *ProcedureReference `json:"ProcedureReference,omitempty"` +} + +// ProcedureReference references a stored procedure by name. +type ProcedureReference struct { + Name *SchemaObjectName `json:"Name,omitempty"` +} + +// ExecuteParameter represents a parameter to an EXEC call. +type ExecuteParameter struct { + ParameterValue ScalarExpression `json:"ParameterValue,omitempty"` + Variable *VariableReference `json:"Variable,omitempty"` + IsOutput bool `json:"IsOutput"` +} diff --git a/ast/grant_statement.go b/ast/grant_statement.go new file mode 100644 index 00000000..9c6f48a2 --- /dev/null +++ b/ast/grant_statement.go @@ -0,0 +1,33 @@ +package ast + +// GrantStatement represents a GRANT statement +type GrantStatement struct { + Permissions []*Permission + Principals []*SecurityPrincipal + WithGrantOption bool +} + +func (s *GrantStatement) node() {} +func (s *GrantStatement) statement() {} + +// Permission represents a permission in GRANT/REVOKE +type Permission struct { + Identifiers []*Identifier +} + +func (p *Permission) node() {} + +// SecurityPrincipal represents a security principal in GRANT/REVOKE +type SecurityPrincipal struct { + PrincipalType string + Identifier *Identifier +} + +func (s *SecurityPrincipal) node() {} + +// PrincipalType values +const ( + PrincipalTypeIdentifier = "Identifier" + PrincipalTypePublic = "Public" + PrincipalTypeNull = "Null" +) diff --git a/ast/if_statement.go b/ast/if_statement.go new file mode 100644 index 00000000..7dc3bad9 --- /dev/null +++ b/ast/if_statement.go @@ -0,0 +1,11 @@ +package ast + +// IfStatement represents an IF statement. +type IfStatement struct { + Predicate BooleanExpression `json:"Predicate,omitempty"` + ThenStatement Statement `json:"ThenStatement,omitempty"` + ElseStatement Statement `json:"ElseStatement,omitempty"` +} + +func (i *IfStatement) node() {} +func (i *IfStatement) statement() {} diff --git a/ast/insert_statement.go b/ast/insert_statement.go new file mode 100644 index 00000000..a805ea82 --- /dev/null +++ b/ast/insert_statement.go @@ -0,0 +1,50 @@ +package ast + +// InsertStatement represents an INSERT statement. +type InsertStatement struct { + InsertSpecification *InsertSpecification `json:"InsertSpecification,omitempty"` + OptimizerHints []*OptimizerHint `json:"OptimizerHints,omitempty"` +} + +func (i *InsertStatement) node() {} +func (i *InsertStatement) statement() {} + +// InsertSpecification contains the details of an INSERT. 
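+// For example, INSERT INTO dbo.t (a, b) VALUES (1, 2) sets InsertOption to "Into",
+// Target to dbo.t, Columns to a and b, and InsertSource to a ValuesInsertSource
+// holding one RowValue.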
+type InsertSpecification struct { + InsertOption string `json:"InsertOption,omitempty"` + InsertSource InsertSource `json:"InsertSource,omitempty"` + Target TableReference `json:"Target,omitempty"` + Columns []*ColumnReferenceExpression `json:"Columns,omitempty"` +} + +// InsertSource is an interface for INSERT sources. +type InsertSource interface { + insertSource() +} + +// ValuesInsertSource represents DEFAULT VALUES or VALUES (...). +type ValuesInsertSource struct { + IsDefaultValues bool `json:"IsDefaultValues"` + RowValues []*RowValue `json:"RowValues,omitempty"` +} + +func (v *ValuesInsertSource) insertSource() {} + +// RowValue represents a row of values. +type RowValue struct { + ColumnValues []ScalarExpression `json:"ColumnValues,omitempty"` +} + +// SelectInsertSource represents INSERT ... SELECT. +type SelectInsertSource struct { + Select QueryExpression `json:"Select,omitempty"` +} + +func (s *SelectInsertSource) insertSource() {} + +// ExecuteInsertSource represents INSERT ... EXEC. +type ExecuteInsertSource struct { + Execute *ExecuteSpecification `json:"Execute,omitempty"` +} + +func (e *ExecuteInsertSource) insertSource() {} diff --git a/ast/internal_open_rowset.go b/ast/internal_open_rowset.go new file mode 100644 index 00000000..468d9f42 --- /dev/null +++ b/ast/internal_open_rowset.go @@ -0,0 +1,11 @@ +package ast + +// InternalOpenRowset represents an OPENROWSET table reference. +type InternalOpenRowset struct { + Identifier *Identifier `json:"Identifier,omitempty"` + VarArgs []ScalarExpression `json:"VarArgs,omitempty"` + ForPath bool `json:"ForPath"` +} + +func (i *InternalOpenRowset) node() {} +func (i *InternalOpenRowset) tableReference() {} diff --git a/ast/literal_optimizer_hint.go b/ast/literal_optimizer_hint.go new file mode 100644 index 00000000..a5ac14e1 --- /dev/null +++ b/ast/literal_optimizer_hint.go @@ -0,0 +1,7 @@ +package ast + +// LiteralOptimizerHint represents an optimizer hint with a value. +type LiteralOptimizerHint struct { + HintKind string `json:"HintKind,omitempty"` + Value ScalarExpression `json:"Value,omitempty"` +} diff --git a/ast/named_table_reference.go b/ast/named_table_reference.go index b6ada740..97d92565 100644 --- a/ast/named_table_reference.go +++ b/ast/named_table_reference.go @@ -4,6 +4,7 @@ package ast type NamedTableReference struct { SchemaObject *SchemaObjectName `json:"SchemaObject,omitempty"` Alias *Identifier `json:"Alias,omitempty"` + TableHints []*TableHint `json:"TableHints,omitempty"` ForPath bool `json:"ForPath,omitempty"` } diff --git a/ast/null_literal.go b/ast/null_literal.go new file mode 100644 index 00000000..ff7b0eec --- /dev/null +++ b/ast/null_literal.go @@ -0,0 +1,10 @@ +package ast + +// NullLiteral represents a NULL literal. +type NullLiteral struct { + LiteralType string `json:"LiteralType,omitempty"` + Value string `json:"Value,omitempty"` +} + +func (n *NullLiteral) node() {} +func (n *NullLiteral) scalarExpression() {} diff --git a/ast/numeric_literal.go b/ast/numeric_literal.go new file mode 100644 index 00000000..a33947eb --- /dev/null +++ b/ast/numeric_literal.go @@ -0,0 +1,10 @@ +package ast + +// NumericLiteral represents a numeric literal (decimal). 
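+// For example, 3.14 parses as a NumericLiteral, while a whole number such as 42
+// parses as an IntegerLiteral instead.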
+type NumericLiteral struct { + LiteralType string `json:"LiteralType,omitempty"` + Value string `json:"Value,omitempty"` +} + +func (*NumericLiteral) node() {} +func (*NumericLiteral) scalarExpression() {} diff --git a/ast/odbc_literal.go b/ast/odbc_literal.go new file mode 100644 index 00000000..6c4a5d4b --- /dev/null +++ b/ast/odbc_literal.go @@ -0,0 +1,12 @@ +package ast + +// OdbcLiteral represents an ODBC literal like {guid'...'}. +type OdbcLiteral struct { + LiteralType string `json:"LiteralType,omitempty"` + OdbcLiteralType string `json:"OdbcLiteralType,omitempty"` + IsNational bool `json:"IsNational"` + Value string `json:"Value,omitempty"` +} + +func (*OdbcLiteral) node() {} +func (*OdbcLiteral) scalarExpression() {} diff --git a/ast/predicate_set_statement.go b/ast/predicate_set_statement.go new file mode 100644 index 00000000..7a316bc9 --- /dev/null +++ b/ast/predicate_set_statement.go @@ -0,0 +1,40 @@ +package ast + +// PredicateSetStatement represents a SET statement like SET ANSI_NULLS ON +type PredicateSetStatement struct { + Options SetOptions + IsOn bool +} + +func (s *PredicateSetStatement) node() {} +func (s *PredicateSetStatement) statement() {} + +// SetOptions represents the options for SET statements +type SetOptions string + +const ( + SetOptionsAnsiNulls SetOptions = "AnsiNulls" + SetOptionsAnsiPadding SetOptions = "AnsiPadding" + SetOptionsAnsiWarnings SetOptions = "AnsiWarnings" + SetOptionsArithAbort SetOptions = "ArithAbort" + SetOptionsArithIgnore SetOptions = "ArithIgnore" + SetOptionsConcatNullYieldsNull SetOptions = "ConcatNullYieldsNull" + SetOptionsCursorCloseOnCommit SetOptions = "CursorCloseOnCommit" + SetOptionsFmtOnly SetOptions = "FmtOnly" + SetOptionsForceplan SetOptions = "Forceplan" + SetOptionsImplicitTransactions SetOptions = "ImplicitTransactions" + SetOptionsNoCount SetOptions = "NoCount" + SetOptionsNoExec SetOptions = "NoExec" + SetOptionsNumericRoundAbort SetOptions = "NumericRoundAbort" + SetOptionsParseOnly SetOptions = "ParseOnly" + SetOptionsQuotedIdentifier SetOptions = "QuotedIdentifier" + SetOptionsRemoteProcTransactions SetOptions = "RemoteProcTransactions" + SetOptionsShowplanAll SetOptions = "ShowplanAll" + SetOptionsShowplanText SetOptions = "ShowplanText" + SetOptionsShowplanXml SetOptions = "ShowplanXml" + SetOptionsStatisticsIo SetOptions = "StatisticsIo" + SetOptionsStatisticsProfile SetOptions = "StatisticsProfile" + SetOptionsStatisticsTime SetOptions = "StatisticsTime" + SetOptionsStatisticsXml SetOptions = "StatisticsXml" + SetOptionsXactAbort SetOptions = "XactAbort" +) diff --git a/ast/query_parenthesis_expression.go b/ast/query_parenthesis_expression.go new file mode 100644 index 00000000..1be738bc --- /dev/null +++ b/ast/query_parenthesis_expression.go @@ -0,0 +1,9 @@ +package ast + +// QueryParenthesisExpression represents a parenthesized query expression. +type QueryParenthesisExpression struct { + QueryExpression QueryExpression `json:"QueryExpression,omitempty"` +} + +func (*QueryParenthesisExpression) node() {} +func (*QueryParenthesisExpression) queryExpression() {} diff --git a/ast/query_specification.go b/ast/query_specification.go index fb2a3bc3..df4b748a 100644 --- a/ast/query_specification.go +++ b/ast/query_specification.go @@ -3,6 +3,7 @@ package ast // QuerySpecification represents a query specification (SELECT ... FROM ...). 
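+// The TopRowFilter field added below captures a TOP (n) [PERCENT] [WITH TIES] clause when present.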
type QuerySpecification struct { UniqueRowFilter string `json:"UniqueRowFilter,omitempty"` + TopRowFilter *TopRowFilter `json:"TopRowFilter,omitempty"` SelectElements []SelectElement `json:"SelectElements,omitempty"` FromClause *FromClause `json:"FromClause,omitempty"` WhereClause *WhereClause `json:"WhereClause,omitempty"` diff --git a/ast/return_statement.go b/ast/return_statement.go new file mode 100644 index 00000000..76813248 --- /dev/null +++ b/ast/return_statement.go @@ -0,0 +1,9 @@ +package ast + +// ReturnStatement represents a RETURN statement. +type ReturnStatement struct { + Expression ScalarExpression `json:"Expression,omitempty"` +} + +func (r *ReturnStatement) node() {} +func (r *ReturnStatement) statement() {} diff --git a/ast/schema_object_function_table_reference.go b/ast/schema_object_function_table_reference.go new file mode 100644 index 00000000..3a35fc31 --- /dev/null +++ b/ast/schema_object_function_table_reference.go @@ -0,0 +1,11 @@ +package ast + +// SchemaObjectFunctionTableReference represents a function call as a table reference. +type SchemaObjectFunctionTableReference struct { + SchemaObject *SchemaObjectName `json:"SchemaObject,omitempty"` + Parameters []ScalarExpression `json:"Parameters,omitempty"` + ForPath bool `json:"ForPath"` +} + +func (s *SchemaObjectFunctionTableReference) node() {} +func (s *SchemaObjectFunctionTableReference) tableReference() {} diff --git a/ast/schema_object_name.go b/ast/schema_object_name.go index a4894dae..b3f4f91c 100644 --- a/ast/schema_object_name.go +++ b/ast/schema_object_name.go @@ -2,9 +2,12 @@ package ast // SchemaObjectName represents a schema object name. type SchemaObjectName struct { - BaseIdentifier *Identifier `json:"BaseIdentifier,omitempty"` - Count int `json:"Count,omitempty"` - Identifiers []*Identifier `json:"Identifiers,omitempty"` + ServerIdentifier *Identifier `json:"ServerIdentifier,omitempty"` + DatabaseIdentifier *Identifier `json:"DatabaseIdentifier,omitempty"` + SchemaIdentifier *Identifier `json:"SchemaIdentifier,omitempty"` + BaseIdentifier *Identifier `json:"BaseIdentifier,omitempty"` + Count int `json:"Count,omitempty"` + Identifiers []*Identifier `json:"Identifiers,omitempty"` } func (*SchemaObjectName) node() {} diff --git a/ast/select_statement.go b/ast/select_statement.go index 36d1b55f..8a38a9d3 100644 --- a/ast/select_statement.go +++ b/ast/select_statement.go @@ -2,8 +2,9 @@ package ast // SelectStatement represents a SELECT statement. type SelectStatement struct { - QueryExpression QueryExpression `json:"QueryExpression,omitempty"` - OptimizerHints []*OptimizerHint `json:"OptimizerHints,omitempty"` + QueryExpression QueryExpression `json:"QueryExpression,omitempty"` + Into *SchemaObjectName `json:"Into,omitempty"` + OptimizerHints []*OptimizerHint `json:"OptimizerHints,omitempty"` } func (*SelectStatement) node() {} diff --git a/ast/set_variable_statement.go b/ast/set_variable_statement.go new file mode 100644 index 00000000..b8e7dad3 --- /dev/null +++ b/ast/set_variable_statement.go @@ -0,0 +1,18 @@ +package ast + +// SetVariableStatement represents a SET @var = value statement. 
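+// For example, SET @x = 1 sets Variable to @x and Expression to the integer literal 1,
+// while SET @c = CURSOR FOR SELECT ... would populate CursorDefinition instead.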
+type SetVariableStatement struct { + Variable *VariableReference `json:"Variable,omitempty"` + Expression ScalarExpression `json:"Expression,omitempty"` + CursorDefinition *CursorDefinition `json:"CursorDefinition,omitempty"` + AssignmentKind string `json:"AssignmentKind,omitempty"` + SeparatorType string `json:"SeparatorType,omitempty"` +} + +func (s *SetVariableStatement) node() {} +func (s *SetVariableStatement) statement() {} + +// CursorDefinition represents a cursor definition. +type CursorDefinition struct { + Select QueryExpression `json:"Select,omitempty"` +} diff --git a/ast/table_hint.go b/ast/table_hint.go new file mode 100644 index 00000000..82d6fd54 --- /dev/null +++ b/ast/table_hint.go @@ -0,0 +1,6 @@ +package ast + +// TableHint represents a table hint. +type TableHint struct { + HintKind string `json:"HintKind,omitempty"` +} diff --git a/ast/top_row_filter.go b/ast/top_row_filter.go new file mode 100644 index 00000000..afa56bfc --- /dev/null +++ b/ast/top_row_filter.go @@ -0,0 +1,10 @@ +package ast + +// TopRowFilter represents a TOP clause in a SELECT statement. +type TopRowFilter struct { + Expression ScalarExpression `json:"Expression,omitempty"` + Percent bool `json:"Percent"` + WithTies bool `json:"WithTies"` +} + +func (*TopRowFilter) node() {} diff --git a/ast/unary_expression.go b/ast/unary_expression.go new file mode 100644 index 00000000..577e8cad --- /dev/null +++ b/ast/unary_expression.go @@ -0,0 +1,10 @@ +package ast + +// UnaryExpression represents a unary expression (e.g., -1, +5). +type UnaryExpression struct { + UnaryExpressionType string `json:"UnaryExpressionType,omitempty"` + Expression ScalarExpression `json:"Expression,omitempty"` +} + +func (u *UnaryExpression) node() {} +func (u *UnaryExpression) scalarExpression() {} diff --git a/ast/unqualified_join.go b/ast/unqualified_join.go new file mode 100644 index 00000000..0c272cc0 --- /dev/null +++ b/ast/unqualified_join.go @@ -0,0 +1,11 @@ +package ast + +// UnqualifiedJoin represents a CROSS JOIN or similar join without ON clause. +type UnqualifiedJoin struct { + UnqualifiedJoinType string `json:"UnqualifiedJoinType,omitempty"` + FirstTableReference TableReference `json:"FirstTableReference,omitempty"` + SecondTableReference TableReference `json:"SecondTableReference,omitempty"` +} + +func (*UnqualifiedJoin) node() {} +func (*UnqualifiedJoin) tableReference() {} diff --git a/ast/update_statement.go b/ast/update_statement.go new file mode 100644 index 00000000..348f1c50 --- /dev/null +++ b/ast/update_statement.go @@ -0,0 +1,33 @@ +package ast + +// UpdateStatement represents an UPDATE statement. +type UpdateStatement struct { + UpdateSpecification *UpdateSpecification `json:"UpdateSpecification,omitempty"` + OptimizerHints []*OptimizerHint `json:"OptimizerHints,omitempty"` +} + +func (u *UpdateStatement) node() {} +func (u *UpdateStatement) statement() {} + +// UpdateSpecification contains the details of an UPDATE. +type UpdateSpecification struct { + SetClauses []SetClause `json:"SetClauses,omitempty"` + Target TableReference `json:"Target,omitempty"` + FromClause *FromClause `json:"FromClause,omitempty"` + WhereClause *WhereClause `json:"WhereClause,omitempty"` +} + +// SetClause is an interface for SET clauses. +type SetClause interface { + setClause() +} + +// AssignmentSetClause represents column = value in UPDATE. 
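+// For example, UPDATE t SET col = col + 1 yields one AssignmentSetClause whose Column is col
+// and whose NewValue is the expression col + 1.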
+type AssignmentSetClause struct { + Variable *VariableReference `json:"Variable,omitempty"` + Column *ColumnReferenceExpression `json:"Column,omitempty"` + NewValue ScalarExpression `json:"NewValue,omitempty"` + AssignmentKind string `json:"AssignmentKind,omitempty"` +} + +func (a *AssignmentSetClause) setClause() {} diff --git a/ast/variable_table_reference.go b/ast/variable_table_reference.go new file mode 100644 index 00000000..59700136 --- /dev/null +++ b/ast/variable_table_reference.go @@ -0,0 +1,10 @@ +package ast + +// VariableTableReference represents a table variable reference (@var). +type VariableTableReference struct { + Variable *VariableReference `json:"Variable,omitempty"` + ForPath bool `json:"ForPath"` +} + +func (v *VariableTableReference) node() {} +func (v *VariableTableReference) tableReference() {} diff --git a/ast/where_clause.go b/ast/where_clause.go index f71dfaa3..944a92c2 100644 --- a/ast/where_clause.go +++ b/ast/where_clause.go @@ -3,6 +3,7 @@ package ast // WhereClause represents a WHERE clause. type WhereClause struct { SearchCondition BooleanExpression `json:"SearchCondition,omitempty"` + Cursor *CursorId `json:"Cursor,omitempty"` } func (*WhereClause) node() {} diff --git a/ast/while_statement.go b/ast/while_statement.go new file mode 100644 index 00000000..e186a996 --- /dev/null +++ b/ast/while_statement.go @@ -0,0 +1,10 @@ +package ast + +// WhileStatement represents a WHILE statement. +type WhileStatement struct { + Predicate BooleanExpression `json:"Predicate,omitempty"` + Statement Statement `json:"Statement,omitempty"` +} + +func (w *WhileStatement) node() {} +func (w *WhileStatement) statement() {} diff --git a/parser/lexer.go b/parser/lexer.go index 94e8270c..2f69345e 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -50,6 +50,86 @@ const ( TokenDatabase TokenScoped TokenCredential + TokenTop + TokenPercent + TokenTies + TokenInto + TokenGroup + TokenBy + TokenHaving + TokenOrder + TokenAsc + TokenDesc + TokenUnion + TokenExcept + TokenIntersect + TokenCross + TokenJoin + TokenInner + TokenLeft + TokenRight + TokenFull + TokenOuter + TokenOn + TokenRollup + TokenCube + TokenNotEqual + TokenLessOrEqual + TokenGreaterOrEqual + TokenNot + TokenLBrace + TokenRBrace + + // DML Keywords + TokenInsert + TokenUpdate + TokenDelete + TokenSet + TokenValues + TokenDefault + TokenNull + TokenExec + TokenExecute + TokenOver + + // DDL Keywords + TokenCreate + TokenView + TokenSchema + TokenProcedure + TokenFunction + TokenTrigger + TokenAuthorization + + // Control flow keywords + TokenDeclare + TokenIf + TokenElse + TokenWhile + TokenBegin + TokenEnd + TokenReturn + TokenBreak + TokenContinue + TokenGoto + TokenTry + TokenCatch + + // Additional keywords + TokenCurrent + TokenOf + TokenCursor + TokenOpenRowset + TokenHoldlock + TokenNowait + TokenFast + TokenMaxdop + + // Security keywords + TokenGrant + TokenRevoke + TokenTo + TokenPublic ) // Token represents a lexical token. 
@@ -136,12 +216,39 @@ func (l *Lexer) NextToken() Token { tok.Literal = "=" l.readChar() case '<': - tok.Type = TokenLessThan - tok.Literal = "<" - l.readChar() + if l.peekChar() == '>' { + l.readChar() + tok.Type = TokenNotEqual + tok.Literal = "<>" + l.readChar() + } else if l.peekChar() == '=' { + l.readChar() + tok.Type = TokenLessOrEqual + tok.Literal = "<=" + l.readChar() + } else { + tok.Type = TokenLessThan + tok.Literal = "<" + l.readChar() + } case '>': - tok.Type = TokenGreaterThan - tok.Literal = ">" + if l.peekChar() == '=' { + l.readChar() + tok.Type = TokenGreaterOrEqual + tok.Literal = ">=" + l.readChar() + } else { + tok.Type = TokenGreaterThan + tok.Literal = ">" + l.readChar() + } + case '{': + tok.Type = TokenLBrace + tok.Literal = "{" + l.readChar() + case '}': + tok.Type = TokenRBrace + tok.Literal = "}" l.readChar() case '+': tok.Type = TokenPlus @@ -284,27 +391,93 @@ func isDigit(ch byte) bool { } var keywords = map[string]TokenType{ - "SELECT": TokenSelect, - "FROM": TokenFrom, - "WHERE": TokenWhere, - "AND": TokenAnd, - "OR": TokenOr, - "AS": TokenAs, - "OPTION": TokenOption, - "ALL": TokenAll, - "DISTINCT": TokenDistinct, - "PRINT": TokenPrint, - "THROW": TokenThrow, - "ALTER": TokenAlter, - "TABLE": TokenTable, - "DROP": TokenDrop, - "INDEX": TokenIndex, - "REVERT": TokenRevert, - "WITH": TokenWith, - "COOKIE": TokenCookie, - "DATABASE": TokenDatabase, - "SCOPED": TokenScoped, - "CREDENTIAL": TokenCredential, + "SELECT": TokenSelect, + "FROM": TokenFrom, + "WHERE": TokenWhere, + "AND": TokenAnd, + "OR": TokenOr, + "AS": TokenAs, + "OPTION": TokenOption, + "ALL": TokenAll, + "DISTINCT": TokenDistinct, + "PRINT": TokenPrint, + "THROW": TokenThrow, + "ALTER": TokenAlter, + "TABLE": TokenTable, + "DROP": TokenDrop, + "INDEX": TokenIndex, + "REVERT": TokenRevert, + "WITH": TokenWith, + "COOKIE": TokenCookie, + "DATABASE": TokenDatabase, + "SCOPED": TokenScoped, + "CREDENTIAL": TokenCredential, + "TOP": TokenTop, + "PERCENT": TokenPercent, + "TIES": TokenTies, + "INTO": TokenInto, + "GROUP": TokenGroup, + "BY": TokenBy, + "HAVING": TokenHaving, + "ORDER": TokenOrder, + "ASC": TokenAsc, + "DESC": TokenDesc, + "UNION": TokenUnion, + "EXCEPT": TokenExcept, + "INTERSECT": TokenIntersect, + "CROSS": TokenCross, + "JOIN": TokenJoin, + "INNER": TokenInner, + "LEFT": TokenLeft, + "RIGHT": TokenRight, + "FULL": TokenFull, + "OUTER": TokenOuter, + "ON": TokenOn, + "ROLLUP": TokenRollup, + "CUBE": TokenCube, + "NOT": TokenNot, + "INSERT": TokenInsert, + "UPDATE": TokenUpdate, + "DELETE": TokenDelete, + "SET": TokenSet, + "VALUES": TokenValues, + "DEFAULT": TokenDefault, + "NULL": TokenNull, + "EXEC": TokenExec, + "EXECUTE": TokenExecute, + "OVER": TokenOver, + "CREATE": TokenCreate, + "VIEW": TokenView, + "SCHEMA": TokenSchema, + "PROCEDURE": TokenProcedure, + "PROC": TokenProcedure, + "FUNCTION": TokenFunction, + "TRIGGER": TokenTrigger, + "AUTHORIZATION": TokenAuthorization, + "DECLARE": TokenDeclare, + "IF": TokenIf, + "ELSE": TokenElse, + "WHILE": TokenWhile, + "BEGIN": TokenBegin, + "END": TokenEnd, + "RETURN": TokenReturn, + "BREAK": TokenBreak, + "CONTINUE": TokenContinue, + "GOTO": TokenGoto, + "TRY": TokenTry, + "CATCH": TokenCatch, + "CURRENT": TokenCurrent, + "OF": TokenOf, + "CURSOR": TokenCursor, + "OPENROWSET": TokenOpenRowset, + "HOLDLOCK": TokenHoldlock, + "NOWAIT": TokenNowait, + "FAST": TokenFast, + "MAXDOP": TokenMaxdop, + "GRANT": TokenGrant, + "REVOKE": TokenRevoke, + "TO": TokenTo, + "PUBLIC": TokenPublic, } func lookupKeyword(ident string) TokenType { diff 
--git a/parser/parser.go b/parser/parser.go index 1bed31cc..47a45fdd 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -51,12 +51,14 @@ func (p *Parser) parseScript() (*ast.Script, error) { script := &ast.Script{} // Parse all batches (separated by GO) - batch, err := p.parseBatch() - if err != nil { - return nil, err - } - if batch != nil && len(batch.Statements) > 0 { - script.Batches = append(script.Batches, batch) + for p.curTok.Type != TokenEOF { + batch, err := p.parseBatch() + if err != nil { + return nil, err + } + if batch != nil && len(batch.Statements) > 0 { + script.Batches = append(script.Batches, batch) + } } return script, nil @@ -66,10 +68,10 @@ func (p *Parser) parseBatch() (*ast.Batch, error) { batch := &ast.Batch{} for p.curTok.Type != TokenEOF { - // Skip GO statements (batch separators) + // Stop at GO statements (batch separators) if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "GO" { p.nextToken() - continue + break } stmt, err := p.parseStatement() @@ -86,8 +88,28 @@ func (p *Parser) parseBatch() (*ast.Batch, error) { func (p *Parser) parseStatement() (ast.Statement, error) { switch p.curTok.Type { - case TokenSelect: + case TokenSelect, TokenLParen: return p.parseSelectStatement() + case TokenInsert: + return p.parseInsertStatement() + case TokenUpdate: + return p.parseUpdateStatement() + case TokenDelete: + return p.parseDeleteStatement() + case TokenDeclare: + return p.parseDeclareVariableStatement() + case TokenSet: + return p.parseSetVariableStatement() + case TokenIf: + return p.parseIfStatement() + case TokenWhile: + return p.parseWhileStatement() + case TokenBegin: + return p.parseBeginEndBlockStatement() + case TokenCreate: + return p.parseCreateStatement() + case TokenExec, TokenExecute: + return p.parseExecuteStatement() case TokenPrint: return p.parsePrintStatement() case TokenThrow: @@ -98,6 +120,14 @@ func (p *Parser) parseStatement() (ast.Statement, error) { return p.parseRevertStatement() case TokenDrop: return p.parseDropStatement() + case TokenReturn: + return p.parseReturnStatement() + case TokenBreak: + return p.parseBreakStatement() + case TokenContinue: + return p.parseContinueStatement() + case TokenGrant: + return p.parseGrantStatement() case TokenSemicolon: p.nextToken() return nil, nil @@ -344,12 +374,13 @@ func (p *Parser) parseThrowStatement() (*ast.ThrowStatement, error) { func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) { stmt := &ast.SelectStatement{} - // Parse query expression - qe, err := p.parseQueryExpression() + // Parse query expression (handles UNION, parens, etc.) 
+ qe, into, err := p.parseQueryExpressionWithInto() if err != nil { return nil, err } stmt.QueryExpression = qe + stmt.Into = into // Parse optional OPTION clause if p.curTok.Type == TokenOption { @@ -369,10 +400,161 @@ func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) { } func (p *Parser) parseQueryExpression() (ast.QueryExpression, error) { - return p.parseQuerySpecification() + qe, _, err := p.parseQueryExpressionWithInto() + return qe, err +} + +func (p *Parser) parseQueryExpressionWithInto() (ast.QueryExpression, *ast.SchemaObjectName, error) { + // Parse primary query expression (could be SELECT or parenthesized) + left, into, err := p.parsePrimaryQueryExpression() + if err != nil { + return nil, nil, err + } + + // Track if we have any binary operations + hasBinaryOp := false + + // Check for binary operations (UNION, EXCEPT, INTERSECT) + for p.curTok.Type == TokenUnion || p.curTok.Type == TokenExcept || p.curTok.Type == TokenIntersect { + hasBinaryOp = true + var opType string + switch p.curTok.Type { + case TokenUnion: + opType = "Union" + case TokenExcept: + opType = "Except" + case TokenIntersect: + opType = "Intersect" + } + p.nextToken() + + // Check for ALL + all := false + if p.curTok.Type == TokenAll { + all = true + p.nextToken() + } + + // Parse the right side + right, rightInto, err := p.parsePrimaryQueryExpression() + if err != nil { + return nil, nil, err + } + + // INTO can only appear in the first query of a UNION + if rightInto != nil && into == nil { + into = rightInto + } + + bqe := &ast.BinaryQueryExpression{ + BinaryQueryExpressionType: opType, + All: all, + FirstQueryExpression: left, + SecondQueryExpression: right, + } + + left = bqe + } + + // Parse ORDER BY after all UNION operations + if p.curTok.Type == TokenOrder { + obc, err := p.parseOrderByClause() + if err != nil { + return nil, nil, err + } + + if hasBinaryOp { + // Attach to BinaryQueryExpression + if bqe, ok := left.(*ast.BinaryQueryExpression); ok { + bqe.OrderByClause = obc + } + } else { + // Attach to QuerySpecification + if qs, ok := left.(*ast.QuerySpecification); ok { + qs.OrderByClause = obc + } + } + } + + return left, into, nil +} + +func (p *Parser) parsePrimaryQueryExpression() (ast.QueryExpression, *ast.SchemaObjectName, error) { + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + qe, into, err := p.parseQueryExpressionWithInto() + if err != nil { + return nil, nil, err + } + if p.curTok.Type != TokenRParen { + return nil, nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) + } + p.nextToken() // consume ) + return &ast.QueryParenthesisExpression{QueryExpression: qe}, into, nil + } + + return p.parseQuerySpecificationWithInto() +} + +func (p *Parser) parseQuerySpecificationWithInto() (*ast.QuerySpecification, *ast.SchemaObjectName, error) { + qs, err := p.parseQuerySpecificationCore() + if err != nil { + return nil, nil, err + } + + // Check for INTO clause after SELECT elements, before FROM + var into *ast.SchemaObjectName + if p.curTok.Type == TokenInto { + p.nextToken() // consume INTO + into, err = p.parseSchemaObjectName() + if err != nil { + return nil, nil, err + } + } + + // Parse optional FROM clause + if p.curTok.Type == TokenFrom { + fromClause, err := p.parseFromClause() + if err != nil { + return nil, nil, err + } + qs.FromClause = fromClause + } + + // Parse optional WHERE clause + if p.curTok.Type == TokenWhere { + whereClause, err := p.parseWhereClause() + if err != nil { + return nil, nil, err + } + qs.WhereClause = whereClause + 
} + + // Parse optional GROUP BY clause + if p.curTok.Type == TokenGroup { + groupByClause, err := p.parseGroupByClause() + if err != nil { + return nil, nil, err + } + qs.GroupByClause = groupByClause + } + + // Parse optional HAVING clause + if p.curTok.Type == TokenHaving { + havingClause, err := p.parseHavingClause() + if err != nil { + return nil, nil, err + } + qs.HavingClause = havingClause + } + + // Note: ORDER BY is parsed at the top level in parseQueryExpressionWithInto + // to correctly handle UNION/EXCEPT/INTERSECT cases + + return qs, into, nil } -func (p *Parser) parseQuerySpecification() (*ast.QuerySpecification, error) { +func (p *Parser) parseQuerySpecificationCore() (*ast.QuerySpecification, error) { qs := &ast.QuerySpecification{ UniqueRowFilter: "NotSpecified", } @@ -392,6 +574,15 @@ func (p *Parser) parseQuerySpecification() (*ast.QuerySpecification, error) { p.nextToken() } + // Check for TOP clause + if p.curTok.Type == TokenTop { + top, err := p.parseTopRowFilter() + if err != nil { + return nil, err + } + qs.TopRowFilter = top + } + // Parse select elements elements, err := p.parseSelectElements() if err != nil { @@ -399,16 +590,52 @@ func (p *Parser) parseQuerySpecification() (*ast.QuerySpecification, error) { } qs.SelectElements = elements - // Parse optional FROM clause - if p.curTok.Type == TokenFrom { - fromClause, err := p.parseFromClause() + return qs, nil +} + +func (p *Parser) parseTopRowFilter() (*ast.TopRowFilter, error) { + // Consume TOP + p.nextToken() + + top := &ast.TopRowFilter{} + + // Check for parenthesized expression + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + expr, err := p.parseScalarExpression() if err != nil { return nil, err } - qs.FromClause = fromClause + top.Expression = expr + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) + } + p.nextToken() // consume ) + } else { + // Parse literal expression + expr, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + top.Expression = expr } - return qs, nil + // Check for PERCENT + if p.curTok.Type == TokenPercent { + top.Percent = true + p.nextToken() + } + + // Check for WITH TIES + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if p.curTok.Type == TokenTies { + top.WithTies = true + p.nextToken() + } + } + + return top, nil } func (p *Parser) parseSelectElements() ([]ast.SelectElement, error) { @@ -443,7 +670,54 @@ func (p *Parser) parseSelectElement() (ast.SelectElement, error) { return nil, err } - return &ast.SelectScalarExpression{Expression: expr}, nil + sse := &ast.SelectScalarExpression{Expression: expr} + + // Check for column alias: [alias], AS alias, or just alias + if p.curTok.Type == TokenIdent && p.curTok.Literal[0] == '[' { + // Bracketed alias without AS + alias := p.parseIdentifier() + sse.ColumnName = &ast.IdentifierOrValueExpression{ + Value: alias.Value, + Identifier: alias, + } + } else if p.curTok.Type == TokenAs { + p.nextToken() // consume AS + alias := p.parseIdentifier() + sse.ColumnName = &ast.IdentifierOrValueExpression{ + Value: alias.Value, + Identifier: alias, + } + } else if p.curTok.Type == TokenIdent { + // Check if this is an alias (not a keyword that starts a new clause) + upper := strings.ToUpper(p.curTok.Literal) + if upper != "FROM" && upper != "WHERE" && upper != "GROUP" && upper != "HAVING" && upper != "ORDER" && upper != "OPTION" && upper != "INTO" && upper != "UNION" && upper != "EXCEPT" && upper != "INTERSECT" && upper != "GO" { + alias 
:= p.parseIdentifier() + sse.ColumnName = &ast.IdentifierOrValueExpression{ + Value: alias.Value, + Identifier: alias, + } + } + } + + return sse, nil +} + +func (p *Parser) parseIdentifier() *ast.Identifier { + literal := p.curTok.Literal + quoteType := "NotQuoted" + + // Handle bracketed identifiers + if len(literal) >= 2 && literal[0] == '[' && literal[len(literal)-1] == ']' { + quoteType = "SquareBracket" + literal = literal[1 : len(literal)-1] + } + + id := &ast.Identifier{ + Value: literal, + QuoteType: quoteType, + } + p.nextToken() + return id } func (p *Parser) parseScalarExpression() (ast.ScalarExpression, error) { @@ -482,6 +756,27 @@ func (p *Parser) parseAdditiveExpression() (ast.ScalarExpression, error) { func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) { switch p.curTok.Type { + case TokenNull: + p.nextToken() + return &ast.NullLiteral{LiteralType: "Null", Value: "null"}, nil + case TokenDefault: + val := p.curTok.Literal + p.nextToken() + return &ast.DefaultLiteral{LiteralType: "Default", Value: val}, nil + case TokenMinus: + p.nextToken() + expr, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + return &ast.UnaryExpression{UnaryExpressionType: "Negative", Expression: expr}, nil + case TokenPlus: + p.nextToken() + expr, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + return &ast.UnaryExpression{UnaryExpressionType: "Positive", Expression: expr}, nil case TokenIdent: // Check if it's a variable reference (starts with @) if strings.HasPrefix(p.curTok.Literal, "@") { @@ -493,14 +788,77 @@ func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) { case TokenNumber: val := p.curTok.Literal p.nextToken() + // Check if it's a decimal number + if strings.Contains(val, ".") { + return &ast.NumericLiteral{LiteralType: "Numeric", Value: val}, nil + } return &ast.IntegerLiteral{LiteralType: "Integer", Value: val}, nil case TokenString: return p.parseStringLiteral() + case TokenLBrace: + return p.parseOdbcLiteral() + case TokenLParen: + // Parenthesized expression or subquery + p.nextToken() + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) + } + p.nextToken() + return expr, nil default: return nil, fmt.Errorf("unexpected token in expression: %s", p.curTok.Literal) } } +func (p *Parser) parseOdbcLiteral() (*ast.OdbcLiteral, error) { + // Consume { + p.nextToken() + + // Expect "guid" identifier + if p.curTok.Type != TokenIdent || strings.ToLower(p.curTok.Literal) != "guid" { + return nil, fmt.Errorf("expected guid in ODBC literal, got %s", p.curTok.Literal) + } + p.nextToken() + + // Check for N prefix for national string + isNational := false + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "N" { + isNational = true + p.nextToken() + } + + // Expect string literal + if p.curTok.Type != TokenString { + return nil, fmt.Errorf("expected string in ODBC literal, got %s", p.curTok.Literal) + } + + raw := p.curTok.Literal + p.nextToken() + + // Remove surrounding quotes + value := raw + if len(raw) >= 2 && raw[0] == '\'' && raw[len(raw)-1] == '\'' { + value = raw[1 : len(raw)-1] + } + + // Consume } + if p.curTok.Type != TokenRBrace { + return nil, fmt.Errorf("expected } in ODBC literal, got %s", p.curTok.Literal) + } + p.nextToken() + + return &ast.OdbcLiteral{ + LiteralType: "Odbc", + OdbcLiteralType: "Guid", + IsNational: isNational, + Value: 
value, + }, nil +} + func (p *Parser) parseStringLiteral() (*ast.StringLiteral, error) { raw := p.curTok.Literal p.nextToken() @@ -583,7 +941,97 @@ func (p *Parser) parseFromClause() (*ast.FromClause, error) { } func (p *Parser) parseTableReference() (ast.TableReference, error) { - return p.parseNamedTableReference() + // Parse the base table reference + baseRef, err := p.parseNamedTableReference() + if err != nil { + return nil, err + } + var left ast.TableReference = baseRef + + // Check for JOINs + for { + // Check for CROSS JOIN + if p.curTok.Type == TokenCross { + p.nextToken() // consume CROSS + if p.curTok.Type != TokenJoin { + return nil, fmt.Errorf("expected JOIN after CROSS, got %s", p.curTok.Literal) + } + p.nextToken() // consume JOIN + + right, err := p.parseNamedTableReference() + if err != nil { + return nil, err + } + + left = &ast.UnqualifiedJoin{ + UnqualifiedJoinType: "CrossJoin", + FirstTableReference: left, + SecondTableReference: right, + } + continue + } + + // Check for qualified JOINs (INNER, LEFT, RIGHT, FULL) + joinType := "" + if p.curTok.Type == TokenInner { + joinType = "Inner" + p.nextToken() + } else if p.curTok.Type == TokenLeft { + joinType = "LeftOuter" + p.nextToken() + if p.curTok.Type == TokenOuter { + p.nextToken() + } + } else if p.curTok.Type == TokenRight { + joinType = "RightOuter" + p.nextToken() + if p.curTok.Type == TokenOuter { + p.nextToken() + } + } else if p.curTok.Type == TokenFull { + joinType = "FullOuter" + p.nextToken() + if p.curTok.Type == TokenOuter { + p.nextToken() + } + } else if p.curTok.Type == TokenJoin { + joinType = "Inner" + } + + if joinType == "" { + break + } + + if p.curTok.Type != TokenJoin { + return nil, fmt.Errorf("expected JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume JOIN + + right, err := p.parseNamedTableReference() + if err != nil { + return nil, err + } + + // Parse ON clause + if p.curTok.Type != TokenOn { + return nil, fmt.Errorf("expected ON after JOIN, got %s", p.curTok.Literal) + } + p.nextToken() // consume ON + + condition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + left = &ast.QualifiedJoin{ + QualifiedJoinType: joinType, + FirstTableReference: left, + SecondTableReference: right, + SearchCondition: condition, + } + } + + return left, nil } func (p *Parser) parseNamedTableReference() (*ast.NamedTableReference, error) { @@ -622,16 +1070,23 @@ func (p *Parser) parseSchemaObjectName() (*ast.SchemaObjectName, error) { var identifiers []*ast.Identifier for { + // Handle empty parts (e.g., myDb..table means myDb..table) + if p.curTok.Type == TokenDot { + // Add an empty identifier for the missing part + identifiers = append(identifiers, &ast.Identifier{ + Value: "", + QuoteType: "NotQuoted", + }) + p.nextToken() // consume dot + continue + } + if p.curTok.Type != TokenIdent { break } - id := &ast.Identifier{ - Value: p.curTok.Literal, - QuoteType: "NotQuoted", - } + id := p.parseIdentifier() identifiers = append(identifiers, id) - p.nextToken() if p.curTok.Type != TokenDot { break @@ -643,14 +1098,42 @@ func (p *Parser) parseSchemaObjectName() (*ast.SchemaObjectName, error) { return nil, fmt.Errorf("expected identifier for schema object name") } - // BaseIdentifier is the last identifier - baseId := identifiers[len(identifiers)-1] + // Filter out nil identifiers for the count and assignment + var nonNilIdentifiers []*ast.Identifier + for _, id := range identifiers { + if id != nil { + nonNilIdentifiers = append(nonNilIdentifiers, id) + } + } - return 
&ast.SchemaObjectName{ - BaseIdentifier: baseId, - Count: len(identifiers), - Identifiers: identifiers, - }, nil + son := &ast.SchemaObjectName{ + Count: len(identifiers), + Identifiers: identifiers, + } + + // Set the appropriate identifier fields based on count + // server.database.schema.table (4 parts) + // database.schema.table (3 parts) + // schema.table (2 parts) - but with .., schema is nil + // table (1 part) + switch len(identifiers) { + case 4: + son.ServerIdentifier = identifiers[0] + son.DatabaseIdentifier = identifiers[1] + son.SchemaIdentifier = identifiers[2] + son.BaseIdentifier = identifiers[3] + case 3: + son.DatabaseIdentifier = identifiers[0] + son.SchemaIdentifier = identifiers[1] + son.BaseIdentifier = identifiers[2] + case 2: + son.SchemaIdentifier = identifiers[0] + son.BaseIdentifier = identifiers[1] + case 1: + son.BaseIdentifier = identifiers[0] + } + + return son, nil } func (p *Parser) parseOptionClause() ([]*ast.OptimizerHint, error) { @@ -702,47 +1185,1544 @@ func convertHintKind(hint string) string { return hint } -// jsonNode represents a generic JSON node from the AST JSON format. -type jsonNode map[string]any +func (p *Parser) parseWhereClause() (*ast.WhereClause, error) { + // Consume WHERE + p.nextToken() -// MarshalScript marshals a Script to JSON in the expected format. -func MarshalScript(s *ast.Script) ([]byte, error) { - node := scriptToJSON(s) - return json.MarshalIndent(node, "", " ") + condition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + return &ast.WhereClause{SearchCondition: condition}, nil } -func scriptToJSON(s *ast.Script) jsonNode { - node := jsonNode{ - "$type": "TSqlScript", - } - if len(s.Batches) > 0 { - batches := make([]jsonNode, len(s.Batches)) - for i, b := range s.Batches { - batches[i] = batchToJSON(b) - } - node["Batches"] = batches +func (p *Parser) parseGroupByClause() (*ast.GroupByClause, error) { + // Consume GROUP + p.nextToken() + + if p.curTok.Type != TokenBy { + return nil, fmt.Errorf("expected BY after GROUP, got %s", p.curTok.Literal) } - return node -} + p.nextToken() // consume BY -func batchToJSON(b *ast.Batch) jsonNode { - node := jsonNode{ - "$type": "TSqlBatch", + gbc := &ast.GroupByClause{ + GroupByOption: "None", + All: false, } - if len(b.Statements) > 0 { - stmts := make([]jsonNode, len(b.Statements)) - for i, stmt := range b.Statements { - stmts[i] = statementToJSON(stmt) - } - node["Statements"] = stmts + + // Check for ALL + if p.curTok.Type == TokenAll { + gbc.All = true + p.nextToken() } - return node -} + + // Parse grouping specifications + for { + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + + spec := &ast.ExpressionGroupingSpecification{ + Expression: expr, + DistributedAggregation: false, + } + gbc.GroupingSpecifications = append(gbc.GroupingSpecifications, spec) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() // consume comma + } + + // Check for WITH ROLLUP or WITH CUBE + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + if p.curTok.Type == TokenRollup { + gbc.GroupByOption = "Rollup" + p.nextToken() + } else if p.curTok.Type == TokenCube { + gbc.GroupByOption = "Cube" + p.nextToken() + } + } + + return gbc, nil +} + +func (p *Parser) parseHavingClause() (*ast.HavingClause, error) { + // Consume HAVING + p.nextToken() + + condition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + return &ast.HavingClause{SearchCondition: condition}, nil +} + +func (p *Parser) 
parseOrderByClause() (*ast.OrderByClause, error) { + // Consume ORDER + p.nextToken() + + if p.curTok.Type != TokenBy { + return nil, fmt.Errorf("expected BY after ORDER, got %s", p.curTok.Literal) + } + p.nextToken() // consume BY + + obc := &ast.OrderByClause{} + + // Parse order by elements + for { + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + + elem := &ast.ExpressionWithSortOrder{ + Expression: expr, + SortOrder: "NotSpecified", + } + + // Check for ASC or DESC + if p.curTok.Type == TokenAsc { + elem.SortOrder = "Ascending" + p.nextToken() + } else if p.curTok.Type == TokenDesc { + elem.SortOrder = "Descending" + p.nextToken() + } + + obc.OrderByElements = append(obc.OrderByElements, elem) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() // consume comma + } + + return obc, nil +} + +func (p *Parser) parseBooleanExpression() (ast.BooleanExpression, error) { + return p.parseBooleanOrExpression() +} + +func (p *Parser) parseBooleanOrExpression() (ast.BooleanExpression, error) { + left, err := p.parseBooleanAndExpression() + if err != nil { + return nil, err + } + + for p.curTok.Type == TokenOr { + p.nextToken() // consume OR + + right, err := p.parseBooleanAndExpression() + if err != nil { + return nil, err + } + + left = &ast.BooleanBinaryExpression{ + BinaryExpressionType: "Or", + FirstExpression: left, + SecondExpression: right, + } + } + + return left, nil +} + +func (p *Parser) parseBooleanAndExpression() (ast.BooleanExpression, error) { + left, err := p.parseBooleanPrimaryExpression() + if err != nil { + return nil, err + } + + for p.curTok.Type == TokenAnd { + p.nextToken() // consume AND + + right, err := p.parseBooleanPrimaryExpression() + if err != nil { + return nil, err + } + + left = &ast.BooleanBinaryExpression{ + BinaryExpressionType: "And", + FirstExpression: left, + SecondExpression: right, + } + } + + return left, nil +} + +func (p *Parser) parseBooleanPrimaryExpression() (ast.BooleanExpression, error) { + // Parse left scalar expression + left, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + + // Check for comparison operator + var compType string + switch p.curTok.Type { + case TokenEquals: + compType = "Equals" + case TokenNotEqual: + compType = "NotEqualToBrackets" + case TokenLessThan: + compType = "LessThan" + case TokenGreaterThan: + compType = "GreaterThan" + case TokenLessOrEqual: + compType = "LessThanOrEqualTo" + case TokenGreaterOrEqual: + compType = "GreaterThanOrEqualTo" + default: + return nil, fmt.Errorf("expected comparison operator, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse right scalar expression + right, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + + return &ast.BooleanComparisonExpression{ + ComparisonType: compType, + FirstExpression: left, + SecondExpression: right, + }, nil +} + +// ======================= New Statement Parsing Functions ======================= + +func (p *Parser) parseInsertStatement() (*ast.InsertStatement, error) { + // Consume INSERT + p.nextToken() + + stmt := &ast.InsertStatement{ + InsertSpecification: &ast.InsertSpecification{ + InsertOption: "None", + }, + } + + // Check for INTO or OVER + if p.curTok.Type == TokenInto { + stmt.InsertSpecification.InsertOption = "Into" + p.nextToken() + } else if p.curTok.Type == TokenOver { + stmt.InsertSpecification.InsertOption = "Over" + p.nextToken() + } + + // Parse target + target, err := p.parseDMLTarget() + if err != nil { + return nil, err + } + 
stmt.InsertSpecification.Target = target + + // Parse optional column list + if p.curTok.Type == TokenLParen { + cols, err := p.parseColumnList() + if err != nil { + return nil, err + } + stmt.InsertSpecification.Columns = cols + } + + // Parse insert source + source, err := p.parseInsertSource() + if err != nil { + return nil, err + } + stmt.InsertSpecification.InsertSource = source + + // Parse optional OPTION clause + if p.curTok.Type == TokenOption { + hints, err := p.parseOptionClause() + if err != nil { + return nil, err + } + stmt.OptimizerHints = hints + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseDMLTarget() (ast.TableReference, error) { + // Check for variable + if p.curTok.Type == TokenIdent && strings.HasPrefix(p.curTok.Literal, "@") { + name := p.curTok.Literal + p.nextToken() + return &ast.VariableTableReference{ + Variable: &ast.VariableReference{Name: name}, + ForPath: false, + }, nil + } + + // Check for OPENROWSET + if p.curTok.Type == TokenOpenRowset { + return p.parseOpenRowset() + } + + // Parse schema object name + son, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + + // Check for function call (has parentheses) + if p.curTok.Type == TokenLParen { + params, err := p.parseFunctionParameters() + if err != nil { + return nil, err + } + return &ast.SchemaObjectFunctionTableReference{ + SchemaObject: son, + Parameters: params, + ForPath: false, + }, nil + } + + ref := &ast.NamedTableReference{ + SchemaObject: son, + ForPath: false, + } + + // Check for table hints WITH (...) + if p.curTok.Type == TokenWith { + hints, err := p.parseTableHints() + if err != nil { + return nil, err + } + ref.TableHints = hints + } + + return ref, nil +} + +func (p *Parser) parseOpenRowset() (*ast.InternalOpenRowset, error) { + // Consume OPENROWSET + p.nextToken() + + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after OPENROWSET, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse identifier + if p.curTok.Type != TokenIdent { + return nil, fmt.Errorf("expected identifier in OPENROWSET, got %s", p.curTok.Literal) + } + id := &ast.Identifier{Value: p.curTok.Literal, QuoteType: "NotQuoted"} + p.nextToken() + + var varArgs []ast.ScalarExpression + for p.curTok.Type == TokenComma { + p.nextToken() + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + varArgs = append(varArgs, expr) + } + + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) in OPENROWSET, got %s", p.curTok.Literal) + } + p.nextToken() + + return &ast.InternalOpenRowset{ + Identifier: id, + VarArgs: varArgs, + ForPath: false, + }, nil +} + +func (p *Parser) parseFunctionParameters() ([]ast.ScalarExpression, error) { + // Consume ( + p.nextToken() + + var params []ast.ScalarExpression + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + params = append(params, expr) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) + } + p.nextToken() + + return params, nil +} + +func (p *Parser) parseTableHints() ([]*ast.TableHint, error) { + // Consume WITH + p.nextToken() + + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after WITH, got %s", p.curTok.Literal) + } + p.nextToken() + + var hints 
[]*ast.TableHint + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + if p.curTok.Type == TokenIdent || p.curTok.Type == TokenHoldlock || p.curTok.Type == TokenNowait { + hintKind := convertTableHintKind(p.curTok.Literal) + hints = append(hints, &ast.TableHint{HintKind: hintKind}) + p.nextToken() + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + + if p.curTok.Type == TokenRParen { + p.nextToken() + } + + return hints, nil +} + +func convertTableHintKind(hint string) string { + hintMap := map[string]string{ + "HOLDLOCK": "HoldLock", + "NOWAIT": "NoWait", + "NOLOCK": "NoLock", + "UPDLOCK": "UpdLock", + "XLOCK": "XLock", + } + if mapped, ok := hintMap[strings.ToUpper(hint)]; ok { + return mapped + } + return hint +} + +func (p *Parser) parseColumnList() ([]*ast.ColumnReferenceExpression, error) { + // Consume ( + p.nextToken() + + var cols []*ast.ColumnReferenceExpression + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + col, err := p.parseMultiPartIdentifierAsColumn() + if err != nil { + return nil, err + } + cols = append(cols, col) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) + } + p.nextToken() + + return cols, nil +} + +func (p *Parser) parseMultiPartIdentifierAsColumn() (*ast.ColumnReferenceExpression, error) { + var identifiers []*ast.Identifier + + for { + // Handle empty parts (e.g., ..a means two empty parts then a) + if p.curTok.Type == TokenDot { + identifiers = append(identifiers, &ast.Identifier{Value: "", QuoteType: "NotQuoted"}) + p.nextToken() + continue + } + + if p.curTok.Type != TokenIdent { + break + } + + id := p.parseIdentifier() + identifiers = append(identifiers, id) + + if p.curTok.Type != TokenDot { + break + } + p.nextToken() + } + + return &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Count: len(identifiers), + Identifiers: identifiers, + }, + }, nil +} + +func (p *Parser) parseInsertSource() (ast.InsertSource, error) { + // Check for DEFAULT VALUES + if p.curTok.Type == TokenDefault { + p.nextToken() + if p.curTok.Type == TokenValues { + p.nextToken() + return &ast.ValuesInsertSource{IsDefaultValues: true}, nil + } + return nil, fmt.Errorf("expected VALUES after DEFAULT, got %s", p.curTok.Literal) + } + + // Check for VALUES (...) 
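+	// each row after VALUES is a parenthesized value list; parseValuesInsertSource (below) consumes the comma-separated rows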
+ if p.curTok.Type == TokenValues { + return p.parseValuesInsertSource() + } + + // Check for EXEC/EXECUTE + if p.curTok.Type == TokenExec || p.curTok.Type == TokenExecute { + return p.parseExecuteInsertSource() + } + + // Otherwise it's a SELECT + qe, err := p.parseQueryExpression() + if err != nil { + return nil, err + } + return &ast.SelectInsertSource{Select: qe}, nil +} + +func (p *Parser) parseValuesInsertSource() (*ast.ValuesInsertSource, error) { + // Consume VALUES + p.nextToken() + + source := &ast.ValuesInsertSource{IsDefaultValues: false} + + // Parse row values + for { + if p.curTok.Type != TokenLParen { + break + } + row, err := p.parseRowValue() + if err != nil { + return nil, err + } + source.RowValues = append(source.RowValues, row) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + + return source, nil +} + +func (p *Parser) parseRowValue() (*ast.RowValue, error) { + // Consume ( + p.nextToken() + + row := &ast.RowValue{} + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + row.ColumnValues = append(row.ColumnValues, expr) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ), got %s", p.curTok.Literal) + } + p.nextToken() + + return row, nil +} + +func (p *Parser) parseExecuteInsertSource() (*ast.ExecuteInsertSource, error) { + execSpec, err := p.parseExecuteSpecification() + if err != nil { + return nil, err + } + return &ast.ExecuteInsertSource{Execute: execSpec}, nil +} + +func (p *Parser) parseExecuteSpecification() (*ast.ExecuteSpecification, error) { + // Consume EXEC/EXECUTE + p.nextToken() + + spec := &ast.ExecuteSpecification{} + + // Check for return variable assignment @var = + if p.curTok.Type == TokenIdent && strings.HasPrefix(p.curTok.Literal, "@") { + varName := p.curTok.Literal + p.nextToken() + if p.curTok.Type == TokenEquals { + spec.Variable = &ast.VariableReference{Name: varName} + p.nextToken() + } else { + // It's actually the procedure variable + spec.ExecutableEntity = &ast.ExecutableProcedureReference{ + ProcedureReference: &ast.ProcedureReferenceName{ + ProcedureVariable: &ast.VariableReference{Name: varName}, + }, + } + return spec, nil + } + } + + // Parse procedure reference + procRef := &ast.ExecutableProcedureReference{} + + if p.curTok.Type == TokenIdent && strings.HasPrefix(p.curTok.Literal, "@") { + // Procedure variable + procRef.ProcedureReference = &ast.ProcedureReferenceName{ + ProcedureVariable: &ast.VariableReference{Name: p.curTok.Literal}, + } + p.nextToken() + } else { + // Procedure name + son, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + procRef.ProcedureReference = &ast.ProcedureReferenceName{ + ProcedureReference: &ast.ProcedureReference{Name: son}, + } + } + + // Parse parameters + for p.curTok.Type != TokenEOF && p.curTok.Type != TokenSemicolon && + p.curTok.Type != TokenOption && !p.isStatementTerminator() { + param, err := p.parseExecuteParameter() + if err != nil { + break + } + procRef.Parameters = append(procRef.Parameters, param) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + + spec.ExecutableEntity = procRef + return spec, nil +} + +func (p *Parser) parseExecuteParameter() (*ast.ExecuteParameter, error) { + param := &ast.ExecuteParameter{IsOutput: false} + + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + param.ParameterValue = 
expr + + return param, nil +} + +func (p *Parser) isStatementTerminator() bool { + switch p.curTok.Type { + case TokenSelect, TokenInsert, TokenUpdate, TokenDelete, TokenDeclare, + TokenIf, TokenWhile, TokenBegin, TokenEnd, TokenCreate, TokenAlter, + TokenDrop, TokenExec, TokenExecute, TokenPrint, TokenThrow: + return true + } + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "GO" { + return true + } + return false +} + +func (p *Parser) parseUpdateStatement() (*ast.UpdateStatement, error) { + // Consume UPDATE + p.nextToken() + + stmt := &ast.UpdateStatement{ + UpdateSpecification: &ast.UpdateSpecification{}, + } + + // Parse target + target, err := p.parseDMLTarget() + if err != nil { + return nil, err + } + stmt.UpdateSpecification.Target = target + + // Expect SET + if p.curTok.Type != TokenSet { + return nil, fmt.Errorf("expected SET, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse SET clauses + setClauses, err := p.parseSetClauses() + if err != nil { + return nil, err + } + stmt.UpdateSpecification.SetClauses = setClauses + + // Parse optional FROM clause + if p.curTok.Type == TokenFrom { + fromClause, err := p.parseFromClause() + if err != nil { + return nil, err + } + stmt.UpdateSpecification.FromClause = fromClause + } + + // Parse optional WHERE clause + if p.curTok.Type == TokenWhere { + whereClause, err := p.parseWhereClause() + if err != nil { + return nil, err + } + stmt.UpdateSpecification.WhereClause = whereClause + } + + // Parse optional OPTION clause + if p.curTok.Type == TokenOption { + hints, err := p.parseOptionClause() + if err != nil { + return nil, err + } + stmt.OptimizerHints = hints + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseSetClauses() ([]ast.SetClause, error) { + var clauses []ast.SetClause + + for { + clause, err := p.parseAssignmentSetClause() + if err != nil { + return nil, err + } + clauses = append(clauses, clause) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + + return clauses, nil +} + +func (p *Parser) parseAssignmentSetClause() (*ast.AssignmentSetClause, error) { + clause := &ast.AssignmentSetClause{AssignmentKind: "Equals"} + + // Could be @var = col = value, @var = value, or col = value + if p.curTok.Type == TokenIdent && strings.HasPrefix(p.curTok.Literal, "@") { + varName := p.curTok.Literal + p.nextToken() + if p.curTok.Type == TokenEquals { + clause.Variable = &ast.VariableReference{Name: varName} + p.nextToken() + + // Check if next is column = value (SET @a = col = value) + if p.curTok.Type == TokenIdent && !strings.HasPrefix(p.curTok.Literal, "@") { + // Could be @a = col = value or @a = expr + savedTok := p.curTok + col, err := p.parseMultiPartIdentifierAsColumn() + if err != nil { + return nil, err + } + if p.curTok.Type == TokenEquals { + clause.Column = col + p.nextToken() + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + clause.NewValue = val + return clause, nil + } + // Restore and parse as expression - need different approach + // The column was actually the value expression + _ = savedTok + clause.NewValue = &ast.ColumnReferenceExpression{ + ColumnType: col.ColumnType, + MultiPartIdentifier: col.MultiPartIdentifier, + } + return clause, nil + } + + // Just @var = value + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + clause.NewValue = val + return clause, nil + } + } + + // col = value + col, err := 
p.parseMultiPartIdentifierAsColumn() + if err != nil { + return nil, err + } + clause.Column = col + + if p.curTok.Type != TokenEquals { + return nil, fmt.Errorf("expected =, got %s", p.curTok.Literal) + } + p.nextToken() + + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + clause.NewValue = val + + return clause, nil +} + +func (p *Parser) parseDeleteStatement() (*ast.DeleteStatement, error) { + // Consume DELETE + p.nextToken() + + stmt := &ast.DeleteStatement{ + DeleteSpecification: &ast.DeleteSpecification{}, + } + + // Skip optional FROM + if p.curTok.Type == TokenFrom { + p.nextToken() + } + + // Parse target + target, err := p.parseDMLTarget() + if err != nil { + return nil, err + } + stmt.DeleteSpecification.Target = target + + // Parse optional FROM clause + if p.curTok.Type == TokenFrom { + fromClause, err := p.parseFromClause() + if err != nil { + return nil, err + } + stmt.DeleteSpecification.FromClause = fromClause + } + + // Parse optional WHERE clause + if p.curTok.Type == TokenWhere { + whereClause, err := p.parseDeleteWhereClause() + if err != nil { + return nil, err + } + stmt.DeleteSpecification.WhereClause = whereClause + } + + // Parse optional OPTION clause + if p.curTok.Type == TokenOption { + hints, err := p.parseOptionClause() + if err != nil { + return nil, err + } + stmt.OptimizerHints = hints + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseDeleteWhereClause() (*ast.WhereClause, error) { + // Consume WHERE + p.nextToken() + + // Check for CURRENT OF cursor_name + if p.curTok.Type == TokenCurrent { + p.nextToken() + if p.curTok.Type != TokenOf { + return nil, fmt.Errorf("expected OF after CURRENT, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse cursor name + cursorName := p.curTok.Literal + p.nextToken() + + return &ast.WhereClause{ + Cursor: &ast.CursorId{ + IsGlobal: false, + Name: &ast.IdentifierOrValueExpression{ + Value: cursorName, + Identifier: &ast.Identifier{ + Value: cursorName, + QuoteType: "NotQuoted", + }, + }, + }, + }, nil + } + + condition, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + + return &ast.WhereClause{SearchCondition: condition}, nil +} + +func (p *Parser) parseDeclareVariableStatement() (*ast.DeclareVariableStatement, error) { + // Consume DECLARE + p.nextToken() + + stmt := &ast.DeclareVariableStatement{} + + for { + decl, err := p.parseDeclareVariableElement() + if err != nil { + return nil, err + } + stmt.Declarations = append(stmt.Declarations, decl) + + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseDeclareVariableElement() (*ast.DeclareVariableElement, error) { + elem := &ast.DeclareVariableElement{} + + // Parse variable name + if p.curTok.Type != TokenIdent || !strings.HasPrefix(p.curTok.Literal, "@") { + return nil, fmt.Errorf("expected variable name, got %s", p.curTok.Literal) + } + elem.VariableName = &ast.Identifier{Value: p.curTok.Literal, QuoteType: "NotQuoted"} + p.nextToken() + + // Skip optional AS + if p.curTok.Type == TokenAs { + p.nextToken() + } + + // Parse data type + dataType, err := p.parseDataType() + if err != nil { + return nil, err + } + elem.DataType = dataType + + // Check for = initial value + if p.curTok.Type == TokenEquals { + p.nextToken() + val, err := p.parseScalarExpression() + if 
err != nil { + return nil, err + } + elem.Value = val + } + + return elem, nil +} + +func (p *Parser) parseDataType() (*ast.SqlDataTypeReference, error) { + dt := &ast.SqlDataTypeReference{} + + if p.curTok.Type == TokenCursor { + dt.SqlDataTypeOption = "Cursor" + p.nextToken() + return dt, nil + } + + if p.curTok.Type != TokenIdent { + return nil, fmt.Errorf("expected data type, got %s", p.curTok.Literal) + } + + typeName := p.curTok.Literal + dt.SqlDataTypeOption = convertDataTypeOption(typeName) + baseId := &ast.Identifier{Value: typeName, QuoteType: "NotQuoted"} + dt.Name = &ast.SchemaObjectName{ + BaseIdentifier: baseId, + Count: 1, + Identifiers: []*ast.Identifier{baseId}, + } + p.nextToken() + + // Check for parameters like VARCHAR(100) + if p.curTok.Type == TokenLParen { + p.nextToken() + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + dt.Parameters = append(dt.Parameters, expr) + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + + return dt, nil +} + +func convertDataTypeOption(typeName string) string { + typeMap := map[string]string{ + "INT": "Int", + "INTEGER": "Int", + "BIGINT": "BigInt", + "SMALLINT": "SmallInt", + "TINYINT": "TinyInt", + "BIT": "Bit", + "DECIMAL": "Decimal", + "NUMERIC": "Numeric", + "MONEY": "Money", + "SMALLMONEY": "SmallMoney", + "FLOAT": "Float", + "REAL": "Real", + "DATETIME": "DateTime", + "DATETIME2": "DateTime2", + "DATE": "Date", + "TIME": "Time", + "CHAR": "Char", + "VARCHAR": "VarChar", + "TEXT": "Text", + "NCHAR": "NChar", + "NVARCHAR": "NVarChar", + "NTEXT": "NText", + "BINARY": "Binary", + "VARBINARY": "VarBinary", + "IMAGE": "Image", + "CURSOR": "Cursor", + "SQL_VARIANT": "Sql_Variant", + "TABLE": "Table", + "UNIQUEIDENTIFIER": "UniqueIdentifier", + "XML": "Xml", + } + if mapped, ok := typeMap[strings.ToUpper(typeName)]; ok { + return mapped + } + // Return with first letter capitalized + if len(typeName) > 0 { + return strings.ToUpper(typeName[:1]) + strings.ToLower(typeName[1:]) + } + return typeName +} + +func (p *Parser) parseSetVariableStatement() (ast.Statement, error) { + // Consume SET + p.nextToken() + + // Check for predicate SET options like SET ANSI_NULLS ON/OFF + if p.curTok.Type == TokenIdent { + optionName := strings.ToUpper(p.curTok.Literal) + var setOpt ast.SetOptions + switch optionName { + case "ANSI_NULLS": + setOpt = ast.SetOptionsAnsiNulls + case "ANSI_PADDING": + setOpt = ast.SetOptionsAnsiPadding + case "ANSI_WARNINGS": + setOpt = ast.SetOptionsAnsiWarnings + case "ARITHABORT": + setOpt = ast.SetOptionsArithAbort + case "ARITHIGNORE": + setOpt = ast.SetOptionsArithIgnore + case "CONCAT_NULL_YIELDS_NULL": + setOpt = ast.SetOptionsConcatNullYieldsNull + case "CURSOR_CLOSE_ON_COMMIT": + setOpt = ast.SetOptionsCursorCloseOnCommit + case "FMTONLY": + setOpt = ast.SetOptionsFmtOnly + case "FORCEPLAN": + setOpt = ast.SetOptionsForceplan + case "IMPLICIT_TRANSACTIONS": + setOpt = ast.SetOptionsImplicitTransactions + case "NOCOUNT": + setOpt = ast.SetOptionsNoCount + case "NOEXEC": + setOpt = ast.SetOptionsNoExec + case "NUMERIC_ROUNDABORT": + setOpt = ast.SetOptionsNumericRoundAbort + case "PARSEONLY": + setOpt = ast.SetOptionsParseOnly + case "QUOTED_IDENTIFIER": + setOpt = ast.SetOptionsQuotedIdentifier + case "REMOTE_PROC_TRANSACTIONS": + setOpt = ast.SetOptionsRemoteProcTransactions + case "SHOWPLAN_ALL": + setOpt = ast.SetOptionsShowplanAll + case 
"SHOWPLAN_TEXT": + setOpt = ast.SetOptionsShowplanText + case "SHOWPLAN_XML": + setOpt = ast.SetOptionsShowplanXml + case "STATISTICS_IO": + setOpt = ast.SetOptionsStatisticsIo + case "STATISTICS_PROFILE": + setOpt = ast.SetOptionsStatisticsProfile + case "STATISTICS_TIME": + setOpt = ast.SetOptionsStatisticsTime + case "STATISTICS_XML": + setOpt = ast.SetOptionsStatisticsXml + case "XACT_ABORT": + setOpt = ast.SetOptionsXactAbort + } + if setOpt != "" { + p.nextToken() // consume option name + isOn := false + // ON is tokenized as TokenOn, not TokenIdent + if p.curTok.Type == TokenOn || (p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "ON") { + isOn = true + p.nextToken() + } else if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "OFF" { + isOn = false + p.nextToken() + } + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return &ast.PredicateSetStatement{ + Options: setOpt, + IsOn: isOn, + }, nil + } + } + + stmt := &ast.SetVariableStatement{ + AssignmentKind: "Equals", + SeparatorType: "Equals", + } + + // Parse variable name + if p.curTok.Type != TokenIdent || !strings.HasPrefix(p.curTok.Literal, "@") { + return nil, fmt.Errorf("expected variable name, got %s", p.curTok.Literal) + } + stmt.Variable = &ast.VariableReference{Name: p.curTok.Literal} + p.nextToken() + + // Expect = + if p.curTok.Type != TokenEquals { + return nil, fmt.Errorf("expected =, got %s", p.curTok.Literal) + } + p.nextToken() + + // Check for CURSOR definition + if p.curTok.Type == TokenCursor { + p.nextToken() + // Parse cursor options and FOR SELECT + // For now, simplified - skip to FOR + for p.curTok.Type != TokenEOF && p.curTok.Type != TokenSemicolon { + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "FOR" { + p.nextToken() + break + } + p.nextToken() + } + if p.curTok.Type == TokenSelect { + qe, err := p.parseQueryExpression() + if err != nil { + return nil, err + } + stmt.CursorDefinition = &ast.CursorDefinition{Select: qe} + } + } else { + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + stmt.Expression = expr + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseIfStatement() (*ast.IfStatement, error) { + // Consume IF + p.nextToken() + + stmt := &ast.IfStatement{} + + // Parse predicate + pred, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + stmt.Predicate = pred + + // Parse THEN statement + thenStmt, err := p.parseStatement() + if err != nil { + return nil, err + } + stmt.ThenStatement = thenStmt + + // Check for ELSE + if p.curTok.Type == TokenElse { + p.nextToken() + elseStmt, err := p.parseStatement() + if err != nil { + return nil, err + } + stmt.ElseStatement = elseStmt + } + + return stmt, nil +} + +func (p *Parser) parseWhileStatement() (*ast.WhileStatement, error) { + // Consume WHILE + p.nextToken() + + stmt := &ast.WhileStatement{} + + // Parse predicate + pred, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + stmt.Predicate = pred + + // Parse body statement + bodyStmt, err := p.parseStatement() + if err != nil { + return nil, err + } + stmt.Statement = bodyStmt + + return stmt, nil +} + +func (p *Parser) parseBeginEndBlockStatement() (*ast.BeginEndBlockStatement, error) { + // Consume BEGIN + p.nextToken() + + stmt := &ast.BeginEndBlockStatement{ + StatementList: &ast.StatementList{}, + } + + // Parse statements 
until END + for p.curTok.Type != TokenEnd && p.curTok.Type != TokenEOF { + s, err := p.parseStatement() + if err != nil { + return nil, err + } + if s != nil { + stmt.StatementList.Statements = append(stmt.StatementList.Statements, s) + } + } + + // Consume END + if p.curTok.Type == TokenEnd { + p.nextToken() + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseCreateStatement() (ast.Statement, error) { + // Consume CREATE + p.nextToken() + + switch p.curTok.Type { + case TokenTable: + return p.parseCreateTableStatement() + case TokenView: + return p.parseCreateViewStatement() + case TokenSchema: + return p.parseCreateSchemaStatement() + default: + return nil, fmt.Errorf("unexpected token after CREATE: %s", p.curTok.Literal) + } +} + +func (p *Parser) parseCreateViewStatement() (*ast.CreateViewStatement, error) { + // Consume VIEW + p.nextToken() + + stmt := &ast.CreateViewStatement{} + + // Parse view name + son, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + stmt.SchemaObjectName = son + + // Check for column list + if p.curTok.Type == TokenLParen { + p.nextToken() + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + if p.curTok.Type == TokenIdent { + stmt.Columns = append(stmt.Columns, &ast.Identifier{Value: p.curTok.Literal, QuoteType: "NotQuoted"}) + p.nextToken() + } + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + + // Check for WITH options + if p.curTok.Type == TokenWith { + p.nextToken() + // Parse view options + for p.curTok.Type == TokenIdent { + opt := ast.ViewOption{OptionKind: p.curTok.Literal} + stmt.ViewOptions = append(stmt.ViewOptions, opt) + p.nextToken() + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + } + + // Expect AS + if p.curTok.Type == TokenAs { + p.nextToken() + } + + // Parse SELECT statement + selStmt, err := p.parseSelectStatement() + if err != nil { + return nil, err + } + stmt.SelectStatement = selStmt + + return stmt, nil +} + +func (p *Parser) parseCreateSchemaStatement() (*ast.CreateSchemaStatement, error) { + // Consume SCHEMA + p.nextToken() + + stmt := &ast.CreateSchemaStatement{} + + // Parse schema name (can be bracketed) or AUTHORIZATION + if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + stmt.Name = p.parseIdentifier() + } + + // Check for AUTHORIZATION + if p.curTok.Type == TokenAuthorization { + p.nextToken() + if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + stmt.Owner = p.parseIdentifier() + } + } + + // Parse schema elements (CREATE TABLE, CREATE VIEW, GRANT) + stmt.StatementList = &ast.StatementList{} + for p.curTok.Type != TokenEOF && p.curTok.Type != TokenSemicolon { + // Check for GO (batch separator) + if p.curTok.Type == TokenIdent && strings.ToUpper(p.curTok.Literal) == "GO" { + break + } + // Parse schema element statements + if p.curTok.Type == TokenCreate || p.curTok.Type == TokenGrant { + elemStmt, err := p.parseStatement() + if err != nil { + break + } + if elemStmt != nil { + stmt.StatementList.Statements = append(stmt.StatementList.Statements, elemStmt) + } + } else { + break + } + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseExecuteStatement() (*ast.ExecuteStatement, error) { + execSpec, err := p.parseExecuteSpecification() + if err != nil { + return nil, err + } + + // Skip 
optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return &ast.ExecuteStatement{ExecuteSpecification: execSpec}, nil +} + +func (p *Parser) parseReturnStatement() (*ast.ReturnStatement, error) { + // Consume RETURN + p.nextToken() + + stmt := &ast.ReturnStatement{} + + // Check for expression + if p.curTok.Type != TokenSemicolon && p.curTok.Type != TokenEOF && !p.isStatementTerminator() { + expr, err := p.parseScalarExpression() + if err == nil { + stmt.Expression = expr + } + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseBreakStatement() (*ast.BreakStatement, error) { + // Consume BREAK + p.nextToken() + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return &ast.BreakStatement{}, nil +} + +func (p *Parser) parseContinueStatement() (*ast.ContinueStatement, error) { + // Consume CONTINUE + p.nextToken() + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return &ast.ContinueStatement{}, nil +} + +// ======================= End New Statement Parsing Functions ======================= + +// jsonNode represents a generic JSON node from the AST JSON format. +type jsonNode map[string]any + +// MarshalScript marshals a Script to JSON in the expected format. +func MarshalScript(s *ast.Script) ([]byte, error) { + node := scriptToJSON(s) + return json.MarshalIndent(node, "", " ") +} + +func scriptToJSON(s *ast.Script) jsonNode { + node := jsonNode{ + "$type": "TSqlScript", + } + if len(s.Batches) > 0 { + batches := make([]jsonNode, len(s.Batches)) + for i, b := range s.Batches { + batches[i] = batchToJSON(b) + } + node["Batches"] = batches + } + return node +} + +func batchToJSON(b *ast.Batch) jsonNode { + node := jsonNode{ + "$type": "TSqlBatch", + } + if len(b.Statements) > 0 { + stmts := make([]jsonNode, len(b.Statements)) + for i, stmt := range b.Statements { + stmts[i] = statementToJSON(stmt) + } + node["Statements"] = stmts + } + return node +} func statementToJSON(stmt ast.Statement) jsonNode { switch s := stmt.(type) { case *ast.SelectStatement: return selectStatementToJSON(s) + case *ast.InsertStatement: + return insertStatementToJSON(s) + case *ast.UpdateStatement: + return updateStatementToJSON(s) + case *ast.DeleteStatement: + return deleteStatementToJSON(s) + case *ast.DeclareVariableStatement: + return declareVariableStatementToJSON(s) + case *ast.SetVariableStatement: + return setVariableStatementToJSON(s) + case *ast.IfStatement: + return ifStatementToJSON(s) + case *ast.WhileStatement: + return whileStatementToJSON(s) + case *ast.BeginEndBlockStatement: + return beginEndBlockStatementToJSON(s) + case *ast.CreateViewStatement: + return createViewStatementToJSON(s) + case *ast.CreateSchemaStatement: + return createSchemaStatementToJSON(s) + case *ast.ExecuteStatement: + return executeStatementToJSON(s) + case *ast.ReturnStatement: + return returnStatementToJSON(s) + case *ast.BreakStatement: + return breakStatementToJSON() + case *ast.ContinueStatement: + return continueStatementToJSON() case *ast.PrintStatement: return printStatementToJSON(s) case *ast.ThrowStatement: @@ -753,507 +2733,1429 @@ func statementToJSON(stmt ast.Statement) jsonNode { return revertStatementToJSON(s) case *ast.DropCredentialStatement: return dropCredentialStatementToJSON(s) + case *ast.CreateTableStatement: + return createTableStatementToJSON(s) + case *ast.GrantStatement: + return 
grantStatementToJSON(s) + case *ast.PredicateSetStatement: + return predicateSetStatementToJSON(s) + default: + return jsonNode{"$type": "UnknownStatement"} + } +} + +func revertStatementToJSON(s *ast.RevertStatement) jsonNode { + node := jsonNode{ + "$type": "RevertStatement", + } + if s.Cookie != nil { + node["Cookie"] = scalarExpressionToJSON(s.Cookie) + } + return node +} + +func dropCredentialStatementToJSON(s *ast.DropCredentialStatement) jsonNode { + node := jsonNode{ + "$type": "DropCredentialStatement", + } + node["IsDatabaseScoped"] = s.IsDatabaseScoped + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + node["IsIfExists"] = s.IsIfExists + return node +} + +func alterTableDropTableElementStatementToJSON(s *ast.AlterTableDropTableElementStatement) jsonNode { + node := jsonNode{ + "$type": "AlterTableDropTableElementStatement", + } + if len(s.AlterTableDropTableElements) > 0 { + elements := make([]jsonNode, len(s.AlterTableDropTableElements)) + for i, e := range s.AlterTableDropTableElements { + elements[i] = alterTableDropTableElementToJSON(e) + } + node["AlterTableDropTableElements"] = elements + } + if s.SchemaObjectName != nil { + node["SchemaObjectName"] = schemaObjectNameToJSON(s.SchemaObjectName) + } + return node +} + +func alterTableDropTableElementToJSON(e *ast.AlterTableDropTableElement) jsonNode { + node := jsonNode{ + "$type": "AlterTableDropTableElement", + } + if e.TableElementType != "" { + node["TableElementType"] = e.TableElementType + } + if e.Name != nil { + node["Name"] = identifierToJSON(e.Name) + } + node["IsIfExists"] = e.IsIfExists + return node +} + +func printStatementToJSON(s *ast.PrintStatement) jsonNode { + node := jsonNode{ + "$type": "PrintStatement", + } + if s.Expression != nil { + node["Expression"] = scalarExpressionToJSON(s.Expression) + } + return node +} + +func throwStatementToJSON(s *ast.ThrowStatement) jsonNode { + node := jsonNode{ + "$type": "ThrowStatement", + } + if s.ErrorNumber != nil { + node["ErrorNumber"] = scalarExpressionToJSON(s.ErrorNumber) + } + if s.Message != nil { + node["Message"] = scalarExpressionToJSON(s.Message) + } + if s.State != nil { + node["State"] = scalarExpressionToJSON(s.State) + } + return node +} + +func selectStatementToJSON(s *ast.SelectStatement) jsonNode { + node := jsonNode{ + "$type": "SelectStatement", + } + if s.QueryExpression != nil { + node["QueryExpression"] = queryExpressionToJSON(s.QueryExpression) + } + if s.Into != nil { + node["Into"] = schemaObjectNameToJSON(s.Into) + } + if len(s.OptimizerHints) > 0 { + hints := make([]jsonNode, len(s.OptimizerHints)) + for i, h := range s.OptimizerHints { + hints[i] = optimizerHintToJSON(h) + } + node["OptimizerHints"] = hints + } + return node +} + +func optimizerHintToJSON(h *ast.OptimizerHint) jsonNode { + node := jsonNode{ + "$type": "OptimizerHint", + } + if h.HintKind != "" { + node["HintKind"] = h.HintKind + } + return node +} + +func queryExpressionToJSON(qe ast.QueryExpression) jsonNode { + switch q := qe.(type) { + case *ast.QuerySpecification: + return querySpecificationToJSON(q) + case *ast.QueryParenthesisExpression: + return queryParenthesisExpressionToJSON(q) + case *ast.BinaryQueryExpression: + return binaryQueryExpressionToJSON(q) + default: + return jsonNode{"$type": "UnknownQueryExpression"} + } +} + +func queryParenthesisExpressionToJSON(q *ast.QueryParenthesisExpression) jsonNode { + node := jsonNode{ + "$type": "QueryParenthesisExpression", + } + if q.QueryExpression != nil { + node["QueryExpression"] = 
queryExpressionToJSON(q.QueryExpression) + } + return node +} + +func binaryQueryExpressionToJSON(q *ast.BinaryQueryExpression) jsonNode { + node := jsonNode{ + "$type": "BinaryQueryExpression", + } + if q.BinaryQueryExpressionType != "" { + node["BinaryQueryExpressionType"] = q.BinaryQueryExpressionType + } + node["All"] = q.All + if q.FirstQueryExpression != nil { + node["FirstQueryExpression"] = queryExpressionToJSON(q.FirstQueryExpression) + } + if q.SecondQueryExpression != nil { + node["SecondQueryExpression"] = queryExpressionToJSON(q.SecondQueryExpression) + } + if q.OrderByClause != nil { + node["OrderByClause"] = orderByClauseToJSON(q.OrderByClause) + } + return node +} + +func querySpecificationToJSON(q *ast.QuerySpecification) jsonNode { + node := jsonNode{ + "$type": "QuerySpecification", + } + if q.UniqueRowFilter != "" { + node["UniqueRowFilter"] = q.UniqueRowFilter + } + if q.TopRowFilter != nil { + node["TopRowFilter"] = topRowFilterToJSON(q.TopRowFilter) + } + if len(q.SelectElements) > 0 { + elems := make([]jsonNode, len(q.SelectElements)) + for i, elem := range q.SelectElements { + elems[i] = selectElementToJSON(elem) + } + node["SelectElements"] = elems + } + if q.FromClause != nil { + node["FromClause"] = fromClauseToJSON(q.FromClause) + } + if q.WhereClause != nil { + node["WhereClause"] = whereClauseToJSON(q.WhereClause) + } + if q.GroupByClause != nil { + node["GroupByClause"] = groupByClauseToJSON(q.GroupByClause) + } + if q.HavingClause != nil { + node["HavingClause"] = havingClauseToJSON(q.HavingClause) + } + if q.OrderByClause != nil { + node["OrderByClause"] = orderByClauseToJSON(q.OrderByClause) + } + return node +} + +func topRowFilterToJSON(t *ast.TopRowFilter) jsonNode { + node := jsonNode{ + "$type": "TopRowFilter", + } + if t.Expression != nil { + node["Expression"] = scalarExpressionToJSON(t.Expression) + } + node["Percent"] = t.Percent + node["WithTies"] = t.WithTies + return node +} + +func selectElementToJSON(elem ast.SelectElement) jsonNode { + switch e := elem.(type) { + case *ast.SelectScalarExpression: + node := jsonNode{ + "$type": "SelectScalarExpression", + } + if e.Expression != nil { + node["Expression"] = scalarExpressionToJSON(e.Expression) + } + if e.ColumnName != nil { + node["ColumnName"] = identifierOrValueExpressionToJSON(e.ColumnName) + } + return node + case *ast.SelectStarExpression: + node := jsonNode{ + "$type": "SelectStarExpression", + } + if e.Qualifier != nil { + node["Qualifier"] = multiPartIdentifierToJSON(e.Qualifier) + } + return node + default: + return jsonNode{"$type": "UnknownSelectElement"} + } +} + +func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { + switch e := expr.(type) { + case *ast.ColumnReferenceExpression: + node := jsonNode{ + "$type": "ColumnReferenceExpression", + } + if e.ColumnType != "" { + node["ColumnType"] = e.ColumnType + } + if e.MultiPartIdentifier != nil { + node["MultiPartIdentifier"] = multiPartIdentifierToJSON(e.MultiPartIdentifier) + } + return node + case *ast.IntegerLiteral: + node := jsonNode{ + "$type": "IntegerLiteral", + } + if e.LiteralType != "" { + node["LiteralType"] = e.LiteralType + } + if e.Value != "" { + node["Value"] = e.Value + } + return node + case *ast.StringLiteral: + node := jsonNode{ + "$type": "StringLiteral", + } + if e.LiteralType != "" { + node["LiteralType"] = e.LiteralType + } + // Always include IsNational and IsLargeObject + node["IsNational"] = e.IsNational + node["IsLargeObject"] = e.IsLargeObject + if e.Value != "" { + node["Value"] = e.Value 
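+	// an empty Value is omitted from the JSON, mirroring the other literal cases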
+ } + return node + case *ast.FunctionCall: + node := jsonNode{ + "$type": "FunctionCall", + } + if e.FunctionName != nil { + node["FunctionName"] = identifierToJSON(e.FunctionName) + } + if len(e.Parameters) > 0 { + params := make([]jsonNode, len(e.Parameters)) + for i, p := range e.Parameters { + params[i] = scalarExpressionToJSON(p) + } + node["Parameters"] = params + } + if e.UniqueRowFilter != "" { + node["UniqueRowFilter"] = e.UniqueRowFilter + } + if e.WithArrayWrapper { + node["WithArrayWrapper"] = e.WithArrayWrapper + } + return node + case *ast.BinaryExpression: + node := jsonNode{ + "$type": "BinaryExpression", + } + if e.BinaryExpressionType != "" { + node["BinaryExpressionType"] = e.BinaryExpressionType + } + if e.FirstExpression != nil { + node["FirstExpression"] = scalarExpressionToJSON(e.FirstExpression) + } + if e.SecondExpression != nil { + node["SecondExpression"] = scalarExpressionToJSON(e.SecondExpression) + } + return node + case *ast.VariableReference: + node := jsonNode{ + "$type": "VariableReference", + } + if e.Name != "" { + node["Name"] = e.Name + } + return node + case *ast.NumericLiteral: + node := jsonNode{ + "$type": "NumericLiteral", + } + if e.LiteralType != "" { + node["LiteralType"] = e.LiteralType + } + if e.Value != "" { + node["Value"] = e.Value + } + return node + case *ast.OdbcLiteral: + node := jsonNode{ + "$type": "OdbcLiteral", + } + if e.LiteralType != "" { + node["LiteralType"] = e.LiteralType + } + if e.OdbcLiteralType != "" { + node["OdbcLiteralType"] = e.OdbcLiteralType + } + node["IsNational"] = e.IsNational + if e.Value != "" { + node["Value"] = e.Value + } + return node + case *ast.NullLiteral: + node := jsonNode{ + "$type": "NullLiteral", + } + if e.LiteralType != "" { + node["LiteralType"] = e.LiteralType + } + if e.Value != "" { + node["Value"] = e.Value + } + return node + case *ast.DefaultLiteral: + node := jsonNode{ + "$type": "DefaultLiteral", + } + if e.LiteralType != "" { + node["LiteralType"] = e.LiteralType + } + if e.Value != "" { + node["Value"] = e.Value + } + return node + case *ast.UnaryExpression: + node := jsonNode{ + "$type": "UnaryExpression", + } + if e.UnaryExpressionType != "" { + node["UnaryExpressionType"] = e.UnaryExpressionType + } + if e.Expression != nil { + node["Expression"] = scalarExpressionToJSON(e.Expression) + } + return node + default: + return jsonNode{"$type": "UnknownScalarExpression"} + } +} + +func identifierToJSON(id *ast.Identifier) jsonNode { + node := jsonNode{ + "$type": "Identifier", + } + // Always include Value, even if empty + node["Value"] = id.Value + if id.QuoteType != "" { + node["QuoteType"] = id.QuoteType + } + return node +} + +func multiPartIdentifierToJSON(mpi *ast.MultiPartIdentifier) jsonNode { + node := jsonNode{ + "$type": "MultiPartIdentifier", + } + if mpi.Count > 0 { + node["Count"] = mpi.Count + } + if len(mpi.Identifiers) > 0 { + ids := make([]jsonNode, len(mpi.Identifiers)) + for i, id := range mpi.Identifiers { + ids[i] = identifierToJSON(id) + } + node["Identifiers"] = ids + } + return node +} + +func identifierOrValueExpressionToJSON(iove *ast.IdentifierOrValueExpression) jsonNode { + node := jsonNode{ + "$type": "IdentifierOrValueExpression", + } + if iove.Value != "" { + node["Value"] = iove.Value + } + if iove.Identifier != nil { + node["Identifier"] = identifierToJSON(iove.Identifier) + } + return node +} + +func fromClauseToJSON(fc *ast.FromClause) jsonNode { + node := jsonNode{ + "$type": "FromClause", + } + if len(fc.TableReferences) > 0 { + refs := 
make([]jsonNode, len(fc.TableReferences)) + for i, ref := range fc.TableReferences { + refs[i] = tableReferenceToJSON(ref) + } + node["TableReferences"] = refs + } + return node +} + +func tableReferenceToJSON(ref ast.TableReference) jsonNode { + switch r := ref.(type) { + case *ast.NamedTableReference: + node := jsonNode{ + "$type": "NamedTableReference", + } + if r.SchemaObject != nil { + node["SchemaObject"] = schemaObjectNameToJSON(r.SchemaObject) + } + if len(r.TableHints) > 0 { + hints := make([]jsonNode, len(r.TableHints)) + for i, h := range r.TableHints { + hints[i] = tableHintToJSON(h) + } + node["TableHints"] = hints + } + if r.Alias != nil { + node["Alias"] = identifierToJSON(r.Alias) + } + node["ForPath"] = r.ForPath + return node + case *ast.QualifiedJoin: + node := jsonNode{ + "$type": "QualifiedJoin", + } + if r.SearchCondition != nil { + node["SearchCondition"] = booleanExpressionToJSON(r.SearchCondition) + } + if r.QualifiedJoinType != "" { + node["QualifiedJoinType"] = r.QualifiedJoinType + } + if r.JoinHint != "" { + node["JoinHint"] = r.JoinHint + } + if r.FirstTableReference != nil { + node["FirstTableReference"] = tableReferenceToJSON(r.FirstTableReference) + } + if r.SecondTableReference != nil { + node["SecondTableReference"] = tableReferenceToJSON(r.SecondTableReference) + } + return node + case *ast.UnqualifiedJoin: + node := jsonNode{ + "$type": "UnqualifiedJoin", + } + if r.UnqualifiedJoinType != "" { + node["UnqualifiedJoinType"] = r.UnqualifiedJoinType + } + if r.FirstTableReference != nil { + node["FirstTableReference"] = tableReferenceToJSON(r.FirstTableReference) + } + if r.SecondTableReference != nil { + node["SecondTableReference"] = tableReferenceToJSON(r.SecondTableReference) + } + return node + case *ast.VariableTableReference: + node := jsonNode{ + "$type": "VariableTableReference", + } + if r.Variable != nil { + node["Variable"] = scalarExpressionToJSON(r.Variable) + } + node["ForPath"] = r.ForPath + return node + case *ast.SchemaObjectFunctionTableReference: + node := jsonNode{ + "$type": "SchemaObjectFunctionTableReference", + } + if r.SchemaObject != nil { + node["SchemaObject"] = schemaObjectNameToJSON(r.SchemaObject) + } + if len(r.Parameters) > 0 { + params := make([]jsonNode, len(r.Parameters)) + for i, p := range r.Parameters { + params[i] = scalarExpressionToJSON(p) + } + node["Parameters"] = params + } + node["ForPath"] = r.ForPath + return node + case *ast.InternalOpenRowset: + node := jsonNode{ + "$type": "InternalOpenRowset", + } + if r.Identifier != nil { + node["Identifier"] = identifierToJSON(r.Identifier) + } + if len(r.VarArgs) > 0 { + args := make([]jsonNode, len(r.VarArgs)) + for i, a := range r.VarArgs { + args[i] = scalarExpressionToJSON(a) + } + node["VarArgs"] = args + } + node["ForPath"] = r.ForPath + return node default: - return jsonNode{"$type": "UnknownStatement"} + return jsonNode{"$type": "UnknownTableReference"} + } +} + +func schemaObjectNameToJSON(son *ast.SchemaObjectName) jsonNode { + node := jsonNode{ + "$type": "SchemaObjectName", + } + if son.ServerIdentifier != nil { + node["ServerIdentifier"] = identifierToJSON(son.ServerIdentifier) + } + if son.DatabaseIdentifier != nil { + node["DatabaseIdentifier"] = identifierToJSON(son.DatabaseIdentifier) + } + if son.SchemaIdentifier != nil { + node["SchemaIdentifier"] = identifierToJSON(son.SchemaIdentifier) + } + if son.BaseIdentifier != nil { + node["BaseIdentifier"] = identifierToJSON(son.BaseIdentifier) + } + if son.Count > 0 { + node["Count"] = son.Count + } + if 
len(son.Identifiers) > 0 { + // Handle $ref for identifiers that reference the named identifiers + ids := make([]any, len(son.Identifiers)) + for i, id := range son.Identifiers { + // Check if this identifier is referenced by one of the named fields + isRef := false + if son.ServerIdentifier != nil && id == son.ServerIdentifier { + isRef = true + } else if son.DatabaseIdentifier != nil && id == son.DatabaseIdentifier { + isRef = true + } else if son.SchemaIdentifier != nil && id == son.SchemaIdentifier { + isRef = true + } else if son.BaseIdentifier != nil && id == son.BaseIdentifier { + isRef = true + } + + if isRef { + ids[i] = jsonNode{"$ref": "Identifier"} + } else { + ids[i] = identifierToJSON(id) + } + } + node["Identifiers"] = ids + } + return node +} + +func booleanExpressionToJSON(expr ast.BooleanExpression) jsonNode { + switch e := expr.(type) { + case *ast.BooleanComparisonExpression: + node := jsonNode{ + "$type": "BooleanComparisonExpression", + } + if e.ComparisonType != "" { + node["ComparisonType"] = e.ComparisonType + } + if e.FirstExpression != nil { + node["FirstExpression"] = scalarExpressionToJSON(e.FirstExpression) + } + if e.SecondExpression != nil { + node["SecondExpression"] = scalarExpressionToJSON(e.SecondExpression) + } + return node + case *ast.BooleanBinaryExpression: + node := jsonNode{ + "$type": "BooleanBinaryExpression", + } + if e.BinaryExpressionType != "" { + node["BinaryExpressionType"] = e.BinaryExpressionType + } + if e.FirstExpression != nil { + node["FirstExpression"] = booleanExpressionToJSON(e.FirstExpression) + } + if e.SecondExpression != nil { + node["SecondExpression"] = booleanExpressionToJSON(e.SecondExpression) + } + return node + default: + return jsonNode{"$type": "UnknownBooleanExpression"} + } +} + +func groupByClauseToJSON(gbc *ast.GroupByClause) jsonNode { + node := jsonNode{ + "$type": "GroupByClause", + } + if gbc.GroupByOption != "" { + node["GroupByOption"] = gbc.GroupByOption + } + // Always include All field + node["All"] = gbc.All + if len(gbc.GroupingSpecifications) > 0 { + specs := make([]jsonNode, len(gbc.GroupingSpecifications)) + for i, spec := range gbc.GroupingSpecifications { + specs[i] = groupingSpecificationToJSON(spec) + } + node["GroupingSpecifications"] = specs + } + return node +} + +func groupingSpecificationToJSON(spec ast.GroupingSpecification) jsonNode { + switch s := spec.(type) { + case *ast.ExpressionGroupingSpecification: + node := jsonNode{ + "$type": "ExpressionGroupingSpecification", + } + // Always include DistributedAggregation field + node["DistributedAggregation"] = s.DistributedAggregation + if s.Expression != nil { + node["Expression"] = scalarExpressionToJSON(s.Expression) + } + return node + default: + return jsonNode{"$type": "UnknownGroupingSpecification"} + } +} + +func havingClauseToJSON(hc *ast.HavingClause) jsonNode { + node := jsonNode{ + "$type": "HavingClause", + } + if hc.SearchCondition != nil { + node["SearchCondition"] = booleanExpressionToJSON(hc.SearchCondition) + } + return node +} + +func orderByClauseToJSON(obc *ast.OrderByClause) jsonNode { + node := jsonNode{ + "$type": "OrderByClause", + } + if len(obc.OrderByElements) > 0 { + elems := make([]jsonNode, len(obc.OrderByElements)) + for i, elem := range obc.OrderByElements { + elems[i] = expressionWithSortOrderToJSON(elem) + } + node["OrderByElements"] = elems + } + return node +} + +func expressionWithSortOrderToJSON(ewso *ast.ExpressionWithSortOrder) jsonNode { + node := jsonNode{ + "$type": "ExpressionWithSortOrder", + } + 
if ewso.SortOrder != "" { + node["SortOrder"] = ewso.SortOrder + } + if ewso.Expression != nil { + node["Expression"] = scalarExpressionToJSON(ewso.Expression) + } + return node +} + +// ======================= New Statement JSON Functions ======================= + +func tableHintToJSON(h *ast.TableHint) jsonNode { + node := jsonNode{ + "$type": "TableHint", + } + if h.HintKind != "" { + node["HintKind"] = h.HintKind + } + return node +} + +func insertStatementToJSON(s *ast.InsertStatement) jsonNode { + node := jsonNode{ + "$type": "InsertStatement", + } + if s.InsertSpecification != nil { + node["InsertSpecification"] = insertSpecificationToJSON(s.InsertSpecification) + } + if len(s.OptimizerHints) > 0 { + hints := make([]jsonNode, len(s.OptimizerHints)) + for i, h := range s.OptimizerHints { + hints[i] = optimizerHintToJSON(h) + } + node["OptimizerHints"] = hints + } + return node +} + +func insertSpecificationToJSON(spec *ast.InsertSpecification) jsonNode { + node := jsonNode{ + "$type": "InsertSpecification", + } + if spec.InsertOption != "" && spec.InsertOption != "None" { + node["InsertOption"] = spec.InsertOption + } + if spec.InsertSource != nil { + node["InsertSource"] = insertSourceToJSON(spec.InsertSource) + } + if spec.Target != nil { + node["Target"] = tableReferenceToJSON(spec.Target) + } + if len(spec.Columns) > 0 { + cols := make([]jsonNode, len(spec.Columns)) + for i, c := range spec.Columns { + cols[i] = scalarExpressionToJSON(c) + } + node["Columns"] = cols + } + return node +} + +func insertSourceToJSON(src ast.InsertSource) jsonNode { + switch s := src.(type) { + case *ast.ValuesInsertSource: + node := jsonNode{ + "$type": "ValuesInsertSource", + } + node["IsDefaultValues"] = s.IsDefaultValues + if len(s.RowValues) > 0 { + rows := make([]jsonNode, len(s.RowValues)) + for i, r := range s.RowValues { + rows[i] = rowValueToJSON(r) + } + node["RowValues"] = rows + } + return node + case *ast.SelectInsertSource: + node := jsonNode{ + "$type": "SelectInsertSource", + } + if s.Select != nil { + node["Select"] = queryExpressionToJSON(s.Select) + } + return node + case *ast.ExecuteInsertSource: + node := jsonNode{ + "$type": "ExecuteInsertSource", + } + if s.Execute != nil { + node["Execute"] = executeSpecificationToJSON(s.Execute) + } + return node + default: + return jsonNode{"$type": "UnknownInsertSource"} + } +} + +func rowValueToJSON(rv *ast.RowValue) jsonNode { + node := jsonNode{ + "$type": "RowValue", + } + if len(rv.ColumnValues) > 0 { + vals := make([]jsonNode, len(rv.ColumnValues)) + for i, v := range rv.ColumnValues { + vals[i] = scalarExpressionToJSON(v) + } + node["ColumnValues"] = vals + } + return node +} + +func executeSpecificationToJSON(spec *ast.ExecuteSpecification) jsonNode { + node := jsonNode{ + "$type": "ExecuteSpecification", + } + if spec.Variable != nil { + node["Variable"] = scalarExpressionToJSON(spec.Variable) + } + if spec.ExecutableEntity != nil { + node["ExecutableEntity"] = executableEntityToJSON(spec.ExecutableEntity) + } + return node +} + +func executableEntityToJSON(entity ast.ExecutableEntity) jsonNode { + switch e := entity.(type) { + case *ast.ExecutableProcedureReference: + node := jsonNode{ + "$type": "ExecutableProcedureReference", + } + if e.ProcedureReference != nil { + node["ProcedureReference"] = procedureReferenceNameToJSON(e.ProcedureReference) + } + if len(e.Parameters) > 0 { + params := make([]jsonNode, len(e.Parameters)) + for i, p := range e.Parameters { + params[i] = executeParameterToJSON(p) + } + node["Parameters"] = 
params + } + return node + default: + return jsonNode{"$type": "UnknownExecutableEntity"} + } +} + +func procedureReferenceNameToJSON(prn *ast.ProcedureReferenceName) jsonNode { + node := jsonNode{ + "$type": "ProcedureReferenceName", + } + if prn.ProcedureVariable != nil { + node["ProcedureVariable"] = scalarExpressionToJSON(prn.ProcedureVariable) + } + if prn.ProcedureReference != nil { + node["ProcedureReference"] = procedureReferenceToJSON(prn.ProcedureReference) + } + return node +} + +func procedureReferenceToJSON(pr *ast.ProcedureReference) jsonNode { + node := jsonNode{ + "$type": "ProcedureReference", } + if pr.Name != nil { + node["Name"] = schemaObjectNameToJSON(pr.Name) + } + return node } -func revertStatementToJSON(s *ast.RevertStatement) jsonNode { +func executeParameterToJSON(ep *ast.ExecuteParameter) jsonNode { node := jsonNode{ - "$type": "RevertStatement", + "$type": "ExecuteParameter", } - if s.Cookie != nil { - node["Cookie"] = scalarExpressionToJSON(s.Cookie) + if ep.ParameterValue != nil { + node["ParameterValue"] = scalarExpressionToJSON(ep.ParameterValue) } + if ep.Variable != nil { + node["Variable"] = scalarExpressionToJSON(ep.Variable) + } + node["IsOutput"] = ep.IsOutput return node } -func dropCredentialStatementToJSON(s *ast.DropCredentialStatement) jsonNode { +func updateStatementToJSON(s *ast.UpdateStatement) jsonNode { node := jsonNode{ - "$type": "DropCredentialStatement", + "$type": "UpdateStatement", } - node["IsDatabaseScoped"] = s.IsDatabaseScoped - if s.Name != nil { - node["Name"] = identifierToJSON(s.Name) + if s.UpdateSpecification != nil { + node["UpdateSpecification"] = updateSpecificationToJSON(s.UpdateSpecification) + } + if len(s.OptimizerHints) > 0 { + hints := make([]jsonNode, len(s.OptimizerHints)) + for i, h := range s.OptimizerHints { + hints[i] = optimizerHintToJSON(h) + } + node["OptimizerHints"] = hints } - node["IsIfExists"] = s.IsIfExists return node } -func alterTableDropTableElementStatementToJSON(s *ast.AlterTableDropTableElementStatement) jsonNode { +func updateSpecificationToJSON(spec *ast.UpdateSpecification) jsonNode { node := jsonNode{ - "$type": "AlterTableDropTableElementStatement", + "$type": "UpdateSpecification", } - if len(s.AlterTableDropTableElements) > 0 { - elements := make([]jsonNode, len(s.AlterTableDropTableElements)) - for i, e := range s.AlterTableDropTableElements { - elements[i] = alterTableDropTableElementToJSON(e) + if len(spec.SetClauses) > 0 { + clauses := make([]jsonNode, len(spec.SetClauses)) + for i, c := range spec.SetClauses { + clauses[i] = setClauseToJSON(c) } - node["AlterTableDropTableElements"] = elements + node["SetClauses"] = clauses } - if s.SchemaObjectName != nil { - node["SchemaObjectName"] = schemaObjectNameToJSON(s.SchemaObjectName) + if spec.Target != nil { + node["Target"] = tableReferenceToJSON(spec.Target) + } + if spec.FromClause != nil { + node["FromClause"] = fromClauseToJSON(spec.FromClause) + } + if spec.WhereClause != nil { + node["WhereClause"] = whereClauseToJSON(spec.WhereClause) } return node } -func alterTableDropTableElementToJSON(e *ast.AlterTableDropTableElement) jsonNode { +func setClauseToJSON(sc ast.SetClause) jsonNode { + switch c := sc.(type) { + case *ast.AssignmentSetClause: + node := jsonNode{ + "$type": "AssignmentSetClause", + } + if c.Variable != nil { + node["Variable"] = scalarExpressionToJSON(c.Variable) + } + if c.Column != nil { + node["Column"] = scalarExpressionToJSON(c.Column) + } + if c.NewValue != nil { + node["NewValue"] = 
scalarExpressionToJSON(c.NewValue) + } + if c.AssignmentKind != "" { + node["AssignmentKind"] = c.AssignmentKind + } + return node + default: + return jsonNode{"$type": "UnknownSetClause"} + } +} + +func deleteStatementToJSON(s *ast.DeleteStatement) jsonNode { node := jsonNode{ - "$type": "AlterTableDropTableElement", + "$type": "DeleteStatement", } - if e.TableElementType != "" { - node["TableElementType"] = e.TableElementType + if s.DeleteSpecification != nil { + node["DeleteSpecification"] = deleteSpecificationToJSON(s.DeleteSpecification) } - if e.Name != nil { - node["Name"] = identifierToJSON(e.Name) + if len(s.OptimizerHints) > 0 { + hints := make([]jsonNode, len(s.OptimizerHints)) + for i, h := range s.OptimizerHints { + hints[i] = optimizerHintToJSON(h) + } + node["OptimizerHints"] = hints } - node["IsIfExists"] = e.IsIfExists return node } -func printStatementToJSON(s *ast.PrintStatement) jsonNode { +func deleteSpecificationToJSON(spec *ast.DeleteSpecification) jsonNode { node := jsonNode{ - "$type": "PrintStatement", + "$type": "DeleteSpecification", } - if s.Expression != nil { - node["Expression"] = scalarExpressionToJSON(s.Expression) + if spec.FromClause != nil { + node["FromClause"] = fromClauseToJSON(spec.FromClause) + } + if spec.WhereClause != nil { + node["WhereClause"] = whereClauseToJSON(spec.WhereClause) + } + if spec.Target != nil { + node["Target"] = tableReferenceToJSON(spec.Target) } return node } -func throwStatementToJSON(s *ast.ThrowStatement) jsonNode { +func whereClauseToJSON(wc *ast.WhereClause) jsonNode { node := jsonNode{ - "$type": "ThrowStatement", - } - if s.ErrorNumber != nil { - node["ErrorNumber"] = scalarExpressionToJSON(s.ErrorNumber) + "$type": "WhereClause", } - if s.Message != nil { - node["Message"] = scalarExpressionToJSON(s.Message) + if wc.Cursor != nil { + node["Cursor"] = cursorIdToJSON(wc.Cursor) } - if s.State != nil { - node["State"] = scalarExpressionToJSON(s.State) + if wc.SearchCondition != nil { + node["SearchCondition"] = booleanExpressionToJSON(wc.SearchCondition) } return node } -func selectStatementToJSON(s *ast.SelectStatement) jsonNode { +func cursorIdToJSON(cid *ast.CursorId) jsonNode { node := jsonNode{ - "$type": "SelectStatement", + "$type": "CursorId", } - if s.QueryExpression != nil { - node["QueryExpression"] = queryExpressionToJSON(s.QueryExpression) + node["IsGlobal"] = cid.IsGlobal + if cid.Name != nil { + node["Name"] = identifierOrValueExpressionToJSON(cid.Name) } - if len(s.OptimizerHints) > 0 { - hints := make([]jsonNode, len(s.OptimizerHints)) - for i, h := range s.OptimizerHints { - hints[i] = optimizerHintToJSON(h) + return node +} + +func declareVariableStatementToJSON(s *ast.DeclareVariableStatement) jsonNode { + node := jsonNode{ + "$type": "DeclareVariableStatement", + } + if len(s.Declarations) > 0 { + decls := make([]jsonNode, len(s.Declarations)) + for i, d := range s.Declarations { + decls[i] = declareVariableElementToJSON(d) } - node["OptimizerHints"] = hints + node["Declarations"] = decls } return node } -func optimizerHintToJSON(h *ast.OptimizerHint) jsonNode { +func declareVariableElementToJSON(elem *ast.DeclareVariableElement) jsonNode { node := jsonNode{ - "$type": "OptimizerHint", + "$type": "DeclareVariableElement", } - if h.HintKind != "" { - node["HintKind"] = h.HintKind + if elem.VariableName != nil { + node["VariableName"] = identifierToJSON(elem.VariableName) + } + if elem.DataType != nil { + node["DataType"] = sqlDataTypeReferenceToJSON(elem.DataType) + } + if elem.Value != nil { + 
node["Value"] = scalarExpressionToJSON(elem.Value) } return node } -func queryExpressionToJSON(qe ast.QueryExpression) jsonNode { - switch q := qe.(type) { - case *ast.QuerySpecification: - return querySpecificationToJSON(q) - default: - return jsonNode{"$type": "UnknownQueryExpression"} +func sqlDataTypeReferenceToJSON(dt *ast.SqlDataTypeReference) jsonNode { + node := jsonNode{ + "$type": "SqlDataTypeReference", + } + if dt.SqlDataTypeOption != "" { + node["SqlDataTypeOption"] = dt.SqlDataTypeOption + } + if len(dt.Parameters) > 0 { + params := make([]jsonNode, len(dt.Parameters)) + for i, p := range dt.Parameters { + params[i] = scalarExpressionToJSON(p) + } + node["Parameters"] = params } + if dt.Name != nil { + node["Name"] = schemaObjectNameToJSON(dt.Name) + } + return node } -func querySpecificationToJSON(q *ast.QuerySpecification) jsonNode { +func setVariableStatementToJSON(s *ast.SetVariableStatement) jsonNode { node := jsonNode{ - "$type": "QuerySpecification", + "$type": "SetVariableStatement", } - if q.UniqueRowFilter != "" { - node["UniqueRowFilter"] = q.UniqueRowFilter + if s.Variable != nil { + node["Variable"] = scalarExpressionToJSON(s.Variable) } - if len(q.SelectElements) > 0 { - elems := make([]jsonNode, len(q.SelectElements)) - for i, elem := range q.SelectElements { - elems[i] = selectElementToJSON(elem) - } - node["SelectElements"] = elems + if s.Expression != nil { + node["Expression"] = scalarExpressionToJSON(s.Expression) } - if q.FromClause != nil { - node["FromClause"] = fromClauseToJSON(q.FromClause) + if s.CursorDefinition != nil { + node["CursorDefinition"] = cursorDefinitionToJSON(s.CursorDefinition) } - if q.WhereClause != nil { - node["WhereClause"] = whereClauseToJSON(q.WhereClause) + if s.AssignmentKind != "" { + node["AssignmentKind"] = s.AssignmentKind } - if q.GroupByClause != nil { - node["GroupByClause"] = groupByClauseToJSON(q.GroupByClause) + if s.SeparatorType != "" { + node["SeparatorType"] = s.SeparatorType } - if q.HavingClause != nil { - node["HavingClause"] = havingClauseToJSON(q.HavingClause) + return node +} + +func cursorDefinitionToJSON(cd *ast.CursorDefinition) jsonNode { + node := jsonNode{ + "$type": "CursorDefinition", } - if q.OrderByClause != nil { - node["OrderByClause"] = orderByClauseToJSON(q.OrderByClause) + if cd.Select != nil { + node["Select"] = queryExpressionToJSON(cd.Select) } return node } -func selectElementToJSON(elem ast.SelectElement) jsonNode { - switch e := elem.(type) { - case *ast.SelectScalarExpression: - node := jsonNode{ - "$type": "SelectScalarExpression", - } - if e.Expression != nil { - node["Expression"] = scalarExpressionToJSON(e.Expression) - } - if e.ColumnName != nil { - node["ColumnName"] = identifierOrValueExpressionToJSON(e.ColumnName) - } - return node - case *ast.SelectStarExpression: - node := jsonNode{ - "$type": "SelectStarExpression", - } - if e.Qualifier != nil { - node["Qualifier"] = multiPartIdentifierToJSON(e.Qualifier) - } - return node - default: - return jsonNode{"$type": "UnknownSelectElement"} +func ifStatementToJSON(s *ast.IfStatement) jsonNode { + node := jsonNode{ + "$type": "IfStatement", } + if s.Predicate != nil { + node["Predicate"] = booleanExpressionToJSON(s.Predicate) + } + if s.ThenStatement != nil { + node["ThenStatement"] = statementToJSON(s.ThenStatement) + } + if s.ElseStatement != nil { + node["ElseStatement"] = statementToJSON(s.ElseStatement) + } + return node } -func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { - switch e := expr.(type) { - case 
*ast.ColumnReferenceExpression: - node := jsonNode{ - "$type": "ColumnReferenceExpression", - } - if e.ColumnType != "" { - node["ColumnType"] = e.ColumnType - } - if e.MultiPartIdentifier != nil { - node["MultiPartIdentifier"] = multiPartIdentifierToJSON(e.MultiPartIdentifier) - } - return node - case *ast.IntegerLiteral: - node := jsonNode{ - "$type": "IntegerLiteral", - } - if e.LiteralType != "" { - node["LiteralType"] = e.LiteralType - } - if e.Value != "" { - node["Value"] = e.Value - } - return node - case *ast.StringLiteral: - node := jsonNode{ - "$type": "StringLiteral", - } - if e.LiteralType != "" { - node["LiteralType"] = e.LiteralType - } - // Always include IsNational and IsLargeObject - node["IsNational"] = e.IsNational - node["IsLargeObject"] = e.IsLargeObject - if e.Value != "" { - node["Value"] = e.Value - } - return node - case *ast.FunctionCall: - node := jsonNode{ - "$type": "FunctionCall", - } - if e.FunctionName != nil { - node["FunctionName"] = identifierToJSON(e.FunctionName) - } - if len(e.Parameters) > 0 { - params := make([]jsonNode, len(e.Parameters)) - for i, p := range e.Parameters { - params[i] = scalarExpressionToJSON(p) - } - node["Parameters"] = params - } - if e.UniqueRowFilter != "" { - node["UniqueRowFilter"] = e.UniqueRowFilter - } - if e.WithArrayWrapper { - node["WithArrayWrapper"] = e.WithArrayWrapper - } - return node - case *ast.BinaryExpression: - node := jsonNode{ - "$type": "BinaryExpression", - } - if e.BinaryExpressionType != "" { - node["BinaryExpressionType"] = e.BinaryExpressionType - } - if e.FirstExpression != nil { - node["FirstExpression"] = scalarExpressionToJSON(e.FirstExpression) - } - if e.SecondExpression != nil { - node["SecondExpression"] = scalarExpressionToJSON(e.SecondExpression) - } - return node - case *ast.VariableReference: - node := jsonNode{ - "$type": "VariableReference", - } - if e.Name != "" { - node["Name"] = e.Name - } - return node - default: - return jsonNode{"$type": "UnknownScalarExpression"} +func whileStatementToJSON(s *ast.WhileStatement) jsonNode { + node := jsonNode{ + "$type": "WhileStatement", + } + if s.Predicate != nil { + node["Predicate"] = booleanExpressionToJSON(s.Predicate) + } + if s.Statement != nil { + node["Statement"] = statementToJSON(s.Statement) + } + return node +} + +func beginEndBlockStatementToJSON(s *ast.BeginEndBlockStatement) jsonNode { + node := jsonNode{ + "$type": "BeginEndBlockStatement", + } + if s.StatementList != nil { + node["StatementList"] = statementListToJSON(s.StatementList) } + return node } -func identifierToJSON(id *ast.Identifier) jsonNode { +func statementListToJSON(sl *ast.StatementList) jsonNode { node := jsonNode{ - "$type": "Identifier", - } - if id.Value != "" { - node["Value"] = id.Value + "$type": "StatementList", } - if id.QuoteType != "" { - node["QuoteType"] = id.QuoteType + if len(sl.Statements) > 0 { + stmts := make([]jsonNode, len(sl.Statements)) + for i, s := range sl.Statements { + stmts[i] = statementToJSON(s) + } + node["Statements"] = stmts } return node } -func multiPartIdentifierToJSON(mpi *ast.MultiPartIdentifier) jsonNode { +func createViewStatementToJSON(s *ast.CreateViewStatement) jsonNode { node := jsonNode{ - "$type": "MultiPartIdentifier", + "$type": "CreateViewStatement", } - if mpi.Count > 0 { - node["Count"] = mpi.Count + if s.SchemaObjectName != nil { + node["SchemaObjectName"] = schemaObjectNameToJSON(s.SchemaObjectName) } - if len(mpi.Identifiers) > 0 { - ids := make([]jsonNode, len(mpi.Identifiers)) - for i, id := range 
mpi.Identifiers { - ids[i] = identifierToJSON(id) + if len(s.Columns) > 0 { + cols := make([]jsonNode, len(s.Columns)) + for i, c := range s.Columns { + cols[i] = identifierToJSON(c) } - node["Identifiers"] = ids + node["Columns"] = cols + } + if s.SelectStatement != nil { + node["SelectStatement"] = selectStatementToJSON(s.SelectStatement) } + node["WithCheckOption"] = s.WithCheckOption + node["IsMaterialized"] = s.IsMaterialized return node } -func identifierOrValueExpressionToJSON(iove *ast.IdentifierOrValueExpression) jsonNode { +func createSchemaStatementToJSON(s *ast.CreateSchemaStatement) jsonNode { node := jsonNode{ - "$type": "IdentifierOrValueExpression", + "$type": "CreateSchemaStatement", } - if iove.Value != "" { - node["Value"] = iove.Value + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) } - if iove.Identifier != nil { - node["Identifier"] = identifierToJSON(iove.Identifier) + if s.Owner != nil { + node["Owner"] = identifierToJSON(s.Owner) + } + if s.StatementList != nil { + node["StatementList"] = statementListToJSON(s.StatementList) } return node } -func fromClauseToJSON(fc *ast.FromClause) jsonNode { +func executeStatementToJSON(s *ast.ExecuteStatement) jsonNode { node := jsonNode{ - "$type": "FromClause", + "$type": "ExecuteStatement", } - if len(fc.TableReferences) > 0 { - refs := make([]jsonNode, len(fc.TableReferences)) - for i, ref := range fc.TableReferences { - refs[i] = tableReferenceToJSON(ref) - } - node["TableReferences"] = refs + if s.ExecuteSpecification != nil { + node["ExecuteSpecification"] = executeSpecificationToJSON(s.ExecuteSpecification) } return node } -func tableReferenceToJSON(ref ast.TableReference) jsonNode { - switch r := ref.(type) { - case *ast.NamedTableReference: - node := jsonNode{ - "$type": "NamedTableReference", - } - if r.SchemaObject != nil { - node["SchemaObject"] = schemaObjectNameToJSON(r.SchemaObject) - } - if r.Alias != nil { - node["Alias"] = identifierToJSON(r.Alias) - } - node["ForPath"] = r.ForPath - return node - case *ast.QualifiedJoin: - node := jsonNode{ - "$type": "QualifiedJoin", - } - if r.SearchCondition != nil { - node["SearchCondition"] = booleanExpressionToJSON(r.SearchCondition) - } - if r.QualifiedJoinType != "" { - node["QualifiedJoinType"] = r.QualifiedJoinType - } - if r.JoinHint != "" { - node["JoinHint"] = r.JoinHint - } - if r.FirstTableReference != nil { - node["FirstTableReference"] = tableReferenceToJSON(r.FirstTableReference) - } - if r.SecondTableReference != nil { - node["SecondTableReference"] = tableReferenceToJSON(r.SecondTableReference) - } - return node - default: - return jsonNode{"$type": "UnknownTableReference"} +func returnStatementToJSON(s *ast.ReturnStatement) jsonNode { + node := jsonNode{ + "$type": "ReturnStatement", + } + if s.Expression != nil { + node["Expression"] = scalarExpressionToJSON(s.Expression) } + return node } -func schemaObjectNameToJSON(son *ast.SchemaObjectName) jsonNode { - node := jsonNode{ - "$type": "SchemaObjectName", +func breakStatementToJSON() jsonNode { + return jsonNode{ + "$type": "BreakStatement", } - if son.BaseIdentifier != nil { - node["BaseIdentifier"] = identifierToJSON(son.BaseIdentifier) +} + +func continueStatementToJSON() jsonNode { + return jsonNode{ + "$type": "ContinueStatement", } - if son.Count > 0 { - node["Count"] = son.Count +} + +func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) { + // Consume TABLE + p.nextToken() + + stmt := &ast.CreateTableStatement{} + + // Parse table name + name, err := 
p.parseSchemaObjectName() + if err != nil { + return nil, err } - if len(son.Identifiers) > 0 { - // Handle $ref for identifiers that reference the base identifier - ids := make([]any, len(son.Identifiers)) - for i, id := range son.Identifiers { - if son.BaseIdentifier != nil && id == son.BaseIdentifier { - ids[i] = jsonNode{"$ref": "Identifier"} - } else { - ids[i] = identifierToJSON(id) - } + stmt.SchemaObjectName = name + + // Expect ( + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after table name, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt.Definition = &ast.TableDefinition{} + + // Parse column definitions + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + colDef, err := p.parseColumnDefinition() + if err != nil { + return nil, err + } + stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break } - node["Identifiers"] = ids } - return node -} -func whereClauseToJSON(wc *ast.WhereClause) jsonNode { - node := jsonNode{ - "$type": "WhereClause", + // Expect ) + if p.curTok.Type == TokenRParen { + p.nextToken() } - if wc.SearchCondition != nil { - node["SearchCondition"] = booleanExpressionToJSON(wc.SearchCondition) + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() } - return node + + return stmt, nil } -func booleanExpressionToJSON(expr ast.BooleanExpression) jsonNode { - switch e := expr.(type) { - case *ast.BooleanComparisonExpression: - node := jsonNode{ - "$type": "BooleanComparisonExpression", - } - if e.ComparisonType != "" { - node["ComparisonType"] = e.ComparisonType - } - if e.FirstExpression != nil { - node["FirstExpression"] = scalarExpressionToJSON(e.FirstExpression) - } - if e.SecondExpression != nil { - node["SecondExpression"] = scalarExpressionToJSON(e.SecondExpression) - } - return node - case *ast.BooleanBinaryExpression: - node := jsonNode{ - "$type": "BooleanBinaryExpression", - } - if e.BinaryExpressionType != "" { - node["BinaryExpressionType"] = e.BinaryExpressionType +func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { + col := &ast.ColumnDefinition{} + + // Parse column name (parseIdentifier already calls nextToken) + col.ColumnIdentifier = p.parseIdentifier() + + // Parse data type + dataType, err := p.parseDataType() + if err != nil { + return nil, err + } + col.DataType = dataType + + return col, nil +} + +func (p *Parser) parseGrantStatement() (*ast.GrantStatement, error) { + // Consume GRANT + p.nextToken() + + stmt := &ast.GrantStatement{} + + // Parse permission(s) + perm := &ast.Permission{} + for p.curTok.Type != TokenTo && p.curTok.Type != TokenEOF { + if p.curTok.Type == TokenIdent || p.curTok.Type == TokenCreate || + p.curTok.Type == TokenProcedure || p.curTok.Type == TokenView || + p.curTok.Type == TokenSelect || p.curTok.Type == TokenInsert || + p.curTok.Type == TokenUpdate || p.curTok.Type == TokenDelete { + perm.Identifiers = append(perm.Identifiers, &ast.Identifier{ + Value: p.curTok.Literal, + QuoteType: "NotQuoted", + }) + p.nextToken() + } else if p.curTok.Type == TokenComma { + stmt.Permissions = append(stmt.Permissions, perm) + perm = &ast.Permission{} + p.nextToken() + } else { + break } - if e.FirstExpression != nil { - node["FirstExpression"] = booleanExpressionToJSON(e.FirstExpression) + } + if len(perm.Identifiers) > 0 { + stmt.Permissions = append(stmt.Permissions, perm) + } + + // Expect TO + if p.curTok.Type == TokenTo { + 
p.nextToken() + } + + // Parse principal(s) + for p.curTok.Type != TokenEOF && p.curTok.Type != TokenSemicolon { + principal := &ast.SecurityPrincipal{} + if p.curTok.Type == TokenPublic { + principal.PrincipalType = "Public" + p.nextToken() + } else if p.curTok.Type == TokenIdent || p.curTok.Type == TokenLBracket { + principal.PrincipalType = "Identifier" + // parseIdentifier already calls nextToken() + principal.Identifier = p.parseIdentifier() + } else { + break } - if e.SecondExpression != nil { - node["SecondExpression"] = booleanExpressionToJSON(e.SecondExpression) + stmt.Principals = append(stmt.Principals, principal) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break } - return node - default: - return jsonNode{"$type": "UnknownBooleanExpression"} } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil } -func groupByClauseToJSON(gbc *ast.GroupByClause) jsonNode { +func createTableStatementToJSON(s *ast.CreateTableStatement) jsonNode { node := jsonNode{ - "$type": "GroupByClause", + "$type": "CreateTableStatement", + "SchemaObjectName": schemaObjectNameToJSON(s.SchemaObjectName), + "AsEdge": s.AsEdge, + "AsFileTable": s.AsFileTable, + "AsNode": s.AsNode, + "Definition": tableDefinitionToJSON(s.Definition), } - if gbc.GroupByOption != "" { - node["GroupByOption"] = gbc.GroupByOption - } - if gbc.All { - node["All"] = gbc.All + return node +} + +func tableDefinitionToJSON(t *ast.TableDefinition) jsonNode { + node := jsonNode{ + "$type": "TableDefinition", } - if len(gbc.GroupingSpecifications) > 0 { - specs := make([]jsonNode, len(gbc.GroupingSpecifications)) - for i, spec := range gbc.GroupingSpecifications { - specs[i] = groupingSpecificationToJSON(spec) + if len(t.ColumnDefinitions) > 0 { + cols := make([]jsonNode, len(t.ColumnDefinitions)) + for i, col := range t.ColumnDefinitions { + cols[i] = columnDefinitionToJSON(col) } - node["GroupingSpecifications"] = specs + node["ColumnDefinitions"] = cols } return node } -func groupingSpecificationToJSON(spec ast.GroupingSpecification) jsonNode { - switch s := spec.(type) { - case *ast.ExpressionGroupingSpecification: - node := jsonNode{ - "$type": "ExpressionGroupingSpecification", - } - if s.Expression != nil { - node["Expression"] = scalarExpressionToJSON(s.Expression) - } - if s.DistributedAggregation { - node["DistributedAggregation"] = s.DistributedAggregation - } - return node +func columnDefinitionToJSON(c *ast.ColumnDefinition) jsonNode { + node := jsonNode{ + "$type": "ColumnDefinition", + "IsPersisted": c.IsPersisted, + "IsRowGuidCol": c.IsRowGuidCol, + "IsHidden": c.IsHidden, + "IsMasked": c.IsMasked, + "ColumnIdentifier": identifierToJSON(c.ColumnIdentifier), + } + if c.DataType != nil { + node["DataType"] = dataTypeReferenceToJSON(c.DataType) + } + return node +} + +func dataTypeReferenceToJSON(d ast.DataTypeReference) jsonNode { + switch dt := d.(type) { + case *ast.SqlDataTypeReference: + return sqlDataTypeReferenceToJSON(dt) default: - return jsonNode{"$type": "UnknownGroupingSpecification"} + return jsonNode{"$type": "UnknownDataType"} } } -func havingClauseToJSON(hc *ast.HavingClause) jsonNode { +func grantStatementToJSON(s *ast.GrantStatement) jsonNode { node := jsonNode{ - "$type": "HavingClause", + "$type": "GrantStatement", + "WithGrantOption": s.WithGrantOption, } - if hc.SearchCondition != nil { - node["SearchCondition"] = booleanExpressionToJSON(hc.SearchCondition) + if len(s.Permissions) > 0 { + perms := make([]jsonNode, 
len(s.Permissions)) + for i, p := range s.Permissions { + perms[i] = permissionToJSON(p) + } + node["Permissions"] = perms + } + if len(s.Principals) > 0 { + principals := make([]jsonNode, len(s.Principals)) + for i, p := range s.Principals { + principals[i] = securityPrincipalToJSON(p) + } + node["Principals"] = principals } return node } -func orderByClauseToJSON(obc *ast.OrderByClause) jsonNode { +func permissionToJSON(p *ast.Permission) jsonNode { node := jsonNode{ - "$type": "OrderByClause", + "$type": "Permission", } - if len(obc.OrderByElements) > 0 { - elems := make([]jsonNode, len(obc.OrderByElements)) - for i, elem := range obc.OrderByElements { - elems[i] = expressionWithSortOrderToJSON(elem) + if len(p.Identifiers) > 0 { + ids := make([]jsonNode, len(p.Identifiers)) + for i, id := range p.Identifiers { + ids[i] = identifierToJSON(id) } - node["OrderByElements"] = elems + node["Identifiers"] = ids } return node } -func expressionWithSortOrderToJSON(ewso *ast.ExpressionWithSortOrder) jsonNode { +func securityPrincipalToJSON(p *ast.SecurityPrincipal) jsonNode { node := jsonNode{ - "$type": "ExpressionWithSortOrder", + "$type": "SecurityPrincipal", + "PrincipalType": p.PrincipalType, } - if ewso.SortOrder != "" { - node["SortOrder"] = ewso.SortOrder - } - if ewso.Expression != nil { - node["Expression"] = scalarExpressionToJSON(ewso.Expression) + if p.Identifier != nil { + node["Identifier"] = identifierToJSON(p.Identifier) } return node } + +func predicateSetStatementToJSON(s *ast.PredicateSetStatement) jsonNode { + return jsonNode{ + "$type": "PredicateSetStatement", + "Options": string(s.Options), + "IsOn": s.IsOn, + } +} diff --git a/parser/testdata/BeginEndBlockStatementTests/metadata.json b/parser/testdata/BeginEndBlockStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/BeginEndBlockStatementTests/metadata.json +++ b/parser/testdata/BeginEndBlockStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/CreateSchemaStatementTests/metadata.json b/parser/testdata/CreateSchemaStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/CreateSchemaStatementTests/metadata.json +++ b/parser/testdata/CreateSchemaStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/DeclareVariableStatementTests/metadata.json b/parser/testdata/DeclareVariableStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/DeclareVariableStatementTests/metadata.json +++ b/parser/testdata/DeclareVariableStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/IfStatementTests/metadata.json b/parser/testdata/IfStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/IfStatementTests/metadata.json +++ b/parser/testdata/IfStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/SelectStatementTests/metadata.json b/parser/testdata/SelectStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/SelectStatementTests/metadata.json +++ b/parser/testdata/SelectStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false}