diff --git a/ast/alter_table_drop_table_element_statement.go b/ast/alter_table_drop_table_element_statement.go new file mode 100644 index 00000000..fe1a388f --- /dev/null +++ b/ast/alter_table_drop_table_element_statement.go @@ -0,0 +1,19 @@ +package ast + +// AlterTableDropTableElementStatement represents an ALTER TABLE ... DROP statement. +type AlterTableDropTableElementStatement struct { + SchemaObjectName *SchemaObjectName + AlterTableDropTableElements []*AlterTableDropTableElement +} + +func (*AlterTableDropTableElementStatement) node() {} +func (*AlterTableDropTableElementStatement) statement() {} + +// AlterTableDropTableElement represents an element being dropped from a table. +type AlterTableDropTableElement struct { + TableElementType string + Name *Identifier + IsIfExists bool +} + +func (*AlterTableDropTableElement) node() {} diff --git a/ast/binary_expression.go b/ast/binary_expression.go new file mode 100644 index 00000000..90f87a8d --- /dev/null +++ b/ast/binary_expression.go @@ -0,0 +1,11 @@ +package ast + +// BinaryExpression represents a binary scalar expression (Add, Subtract, etc.). +type BinaryExpression struct { + BinaryExpressionType string `json:"BinaryExpressionType,omitempty"` + FirstExpression ScalarExpression `json:"FirstExpression,omitempty"` + SecondExpression ScalarExpression `json:"SecondExpression,omitempty"` +} + +func (*BinaryExpression) node() {} +func (*BinaryExpression) scalarExpression() {} diff --git a/ast/drop_credential_statement.go b/ast/drop_credential_statement.go new file mode 100644 index 00000000..f1c09387 --- /dev/null +++ b/ast/drop_credential_statement.go @@ -0,0 +1,11 @@ +package ast + +// DropCredentialStatement represents a DROP CREDENTIAL statement. +type DropCredentialStatement struct { + IsDatabaseScoped bool + Name *Identifier + IsIfExists bool +} + +func (*DropCredentialStatement) node() {} +func (*DropCredentialStatement) statement() {} diff --git a/ast/print_statement.go b/ast/print_statement.go new file mode 100644 index 00000000..89195b2e --- /dev/null +++ b/ast/print_statement.go @@ -0,0 +1,9 @@ +package ast + +// PrintStatement represents a PRINT statement. +type PrintStatement struct { + Expression ScalarExpression +} + +func (*PrintStatement) node() {} +func (*PrintStatement) statement() {} diff --git a/ast/revert_statement.go b/ast/revert_statement.go new file mode 100644 index 00000000..b5deb5d5 --- /dev/null +++ b/ast/revert_statement.go @@ -0,0 +1,9 @@ +package ast + +// RevertStatement represents a REVERT statement. +type RevertStatement struct { + Cookie ScalarExpression +} + +func (*RevertStatement) node() {} +func (*RevertStatement) statement() {} diff --git a/ast/throw_statement.go b/ast/throw_statement.go new file mode 100644 index 00000000..0d97d52a --- /dev/null +++ b/ast/throw_statement.go @@ -0,0 +1,11 @@ +package ast + +// ThrowStatement represents a THROW statement. +type ThrowStatement struct { + ErrorNumber ScalarExpression + Message ScalarExpression + State ScalarExpression +} + +func (*ThrowStatement) node() {} +func (*ThrowStatement) statement() {} diff --git a/ast/variable_reference.go b/ast/variable_reference.go new file mode 100644 index 00000000..974656c3 --- /dev/null +++ b/ast/variable_reference.go @@ -0,0 +1,9 @@ +package ast + +// VariableReference represents a reference to a variable (e.g., @var). +type VariableReference struct { + Name string `json:"Name,omitempty"` +} + +func (*VariableReference) node() {} +func (*VariableReference) scalarExpression() {} diff --git a/parser/lexer.go b/parser/lexer.go index 3dd8ac7a..94e8270c 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -38,6 +38,18 @@ const ( TokenOption TokenAll TokenDistinct + TokenPrint + TokenThrow + TokenAlter + TokenTable + TokenDrop + TokenIndex + TokenRevert + TokenWith + TokenCookie + TokenDatabase + TokenScoped + TokenCredential ) // Token represents a lexical token. @@ -281,6 +293,18 @@ var keywords = map[string]TokenType{ "OPTION": TokenOption, "ALL": TokenAll, "DISTINCT": TokenDistinct, + "PRINT": TokenPrint, + "THROW": TokenThrow, + "ALTER": TokenAlter, + "TABLE": TokenTable, + "DROP": TokenDrop, + "INDEX": TokenIndex, + "REVERT": TokenRevert, + "WITH": TokenWith, + "COOKIE": TokenCookie, + "DATABASE": TokenDatabase, + "SCOPED": TokenScoped, + "CREDENTIAL": TokenCredential, } func lookupKeyword(ident string) TokenType { diff --git a/parser/parser.go b/parser/parser.go index 6c3a5179..1bed31cc 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -88,6 +88,16 @@ func (p *Parser) parseStatement() (ast.Statement, error) { switch p.curTok.Type { case TokenSelect: return p.parseSelectStatement() + case TokenPrint: + return p.parsePrintStatement() + case TokenThrow: + return p.parseThrowStatement() + case TokenAlter: + return p.parseAlterStatement() + case TokenRevert: + return p.parseRevertStatement() + case TokenDrop: + return p.parseDropStatement() case TokenSemicolon: p.nextToken() return nil, nil @@ -96,6 +106,241 @@ func (p *Parser) parseStatement() (ast.Statement, error) { } } +func (p *Parser) parseRevertStatement() (*ast.RevertStatement, error) { + // Consume REVERT + p.nextToken() + + stmt := &ast.RevertStatement{} + + // Check for WITH COOKIE = expression + if p.curTok.Type == TokenWith { + p.nextToken() // consume WITH + + if p.curTok.Type != TokenCookie { + return nil, fmt.Errorf("expected COOKIE after WITH, got %s", p.curTok.Literal) + } + p.nextToken() // consume COOKIE + + if p.curTok.Type != TokenEquals { + return nil, fmt.Errorf("expected = after COOKIE, got %s", p.curTok.Literal) + } + p.nextToken() // consume = + + cookie, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + stmt.Cookie = cookie + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseDropStatement() (ast.Statement, error) { + // Consume DROP + p.nextToken() + + // Check what type of DROP statement this is + if p.curTok.Type == TokenDatabase { + return p.parseDropDatabaseScopedStatement() + } + + return nil, fmt.Errorf("unexpected token after DROP: %s", p.curTok.Literal) +} + +func (p *Parser) parseDropDatabaseScopedStatement() (ast.Statement, error) { + // Consume DATABASE + p.nextToken() + + if p.curTok.Type != TokenScoped { + return nil, fmt.Errorf("expected SCOPED after DATABASE, got %s", p.curTok.Literal) + } + p.nextToken() // consume SCOPED + + if p.curTok.Type == TokenCredential { + return p.parseDropCredentialStatement(true) + } + + return nil, fmt.Errorf("unexpected token after SCOPED: %s", p.curTok.Literal) +} + +func (p *Parser) parseDropCredentialStatement(isDatabaseScoped bool) (*ast.DropCredentialStatement, error) { + // Consume CREDENTIAL + p.nextToken() + + stmt := &ast.DropCredentialStatement{ + IsDatabaseScoped: isDatabaseScoped, + IsIfExists: false, + } + + // Parse credential name + if p.curTok.Type != TokenIdent { + return nil, fmt.Errorf("expected identifier, got %s", p.curTok.Literal) + } + + stmt.Name = &ast.Identifier{ + Value: p.curTok.Literal, + QuoteType: "NotQuoted", + } + p.nextToken() + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseAlterStatement() (ast.Statement, error) { + // Consume ALTER + p.nextToken() + + // Check what type of ALTER statement this is + if p.curTok.Type == TokenTable { + return p.parseAlterTableStatement() + } + + return nil, fmt.Errorf("unexpected token after ALTER: %s", p.curTok.Literal) +} + +func (p *Parser) parseAlterTableStatement() (ast.Statement, error) { + // Consume TABLE + p.nextToken() + + // Parse table name + tableName, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + + // Check what kind of ALTER TABLE statement this is + if p.curTok.Type == TokenDrop { + return p.parseAlterTableDropStatement(tableName) + } + + return nil, fmt.Errorf("unexpected token in ALTER TABLE: %s", p.curTok.Literal) +} + +func (p *Parser) parseAlterTableDropStatement(tableName *ast.SchemaObjectName) (*ast.AlterTableDropTableElementStatement, error) { + // Consume DROP + p.nextToken() + + stmt := &ast.AlterTableDropTableElementStatement{ + SchemaObjectName: tableName, + } + + // Parse the element type and name + var elementType string + switch p.curTok.Type { + case TokenIndex: + elementType = "Index" + p.nextToken() + default: + return nil, fmt.Errorf("unexpected token after DROP: %s", p.curTok.Literal) + } + + // Parse the element name + if p.curTok.Type != TokenIdent { + return nil, fmt.Errorf("expected identifier after %s, got %s", elementType, p.curTok.Literal) + } + + element := &ast.AlterTableDropTableElement{ + TableElementType: elementType, + Name: &ast.Identifier{ + Value: p.curTok.Literal, + QuoteType: "NotQuoted", + }, + IsIfExists: false, + } + p.nextToken() + + stmt.AlterTableDropTableElements = append(stmt.AlterTableDropTableElements, element) + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parsePrintStatement() (*ast.PrintStatement, error) { + // Consume PRINT + p.nextToken() + + // Parse expression + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return &ast.PrintStatement{Expression: expr}, nil +} + +func (p *Parser) parseThrowStatement() (*ast.ThrowStatement, error) { + // Consume THROW + p.nextToken() + + stmt := &ast.ThrowStatement{} + + // THROW can be used without arguments (re-throw) + if p.curTok.Type == TokenSemicolon || p.curTok.Type == TokenEOF || + p.curTok.Type == TokenSelect || p.curTok.Type == TokenPrint || p.curTok.Type == TokenThrow { + return stmt, nil + } + + // Parse error number + errNum, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + stmt.ErrorNumber = errNum + + // Expect comma + if p.curTok.Type != TokenComma { + return nil, fmt.Errorf("expected comma after error number, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse message + msg, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + stmt.Message = msg + + // Expect comma + if p.curTok.Type != TokenComma { + return nil, fmt.Errorf("expected comma after message, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse state + state, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + stmt.State = state + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) { stmt := &ast.SelectStatement{} @@ -202,16 +447,83 @@ func (p *Parser) parseSelectElement() (ast.SelectElement, error) { } func (p *Parser) parseScalarExpression() (ast.ScalarExpression, error) { - // For now, only handle column references and identifiers - if p.curTok.Type == TokenIdent { - return p.parseColumnReference() + return p.parseAdditiveExpression() +} + +func (p *Parser) parseAdditiveExpression() (ast.ScalarExpression, error) { + left, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + + for p.curTok.Type == TokenPlus || p.curTok.Type == TokenMinus { + var opType string + if p.curTok.Type == TokenPlus { + opType = "Add" + } else { + opType = "Subtract" + } + p.nextToken() + + right, err := p.parsePrimaryExpression() + if err != nil { + return nil, err + } + + left = &ast.BinaryExpression{ + BinaryExpressionType: opType, + FirstExpression: left, + SecondExpression: right, + } } - if p.curTok.Type == TokenNumber { + + return left, nil +} + +func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) { + switch p.curTok.Type { + case TokenIdent: + // Check if it's a variable reference (starts with @) + if strings.HasPrefix(p.curTok.Literal, "@") { + name := p.curTok.Literal + p.nextToken() + return &ast.VariableReference{Name: name}, nil + } + return p.parseColumnReference() + case TokenNumber: val := p.curTok.Literal p.nextToken() return &ast.IntegerLiteral{LiteralType: "Integer", Value: val}, nil + case TokenString: + return p.parseStringLiteral() + default: + return nil, fmt.Errorf("unexpected token in expression: %s", p.curTok.Literal) } - return nil, fmt.Errorf("unexpected token in expression: %s", p.curTok.Literal) +} + +func (p *Parser) parseStringLiteral() (*ast.StringLiteral, error) { + raw := p.curTok.Literal + p.nextToken() + + // Remove surrounding quotes and handle escaped quotes + if len(raw) >= 2 && raw[0] == '\'' && raw[len(raw)-1] == '\'' { + inner := raw[1 : len(raw)-1] + // Replace escaped quotes + value := strings.ReplaceAll(inner, "''", "'") + return &ast.StringLiteral{ + LiteralType: "String", + IsNational: false, + IsLargeObject: false, + Value: value, + }, nil + } + + return &ast.StringLiteral{ + LiteralType: "String", + IsNational: false, + IsLargeObject: false, + Value: raw, + }, nil } func (p *Parser) parseColumnReference() (*ast.ColumnReferenceExpression, error) { @@ -431,11 +743,100 @@ func statementToJSON(stmt ast.Statement) jsonNode { switch s := stmt.(type) { case *ast.SelectStatement: return selectStatementToJSON(s) + case *ast.PrintStatement: + return printStatementToJSON(s) + case *ast.ThrowStatement: + return throwStatementToJSON(s) + case *ast.AlterTableDropTableElementStatement: + return alterTableDropTableElementStatementToJSON(s) + case *ast.RevertStatement: + return revertStatementToJSON(s) + case *ast.DropCredentialStatement: + return dropCredentialStatementToJSON(s) default: return jsonNode{"$type": "UnknownStatement"} } } +func revertStatementToJSON(s *ast.RevertStatement) jsonNode { + node := jsonNode{ + "$type": "RevertStatement", + } + if s.Cookie != nil { + node["Cookie"] = scalarExpressionToJSON(s.Cookie) + } + return node +} + +func dropCredentialStatementToJSON(s *ast.DropCredentialStatement) jsonNode { + node := jsonNode{ + "$type": "DropCredentialStatement", + } + node["IsDatabaseScoped"] = s.IsDatabaseScoped + if s.Name != nil { + node["Name"] = identifierToJSON(s.Name) + } + node["IsIfExists"] = s.IsIfExists + return node +} + +func alterTableDropTableElementStatementToJSON(s *ast.AlterTableDropTableElementStatement) jsonNode { + node := jsonNode{ + "$type": "AlterTableDropTableElementStatement", + } + if len(s.AlterTableDropTableElements) > 0 { + elements := make([]jsonNode, len(s.AlterTableDropTableElements)) + for i, e := range s.AlterTableDropTableElements { + elements[i] = alterTableDropTableElementToJSON(e) + } + node["AlterTableDropTableElements"] = elements + } + if s.SchemaObjectName != nil { + node["SchemaObjectName"] = schemaObjectNameToJSON(s.SchemaObjectName) + } + return node +} + +func alterTableDropTableElementToJSON(e *ast.AlterTableDropTableElement) jsonNode { + node := jsonNode{ + "$type": "AlterTableDropTableElement", + } + if e.TableElementType != "" { + node["TableElementType"] = e.TableElementType + } + if e.Name != nil { + node["Name"] = identifierToJSON(e.Name) + } + node["IsIfExists"] = e.IsIfExists + return node +} + +func printStatementToJSON(s *ast.PrintStatement) jsonNode { + node := jsonNode{ + "$type": "PrintStatement", + } + if s.Expression != nil { + node["Expression"] = scalarExpressionToJSON(s.Expression) + } + return node +} + +func throwStatementToJSON(s *ast.ThrowStatement) jsonNode { + node := jsonNode{ + "$type": "ThrowStatement", + } + if s.ErrorNumber != nil { + node["ErrorNumber"] = scalarExpressionToJSON(s.ErrorNumber) + } + if s.Message != nil { + node["Message"] = scalarExpressionToJSON(s.Message) + } + if s.State != nil { + node["State"] = scalarExpressionToJSON(s.State) + } + return node +} + func selectStatementToJSON(s *ast.SelectStatement) jsonNode { node := jsonNode{ "$type": "SelectStatement", @@ -561,12 +962,9 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { if e.LiteralType != "" { node["LiteralType"] = e.LiteralType } - if e.IsNational { - node["IsNational"] = e.IsNational - } - if e.IsLargeObject { - node["IsLargeObject"] = e.IsLargeObject - } + // Always include IsNational and IsLargeObject + node["IsNational"] = e.IsNational + node["IsLargeObject"] = e.IsLargeObject if e.Value != "" { node["Value"] = e.Value } @@ -592,6 +990,28 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode { node["WithArrayWrapper"] = e.WithArrayWrapper } return node + case *ast.BinaryExpression: + node := jsonNode{ + "$type": "BinaryExpression", + } + if e.BinaryExpressionType != "" { + node["BinaryExpressionType"] = e.BinaryExpressionType + } + if e.FirstExpression != nil { + node["FirstExpression"] = scalarExpressionToJSON(e.FirstExpression) + } + if e.SecondExpression != nil { + node["SecondExpression"] = scalarExpressionToJSON(e.SecondExpression) + } + return node + case *ast.VariableReference: + node := jsonNode{ + "$type": "VariableReference", + } + if e.Name != "" { + node["Name"] = e.Name + } + return node default: return jsonNode{"$type": "UnknownScalarExpression"} } diff --git a/parser/testdata/AlterTableDropTableElementStatementTests130/metadata.json b/parser/testdata/AlterTableDropTableElementStatementTests130/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/AlterTableDropTableElementStatementTests130/metadata.json +++ b/parser/testdata/AlterTableDropTableElementStatementTests130/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/Baselines110_ThrowStatementTests/metadata.json b/parser/testdata/Baselines110_ThrowStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/Baselines110_ThrowStatementTests/metadata.json +++ b/parser/testdata/Baselines110_ThrowStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/Baselines130_AlterTableDropTableElementStatementTests130/metadata.json b/parser/testdata/Baselines130_AlterTableDropTableElementStatementTests130/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/Baselines130_AlterTableDropTableElementStatementTests130/metadata.json +++ b/parser/testdata/Baselines130_AlterTableDropTableElementStatementTests130/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/Baselines130_DropDatabaseScopedCredentialStatementTests130/metadata.json b/parser/testdata/Baselines130_DropDatabaseScopedCredentialStatementTests130/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/Baselines130_DropDatabaseScopedCredentialStatementTests130/metadata.json +++ b/parser/testdata/Baselines130_DropDatabaseScopedCredentialStatementTests130/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/Baselines90_RevertStatementTests/metadata.json b/parser/testdata/Baselines90_RevertStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/Baselines90_RevertStatementTests/metadata.json +++ b/parser/testdata/Baselines90_RevertStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/BaselinesCommon_PrintStatementTests/metadata.json b/parser/testdata/BaselinesCommon_PrintStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/BaselinesCommon_PrintStatementTests/metadata.json +++ b/parser/testdata/BaselinesCommon_PrintStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/DropDatabaseScopedCredentialStatementTests130/metadata.json b/parser/testdata/DropDatabaseScopedCredentialStatementTests130/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/DropDatabaseScopedCredentialStatementTests130/metadata.json +++ b/parser/testdata/DropDatabaseScopedCredentialStatementTests130/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/PrintStatementTests/metadata.json b/parser/testdata/PrintStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/PrintStatementTests/metadata.json +++ b/parser/testdata/PrintStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/RevertStatementTests/metadata.json b/parser/testdata/RevertStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/RevertStatementTests/metadata.json +++ b/parser/testdata/RevertStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false} diff --git a/parser/testdata/ThrowStatementTests/metadata.json b/parser/testdata/ThrowStatementTests/metadata.json index 49e9182b..e27d63a6 100644 --- a/parser/testdata/ThrowStatementTests/metadata.json +++ b/parser/testdata/ThrowStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} +{"skip": false}