diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000000..7039ea41aa --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,24 @@ +name: Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.25' + + - name: Build + run: go build ./... + + - name: Test + run: go test -v ./... diff --git a/ast/ast.go b/ast/ast.go new file mode 100644 index 0000000000..5d9be1b59b --- /dev/null +++ b/ast/ast.go @@ -0,0 +1,817 @@ +// Package ast defines the abstract syntax tree for ClickHouse SQL. +package ast + +import ( + "github.com/kyleconroy/doubleclick/token" +) + +// Node is the interface implemented by all AST nodes. +type Node interface { + Pos() token.Position + End() token.Position +} + +// Statement is the interface implemented by all statement nodes. +type Statement interface { + Node + statementNode() +} + +// Expression is the interface implemented by all expression nodes. +type Expression interface { + Node + expressionNode() +} + +// ----------------------------------------------------------------------------- +// Statements + +// SelectWithUnionQuery represents a SELECT query possibly with UNION. +type SelectWithUnionQuery struct { + Position token.Position `json:"-"` + Selects []Statement `json:"selects"` + UnionAll bool `json:"union_all,omitempty"` +} + +func (s *SelectWithUnionQuery) Pos() token.Position { return s.Position } +func (s *SelectWithUnionQuery) End() token.Position { return s.Position } +func (s *SelectWithUnionQuery) statementNode() {} + +// SelectQuery represents a SELECT statement. 
+type SelectQuery struct { + Position token.Position `json:"-"` + With []Expression `json:"with,omitempty"` + Distinct bool `json:"distinct,omitempty"` + Top Expression `json:"top,omitempty"` + Columns []Expression `json:"columns"` + From *TablesInSelectQuery `json:"from,omitempty"` + PreWhere Expression `json:"prewhere,omitempty"` + Where Expression `json:"where,omitempty"` + GroupBy []Expression `json:"group_by,omitempty"` + WithRollup bool `json:"with_rollup,omitempty"` + WithTotals bool `json:"with_totals,omitempty"` + Having Expression `json:"having,omitempty"` + OrderBy []*OrderByElement `json:"order_by,omitempty"` + Limit Expression `json:"limit,omitempty"` + Offset Expression `json:"offset,omitempty"` + Settings []*SettingExpr `json:"settings,omitempty"` + Format *Identifier `json:"format,omitempty"` +} + +func (s *SelectQuery) Pos() token.Position { return s.Position } +func (s *SelectQuery) End() token.Position { return s.Position } +func (s *SelectQuery) statementNode() {} + +// TablesInSelectQuery represents the tables in a SELECT query. +type TablesInSelectQuery struct { + Position token.Position `json:"-"` + Tables []*TablesInSelectQueryElement `json:"tables"` +} + +func (t *TablesInSelectQuery) Pos() token.Position { return t.Position } +func (t *TablesInSelectQuery) End() token.Position { return t.Position } + +// TablesInSelectQueryElement represents a single table element in a SELECT. +type TablesInSelectQueryElement struct { + Position token.Position `json:"-"` + Table *TableExpression `json:"table"` + Join *TableJoin `json:"join,omitempty"` +} + +func (t *TablesInSelectQueryElement) Pos() token.Position { return t.Position } +func (t *TablesInSelectQueryElement) End() token.Position { return t.Position } + +// TableExpression represents a table reference. 
+type TableExpression struct { + Position token.Position `json:"-"` + Table Expression `json:"table"` // TableIdentifier, Subquery, or Function + Alias string `json:"alias,omitempty"` + Final bool `json:"final,omitempty"` + Sample *SampleClause `json:"sample,omitempty"` +} + +func (t *TableExpression) Pos() token.Position { return t.Position } +func (t *TableExpression) End() token.Position { return t.Position } + +// SampleClause represents a SAMPLE clause. +type SampleClause struct { + Position token.Position `json:"-"` + Ratio Expression `json:"ratio"` + Offset Expression `json:"offset,omitempty"` +} + +func (s *SampleClause) Pos() token.Position { return s.Position } +func (s *SampleClause) End() token.Position { return s.Position } + +// TableJoin represents a JOIN clause. +type TableJoin struct { + Position token.Position `json:"-"` + Type JoinType `json:"type"` + Strictness JoinStrictness `json:"strictness,omitempty"` + Global bool `json:"global,omitempty"` + On Expression `json:"on,omitempty"` + Using []Expression `json:"using,omitempty"` +} + +func (t *TableJoin) Pos() token.Position { return t.Position } +func (t *TableJoin) End() token.Position { return t.Position } + +// JoinType represents the type of join. +type JoinType string + +const ( + JoinInner JoinType = "INNER" + JoinLeft JoinType = "LEFT" + JoinRight JoinType = "RIGHT" + JoinFull JoinType = "FULL" + JoinCross JoinType = "CROSS" +) + +// JoinStrictness represents the join strictness. +type JoinStrictness string + +const ( + JoinStrictAny JoinStrictness = "ANY" + JoinStrictAll JoinStrictness = "ALL" + JoinStrictAsof JoinStrictness = "ASOF" + JoinStrictSemi JoinStrictness = "SEMI" + JoinStrictAnti JoinStrictness = "ANTI" +) + +// OrderByElement represents an ORDER BY element. 
+type OrderByElement struct { + Position token.Position `json:"-"` + Expression Expression `json:"expression"` + Descending bool `json:"descending,omitempty"` + NullsFirst *bool `json:"nulls_first,omitempty"` + Collate string `json:"collate,omitempty"` + WithFill bool `json:"with_fill,omitempty"` +} + +func (o *OrderByElement) Pos() token.Position { return o.Position } +func (o *OrderByElement) End() token.Position { return o.Position } + +// SettingExpr represents a setting expression. +type SettingExpr struct { + Position token.Position `json:"-"` + Name string `json:"name"` + Value Expression `json:"value"` +} + +func (s *SettingExpr) Pos() token.Position { return s.Position } +func (s *SettingExpr) End() token.Position { return s.Position } + +// InsertQuery represents an INSERT statement. +type InsertQuery struct { + Position token.Position `json:"-"` + Database string `json:"database,omitempty"` + Table string `json:"table"` + Columns []*Identifier `json:"columns,omitempty"` + Select Statement `json:"select,omitempty"` + Format *Identifier `json:"format,omitempty"` +} + +func (i *InsertQuery) Pos() token.Position { return i.Position } +func (i *InsertQuery) End() token.Position { return i.Position } +func (i *InsertQuery) statementNode() {} + +// CreateQuery represents a CREATE statement. 
+type CreateQuery struct { + Position token.Position `json:"-"` + OrReplace bool `json:"or_replace,omitempty"` + IfNotExists bool `json:"if_not_exists,omitempty"` + Temporary bool `json:"temporary,omitempty"` + Database string `json:"database,omitempty"` + Table string `json:"table,omitempty"` + View string `json:"view,omitempty"` + Materialized bool `json:"materialized,omitempty"` + Columns []*ColumnDeclaration `json:"columns,omitempty"` + Constraints []*Constraint `json:"constraints,omitempty"` + Engine *EngineClause `json:"engine,omitempty"` + OrderBy []Expression `json:"order_by,omitempty"` + PartitionBy Expression `json:"partition_by,omitempty"` + PrimaryKey []Expression `json:"primary_key,omitempty"` + SampleBy Expression `json:"sample_by,omitempty"` + TTL *TTLClause `json:"ttl,omitempty"` + Settings []*SettingExpr `json:"settings,omitempty"` + AsSelect Statement `json:"as_select,omitempty"` + Comment string `json:"comment,omitempty"` + OnCluster string `json:"on_cluster,omitempty"` + CreateDatabase bool `json:"create_database,omitempty"` +} + +func (c *CreateQuery) Pos() token.Position { return c.Position } +func (c *CreateQuery) End() token.Position { return c.Position } +func (c *CreateQuery) statementNode() {} + +// ColumnDeclaration represents a column definition. +type ColumnDeclaration struct { + Position token.Position `json:"-"` + Name string `json:"name"` + Type *DataType `json:"type"` + Nullable *bool `json:"nullable,omitempty"` + Default Expression `json:"default,omitempty"` + DefaultKind string `json:"default_kind,omitempty"` // DEFAULT, MATERIALIZED, ALIAS, EPHEMERAL + Codec *CodecExpr `json:"codec,omitempty"` + TTL Expression `json:"ttl,omitempty"` + Comment string `json:"comment,omitempty"` +} + +func (c *ColumnDeclaration) Pos() token.Position { return c.Position } +func (c *ColumnDeclaration) End() token.Position { return c.Position } + +// DataType represents a data type. 
+type DataType struct { + Position token.Position `json:"-"` + Name string `json:"name"` + Parameters []Expression `json:"parameters,omitempty"` +} + +func (d *DataType) Pos() token.Position { return d.Position } +func (d *DataType) End() token.Position { return d.Position } +func (d *DataType) expressionNode() {} + +// CodecExpr represents a CODEC expression. +type CodecExpr struct { + Position token.Position `json:"-"` + Codecs []*FunctionCall `json:"codecs"` +} + +func (c *CodecExpr) Pos() token.Position { return c.Position } +func (c *CodecExpr) End() token.Position { return c.Position } + +// Constraint represents a table constraint. +type Constraint struct { + Position token.Position `json:"-"` + Name string `json:"name,omitempty"` + Expression Expression `json:"expression"` +} + +func (c *Constraint) Pos() token.Position { return c.Position } +func (c *Constraint) End() token.Position { return c.Position } + +// EngineClause represents an ENGINE clause. +type EngineClause struct { + Position token.Position `json:"-"` + Name string `json:"name"` + Parameters []Expression `json:"parameters,omitempty"` +} + +func (e *EngineClause) Pos() token.Position { return e.Position } +func (e *EngineClause) End() token.Position { return e.Position } + +// TTLClause represents a TTL clause. +type TTLClause struct { + Position token.Position `json:"-"` + Expression Expression `json:"expression"` +} + +func (t *TTLClause) Pos() token.Position { return t.Position } +func (t *TTLClause) End() token.Position { return t.Position } + +// DropQuery represents a DROP statement. 
+type DropQuery struct { + Position token.Position `json:"-"` + IfExists bool `json:"if_exists,omitempty"` + Database string `json:"database,omitempty"` + Table string `json:"table,omitempty"` + View string `json:"view,omitempty"` + Temporary bool `json:"temporary,omitempty"` + OnCluster string `json:"on_cluster,omitempty"` + DropDatabase bool `json:"drop_database,omitempty"` +} + +func (d *DropQuery) Pos() token.Position { return d.Position } +func (d *DropQuery) End() token.Position { return d.Position } +func (d *DropQuery) statementNode() {} + +// AlterQuery represents an ALTER statement. +type AlterQuery struct { + Position token.Position `json:"-"` + Database string `json:"database,omitempty"` + Table string `json:"table"` + Commands []*AlterCommand `json:"commands"` + OnCluster string `json:"on_cluster,omitempty"` +} + +func (a *AlterQuery) Pos() token.Position { return a.Position } +func (a *AlterQuery) End() token.Position { return a.Position } +func (a *AlterQuery) statementNode() {} + +// AlterCommand represents an ALTER command. +type AlterCommand struct { + Position token.Position `json:"-"` + Type AlterCommandType `json:"type"` + Column *ColumnDeclaration `json:"column,omitempty"` + ColumnName string `json:"column_name,omitempty"` + AfterColumn string `json:"after_column,omitempty"` + NewName string `json:"new_name,omitempty"` + Index string `json:"index,omitempty"` + Constraint *Constraint `json:"constraint,omitempty"` + Partition Expression `json:"partition,omitempty"` + TTL *TTLClause `json:"ttl,omitempty"` + Settings []*SettingExpr `json:"settings,omitempty"` +} + +func (a *AlterCommand) Pos() token.Position { return a.Position } +func (a *AlterCommand) End() token.Position { return a.Position } + +// AlterCommandType represents the type of ALTER command. 
+type AlterCommandType string + +const ( + AlterAddColumn AlterCommandType = "ADD_COLUMN" + AlterDropColumn AlterCommandType = "DROP_COLUMN" + AlterModifyColumn AlterCommandType = "MODIFY_COLUMN" + AlterRenameColumn AlterCommandType = "RENAME_COLUMN" + AlterClearColumn AlterCommandType = "CLEAR_COLUMN" + AlterCommentColumn AlterCommandType = "COMMENT_COLUMN" + AlterAddIndex AlterCommandType = "ADD_INDEX" + AlterDropIndex AlterCommandType = "DROP_INDEX" + AlterAddConstraint AlterCommandType = "ADD_CONSTRAINT" + AlterDropConstraint AlterCommandType = "DROP_CONSTRAINT" + AlterModifyTTL AlterCommandType = "MODIFY_TTL" + AlterModifySetting AlterCommandType = "MODIFY_SETTING" + AlterDropPartition AlterCommandType = "DROP_PARTITION" + AlterDetachPartition AlterCommandType = "DETACH_PARTITION" + AlterAttachPartition AlterCommandType = "ATTACH_PARTITION" +) + +// TruncateQuery represents a TRUNCATE statement. +type TruncateQuery struct { + Position token.Position `json:"-"` + IfExists bool `json:"if_exists,omitempty"` + Database string `json:"database,omitempty"` + Table string `json:"table"` + OnCluster string `json:"on_cluster,omitempty"` +} + +func (t *TruncateQuery) Pos() token.Position { return t.Position } +func (t *TruncateQuery) End() token.Position { return t.Position } +func (t *TruncateQuery) statementNode() {} + +// UseQuery represents a USE statement. +type UseQuery struct { + Position token.Position `json:"-"` + Database string `json:"database"` +} + +func (u *UseQuery) Pos() token.Position { return u.Position } +func (u *UseQuery) End() token.Position { return u.Position } +func (u *UseQuery) statementNode() {} + +// DescribeQuery represents a DESCRIBE statement. 
+type DescribeQuery struct { + Position token.Position `json:"-"` + Database string `json:"database,omitempty"` + Table string `json:"table"` +} + +func (d *DescribeQuery) Pos() token.Position { return d.Position } +func (d *DescribeQuery) End() token.Position { return d.Position } +func (d *DescribeQuery) statementNode() {} + +// ShowQuery represents a SHOW statement. +type ShowQuery struct { + Position token.Position `json:"-"` + ShowType ShowType `json:"show_type"` + Database string `json:"database,omitempty"` + From string `json:"from,omitempty"` + Like string `json:"like,omitempty"` + Where Expression `json:"where,omitempty"` + Limit Expression `json:"limit,omitempty"` +} + +func (s *ShowQuery) Pos() token.Position { return s.Position } +func (s *ShowQuery) End() token.Position { return s.Position } +func (s *ShowQuery) statementNode() {} + +// ShowType represents the type of SHOW statement. +type ShowType string + +const ( + ShowTables ShowType = "TABLES" + ShowDatabases ShowType = "DATABASES" + ShowProcesses ShowType = "PROCESSLIST" + ShowCreate ShowType = "CREATE" +) + +// ExplainQuery represents an EXPLAIN statement. +type ExplainQuery struct { + Position token.Position `json:"-"` + ExplainType ExplainType `json:"explain_type"` + Statement Statement `json:"statement"` +} + +func (e *ExplainQuery) Pos() token.Position { return e.Position } +func (e *ExplainQuery) End() token.Position { return e.Position } +func (e *ExplainQuery) statementNode() {} + +// ExplainType represents the type of EXPLAIN. +type ExplainType string + +const ( + ExplainAST ExplainType = "AST" + ExplainSyntax ExplainType = "SYNTAX" + ExplainPlan ExplainType = "PLAN" + ExplainPipeline ExplainType = "PIPELINE" + ExplainEstimate ExplainType = "ESTIMATE" +) + +// SetQuery represents a SET statement. 
+type SetQuery struct { + Position token.Position `json:"-"` + Settings []*SettingExpr `json:"settings"` +} + +func (s *SetQuery) Pos() token.Position { return s.Position } +func (s *SetQuery) End() token.Position { return s.Position } +func (s *SetQuery) statementNode() {} + +// OptimizeQuery represents an OPTIMIZE statement. +type OptimizeQuery struct { + Position token.Position `json:"-"` + Database string `json:"database,omitempty"` + Table string `json:"table"` + Partition Expression `json:"partition,omitempty"` + Final bool `json:"final,omitempty"` + Dedupe bool `json:"dedupe,omitempty"` + OnCluster string `json:"on_cluster,omitempty"` +} + +func (o *OptimizeQuery) Pos() token.Position { return o.Position } +func (o *OptimizeQuery) End() token.Position { return o.Position } +func (o *OptimizeQuery) statementNode() {} + +// SystemQuery represents a SYSTEM statement. +type SystemQuery struct { + Position token.Position `json:"-"` + Command string `json:"command"` + Database string `json:"database,omitempty"` + Table string `json:"table,omitempty"` +} + +func (s *SystemQuery) Pos() token.Position { return s.Position } +func (s *SystemQuery) End() token.Position { return s.Position } +func (s *SystemQuery) statementNode() {} + +// ----------------------------------------------------------------------------- +// Expressions + +// Identifier represents an identifier. +type Identifier struct { + Position token.Position `json:"-"` + Parts []string `json:"parts"` // e.g., ["db", "table", "column"] for db.table.column + Alias string `json:"alias,omitempty"` +} + +func (i *Identifier) Pos() token.Position { return i.Position } +func (i *Identifier) End() token.Position { return i.Position } +func (i *Identifier) expressionNode() {} + +// Name returns the full identifier name. 
+func (i *Identifier) Name() string { + if len(i.Parts) == 0 { + return "" + } + if len(i.Parts) == 1 { + return i.Parts[0] + } + result := i.Parts[0] + for _, p := range i.Parts[1:] { + result += "." + p + } + return result +} + +// TableIdentifier represents a table identifier. +type TableIdentifier struct { + Position token.Position `json:"-"` + Database string `json:"database,omitempty"` + Table string `json:"table"` + Alias string `json:"alias,omitempty"` +} + +func (t *TableIdentifier) Pos() token.Position { return t.Position } +func (t *TableIdentifier) End() token.Position { return t.Position } +func (t *TableIdentifier) expressionNode() {} + +// Literal represents a literal value. +type Literal struct { + Position token.Position `json:"-"` + Type LiteralType `json:"type"` + Value interface{} `json:"value"` +} + +func (l *Literal) Pos() token.Position { return l.Position } +func (l *Literal) End() token.Position { return l.Position } +func (l *Literal) expressionNode() {} + +// LiteralType represents the type of a literal. +type LiteralType string + +const ( + LiteralString LiteralType = "String" + LiteralInteger LiteralType = "Integer" + LiteralFloat LiteralType = "Float" + LiteralBoolean LiteralType = "Boolean" + LiteralNull LiteralType = "Null" + LiteralArray LiteralType = "Array" + LiteralTuple LiteralType = "Tuple" +) + +// Asterisk represents a *. +type Asterisk struct { + Position token.Position `json:"-"` + Table string `json:"table,omitempty"` // for table.* +} + +func (a *Asterisk) Pos() token.Position { return a.Position } +func (a *Asterisk) End() token.Position { return a.Position } +func (a *Asterisk) expressionNode() {} + +// FunctionCall represents a function call. 
+type FunctionCall struct { + Position token.Position `json:"-"` + Name string `json:"name"` + Arguments []Expression `json:"arguments,omitempty"` + Distinct bool `json:"distinct,omitempty"` + Over *WindowSpec `json:"over,omitempty"` + Alias string `json:"alias,omitempty"` +} + +func (f *FunctionCall) Pos() token.Position { return f.Position } +func (f *FunctionCall) End() token.Position { return f.Position } +func (f *FunctionCall) expressionNode() {} + +// WindowSpec represents a window specification. +type WindowSpec struct { + Position token.Position `json:"-"` + Name string `json:"name,omitempty"` + PartitionBy []Expression `json:"partition_by,omitempty"` + OrderBy []*OrderByElement `json:"order_by,omitempty"` + Frame *WindowFrame `json:"frame,omitempty"` +} + +func (w *WindowSpec) Pos() token.Position { return w.Position } +func (w *WindowSpec) End() token.Position { return w.Position } + +// WindowFrame represents a window frame. +type WindowFrame struct { + Position token.Position `json:"-"` + Type WindowFrameType `json:"type"` + StartBound *FrameBound `json:"start"` + EndBound *FrameBound `json:"end,omitempty"` +} + +func (w *WindowFrame) Pos() token.Position { return w.Position } +func (w *WindowFrame) End() token.Position { return w.Position } + +// WindowFrameType represents the type of window frame. +type WindowFrameType string + +const ( + FrameRows WindowFrameType = "ROWS" + FrameRange WindowFrameType = "RANGE" + FrameGroups WindowFrameType = "GROUPS" +) + +// FrameBound represents a window frame bound. +type FrameBound struct { + Position token.Position `json:"-"` + Type FrameBoundType `json:"type"` + Offset Expression `json:"offset,omitempty"` +} + +func (f *FrameBound) Pos() token.Position { return f.Position } +func (f *FrameBound) End() token.Position { return f.Position } + +// FrameBoundType represents the type of frame bound. 
+type FrameBoundType string + +const ( + BoundCurrentRow FrameBoundType = "CURRENT_ROW" + BoundUnboundedPre FrameBoundType = "UNBOUNDED_PRECEDING" + BoundUnboundedFol FrameBoundType = "UNBOUNDED_FOLLOWING" + BoundPreceding FrameBoundType = "PRECEDING" + BoundFollowing FrameBoundType = "FOLLOWING" +) + +// BinaryExpr represents a binary expression. +type BinaryExpr struct { + Position token.Position `json:"-"` + Left Expression `json:"left"` + Op string `json:"op"` + Right Expression `json:"right"` +} + +func (b *BinaryExpr) Pos() token.Position { return b.Position } +func (b *BinaryExpr) End() token.Position { return b.Position } +func (b *BinaryExpr) expressionNode() {} + +// UnaryExpr represents a unary expression. +type UnaryExpr struct { + Position token.Position `json:"-"` + Op string `json:"op"` + Operand Expression `json:"operand"` +} + +func (u *UnaryExpr) Pos() token.Position { return u.Position } +func (u *UnaryExpr) End() token.Position { return u.Position } +func (u *UnaryExpr) expressionNode() {} + +// Subquery represents a subquery. +type Subquery struct { + Position token.Position `json:"-"` + Query Statement `json:"query"` + Alias string `json:"alias,omitempty"` +} + +func (s *Subquery) Pos() token.Position { return s.Position } +func (s *Subquery) End() token.Position { return s.Position } +func (s *Subquery) expressionNode() {} + +// WithElement represents a WITH element (CTE). +type WithElement struct { + Position token.Position `json:"-"` + Name string `json:"name"` + Query Expression `json:"query"` // Subquery or Expression +} + +func (w *WithElement) Pos() token.Position { return w.Position } +func (w *WithElement) End() token.Position { return w.Position } +func (w *WithElement) expressionNode() {} + +// CaseExpr represents a CASE expression. +type CaseExpr struct { + Position token.Position `json:"-"` + Operand Expression `json:"operand,omitempty"` // for CASE x WHEN ... 
+ Whens []*WhenClause `json:"whens"` + Else Expression `json:"else,omitempty"` + Alias string `json:"alias,omitempty"` +} + +func (c *CaseExpr) Pos() token.Position { return c.Position } +func (c *CaseExpr) End() token.Position { return c.Position } +func (c *CaseExpr) expressionNode() {} + +// WhenClause represents a WHEN clause in a CASE expression. +type WhenClause struct { + Position token.Position `json:"-"` + Condition Expression `json:"condition"` + Result Expression `json:"result"` +} + +func (w *WhenClause) Pos() token.Position { return w.Position } +func (w *WhenClause) End() token.Position { return w.Position } + +// CastExpr represents a CAST expression. +type CastExpr struct { + Position token.Position `json:"-"` + Expr Expression `json:"expr"` + Type *DataType `json:"type"` + Alias string `json:"alias,omitempty"` +} + +func (c *CastExpr) Pos() token.Position { return c.Position } +func (c *CastExpr) End() token.Position { return c.Position } +func (c *CastExpr) expressionNode() {} + +// ExtractExpr represents an EXTRACT expression. +type ExtractExpr struct { + Position token.Position `json:"-"` + Field string `json:"field"` // YEAR, MONTH, DAY, etc. + From Expression `json:"from"` + Alias string `json:"alias,omitempty"` +} + +func (e *ExtractExpr) Pos() token.Position { return e.Position } +func (e *ExtractExpr) End() token.Position { return e.Position } +func (e *ExtractExpr) expressionNode() {} + +// IntervalExpr represents an INTERVAL expression. +type IntervalExpr struct { + Position token.Position `json:"-"` + Value Expression `json:"value"` + Unit string `json:"unit"` // YEAR, MONTH, DAY, HOUR, MINUTE, SECOND, etc. +} + +func (i *IntervalExpr) Pos() token.Position { return i.Position } +func (i *IntervalExpr) End() token.Position { return i.Position } +func (i *IntervalExpr) expressionNode() {} + +// ArrayAccess represents array element access. 
+type ArrayAccess struct { + Position token.Position `json:"-"` + Array Expression `json:"array"` + Index Expression `json:"index"` +} + +func (a *ArrayAccess) Pos() token.Position { return a.Position } +func (a *ArrayAccess) End() token.Position { return a.Position } +func (a *ArrayAccess) expressionNode() {} + +// TupleAccess represents tuple element access. +type TupleAccess struct { + Position token.Position `json:"-"` + Tuple Expression `json:"tuple"` + Index Expression `json:"index"` +} + +func (t *TupleAccess) Pos() token.Position { return t.Position } +func (t *TupleAccess) End() token.Position { return t.Position } +func (t *TupleAccess) expressionNode() {} + +// Lambda represents a lambda expression. +type Lambda struct { + Position token.Position `json:"-"` + Parameters []string `json:"parameters"` + Body Expression `json:"body"` +} + +func (l *Lambda) Pos() token.Position { return l.Position } +func (l *Lambda) End() token.Position { return l.Position } +func (l *Lambda) expressionNode() {} + +// Parameter represents a parameter placeholder. +type Parameter struct { + Position token.Position `json:"-"` + Name string `json:"name,omitempty"` + Type *DataType `json:"type,omitempty"` +} + +func (p *Parameter) Pos() token.Position { return p.Position } +func (p *Parameter) End() token.Position { return p.Position } +func (p *Parameter) expressionNode() {} + +// AliasedExpr represents an expression with an alias. +type AliasedExpr struct { + Position token.Position `json:"-"` + Expr Expression `json:"expr"` + Alias string `json:"alias"` +} + +func (a *AliasedExpr) Pos() token.Position { return a.Position } +func (a *AliasedExpr) End() token.Position { return a.Position } +func (a *AliasedExpr) expressionNode() {} + +// BetweenExpr represents a BETWEEN expression. 
+type BetweenExpr struct { + Position token.Position `json:"-"` + Expr Expression `json:"expr"` + Not bool `json:"not,omitempty"` + Low Expression `json:"low"` + High Expression `json:"high"` +} + +func (b *BetweenExpr) Pos() token.Position { return b.Position } +func (b *BetweenExpr) End() token.Position { return b.Position } +func (b *BetweenExpr) expressionNode() {} + +// InExpr represents an IN expression. +type InExpr struct { + Position token.Position `json:"-"` + Expr Expression `json:"expr"` + Not bool `json:"not,omitempty"` + Global bool `json:"global,omitempty"` + List []Expression `json:"list,omitempty"` + Query Statement `json:"query,omitempty"` +} + +func (i *InExpr) Pos() token.Position { return i.Position } +func (i *InExpr) End() token.Position { return i.Position } +func (i *InExpr) expressionNode() {} + +// IsNullExpr represents an IS NULL or IS NOT NULL expression. +type IsNullExpr struct { + Position token.Position `json:"-"` + Expr Expression `json:"expr"` + Not bool `json:"not,omitempty"` +} + +func (i *IsNullExpr) Pos() token.Position { return i.Position } +func (i *IsNullExpr) End() token.Position { return i.Position } +func (i *IsNullExpr) expressionNode() {} + +// LikeExpr represents a LIKE or ILIKE expression. +type LikeExpr struct { + Position token.Position `json:"-"` + Expr Expression `json:"expr"` + Not bool `json:"not,omitempty"` + CaseInsensitive bool `json:"case_insensitive,omitempty"` // true for ILIKE + Pattern Expression `json:"pattern"` +} + +func (l *LikeExpr) Pos() token.Position { return l.Position } +func (l *LikeExpr) End() token.Position { return l.Position } +func (l *LikeExpr) expressionNode() {} + +// ExistsExpr represents an EXISTS expression. 
+type ExistsExpr struct { + Position token.Position `json:"-"` + Query Statement `json:"query"` +} + +func (e *ExistsExpr) Pos() token.Position { return e.Position } +func (e *ExistsExpr) End() token.Position { return e.Position } +func (e *ExistsExpr) expressionNode() {} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000000..209ec7d64c --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/kyleconroy/doubleclick + +go 1.24.7 diff --git a/lexer/lexer.go b/lexer/lexer.go new file mode 100644 index 0000000000..cad15e57ff --- /dev/null +++ b/lexer/lexer.go @@ -0,0 +1,420 @@ +// Package lexer implements a lexer for ClickHouse SQL. +package lexer + +import ( + "bufio" + "io" + "strings" + "unicode" + "unicode/utf8" + + "github.com/kyleconroy/doubleclick/token" +) + +// Lexer tokenizes ClickHouse SQL input. +type Lexer struct { + reader *bufio.Reader + ch rune // current character + pos token.Position + eof bool +} + +// Item represents a lexical token with its value and position. +type Item struct { + Token token.Token + Value string + Pos token.Position +} + +// New creates a new Lexer from an io.Reader. +func New(r io.Reader) *Lexer { + l := &Lexer{ + reader: bufio.NewReader(r), + pos: token.Position{Offset: 0, Line: 1, Column: 0}, + } + l.readChar() + return l +} + +func (l *Lexer) readChar() { + if l.eof { + l.ch = 0 + return + } + + r, size, err := l.reader.ReadRune() + if err != nil { + l.ch = 0 + l.eof = true + return + } + + if l.ch == '\n' { + l.pos.Line++ + l.pos.Column = 1 + } else { + l.pos.Column++ + } + l.pos.Offset += size + l.ch = r +} + +func (l *Lexer) peekChar() rune { + if l.eof { + return 0 + } + bytes, err := l.reader.Peek(1) + if err != nil || len(bytes) == 0 { + return 0 + } + r, _ := utf8.DecodeRune(bytes) + return r +} + +func (l *Lexer) skipWhitespace() { + for unicode.IsSpace(l.ch) { + l.readChar() + } +} + +// NextToken returns the next token from the input. 
func (l *Lexer) NextToken() Item {
	l.skipWhitespace()

	// Capture the position before consuming so the token points at its
	// first character.
	pos := l.pos

	if l.eof || l.ch == 0 {
		return Item{Token: token.EOF, Value: "", Pos: pos}
	}

	// Handle comments. These must be checked before the '-' and '/'
	// operator cases in the switch below.
	if l.ch == '-' && l.peekChar() == '-' {
		return l.readLineComment()
	}
	if l.ch == '/' && l.peekChar() == '*' {
		return l.readBlockComment()
	}

	switch l.ch {
	case '+':
		l.readChar()
		return Item{Token: token.PLUS, Value: "+", Pos: pos}
	case '-':
		// '--' was already consumed as a comment above, so only the lambda
		// arrow '->' and plain minus remain.
		if l.peekChar() == '>' {
			l.readChar()
			l.readChar()
			return Item{Token: token.ARROW, Value: "->", Pos: pos}
		}
		l.readChar()
		return Item{Token: token.MINUS, Value: "-", Pos: pos}
	case '*':
		l.readChar()
		return Item{Token: token.ASTERISK, Value: "*", Pos: pos}
	case '/':
		l.readChar()
		return Item{Token: token.SLASH, Value: "/", Pos: pos}
	case '%':
		l.readChar()
		return Item{Token: token.PERCENT, Value: "%", Pos: pos}
	case '=':
		l.readChar()
		return Item{Token: token.EQ, Value: "=", Pos: pos}
	case '!':
		if l.peekChar() == '=' {
			l.readChar()
			l.readChar()
			return Item{Token: token.NEQ, Value: "!=", Pos: pos}
		}
		// A bare '!' is not a ClickHouse operator.
		l.readChar()
		return Item{Token: token.ILLEGAL, Value: "!", Pos: pos}
	case '<':
		if l.peekChar() == '=' {
			l.readChar()
			l.readChar()
			return Item{Token: token.LTE, Value: "<=", Pos: pos}
		}
		// '<>' is the alternate not-equal spelling.
		if l.peekChar() == '>' {
			l.readChar()
			l.readChar()
			return Item{Token: token.NEQ, Value: "<>", Pos: pos}
		}
		l.readChar()
		return Item{Token: token.LT, Value: "<", Pos: pos}
	case '>':
		if l.peekChar() == '=' {
			l.readChar()
			l.readChar()
			return Item{Token: token.GTE, Value: ">=", Pos: pos}
		}
		l.readChar()
		return Item{Token: token.GT, Value: ">", Pos: pos}
	case '|':
		// Only '||' (string concatenation) is valid; a lone '|' is illegal.
		if l.peekChar() == '|' {
			l.readChar()
			l.readChar()
			return Item{Token: token.CONCAT, Value: "||", Pos: pos}
		}
		l.readChar()
		return Item{Token: token.ILLEGAL, Value: "|", Pos: pos}
	case ':':
		// '::' is the cast operator.
		if l.peekChar() == ':' {
			l.readChar()
			l.readChar()
			return Item{Token: token.COLONCOLON, Value: "::", Pos: pos}
		}
		l.readChar()
		return Item{Token: token.COLON, Value: ":", Pos: pos}
	case '(':
		l.readChar()
		return Item{Token: token.LPAREN, Value: "(", Pos: pos}
	case ')':
		l.readChar()
		return Item{Token: token.RPAREN, Value: ")", Pos: pos}
	case '[':
		l.readChar()
		return Item{Token: token.LBRACKET, Value: "[", Pos: pos}
	case ']':
		l.readChar()
		return Item{Token: token.RBRACKET, Value: "]", Pos: pos}
	case '{':
		// '{name:Type}' query parameter; the whole braced body becomes one
		// PARAM token.
		return l.readParameter()
	case '}':
		l.readChar()
		return Item{Token: token.RBRACE, Value: "}", Pos: pos}
	case ',':
		l.readChar()
		return Item{Token: token.COMMA, Value: ",", Pos: pos}
	case '.':
		// '.5' style decimal literal; otherwise a qualification dot.
		if unicode.IsDigit(l.peekChar()) {
			return l.readNumber()
		}
		l.readChar()
		return Item{Token: token.DOT, Value: ".", Pos: pos}
	case ';':
		l.readChar()
		return Item{Token: token.SEMICOLON, Value: ";", Pos: pos}
	case '?':
		l.readChar()
		return Item{Token: token.QUESTION, Value: "?", Pos: pos}
	case '\'':
		return l.readString('\'')
	case '"':
		return l.readQuotedIdentifier()
	case '`':
		return l.readBacktickIdentifier()
	default:
		if unicode.IsDigit(l.ch) {
			return l.readNumber()
		}
		if isIdentStart(l.ch) {
			return l.readIdentifier()
		}
		ch := l.ch
		l.readChar()
		return Item{Token: token.ILLEGAL, Value: string(ch), Pos: pos}
	}
}

// readLineComment consumes a '--' comment through the end of the line.
// The returned value includes the leading '--' but not the newline.
func (l *Lexer) readLineComment() Item {
	pos := l.pos
	var sb strings.Builder
	// Skip --
	sb.WriteRune(l.ch)
	l.readChar()
	sb.WriteRune(l.ch)
	l.readChar()

	for l.ch != '\n' && l.ch != 0 && !l.eof {
		sb.WriteRune(l.ch)
		l.readChar()
	}
	return Item{Token: token.COMMENT, Value: sb.String(), Pos: pos}
}

// readBlockComment consumes a '/* ... */' comment, delimiters included.
// An unterminated comment runs to EOF without an error token.
func (l *Lexer) readBlockComment() Item {
	pos := l.pos
	var sb strings.Builder
	// Skip /*
	sb.WriteRune(l.ch)
	l.readChar()
	sb.WriteRune(l.ch)
	l.readChar()

	for !l.eof {
		if l.ch == '*' && l.peekChar() == '/' {
			sb.WriteRune(l.ch)
			l.readChar()
			sb.WriteRune(l.ch)
			l.readChar()
			break
		}
		sb.WriteRune(l.ch)
		l.readChar()
	}
	return Item{Token: token.COMMENT, Value: sb.String(), Pos: pos}
}

// readString consumes a quoted string literal. Both doubled quotes ('')
// and backslash escapes are kept verbatim in the token value; the
// surrounding quotes are stripped. An unterminated string runs to EOF
// without an error token.
func (l *Lexer) readString(quote rune) Item {
	pos := l.pos
	var sb strings.Builder
	l.readChar() // skip opening quote

	for !l.eof {
		if l.ch == quote {
			// Check for escaped quote
			if l.peekChar() == quote {
				sb.WriteRune(l.ch)
				l.readChar()
				sb.WriteRune(l.ch)
				l.readChar()
				continue
			}
			l.readChar() // skip closing quote
			break
		}
		if l.ch == '\\' {
			// Keep the backslash and the escaped character as-is; escape
			// interpretation is left to later stages.
			sb.WriteRune(l.ch)
			l.readChar()
			if !l.eof {
				sb.WriteRune(l.ch)
				l.readChar()
			}
			continue
		}
		sb.WriteRune(l.ch)
		l.readChar()
	}
	return Item{Token: token.STRING, Value: sb.String(), Pos: pos}
}

// readQuotedIdentifier consumes a double-quoted identifier. Backslash
// escapes are unescaped (the backslash is dropped), unlike readString.
func (l *Lexer) readQuotedIdentifier() Item {
	pos := l.pos
	var sb strings.Builder
	l.readChar() // skip opening quote

	for !l.eof && l.ch != '"' {
		if l.ch == '\\' {
			l.readChar()
			if !l.eof {
				sb.WriteRune(l.ch)
				l.readChar()
			}
			continue
		}
		sb.WriteRune(l.ch)
		l.readChar()
	}
	if l.ch == '"' {
		l.readChar() // skip closing quote
	}
	return Item{Token: token.IDENT, Value: sb.String(), Pos: pos}
}

// readBacktickIdentifier consumes a backtick-quoted identifier verbatim
// (no escape processing).
func (l *Lexer) readBacktickIdentifier() Item {
	pos := l.pos
	var sb strings.Builder
	l.readChar() // skip opening backtick

	for !l.eof && l.ch != '`' {
		sb.WriteRune(l.ch)
		l.readChar()
	}
	if l.ch == '`' {
		l.readChar() // skip closing backtick
	}
	return Item{Token: token.IDENT, Value: sb.String(), Pos: pos}
}

// readNumber consumes a numeric literal: integer, decimal (including a
// leading-dot form like .5), and scientific notation with an optional
// signed exponent. Hex literals are not recognized here.
func (l *Lexer) readNumber() Item {
	pos := l.pos
	var sb strings.Builder

	// Handle leading dot for decimals like .5
	if l.ch == '.' {
		sb.WriteRune(l.ch)
		l.readChar()
	}

	// Read integer part
	for unicode.IsDigit(l.ch) {
		sb.WriteRune(l.ch)
		l.readChar()
	}

	// Check for decimal point. The digit lookahead keeps '1.column' from
	// swallowing the qualification dot.
	if l.ch == '.' && unicode.IsDigit(l.peekChar()) {
		sb.WriteRune(l.ch)
		l.readChar()
		for unicode.IsDigit(l.ch) {
			sb.WriteRune(l.ch)
			l.readChar()
		}
	}

	// Check for exponent
	if l.ch == 'e' || l.ch == 'E' {
		sb.WriteRune(l.ch)
		l.readChar()
		if l.ch == '+' || l.ch == '-' {
			sb.WriteRune(l.ch)
			l.readChar()
		}
		for unicode.IsDigit(l.ch) {
			sb.WriteRune(l.ch)
			l.readChar()
		}
	}

	return Item{Token: token.NUMBER, Value: sb.String(), Pos: pos}
}

// readIdentifier consumes an identifier or keyword; token.Lookup decides
// which, using the uppercased spelling. The original casing is preserved
// in the token value.
func (l *Lexer) readIdentifier() Item {
	pos := l.pos
	var sb strings.Builder

	for isIdentChar(l.ch) {
		sb.WriteRune(l.ch)
		l.readChar()
	}

	ident := sb.String()
	tok := token.Lookup(strings.ToUpper(ident))
	return Item{Token: tok, Value: ident, Pos: pos}
}

// readParameter consumes a '{...}' query parameter; the token value is the
// raw text between the braces (e.g. "name:UInt64"), left for the parser to
// split.
func (l *Lexer) readParameter() Item {
	pos := l.pos
	var sb strings.Builder
	l.readChar() // skip opening brace

	for !l.eof && l.ch != '}' {
		sb.WriteRune(l.ch)
		l.readChar()
	}
	if l.ch == '}' {
		l.readChar() // skip closing brace
	}
	return Item{Token: token.PARAM, Value: sb.String(), Pos: pos}
}

// isIdentStart reports whether ch may begin an identifier.
func isIdentStart(ch rune) bool {
	return ch == '_' || unicode.IsLetter(ch)
}

// isIdentChar reports whether ch may continue an identifier.
func isIdentChar(ch rune) bool {
	return ch == '_' || unicode.IsLetter(ch) || unicode.IsDigit(ch)
}

// Tokenize returns all tokens from the reader.
+func Tokenize(r io.Reader) []Item { + l := New(r) + var items []Item + for { + item := l.NextToken() + items = append(items, item) + if item.Token == token.EOF { + break + } + } + return items +} diff --git a/parser/expression.go b/parser/expression.go new file mode 100644 index 0000000000..f7df361680 --- /dev/null +++ b/parser/expression.go @@ -0,0 +1,1021 @@ +package parser + +import ( + "strconv" + "strings" + + "github.com/kyleconroy/doubleclick/ast" + "github.com/kyleconroy/doubleclick/token" +) + +// Operator precedence levels +const ( + LOWEST = iota + ALIAS_PREC // AS + OR_PREC // OR + AND_PREC // AND + NOT_PREC // NOT + COMPARE // =, !=, <, >, <=, >=, LIKE, IN, BETWEEN, IS + CONCAT_PREC // || + ADD_PREC // +, - + MUL_PREC // *, /, % + UNARY // -x, NOT x + CALL // function(), array[] + HIGHEST +) + +func (p *Parser) precedence(tok token.Token) int { + switch tok { + case token.AS: + return ALIAS_PREC + case token.OR: + return OR_PREC + case token.AND: + return AND_PREC + case token.NOT: + return NOT_PREC + case token.EQ, token.NEQ, token.LT, token.GT, token.LTE, token.GTE, + token.LIKE, token.ILIKE, token.IN, token.BETWEEN, token.IS: + return COMPARE + case token.CONCAT: + return CONCAT_PREC + case token.PLUS, token.MINUS: + return ADD_PREC + case token.ASTERISK, token.SLASH, token.PERCENT: + return MUL_PREC + case token.LPAREN, token.LBRACKET: + return CALL + default: + return LOWEST + } +} + +func (p *Parser) parseExpressionList() []ast.Expression { + var exprs []ast.Expression + + if p.currentIs(token.RPAREN) || p.currentIs(token.EOF) { + return exprs + } + + exprs = append(exprs, p.parseExpression(LOWEST)) + + for p.currentIs(token.COMMA) { + p.nextToken() + exprs = append(exprs, p.parseExpression(LOWEST)) + } + + return exprs +} + +func (p *Parser) parseExpression(precedence int) ast.Expression { + left := p.parsePrefixExpression() + if left == nil { + return nil + } + + for !p.currentIs(token.EOF) && precedence < p.precedence(p.current.Token) { + 
left = p.parseInfixExpression(left) + if left == nil { + return nil + } + } + + return left +} + +func (p *Parser) parsePrefixExpression() ast.Expression { + switch p.current.Token { + case token.IDENT: + return p.parseIdentifierOrFunction() + case token.NUMBER: + return p.parseNumber() + case token.STRING: + return p.parseString() + case token.TRUE, token.FALSE: + return p.parseBoolean() + case token.NULL: + return p.parseNull() + case token.MINUS: + return p.parseUnaryMinus() + case token.NOT: + return p.parseNot() + case token.LPAREN: + return p.parseGroupedOrTuple() + case token.LBRACKET: + return p.parseArrayLiteral() + case token.ASTERISK: + return p.parseAsterisk() + case token.CASE: + return p.parseCase() + case token.CAST: + return p.parseCast() + case token.EXTRACT: + return p.parseExtract() + case token.INTERVAL: + return p.parseInterval() + case token.EXISTS: + return p.parseExists() + case token.PARAM: + return p.parseParameter() + case token.QUESTION: + return p.parsePositionalParameter() + case token.SUBSTRING: + return p.parseSubstring() + case token.TRIM: + return p.parseTrim() + default: + return nil + } +} + +func (p *Parser) parseInfixExpression(left ast.Expression) ast.Expression { + switch p.current.Token { + case token.PLUS, token.MINUS, token.ASTERISK, token.SLASH, token.PERCENT, + token.EQ, token.NEQ, token.LT, token.GT, token.LTE, token.GTE, + token.AND, token.OR, token.CONCAT: + return p.parseBinaryExpression(left) + case token.LIKE, token.ILIKE: + return p.parseLikeExpression(left, false) + case token.NOT: + // NOT IN, NOT LIKE, NOT BETWEEN, IS NOT + p.nextToken() + switch p.current.Token { + case token.IN: + return p.parseInExpression(left, true) + case token.LIKE: + return p.parseLikeExpression(left, true) + case token.ILIKE: + return p.parseLikeExpression(left, true) + case token.BETWEEN: + return p.parseBetweenExpression(left, true) + default: + // Put back NOT and treat as binary + return left + } + case token.IN: + return 
p.parseInExpression(left, false) + case token.BETWEEN: + return p.parseBetweenExpression(left, false) + case token.IS: + return p.parseIsExpression(left) + case token.LPAREN: + // Function call on identifier + if ident, ok := left.(*ast.Identifier); ok { + return p.parseFunctionCall(ident.Name(), ident.Position) + } + return left + case token.LBRACKET: + return p.parseArrayAccess(left) + case token.DOT: + return p.parseDotAccess(left) + case token.AS: + return p.parseAlias(left) + case token.COLONCOLON: + return p.parseCastOperator(left) + case token.ARROW: + return p.parseLambda(left) + default: + return left + } +} + +func (p *Parser) parseIdentifierOrFunction() ast.Expression { + pos := p.current.Pos + name := p.current.Value + p.nextToken() + + // Check for function call + if p.currentIs(token.LPAREN) { + return p.parseFunctionCall(name, pos) + } + + // Check for qualified identifier (a.b.c) + parts := []string{name} + for p.currentIs(token.DOT) { + p.nextToken() + if p.currentIs(token.IDENT) { + parts = append(parts, p.current.Value) + p.nextToken() + } else if p.currentIs(token.ASTERISK) { + // table.* + p.nextToken() + return &ast.Asterisk{ + Position: pos, + Table: strings.Join(parts, "."), + } + } else { + break + } + } + + // Check for function call after qualified name + if p.currentIs(token.LPAREN) { + return p.parseFunctionCall(strings.Join(parts, "."), pos) + } + + return &ast.Identifier{ + Position: pos, + Parts: parts, + } +} + +func (p *Parser) parseFunctionCall(name string, pos token.Position) *ast.FunctionCall { + fn := &ast.FunctionCall{ + Position: pos, + Name: name, + } + + p.nextToken() // skip ( + + // Handle DISTINCT + if p.currentIs(token.DISTINCT) { + fn.Distinct = true + p.nextToken() + } + + // Parse arguments + if !p.currentIs(token.RPAREN) { + fn.Arguments = p.parseExpressionList() + } + + p.expect(token.RPAREN) + + // Handle OVER clause for window functions + if p.currentIs(token.OVER) { + p.nextToken() + fn.Over = 
p.parseWindowSpec() + } + + // Handle alias + if p.currentIs(token.AS) { + p.nextToken() + if p.currentIs(token.IDENT) { + fn.Alias = p.current.Value + p.nextToken() + } + } + + return fn +} + +func (p *Parser) parseWindowSpec() *ast.WindowSpec { + spec := &ast.WindowSpec{ + Position: p.current.Pos, + } + + if p.currentIs(token.IDENT) { + // Window name reference + spec.Name = p.current.Value + p.nextToken() + return spec + } + + if !p.expect(token.LPAREN) { + return spec + } + + // Parse PARTITION BY + if p.currentIs(token.PARTITION) { + p.nextToken() + if p.expect(token.BY) { + spec.PartitionBy = p.parseExpressionList() + } + } + + // Parse ORDER BY + if p.currentIs(token.ORDER) { + p.nextToken() + if p.expect(token.BY) { + spec.OrderBy = p.parseOrderByList() + } + } + + // Parse frame specification + if p.currentIs(token.IDENT) { + frameType := strings.ToUpper(p.current.Value) + if frameType == "ROWS" || frameType == "RANGE" || frameType == "GROUPS" { + spec.Frame = p.parseWindowFrame() + } + } + + p.expect(token.RPAREN) + return spec +} + +func (p *Parser) parseWindowFrame() *ast.WindowFrame { + frame := &ast.WindowFrame{ + Position: p.current.Pos, + } + + switch strings.ToUpper(p.current.Value) { + case "ROWS": + frame.Type = ast.FrameRows + case "RANGE": + frame.Type = ast.FrameRange + case "GROUPS": + frame.Type = ast.FrameGroups + } + p.nextToken() + + if p.currentIs(token.BETWEEN) { + p.nextToken() + frame.StartBound = p.parseFrameBound() + if p.currentIs(token.AND) { + p.nextToken() + frame.EndBound = p.parseFrameBound() + } + } else { + frame.StartBound = p.parseFrameBound() + } + + return frame +} + +func (p *Parser) parseFrameBound() *ast.FrameBound { + bound := &ast.FrameBound{ + Position: p.current.Pos, + } + + if p.currentIs(token.IDENT) && strings.ToUpper(p.current.Value) == "CURRENT" { + p.nextToken() + if p.currentIs(token.IDENT) && strings.ToUpper(p.current.Value) == "ROW" { + p.nextToken() + } + bound.Type = ast.BoundCurrentRow + return bound + 
} + + if p.currentIs(token.IDENT) && strings.ToUpper(p.current.Value) == "UNBOUNDED" { + p.nextToken() + if p.currentIs(token.IDENT) { + switch strings.ToUpper(p.current.Value) { + case "PRECEDING": + bound.Type = ast.BoundUnboundedPre + case "FOLLOWING": + bound.Type = ast.BoundUnboundedFol + } + p.nextToken() + } + return bound + } + + // n PRECEDING or n FOLLOWING + bound.Offset = p.parseExpression(LOWEST) + if p.currentIs(token.IDENT) { + switch strings.ToUpper(p.current.Value) { + case "PRECEDING": + bound.Type = ast.BoundPreceding + case "FOLLOWING": + bound.Type = ast.BoundFollowing + } + p.nextToken() + } + + return bound +} + +func (p *Parser) parseNumber() ast.Expression { + lit := &ast.Literal{ + Position: p.current.Pos, + } + + value := p.current.Value + p.nextToken() + + // Check if it's a float + if strings.Contains(value, ".") || strings.ContainsAny(value, "eE") { + f, err := strconv.ParseFloat(value, 64) + if err != nil { + lit.Type = ast.LiteralString + lit.Value = value + } else { + lit.Type = ast.LiteralFloat + lit.Value = f + } + } else { + i, err := strconv.ParseInt(value, 10, 64) + if err != nil { + lit.Type = ast.LiteralString + lit.Value = value + } else { + lit.Type = ast.LiteralInteger + lit.Value = i + } + } + + return lit +} + +func (p *Parser) parseString() ast.Expression { + lit := &ast.Literal{ + Position: p.current.Pos, + Type: ast.LiteralString, + Value: p.current.Value, + } + p.nextToken() + return lit +} + +func (p *Parser) parseBoolean() ast.Expression { + lit := &ast.Literal{ + Position: p.current.Pos, + Type: ast.LiteralBoolean, + Value: p.current.Token == token.TRUE, + } + p.nextToken() + return lit +} + +func (p *Parser) parseNull() ast.Expression { + lit := &ast.Literal{ + Position: p.current.Pos, + Type: ast.LiteralNull, + Value: nil, + } + p.nextToken() + return lit +} + +func (p *Parser) parseUnaryMinus() ast.Expression { + expr := &ast.UnaryExpr{ + Position: p.current.Pos, + Op: "-", + } + p.nextToken() + expr.Operand = 
p.parseExpression(UNARY) + return expr +} + +func (p *Parser) parseNot() ast.Expression { + expr := &ast.UnaryExpr{ + Position: p.current.Pos, + Op: "NOT", + } + p.nextToken() + expr.Operand = p.parseExpression(NOT_PREC) + return expr +} + +func (p *Parser) parseGroupedOrTuple() ast.Expression { + pos := p.current.Pos + p.nextToken() // skip ( + + // Check for subquery + if p.currentIs(token.SELECT) || p.currentIs(token.WITH) { + subquery := p.parseSelectWithUnion() + p.expect(token.RPAREN) + return &ast.Subquery{ + Position: pos, + Query: subquery, + } + } + + // Parse first expression + first := p.parseExpression(LOWEST) + + // Check if it's a tuple + if p.currentIs(token.COMMA) { + elements := []ast.Expression{first} + for p.currentIs(token.COMMA) { + p.nextToken() + elements = append(elements, p.parseExpression(LOWEST)) + } + p.expect(token.RPAREN) + return &ast.Literal{ + Position: pos, + Type: ast.LiteralTuple, + Value: elements, + } + } + + p.expect(token.RPAREN) + return first +} + +func (p *Parser) parseArrayLiteral() ast.Expression { + lit := &ast.Literal{ + Position: p.current.Pos, + Type: ast.LiteralArray, + } + p.nextToken() // skip [ + + var elements []ast.Expression + if !p.currentIs(token.RBRACKET) { + elements = p.parseExpressionList() + } + lit.Value = elements + + p.expect(token.RBRACKET) + return lit +} + +func (p *Parser) parseAsterisk() ast.Expression { + asterisk := &ast.Asterisk{ + Position: p.current.Pos, + } + p.nextToken() + return asterisk +} + +func (p *Parser) parseCase() ast.Expression { + expr := &ast.CaseExpr{ + Position: p.current.Pos, + } + p.nextToken() // skip CASE + + // Check for CASE operand (simple CASE) + if !p.currentIs(token.WHEN) { + expr.Operand = p.parseExpression(LOWEST) + } + + // Parse WHEN clauses + for p.currentIs(token.WHEN) { + when := &ast.WhenClause{ + Position: p.current.Pos, + } + p.nextToken() // skip WHEN + + when.Condition = p.parseExpression(LOWEST) + + if !p.expect(token.THEN) { + break + } + + 
when.Result = p.parseExpression(LOWEST) + expr.Whens = append(expr.Whens, when) + } + + // Parse ELSE clause + if p.currentIs(token.ELSE) { + p.nextToken() + expr.Else = p.parseExpression(LOWEST) + } + + p.expect(token.END) + + // Handle alias + if p.currentIs(token.AS) { + p.nextToken() + if p.currentIs(token.IDENT) { + expr.Alias = p.current.Value + p.nextToken() + } + } + + return expr +} + +func (p *Parser) parseCast() ast.Expression { + expr := &ast.CastExpr{ + Position: p.current.Pos, + } + p.nextToken() // skip CAST + + if !p.expect(token.LPAREN) { + return nil + } + + expr.Expr = p.parseExpression(LOWEST) + + if !p.expect(token.AS) { + return nil + } + + expr.Type = p.parseDataType() + + p.expect(token.RPAREN) + + return expr +} + +func (p *Parser) parseExtract() ast.Expression { + expr := &ast.ExtractExpr{ + Position: p.current.Pos, + } + p.nextToken() // skip EXTRACT + + if !p.expect(token.LPAREN) { + return nil + } + + // Parse field (YEAR, MONTH, etc.) + if p.currentIs(token.IDENT) { + expr.Field = strings.ToUpper(p.current.Value) + p.nextToken() + } + + if !p.expect(token.FROM) { + return nil + } + + expr.From = p.parseExpression(LOWEST) + + p.expect(token.RPAREN) + + return expr +} + +func (p *Parser) parseInterval() ast.Expression { + expr := &ast.IntervalExpr{ + Position: p.current.Pos, + } + p.nextToken() // skip INTERVAL + + expr.Value = p.parseExpression(LOWEST) + + // Parse unit + if p.currentIs(token.IDENT) { + expr.Unit = strings.ToUpper(p.current.Value) + p.nextToken() + } + + return expr +} + +func (p *Parser) parseExists() ast.Expression { + expr := &ast.ExistsExpr{ + Position: p.current.Pos, + } + p.nextToken() // skip EXISTS + + if !p.expect(token.LPAREN) { + return nil + } + + expr.Query = p.parseSelectWithUnion() + + p.expect(token.RPAREN) + + return expr +} + +func (p *Parser) parseParameter() ast.Expression { + param := &ast.Parameter{ + Position: p.current.Pos, + } + + value := p.current.Value + p.nextToken() + + // Parse {name:Type} 
format + parts := strings.SplitN(value, ":", 2) + param.Name = parts[0] + if len(parts) > 1 { + param.Type = &ast.DataType{Name: parts[1]} + } + + return param +} + +func (p *Parser) parsePositionalParameter() ast.Expression { + param := &ast.Parameter{ + Position: p.current.Pos, + } + p.nextToken() + return param +} + +func (p *Parser) parseSubstring() ast.Expression { + pos := p.current.Pos + p.nextToken() // skip SUBSTRING + + if !p.expect(token.LPAREN) { + return nil + } + + args := []ast.Expression{p.parseExpression(LOWEST)} + + // Handle FROM + if p.currentIs(token.FROM) { + p.nextToken() + args = append(args, p.parseExpression(LOWEST)) + } else if p.currentIs(token.COMMA) { + p.nextToken() + args = append(args, p.parseExpression(LOWEST)) + } + + // Handle FOR + if p.currentIs(token.FOR) { + p.nextToken() + args = append(args, p.parseExpression(LOWEST)) + } else if p.currentIs(token.COMMA) { + p.nextToken() + args = append(args, p.parseExpression(LOWEST)) + } + + p.expect(token.RPAREN) + + return &ast.FunctionCall{ + Position: pos, + Name: "substring", + Arguments: args, + } +} + +func (p *Parser) parseTrim() ast.Expression { + pos := p.current.Pos + p.nextToken() // skip TRIM + + if !p.expect(token.LPAREN) { + return nil + } + + var trimType string + var trimChars ast.Expression + + // Check for LEADING, TRAILING, BOTH + if p.currentIs(token.LEADING) { + trimType = "LEADING" + p.nextToken() + } else if p.currentIs(token.TRAILING) { + trimType = "TRAILING" + p.nextToken() + } else if p.currentIs(token.BOTH) { + trimType = "BOTH" + p.nextToken() + } + + // Parse characters to trim (if specified) + if !p.currentIs(token.FROM) && !p.currentIs(token.RPAREN) { + trimChars = p.parseExpression(LOWEST) + } + + // FROM clause + var expr ast.Expression + if p.currentIs(token.FROM) { + p.nextToken() + expr = p.parseExpression(LOWEST) + } else { + expr = trimChars + trimChars = nil + } + + p.expect(token.RPAREN) + + // Build appropriate function call + fnName := "trim" + 
switch trimType { + case "LEADING": + fnName = "trimLeft" + case "TRAILING": + fnName = "trimRight" + } + + args := []ast.Expression{expr} + if trimChars != nil { + args = append(args, trimChars) + } + + return &ast.FunctionCall{ + Position: pos, + Name: fnName, + Arguments: args, + } +} + +func (p *Parser) parseBinaryExpression(left ast.Expression) ast.Expression { + expr := &ast.BinaryExpr{ + Position: p.current.Pos, + Left: left, + Op: p.current.Value, + } + + if p.current.Token.IsKeyword() { + expr.Op = strings.ToUpper(p.current.Value) + } + + prec := p.precedence(p.current.Token) + p.nextToken() + + expr.Right = p.parseExpression(prec) + return expr +} + +func (p *Parser) parseLikeExpression(left ast.Expression, not bool) ast.Expression { + expr := &ast.LikeExpr{ + Position: p.current.Pos, + Expr: left, + Not: not, + } + + if p.currentIs(token.ILIKE) { + expr.CaseInsensitive = true + } + + p.nextToken() // skip LIKE/ILIKE + + expr.Pattern = p.parseExpression(COMPARE) + return expr +} + +func (p *Parser) parseInExpression(left ast.Expression, not bool) ast.Expression { + expr := &ast.InExpr{ + Position: p.current.Pos, + Expr: left, + Not: not, + } + + // Handle GLOBAL IN + if p.currentIs(token.GLOBAL) { + expr.Global = true + p.nextToken() + } + + p.nextToken() // skip IN + + if !p.expect(token.LPAREN) { + return nil + } + + // Check for subquery + if p.currentIs(token.SELECT) || p.currentIs(token.WITH) { + expr.Query = p.parseSelectWithUnion() + } else { + expr.List = p.parseExpressionList() + } + + p.expect(token.RPAREN) + return expr +} + +func (p *Parser) parseBetweenExpression(left ast.Expression, not bool) ast.Expression { + expr := &ast.BetweenExpr{ + Position: p.current.Pos, + Expr: left, + Not: not, + } + + p.nextToken() // skip BETWEEN + + expr.Low = p.parseExpression(COMPARE) + + if !p.expect(token.AND) { + return nil + } + + expr.High = p.parseExpression(COMPARE) + return expr +} + +func (p *Parser) parseIsExpression(left ast.Expression) 
ast.Expression { + pos := p.current.Pos + p.nextToken() // skip IS + + not := false + if p.currentIs(token.NOT) { + not = true + p.nextToken() + } + + if p.currentIs(token.NULL) { + p.nextToken() + return &ast.IsNullExpr{ + Position: pos, + Expr: left, + Not: not, + } + } + + // IS TRUE, IS FALSE + if p.currentIs(token.TRUE) || p.currentIs(token.FALSE) { + value := p.currentIs(token.TRUE) + if not { + value = !value + } + p.nextToken() + return &ast.BinaryExpr{ + Position: pos, + Left: left, + Op: "=", + Right: &ast.Literal{ + Position: pos, + Type: ast.LiteralBoolean, + Value: value, + }, + } + } + + return left +} + +func (p *Parser) parseArrayAccess(left ast.Expression) ast.Expression { + expr := &ast.ArrayAccess{ + Position: p.current.Pos, + Array: left, + } + + p.nextToken() // skip [ + expr.Index = p.parseExpression(LOWEST) + p.expect(token.RBRACKET) + + return expr +} + +func (p *Parser) parseDotAccess(left ast.Expression) ast.Expression { + p.nextToken() // skip . + + // Check for tuple access with number + if p.currentIs(token.NUMBER) { + expr := &ast.TupleAccess{ + Position: p.current.Pos, + Tuple: left, + Index: p.parseNumber(), + } + return expr + } + + // Regular identifier access + if p.currentIs(token.IDENT) { + if ident, ok := left.(*ast.Identifier); ok { + ident.Parts = append(ident.Parts, p.current.Value) + p.nextToken() + + // Check for function call + if p.currentIs(token.LPAREN) { + return p.parseFunctionCall(ident.Name(), ident.Position) + } + + // Check for table.* + if p.currentIs(token.ASTERISK) { + tableName := ident.Name() + p.nextToken() + return &ast.Asterisk{ + Position: ident.Position, + Table: tableName, + } + } + + return ident + } + } + + return left +} + +func (p *Parser) parseAlias(left ast.Expression) ast.Expression { + p.nextToken() // skip AS + + alias := "" + if p.currentIs(token.IDENT) { + alias = p.current.Value + p.nextToken() + } + + // Set alias on the expression if it supports it + switch e := left.(type) { + case 
*ast.Identifier: + e.Alias = alias + return e + case *ast.FunctionCall: + e.Alias = alias + return e + case *ast.Subquery: + e.Alias = alias + return e + default: + return &ast.AliasedExpr{ + Position: left.Pos(), + Expr: left, + Alias: alias, + } + } +} + +func (p *Parser) parseCastOperator(left ast.Expression) ast.Expression { + expr := &ast.CastExpr{ + Position: p.current.Pos, + Expr: left, + } + + p.nextToken() // skip :: + + expr.Type = p.parseDataType() + return expr +} + +func (p *Parser) parseLambda(left ast.Expression) ast.Expression { + lambda := &ast.Lambda{ + Position: p.current.Pos, + } + + // Extract parameter names from left expression + switch e := left.(type) { + case *ast.Identifier: + lambda.Parameters = e.Parts + case *ast.Literal: + if e.Type == ast.LiteralTuple { + if exprs, ok := e.Value.([]ast.Expression); ok { + for _, expr := range exprs { + if ident, ok := expr.(*ast.Identifier); ok { + lambda.Parameters = append(lambda.Parameters, ident.Name()) + } + } + } + } + } + + p.nextToken() // skip -> + + lambda.Body = p.parseExpression(LOWEST) + return lambda +} diff --git a/parser/parser.go b/parser/parser.go new file mode 100644 index 0000000000..59f1422de1 --- /dev/null +++ b/parser/parser.go @@ -0,0 +1,1707 @@ +// Package parser implements a parser for ClickHouse SQL. +package parser + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/kyleconroy/doubleclick/ast" + "github.com/kyleconroy/doubleclick/lexer" + "github.com/kyleconroy/doubleclick/token" +) + +// Parser parses ClickHouse SQL statements. +type Parser struct { + lexer *lexer.Lexer + current lexer.Item + peek lexer.Item + errors []error +} + +// New creates a new Parser from an io.Reader. 
+func New(r io.Reader) *Parser { + p := &Parser{ + lexer: lexer.New(r), + } + // Read two tokens to initialize current and peek + p.nextToken() + p.nextToken() + return p +} + +func (p *Parser) nextToken() { + p.current = p.peek + for { + p.peek = p.lexer.NextToken() + // Skip comments and whitespace + if p.peek.Token != token.COMMENT && p.peek.Token != token.WHITESPACE { + break + } + } +} + +func (p *Parser) currentIs(t token.Token) bool { + return p.current.Token == t +} + +func (p *Parser) peekIs(t token.Token) bool { + return p.peek.Token == t +} + +func (p *Parser) expect(t token.Token) bool { + if p.currentIs(t) { + p.nextToken() + return true + } + p.errors = append(p.errors, fmt.Errorf("expected %s, got %s at line %d, column %d", + t, p.current.Token, p.current.Pos.Line, p.current.Pos.Column)) + return false +} + +func (p *Parser) expectPeek(t token.Token) bool { + if p.peekIs(t) { + p.nextToken() + return true + } + p.errors = append(p.errors, fmt.Errorf("expected %s, got %s at line %d, column %d", + t, p.peek.Token, p.peek.Pos.Line, p.peek.Pos.Column)) + return false +} + +// Parse parses SQL statements from the input. +func Parse(ctx context.Context, r io.Reader) ([]ast.Statement, error) { + p := New(r) + return p.ParseStatements(ctx) +} + +// ParseStatements parses multiple SQL statements. 
+func (p *Parser) ParseStatements(ctx context.Context) ([]ast.Statement, error) { + var statements []ast.Statement + + for !p.currentIs(token.EOF) { + select { + case <-ctx.Done(): + return statements, ctx.Err() + default: + } + + stmt := p.parseStatement() + if stmt != nil { + statements = append(statements, stmt) + } + + // Skip semicolons between statements + for p.currentIs(token.SEMICOLON) { + p.nextToken() + } + } + + if len(p.errors) > 0 { + return statements, fmt.Errorf("parse errors: %v", p.errors) + } + return statements, nil +} + +func (p *Parser) parseStatement() ast.Statement { + switch p.current.Token { + case token.SELECT: + return p.parseSelectWithUnion() + case token.WITH: + return p.parseSelectWithUnion() + case token.INSERT: + return p.parseInsert() + case token.CREATE: + return p.parseCreate() + case token.DROP: + return p.parseDrop() + case token.ALTER: + return p.parseAlter() + case token.TRUNCATE: + return p.parseTruncate() + case token.USE: + return p.parseUse() + case token.DESCRIBE: + return p.parseDescribe() + case token.SHOW: + return p.parseShow() + case token.EXPLAIN: + return p.parseExplain() + case token.SET: + return p.parseSet() + case token.OPTIMIZE: + return p.parseOptimize() + case token.SYSTEM: + return p.parseSystem() + default: + p.errors = append(p.errors, fmt.Errorf("unexpected token %s at line %d, column %d", + p.current.Token, p.current.Pos.Line, p.current.Pos.Column)) + p.nextToken() + return nil + } +} + +// parseSelectWithUnion parses SELECT ... UNION ... 
queries.
func (p *Parser) parseSelectWithUnion() *ast.SelectWithUnionQuery {
	query := &ast.SelectWithUnionQuery{
		Position: p.current.Pos,
	}

	// Parse first SELECT
	sel := p.parseSelect()
	if sel == nil {
		return nil
	}
	query.Selects = append(query.Selects, sel)

	// Parse UNION clauses.
	// NOTE(review): UnionAll is a single flag on the whole chain, so
	// 'A UNION B UNION ALL C' marks every union as ALL — confirm intended.
	for p.currentIs(token.UNION) {
		p.nextToken() // skip UNION
		if p.currentIs(token.ALL) {
			query.UnionAll = true
			p.nextToken()
		}
		sel := p.parseSelect()
		if sel == nil {
			break
		}
		query.Selects = append(query.Selects, sel)
	}

	return query
}

// parseSelect parses a single SELECT statement, clause by clause in
// ClickHouse order: WITH, SELECT [DISTINCT] [TOP], columns, FROM,
// PREWHERE, WHERE, GROUP BY [WITH ROLLUP|TOTALS], HAVING, ORDER BY,
// LIMIT [n, m], OFFSET, SETTINGS, FORMAT. Every clause is optional
// except the column list.
func (p *Parser) parseSelect() *ast.SelectQuery {
	sel := &ast.SelectQuery{
		Position: p.current.Pos,
	}

	// Handle WITH clause
	if p.currentIs(token.WITH) {
		p.nextToken()
		sel.With = p.parseWithClause()
	}

	if !p.expect(token.SELECT) {
		return nil
	}

	// Handle DISTINCT
	if p.currentIs(token.DISTINCT) {
		sel.Distinct = true
		p.nextToken()
	}

	// Handle TOP
	if p.currentIs(token.TOP) {
		p.nextToken()
		sel.Top = p.parseExpression(LOWEST)
	}

	// Parse column list
	sel.Columns = p.parseExpressionList()

	// Parse FROM clause
	if p.currentIs(token.FROM) {
		p.nextToken()
		sel.From = p.parseTablesInSelect()
	}

	// Parse PREWHERE clause
	if p.currentIs(token.PREWHERE) {
		p.nextToken()
		sel.PreWhere = p.parseExpression(LOWEST)
	}

	// Parse WHERE clause
	if p.currentIs(token.WHERE) {
		p.nextToken()
		sel.Where = p.parseExpression(LOWEST)
	}

	// Parse GROUP BY clause
	if p.currentIs(token.GROUP) {
		p.nextToken()
		if !p.expect(token.BY) {
			return nil
		}
		sel.GroupBy = p.parseExpressionList()

		// WITH ROLLUP
		if p.currentIs(token.WITH) && p.peekIs(token.ROLLUP) {
			p.nextToken()
			p.nextToken()
			sel.WithRollup = true
		}

		// WITH TOTALS
		if p.currentIs(token.WITH) && p.peekIs(token.TOTALS) {
			p.nextToken()
			p.nextToken()
			sel.WithTotals = true
		}
	}

	// Parse HAVING clause
	if p.currentIs(token.HAVING) {
		p.nextToken()
		sel.Having = p.parseExpression(LOWEST)
	}

	// Parse ORDER BY clause
	if p.currentIs(token.ORDER) {
		p.nextToken()
		if !p.expect(token.BY) {
			return nil
		}
		sel.OrderBy = p.parseOrderByList()
	}

	// Parse LIMIT clause
	if p.currentIs(token.LIMIT) {
		p.nextToken()
		sel.Limit = p.parseExpression(LOWEST)

		// LIMIT n, m syntax (offset, limit): the first expression was the
		// offset, the second is the limit.
		if p.currentIs(token.COMMA) {
			p.nextToken()
			sel.Offset = sel.Limit
			sel.Limit = p.parseExpression(LOWEST)
		}
	}

	// Parse OFFSET clause
	if p.currentIs(token.OFFSET) {
		p.nextToken()
		sel.Offset = p.parseExpression(LOWEST)
	}

	// Parse SETTINGS clause
	if p.currentIs(token.SETTINGS) {
		p.nextToken()
		sel.Settings = p.parseSettingsList()
	}

	// Parse FORMAT clause
	if p.currentIs(token.FORMAT) {
		p.nextToken()
		if p.currentIs(token.IDENT) {
			sel.Format = &ast.Identifier{
				Position: p.current.Pos,
				Parts:    []string{p.current.Value},
			}
			p.nextToken()
		}
	}

	return sel
}

// parseWithClause parses the comma-separated WITH list. Each element is
// '(subquery) AS name' or 'expr AS name'.
// NOTE(review): this accepts '(subq) AS name' but not the standard-SQL
// 'name AS (subq)' ordering — confirm which forms ClickHouse requires.
func (p *Parser) parseWithClause() []ast.Expression {
	var elements []ast.Expression

	for {
		elem := &ast.WithElement{
			Position: p.current.Pos,
		}

		// Check if it's a subquery or expression
		if p.currentIs(token.LPAREN) {
			// Subquery
			p.nextToken()
			subquery := p.parseSelectWithUnion()
			if !p.expect(token.RPAREN) {
				return nil
			}
			elem.Query = &ast.Subquery{Query: subquery}
		} else {
			// Expression
			elem.Query = p.parseExpression(LOWEST)
		}

		if !p.expect(token.AS) {
			return nil
		}

		if p.currentIs(token.IDENT) {
			elem.Name = p.current.Value
			p.nextToken()
		}

		elements = append(elements, elem)

		if !p.currentIs(token.COMMA) {
			break
		}
		p.nextToken()
	}

	return elements
}

// parseTablesInSelect parses the FROM clause: a first table followed by
// any number of joined tables (comma joins included).
func (p *Parser) parseTablesInSelect() *ast.TablesInSelectQuery {
	tables := &ast.TablesInSelectQuery{
		Position: p.current.Pos,
	}

	// Parse first table
	elem := p.parseTableElement()
	if elem == nil {
		return nil
	}
	tables.Tables = append(tables.Tables, elem)

	// Parse JOINs
	for p.isJoinKeyword() {
		elem := p.parseTableElementWithJoin()
		if elem == nil {
			break
		}
		tables.Tables = append(tables.Tables, elem)
	}

	return tables
}

// isJoinKeyword reports whether the current token can begin a join clause
// (including the comma of an implicit cross join).
func (p *Parser) isJoinKeyword() bool {
	switch p.current.Token {
	case token.JOIN, token.INNER, token.LEFT, token.RIGHT, token.FULL, token.CROSS,
		token.GLOBAL, token.ANY, token.ALL, token.ASOF, token.SEMI, token.ANTI:
		return true
	case token.COMMA:
		return true
	}
	return false
}

// parseTableElement parses a single table reference with no join prefix.
func (p *Parser) parseTableElement() *ast.TablesInSelectQueryElement {
	elem := &ast.TablesInSelectQueryElement{
		Position: p.current.Pos,
	}

	elem.Table = p.parseTableExpression()
	return elem
}

// parseTableElementWithJoin parses one joined table:
// [GLOBAL] [ANY|ALL|ASOF|SEMI|ANTI] [INNER|LEFT|RIGHT|FULL|CROSS [OUTER]]
// JOIN table [ON expr | USING (cols)], or ', table' as an implicit cross
// join with no TableJoin attached.
func (p *Parser) parseTableElementWithJoin() *ast.TablesInSelectQueryElement {
	elem := &ast.TablesInSelectQueryElement{
		Position: p.current.Pos,
	}

	// Handle comma join (implicit cross join)
	if p.currentIs(token.COMMA) {
		p.nextToken()
		elem.Table = p.parseTableExpression()
		return elem
	}

	// Parse JOIN
	join := &ast.TableJoin{
		Position: p.current.Pos,
	}

	// Parse join modifiers
	if p.currentIs(token.GLOBAL) {
		join.Global = true
		p.nextToken()
	}

	// Parse strictness
	switch p.current.Token {
	case token.ANY:
		join.Strictness = ast.JoinStrictAny
		p.nextToken()
	case token.ALL:
		join.Strictness = ast.JoinStrictAll
		p.nextToken()
	case token.ASOF:
		join.Strictness = ast.JoinStrictAsof
		p.nextToken()
	case token.SEMI:
		join.Strictness = ast.JoinStrictSemi
		p.nextToken()
	case token.ANTI:
		join.Strictness = ast.JoinStrictAnti
		p.nextToken()
	}

	// Parse join type; a bare JOIN defaults to INNER.
	switch p.current.Token {
	case token.INNER:
		join.Type = ast.JoinInner
		p.nextToken()
	case token.LEFT:
		join.Type = ast.JoinLeft
		p.nextToken()
		if p.currentIs(token.OUTER) {
			p.nextToken()
		}
	case token.RIGHT:
		join.Type = ast.JoinRight
		p.nextToken()
		if p.currentIs(token.OUTER) {
			p.nextToken()
		}
	case token.FULL:
		join.Type = ast.JoinFull
		p.nextToken()
		if p.currentIs(token.OUTER) {
			p.nextToken()
		}
	case token.CROSS:
		join.Type = ast.JoinCross
		p.nextToken()
	default:
		join.Type = ast.JoinInner
	}

	if !p.expect(token.JOIN) {
		return nil
	}

	elem.Table = p.parseTableExpression()

	// Parse ON or USING clause (USING accepts both parenthesized and bare
	// column lists).
	if p.currentIs(token.ON) {
		p.nextToken()
		join.On = p.parseExpression(LOWEST)
	} else if p.currentIs(token.USING) {
		p.nextToken()
		if p.currentIs(token.LPAREN) {
			p.nextToken()
			join.Using = p.parseExpressionList()
			p.expect(token.RPAREN)
		} else {
			join.Using = p.parseExpressionList()
		}
	}

	elem.Join = join
	return elem
}

// parseTableExpression parses a table source: a parenthesized subquery or
// expression, a table function, a possibly database-qualified table name,
// plus the trailing FINAL and SAMPLE modifiers. (Definition continues
// beyond this chunk.)
func (p *Parser) parseTableExpression() *ast.TableExpression {
	expr := &ast.TableExpression{
		Position: p.current.Pos,
	}

	// Handle subquery
	if p.currentIs(token.LPAREN) {
		p.nextToken()
		if p.currentIs(token.SELECT) || p.currentIs(token.WITH) {
			subquery := p.parseSelectWithUnion()
			expr.Table = &ast.Subquery{Query: subquery}
		} else {
			// Table function or expression
			expr.Table = p.parseExpression(LOWEST)
		}
		p.expect(token.RPAREN)
	} else if p.currentIs(token.IDENT) {
		// Table identifier or function
		ident := p.current.Value
		pos := p.current.Pos
		p.nextToken()

		if p.currentIs(token.LPAREN) {
			// Table function
			expr.Table = p.parseFunctionCall(ident, pos)
		} else if p.currentIs(token.DOT) {
			// database.table
			p.nextToken()
			tableName := ""
			if p.currentIs(token.IDENT) {
				tableName = p.current.Value
				p.nextToken()
			}
			expr.Table = &ast.TableIdentifier{
				Position: pos,
				Database: ident,
				Table:    tableName,
			}
		} else {
			expr.Table = &ast.TableIdentifier{
				Position: pos,
				Table:    ident,
			}
		}
	}

	// Handle FINAL
	if p.currentIs(token.FINAL) {
		expr.Final = true
		p.nextToken()
	}

	// Handle SAMPLE
	if p.currentIs(token.SAMPLE) {
		p.nextToken()
		expr.Sample = &ast.SampleClause{
			Position: p.current.Pos,
			Ratio:    p.parseExpression(LOWEST),
		}
		if p.currentIs(token.OFFSET) {
p.nextToken() + expr.Sample.Offset = p.parseExpression(LOWEST) + } + } + + // Handle alias + if p.currentIs(token.AS) { + p.nextToken() + if p.currentIs(token.IDENT) { + expr.Alias = p.current.Value + p.nextToken() + } + } else if p.currentIs(token.IDENT) && !p.isKeywordForClause() { + expr.Alias = p.current.Value + p.nextToken() + } + + return expr +} + +func (p *Parser) isKeywordForClause() bool { + switch p.current.Token { + case token.WHERE, token.GROUP, token.HAVING, token.ORDER, token.LIMIT, + token.OFFSET, token.UNION, token.EXCEPT, token.SETTINGS, token.FORMAT, + token.PREWHERE, token.JOIN, token.LEFT, token.RIGHT, token.INNER, + token.FULL, token.CROSS, token.ON, token.USING, token.GLOBAL, + token.ANY, token.ALL, token.SEMI, token.ANTI, token.ASOF: + return true + } + return false +} + +func (p *Parser) parseOrderByList() []*ast.OrderByElement { + var elements []*ast.OrderByElement + + for { + elem := &ast.OrderByElement{ + Position: p.current.Pos, + Expression: p.parseExpression(LOWEST), + } + + // Handle ASC/DESC + if p.currentIs(token.ASC) { + p.nextToken() + } else if p.currentIs(token.DESC) { + elem.Descending = true + p.nextToken() + } + + // Handle NULLS FIRST/LAST + if p.currentIs(token.NULLS) { + p.nextToken() + if p.currentIs(token.FIRST) { + t := true + elem.NullsFirst = &t + p.nextToken() + } else { + // NULLS LAST + f := false + elem.NullsFirst = &f + p.nextToken() + } + } + + // Handle COLLATE + if p.currentIs(token.COLLATE) { + p.nextToken() + if p.currentIs(token.STRING) || p.currentIs(token.IDENT) { + elem.Collate = p.current.Value + p.nextToken() + } + } + + elements = append(elements, elem) + + if !p.currentIs(token.COMMA) { + break + } + p.nextToken() + } + + return elements +} + +func (p *Parser) parseSettingsList() []*ast.SettingExpr { + var settings []*ast.SettingExpr + + for { + if !p.currentIs(token.IDENT) { + break + } + + setting := &ast.SettingExpr{ + Position: p.current.Pos, + Name: p.current.Value, + } + p.nextToken() + + if 
!p.expect(token.EQ) { + break + } + + setting.Value = p.parseExpression(LOWEST) + settings = append(settings, setting) + + if !p.currentIs(token.COMMA) { + break + } + p.nextToken() + } + + return settings +} + +func (p *Parser) parseInsert() *ast.InsertQuery { + ins := &ast.InsertQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip INSERT + + if !p.expect(token.INTO) { + return nil + } + + // Skip optional TABLE keyword + if p.currentIs(token.TABLE) { + p.nextToken() + } + + // Parse table name + if p.currentIs(token.IDENT) { + tableName := p.current.Value + p.nextToken() + + if p.currentIs(token.DOT) { + p.nextToken() + ins.Database = tableName + if p.currentIs(token.IDENT) { + ins.Table = p.current.Value + p.nextToken() + } + } else { + ins.Table = tableName + } + } + + // Parse column list + if p.currentIs(token.LPAREN) { + p.nextToken() + for !p.currentIs(token.RPAREN) && !p.currentIs(token.EOF) { + if p.currentIs(token.IDENT) { + ins.Columns = append(ins.Columns, &ast.Identifier{ + Position: p.current.Pos, + Parts: []string{p.current.Value}, + }) + p.nextToken() + } + if p.currentIs(token.COMMA) { + p.nextToken() + } else { + break + } + } + p.expect(token.RPAREN) + } + + // Parse VALUES or SELECT + if p.currentIs(token.VALUES) { + p.nextToken() + // VALUES are typically provided externally, skip for now + } else if p.currentIs(token.SELECT) || p.currentIs(token.WITH) { + ins.Select = p.parseSelectWithUnion() + } + + // Parse FORMAT + if p.currentIs(token.FORMAT) { + p.nextToken() + if p.currentIs(token.IDENT) { + ins.Format = &ast.Identifier{ + Position: p.current.Pos, + Parts: []string{p.current.Value}, + } + p.nextToken() + } + } + + return ins +} + +func (p *Parser) parseCreate() *ast.CreateQuery { + create := &ast.CreateQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip CREATE + + // Handle OR REPLACE + if p.currentIs(token.OR) { + p.nextToken() + if p.currentIs(token.REPLACE) { + create.OrReplace = true + p.nextToken() + } + } + 
+ // Handle TEMPORARY + if p.currentIs(token.TEMPORARY) { + create.Temporary = true + p.nextToken() + } + + // Handle MATERIALIZED + if p.currentIs(token.MATERIALIZED) { + create.Materialized = true + p.nextToken() + } + + // What are we creating? + switch p.current.Token { + case token.TABLE: + p.nextToken() + p.parseCreateTable(create) + case token.DATABASE: + create.CreateDatabase = true + p.nextToken() + p.parseCreateDatabase(create) + case token.VIEW: + p.nextToken() + p.parseCreateView(create) + default: + p.errors = append(p.errors, fmt.Errorf("expected TABLE, DATABASE, or VIEW after CREATE")) + return nil + } + + return create +} + +func (p *Parser) parseCreateTable(create *ast.CreateQuery) { + // Handle IF NOT EXISTS + if p.currentIs(token.IF) { + p.nextToken() + if p.currentIs(token.NOT) { + p.nextToken() + if p.currentIs(token.EXISTS) { + create.IfNotExists = true + p.nextToken() + } + } + } + + // Parse table name + if p.currentIs(token.IDENT) { + tableName := p.current.Value + p.nextToken() + + if p.currentIs(token.DOT) { + p.nextToken() + create.Database = tableName + if p.currentIs(token.IDENT) { + create.Table = p.current.Value + p.nextToken() + } + } else { + create.Table = tableName + } + } + + // Handle ON CLUSTER + if p.currentIs(token.ON) { + p.nextToken() + if p.currentIs(token.CLUSTER) { + p.nextToken() + if p.currentIs(token.IDENT) || p.currentIs(token.STRING) { + create.OnCluster = p.current.Value + p.nextToken() + } + } + } + + // Parse column definitions + if p.currentIs(token.LPAREN) { + p.nextToken() + for !p.currentIs(token.RPAREN) && !p.currentIs(token.EOF) { + col := p.parseColumnDeclaration() + if col != nil { + create.Columns = append(create.Columns, col) + } + if p.currentIs(token.COMMA) { + p.nextToken() + } else { + break + } + } + p.expect(token.RPAREN) + } + + // Parse ENGINE + if p.currentIs(token.ENGINE) { + p.nextToken() + if p.currentIs(token.EQ) { + p.nextToken() + } + create.Engine = p.parseEngineClause() + } + + // 
Parse ORDER BY + if p.currentIs(token.ORDER) { + p.nextToken() + if p.expect(token.BY) { + if p.currentIs(token.LPAREN) { + p.nextToken() + create.OrderBy = p.parseExpressionList() + p.expect(token.RPAREN) + } else { + create.OrderBy = []ast.Expression{p.parseExpression(LOWEST)} + } + } + } + + // Parse PARTITION BY + if p.currentIs(token.PARTITION) { + p.nextToken() + if p.expect(token.BY) { + create.PartitionBy = p.parseExpression(LOWEST) + } + } + + // Parse PRIMARY KEY + if p.currentIs(token.PRIMARY) { + p.nextToken() + if p.expect(token.KEY) { + if p.currentIs(token.LPAREN) { + p.nextToken() + create.PrimaryKey = p.parseExpressionList() + p.expect(token.RPAREN) + } else { + create.PrimaryKey = []ast.Expression{p.parseExpression(LOWEST)} + } + } + } + + // Parse SAMPLE BY + if p.currentIs(token.SAMPLE) { + p.nextToken() + if p.expect(token.BY) { + create.SampleBy = p.parseExpression(LOWEST) + } + } + + // Parse TTL + if p.currentIs(token.TTL) { + p.nextToken() + create.TTL = &ast.TTLClause{ + Position: p.current.Pos, + Expression: p.parseExpression(LOWEST), + } + } + + // Parse SETTINGS + if p.currentIs(token.SETTINGS) { + p.nextToken() + create.Settings = p.parseSettingsList() + } + + // Parse AS SELECT + if p.currentIs(token.AS) { + p.nextToken() + if p.currentIs(token.SELECT) || p.currentIs(token.WITH) { + create.AsSelect = p.parseSelectWithUnion() + } + } +} + +func (p *Parser) parseCreateDatabase(create *ast.CreateQuery) { + // Handle IF NOT EXISTS + if p.currentIs(token.IF) { + p.nextToken() + if p.currentIs(token.NOT) { + p.nextToken() + if p.currentIs(token.EXISTS) { + create.IfNotExists = true + p.nextToken() + } + } + } + + // Parse database name + if p.currentIs(token.IDENT) { + create.Database = p.current.Value + p.nextToken() + } + + // Handle ON CLUSTER + if p.currentIs(token.ON) { + p.nextToken() + if p.currentIs(token.CLUSTER) { + p.nextToken() + if p.currentIs(token.IDENT) || p.currentIs(token.STRING) { + create.OnCluster = p.current.Value + 
p.nextToken() + } + } + } + + // Parse ENGINE + if p.currentIs(token.ENGINE) { + p.nextToken() + if p.currentIs(token.EQ) { + p.nextToken() + } + create.Engine = p.parseEngineClause() + } +} + +func (p *Parser) parseCreateView(create *ast.CreateQuery) { + // Handle IF NOT EXISTS + if p.currentIs(token.IF) { + p.nextToken() + if p.currentIs(token.NOT) { + p.nextToken() + if p.currentIs(token.EXISTS) { + create.IfNotExists = true + p.nextToken() + } + } + } + + // Parse view name + if p.currentIs(token.IDENT) { + viewName := p.current.Value + p.nextToken() + + if p.currentIs(token.DOT) { + p.nextToken() + create.Database = viewName + if p.currentIs(token.IDENT) { + create.View = p.current.Value + p.nextToken() + } + } else { + create.View = viewName + } + } + + // Handle ON CLUSTER + if p.currentIs(token.ON) { + p.nextToken() + if p.currentIs(token.CLUSTER) { + p.nextToken() + if p.currentIs(token.IDENT) || p.currentIs(token.STRING) { + create.OnCluster = p.current.Value + p.nextToken() + } + } + } + + // Parse AS SELECT + if p.currentIs(token.AS) { + p.nextToken() + if p.currentIs(token.SELECT) || p.currentIs(token.WITH) { + create.AsSelect = p.parseSelectWithUnion() + } + } +} + +func (p *Parser) parseColumnDeclaration() *ast.ColumnDeclaration { + col := &ast.ColumnDeclaration{ + Position: p.current.Pos, + } + + // Parse column name + if p.currentIs(token.IDENT) { + col.Name = p.current.Value + p.nextToken() + } else { + return nil + } + + // Parse data type + col.Type = p.parseDataType() + + // Parse DEFAULT/MATERIALIZED/ALIAS + switch p.current.Token { + case token.DEFAULT: + col.DefaultKind = "DEFAULT" + p.nextToken() + col.Default = p.parseExpression(LOWEST) + case token.MATERIALIZED: + col.DefaultKind = "MATERIALIZED" + p.nextToken() + col.Default = p.parseExpression(LOWEST) + case token.ALIAS: + col.DefaultKind = "ALIAS" + p.nextToken() + col.Default = p.parseExpression(LOWEST) + } + + // Parse CODEC + if p.currentIs(token.IDENT) && 
strings.ToUpper(p.current.Value) == "CODEC" { + p.nextToken() + col.Codec = p.parseCodecExpr() + } + + // Parse TTL + if p.currentIs(token.TTL) { + p.nextToken() + col.TTL = p.parseExpression(LOWEST) + } + + return col +} + +func (p *Parser) parseDataType() *ast.DataType { + if !p.currentIs(token.IDENT) { + return nil + } + + dt := &ast.DataType{ + Position: p.current.Pos, + Name: p.current.Value, + } + p.nextToken() + + // Parse type parameters + if p.currentIs(token.LPAREN) { + p.nextToken() + for !p.currentIs(token.RPAREN) && !p.currentIs(token.EOF) { + // Could be another data type or an expression + if p.currentIs(token.IDENT) && p.isDataTypeName(p.current.Value) { + dt.Parameters = append(dt.Parameters, p.parseDataType()) + } else { + dt.Parameters = append(dt.Parameters, p.parseExpression(LOWEST)) + } + if p.currentIs(token.COMMA) { + p.nextToken() + } else { + break + } + } + p.expect(token.RPAREN) + } + + return dt +} + +func (p *Parser) isDataTypeName(name string) bool { + upper := strings.ToUpper(name) + types := []string{ + "INT8", "INT16", "INT32", "INT64", "INT128", "INT256", + "UINT8", "UINT16", "UINT32", "UINT64", "UINT128", "UINT256", + "FLOAT32", "FLOAT64", + "DECIMAL", "DECIMAL32", "DECIMAL64", "DECIMAL128", "DECIMAL256", + "STRING", "FIXEDSTRING", + "UUID", "DATE", "DATE32", "DATETIME", "DATETIME64", + "ENUM", "ENUM8", "ENUM16", + "ARRAY", "TUPLE", "MAP", "NESTED", + "NULLABLE", "LOWCARDINALITY", + "BOOL", "BOOLEAN", + "IPV4", "IPV6", + "NOTHING", "INTERVAL", + } + for _, t := range types { + if upper == t { + return true + } + } + return false +} + +func (p *Parser) parseCodecExpr() *ast.CodecExpr { + codec := &ast.CodecExpr{ + Position: p.current.Pos, + } + + if !p.expect(token.LPAREN) { + return nil + } + + for !p.currentIs(token.RPAREN) && !p.currentIs(token.EOF) { + if p.currentIs(token.IDENT) { + name := p.current.Value + pos := p.current.Pos + p.nextToken() + + fn := &ast.FunctionCall{ + Position: pos, + Name: name, + } + + if 
p.currentIs(token.LPAREN) { + p.nextToken() + if !p.currentIs(token.RPAREN) { + fn.Arguments = p.parseExpressionList() + } + p.expect(token.RPAREN) + } + + codec.Codecs = append(codec.Codecs, fn) + } + + if p.currentIs(token.COMMA) { + p.nextToken() + } else { + break + } + } + + p.expect(token.RPAREN) + return codec +} + +func (p *Parser) parseEngineClause() *ast.EngineClause { + engine := &ast.EngineClause{ + Position: p.current.Pos, + } + + if p.currentIs(token.IDENT) { + engine.Name = p.current.Value + p.nextToken() + } + + if p.currentIs(token.LPAREN) { + p.nextToken() + if !p.currentIs(token.RPAREN) { + engine.Parameters = p.parseExpressionList() + } + p.expect(token.RPAREN) + } + + return engine +} + +func (p *Parser) parseDrop() *ast.DropQuery { + drop := &ast.DropQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip DROP + + // Handle TEMPORARY + if p.currentIs(token.TEMPORARY) { + drop.Temporary = true + p.nextToken() + } + + // What are we dropping? + switch p.current.Token { + case token.TABLE: + p.nextToken() + case token.DATABASE: + drop.DropDatabase = true + p.nextToken() + case token.VIEW: + p.nextToken() + default: + p.nextToken() // skip unknown token + } + + // Handle IF EXISTS + if p.currentIs(token.IF) { + p.nextToken() + if p.currentIs(token.EXISTS) { + drop.IfExists = true + p.nextToken() + } + } + + // Parse name + if p.currentIs(token.IDENT) { + name := p.current.Value + p.nextToken() + + if p.currentIs(token.DOT) { + p.nextToken() + drop.Database = name + if p.currentIs(token.IDENT) { + if drop.DropDatabase { + drop.Database = p.current.Value + } else { + drop.Table = p.current.Value + } + p.nextToken() + } + } else { + if drop.DropDatabase { + drop.Database = name + } else { + drop.Table = name + } + } + } + + // Handle ON CLUSTER + if p.currentIs(token.ON) { + p.nextToken() + if p.currentIs(token.CLUSTER) { + p.nextToken() + if p.currentIs(token.IDENT) || p.currentIs(token.STRING) { + drop.OnCluster = p.current.Value + 
p.nextToken() + } + } + } + + return drop +} + +func (p *Parser) parseAlter() *ast.AlterQuery { + alter := &ast.AlterQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip ALTER + + if !p.expect(token.TABLE) { + return nil + } + + // Parse table name + if p.currentIs(token.IDENT) { + tableName := p.current.Value + p.nextToken() + + if p.currentIs(token.DOT) { + p.nextToken() + alter.Database = tableName + if p.currentIs(token.IDENT) { + alter.Table = p.current.Value + p.nextToken() + } + } else { + alter.Table = tableName + } + } + + // Handle ON CLUSTER + if p.currentIs(token.ON) { + p.nextToken() + if p.currentIs(token.CLUSTER) { + p.nextToken() + if p.currentIs(token.IDENT) || p.currentIs(token.STRING) { + alter.OnCluster = p.current.Value + p.nextToken() + } + } + } + + // Parse commands + for { + cmd := p.parseAlterCommand() + if cmd == nil { + break + } + alter.Commands = append(alter.Commands, cmd) + + if !p.currentIs(token.COMMA) { + break + } + p.nextToken() + } + + return alter +} + +func (p *Parser) parseAlterCommand() *ast.AlterCommand { + cmd := &ast.AlterCommand{ + Position: p.current.Pos, + } + + switch p.current.Token { + case token.ADD: + p.nextToken() + if p.currentIs(token.COLUMN) { + cmd.Type = ast.AlterAddColumn + p.nextToken() + cmd.Column = p.parseColumnDeclaration() + if p.currentIs(token.IDENT) && strings.ToUpper(p.current.Value) == "AFTER" { + p.nextToken() + if p.currentIs(token.IDENT) { + cmd.AfterColumn = p.current.Value + p.nextToken() + } + } + } else if p.currentIs(token.INDEX) { + cmd.Type = ast.AlterAddIndex + p.nextToken() + // Parse index definition + } else if p.currentIs(token.CONSTRAINT) { + cmd.Type = ast.AlterAddConstraint + p.nextToken() + // Parse constraint + } + case token.DROP: + p.nextToken() + if p.currentIs(token.COLUMN) { + cmd.Type = ast.AlterDropColumn + p.nextToken() + if p.currentIs(token.IF) { + p.nextToken() + p.expect(token.EXISTS) + } + if p.currentIs(token.IDENT) { + cmd.ColumnName = 
p.current.Value + p.nextToken() + } + } else if p.currentIs(token.INDEX) { + cmd.Type = ast.AlterDropIndex + p.nextToken() + if p.currentIs(token.IDENT) { + cmd.Index = p.current.Value + p.nextToken() + } + } else if p.currentIs(token.CONSTRAINT) { + cmd.Type = ast.AlterDropConstraint + p.nextToken() + } else if p.currentIs(token.PARTITION) { + cmd.Type = ast.AlterDropPartition + p.nextToken() + cmd.Partition = p.parseExpression(LOWEST) + } + case token.MODIFY: + p.nextToken() + if p.currentIs(token.COLUMN) { + cmd.Type = ast.AlterModifyColumn + p.nextToken() + cmd.Column = p.parseColumnDeclaration() + } else if p.currentIs(token.TTL) { + cmd.Type = ast.AlterModifyTTL + p.nextToken() + cmd.TTL = &ast.TTLClause{ + Position: p.current.Pos, + Expression: p.parseExpression(LOWEST), + } + } else if p.currentIs(token.SETTINGS) { + cmd.Type = ast.AlterModifySetting + p.nextToken() + cmd.Settings = p.parseSettingsList() + } + case token.RENAME: + p.nextToken() + if p.currentIs(token.COLUMN) { + cmd.Type = ast.AlterRenameColumn + p.nextToken() + if p.currentIs(token.IDENT) { + cmd.ColumnName = p.current.Value + p.nextToken() + } + if p.currentIs(token.TO) { + p.nextToken() + if p.currentIs(token.IDENT) { + cmd.NewName = p.current.Value + p.nextToken() + } + } + } + case token.DETACH: + p.nextToken() + if p.currentIs(token.PARTITION) { + cmd.Type = ast.AlterDetachPartition + p.nextToken() + cmd.Partition = p.parseExpression(LOWEST) + } + case token.ATTACH: + p.nextToken() + if p.currentIs(token.PARTITION) { + cmd.Type = ast.AlterAttachPartition + p.nextToken() + cmd.Partition = p.parseExpression(LOWEST) + } + default: + return nil + } + + return cmd +} + +func (p *Parser) parseTruncate() *ast.TruncateQuery { + trunc := &ast.TruncateQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip TRUNCATE + + if p.currentIs(token.TABLE) { + p.nextToken() + } + + // Handle IF EXISTS + if p.currentIs(token.IF) { + p.nextToken() + if p.currentIs(token.EXISTS) { + trunc.IfExists = 
true
			p.nextToken()
		}
	}

	// Parse the target name, optionally qualified as database.table.
	if p.currentIs(token.IDENT) {
		tableName := p.current.Value
		p.nextToken()

		if p.currentIs(token.DOT) {
			p.nextToken()
			trunc.Database = tableName
			if p.currentIs(token.IDENT) {
				trunc.Table = p.current.Value
				p.nextToken()
			}
		} else {
			trunc.Table = tableName
		}
	}

	// Handle ON CLUSTER <name>; the cluster name may be an identifier or a
	// quoted string.
	if p.currentIs(token.ON) {
		p.nextToken()
		if p.currentIs(token.CLUSTER) {
			p.nextToken()
			if p.currentIs(token.IDENT) || p.currentIs(token.STRING) {
				trunc.OnCluster = p.current.Value
				p.nextToken()
			}
		}
	}

	return trunc
}

// parseUse parses "USE <database>". A missing database name leaves the
// Database field empty rather than reporting an error.
func (p *Parser) parseUse() *ast.UseQuery {
	use := &ast.UseQuery{
		Position: p.current.Pos,
	}

	p.nextToken() // skip USE

	if p.currentIs(token.IDENT) {
		use.Database = p.current.Value
		p.nextToken()
	}

	return use
}

// parseDescribe parses "DESCRIBE [TABLE] [db.]table".
func (p *Parser) parseDescribe() *ast.DescribeQuery {
	desc := &ast.DescribeQuery{
		Position: p.current.Pos,
	}

	p.nextToken() // skip DESCRIBE

	// The TABLE keyword is optional.
	if p.currentIs(token.TABLE) {
		p.nextToken()
	}

	// Parse the table name, optionally qualified as database.table.
	if p.currentIs(token.IDENT) {
		tableName := p.current.Value
		p.nextToken()

		if p.currentIs(token.DOT) {
			p.nextToken()
			desc.Database = tableName
			if p.currentIs(token.IDENT) {
				desc.Table = p.current.Value
				p.nextToken()
			}
		} else {
			desc.Table = tableName
		}
	}

	return desc
}

// parseShow parses SHOW TABLES / SHOW DATABASES / SHOW CREATE [TABLE] and,
// in the default branch, identifier-named variants such as SHOW PROCESSLIST.
func (p *Parser) parseShow() *ast.ShowQuery {
	show := &ast.ShowQuery{
		Position: p.current.Pos,
	}

	p.nextToken() // skip SHOW

	switch p.current.Token {
	case token.TABLES:
		show.ShowType = ast.ShowTables
		p.nextToken()
	case token.DATABASES:
		show.ShowType = ast.ShowDatabases
		p.nextToken()
	case token.CREATE:
		show.ShowType = ast.ShowCreate
		p.nextToken()
		// The TABLE keyword after SHOW CREATE is optional.
		if p.currentIs(token.TABLE) {
			p.nextToken()
		}
	default:
		// Handle SHOW PROCESSLIST etc.
+ if p.currentIs(token.IDENT) { + if strings.ToUpper(p.current.Value) == "PROCESSLIST" { + show.ShowType = ast.ShowProcesses + } + p.nextToken() + } + } + + // Parse FROM clause + if p.currentIs(token.FROM) { + p.nextToken() + if p.currentIs(token.IDENT) { + show.From = p.current.Value + p.nextToken() + } + } + + // Parse LIKE clause + if p.currentIs(token.LIKE) { + p.nextToken() + if p.currentIs(token.STRING) { + show.Like = p.current.Value + p.nextToken() + } + } + + // Parse WHERE clause + if p.currentIs(token.WHERE) { + p.nextToken() + show.Where = p.parseExpression(LOWEST) + } + + // Parse LIMIT clause + if p.currentIs(token.LIMIT) { + p.nextToken() + show.Limit = p.parseExpression(LOWEST) + } + + return show +} + +func (p *Parser) parseExplain() *ast.ExplainQuery { + explain := &ast.ExplainQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip EXPLAIN + + // Parse explain type + if p.currentIs(token.IDENT) { + switch strings.ToUpper(p.current.Value) { + case "AST": + explain.ExplainType = ast.ExplainAST + p.nextToken() + case "SYNTAX": + explain.ExplainType = ast.ExplainSyntax + p.nextToken() + case "PLAN": + explain.ExplainType = ast.ExplainPlan + p.nextToken() + case "PIPELINE": + explain.ExplainType = ast.ExplainPipeline + p.nextToken() + case "ESTIMATE": + explain.ExplainType = ast.ExplainEstimate + p.nextToken() + default: + explain.ExplainType = ast.ExplainPlan + } + } + + // Parse the statement being explained + explain.Statement = p.parseStatement() + + return explain +} + +func (p *Parser) parseSet() *ast.SetQuery { + set := &ast.SetQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip SET + + set.Settings = p.parseSettingsList() + + return set +} + +func (p *Parser) parseOptimize() *ast.OptimizeQuery { + opt := &ast.OptimizeQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip OPTIMIZE + + if !p.expect(token.TABLE) { + return nil + } + + // Parse table name + if p.currentIs(token.IDENT) { + tableName := p.current.Value + 
p.nextToken() + + if p.currentIs(token.DOT) { + p.nextToken() + opt.Database = tableName + if p.currentIs(token.IDENT) { + opt.Table = p.current.Value + p.nextToken() + } + } else { + opt.Table = tableName + } + } + + // Handle ON CLUSTER + if p.currentIs(token.ON) { + p.nextToken() + if p.currentIs(token.CLUSTER) { + p.nextToken() + if p.currentIs(token.IDENT) || p.currentIs(token.STRING) { + opt.OnCluster = p.current.Value + p.nextToken() + } + } + } + + // Handle PARTITION + if p.currentIs(token.PARTITION) { + p.nextToken() + opt.Partition = p.parseExpression(LOWEST) + } + + // Handle FINAL + if p.currentIs(token.FINAL) { + opt.Final = true + p.nextToken() + } + + // Handle DEDUPLICATE + if p.currentIs(token.IDENT) && strings.ToUpper(p.current.Value) == "DEDUPLICATE" { + opt.Dedupe = true + p.nextToken() + } + + return opt +} + +func (p *Parser) parseSystem() *ast.SystemQuery { + sys := &ast.SystemQuery{ + Position: p.current.Pos, + } + + p.nextToken() // skip SYSTEM + + // Read the command + var parts []string + for p.currentIs(token.IDENT) { + parts = append(parts, p.current.Value) + p.nextToken() + } + sys.Command = strings.Join(parts, " ") + + return sys +} diff --git a/parser/parser_test.go b/parser/parser_test.go new file mode 100644 index 0000000000..2749dc58fe --- /dev/null +++ b/parser/parser_test.go @@ -0,0 +1,558 @@ +package parser_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "os/exec" + "strings" + "testing" + "time" + + "github.com/kyleconroy/doubleclick/ast" + "github.com/kyleconroy/doubleclick/parser" +) + +// clickhouseAvailable checks if ClickHouse server is running +func clickhouseAvailable() bool { + resp, err := http.Get("http://127.0.0.1:8123/ping") + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == 200 +} + +// getClickHouseAST runs EXPLAIN AST on ClickHouse and returns the output +func getClickHouseAST(query string) (string, error) { + explainQuery := 
fmt.Sprintf("EXPLAIN AST %s", query) + resp, err := http.Get("http://127.0.0.1:8123/?query=" + url.QueryEscape(explainQuery)) + if err != nil { + return "", err + } + defer resp.Body.Close() + + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + return buf.String(), nil +} + +// TestParserBasicSelect tests basic SELECT parsing +func TestParserBasicSelect(t *testing.T) { + tests := []struct { + name string + query string + }{ + {"simple select", "SELECT 1"}, + {"select columns", "SELECT id, name FROM users"}, + {"select with where", "SELECT * FROM users WHERE id = 1"}, + {"select with alias", "SELECT id AS user_id FROM users"}, + {"select distinct", "SELECT DISTINCT name FROM users"}, + {"select with limit", "SELECT * FROM users LIMIT 10"}, + {"select with offset", "SELECT * FROM users LIMIT 10 OFFSET 5"}, + {"select with order", "SELECT * FROM users ORDER BY name ASC"}, + {"select with order desc", "SELECT * FROM users ORDER BY id DESC"}, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmts, err := parser.Parse(ctx, strings.NewReader(tt.query)) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + if _, ok := stmts[0].(*ast.SelectWithUnionQuery); !ok { + t.Fatalf("Expected SelectWithUnionQuery, got %T", stmts[0]) + } + }) + } +} + +// TestParserComplexSelect tests complex SELECT parsing +func TestParserComplexSelect(t *testing.T) { + tests := []struct { + name string + query string + }{ + {"group by", "SELECT count(*) FROM users GROUP BY status"}, + {"group by having", "SELECT count(*) FROM users GROUP BY status HAVING count(*) > 1"}, + {"multiple tables", "SELECT * FROM users, orders"}, + {"inner join", "SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id"}, + {"left join", "SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id"}, + {"subquery in where", "SELECT * FROM users WHERE id IN 
(SELECT user_id FROM orders)"}, + {"subquery in from", "SELECT * FROM (SELECT id FROM users) AS t"}, + {"union all", "SELECT 1 UNION ALL SELECT 2"}, + {"case expression", "SELECT CASE WHEN id > 1 THEN 'big' ELSE 'small' END FROM users"}, + {"between", "SELECT * FROM users WHERE id BETWEEN 1 AND 10"}, + {"like", "SELECT * FROM users WHERE name LIKE '%test%'"}, + {"is null", "SELECT * FROM users WHERE name IS NULL"}, + {"is not null", "SELECT * FROM users WHERE name IS NOT NULL"}, + {"in list", "SELECT * FROM users WHERE id IN (1, 2, 3)"}, + {"not in", "SELECT * FROM users WHERE id NOT IN (1, 2, 3)"}, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmts, err := parser.Parse(ctx, strings.NewReader(tt.query)) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + }) + } +} + +// TestParserFunctions tests function parsing +func TestParserFunctions(t *testing.T) { + tests := []struct { + name string + query string + }{ + {"count", "SELECT count(*) FROM users"}, + {"sum", "SELECT sum(amount) FROM orders"}, + {"avg", "SELECT avg(price) FROM products"}, + {"min max", "SELECT min(id), max(id) FROM users"}, + {"nested functions", "SELECT toDate(now()) FROM users"}, + {"function with multiple args", "SELECT substring(name, 1, 5) FROM users"}, + {"distinct in function", "SELECT count(DISTINCT id) FROM users"}, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmts, err := parser.Parse(ctx, strings.NewReader(tt.query)) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + }) + } +} + +// TestParserExpressions tests expression parsing +func TestParserExpressions(t *testing.T) { + tests := []struct { + name string + query string + }{ + {"arithmetic", "SELECT 1 + 2 * 3"}, + {"comparison", "SELECT 1 < 
2"}, + {"logical and", "SELECT 1 AND 2"}, + {"logical or", "SELECT 1 OR 2"}, + {"logical not", "SELECT NOT 1"}, + {"unary minus", "SELECT -5"}, + {"parentheses", "SELECT (1 + 2) * 3"}, + {"string literal", "SELECT 'hello'"}, + {"integer literal", "SELECT 42"}, + {"float literal", "SELECT 3.14"}, + {"null literal", "SELECT NULL"}, + {"boolean true", "SELECT true"}, + {"boolean false", "SELECT false"}, + {"array literal", "SELECT [1, 2, 3]"}, + {"tuple literal", "SELECT (1, 'a')"}, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmts, err := parser.Parse(ctx, strings.NewReader(tt.query)) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + }) + } +} + +// TestParserDDL tests DDL statement parsing +func TestParserDDL(t *testing.T) { + tests := []struct { + name string + query string + stmtType interface{} + }{ + {"create table", "CREATE TABLE test (id UInt64, name String) ENGINE = MergeTree() ORDER BY id", &ast.CreateQuery{}}, + {"create table if not exists", "CREATE TABLE IF NOT EXISTS test (id UInt64) ENGINE = MergeTree() ORDER BY id", &ast.CreateQuery{}}, + {"drop table", "DROP TABLE test", &ast.DropQuery{}}, + {"drop table if exists", "DROP TABLE IF EXISTS test", &ast.DropQuery{}}, + {"truncate table", "TRUNCATE TABLE test", &ast.TruncateQuery{}}, + {"alter add column", "ALTER TABLE test ADD COLUMN age UInt32", &ast.AlterQuery{}}, + {"alter drop column", "ALTER TABLE test DROP COLUMN age", &ast.AlterQuery{}}, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmts, err := parser.Parse(ctx, strings.NewReader(tt.query)) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + }) + } +} + +// TestParserOtherStatements tests other statement types +func TestParserOtherStatements(t 
*testing.T) { + tests := []struct { + name string + query string + }{ + {"use database", "USE mydb"}, + {"describe table", "DESCRIBE TABLE users"}, + {"show tables", "SHOW TABLES"}, + {"show databases", "SHOW DATABASES"}, + {"insert into", "INSERT INTO users (id, name) VALUES"}, + {"insert select", "INSERT INTO users SELECT * FROM old_users"}, + {"set setting", "SET max_threads = 4"}, + {"explain", "EXPLAIN SELECT 1"}, + {"explain ast", "EXPLAIN AST SELECT 1"}, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmts, err := parser.Parse(ctx, strings.NewReader(tt.query)) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if len(stmts) != 1 { + t.Fatalf("Expected 1 statement, got %d", len(stmts)) + } + }) + } +} + +// TestParserWithClickHouse compares parsing with ClickHouse's EXPLAIN AST +func TestParserWithClickHouse(t *testing.T) { + if !clickhouseAvailable() { + t.Skip("ClickHouse not available") + } + + tests := []struct { + name string + query string + }{ + {"simple select", "SELECT 1"}, + {"select from table", "SELECT id, name FROM users"}, + {"select with where", "SELECT * FROM users WHERE id = 1"}, + {"select with and", "SELECT * FROM users WHERE id = 1 AND status = 'active'"}, + {"select with order limit", "SELECT * FROM users ORDER BY name LIMIT 10"}, + {"select with join", "SELECT a.id FROM users a JOIN orders b ON a.id = b.user_id"}, + {"select with group by", "SELECT count(*) FROM orders GROUP BY user_id"}, + {"select with having", "SELECT count(*) FROM orders GROUP BY user_id HAVING count(*) > 1"}, + {"select with subquery", "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"}, + {"select with case", "SELECT CASE WHEN id > 1 THEN 'big' ELSE 'small' END FROM users"}, + {"select with functions", "SELECT toDate(now()), count(*) FROM users"}, + {"select with between", "SELECT * FROM users WHERE id BETWEEN 1 AND 10"}, + {"select with like", "SELECT * FROM users WHERE name LIKE 
'%test%'"}, + {"union all", "SELECT 1 UNION ALL SELECT 2"}, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse with our parser + stmts, err := parser.Parse(ctx, strings.NewReader(tt.query)) + if err != nil { + t.Fatalf("Our parser error: %v", err) + } + if len(stmts) == 0 { + t.Fatal("Our parser returned no statements") + } + + // Get ClickHouse's AST + chAST, err := getClickHouseAST(tt.query) + if err != nil { + t.Fatalf("ClickHouse error: %v", err) + } + + // Verify ClickHouse accepted the query (no error in response) + if strings.Contains(chAST, "Code:") || strings.Contains(chAST, "Exception:") { + t.Fatalf("ClickHouse rejected query: %s", chAST) + } + + // Log both ASTs for comparison + t.Logf("Query: %s", tt.query) + t.Logf("ClickHouse AST:\n%s", chAST) + + // Verify our AST can be serialized to JSON + jsonBytes, err := json.MarshalIndent(stmts[0], "", " ") + if err != nil { + t.Fatalf("JSON marshal error: %v", err) + } + t.Logf("Our AST (JSON):\n%s", string(jsonBytes)) + }) + } +} + +// TestParserJSONSerialization tests that all AST nodes can be serialized to JSON +func TestParserJSONSerialization(t *testing.T) { + tests := []struct { + name string + query string + }{ + {"select", "SELECT id, name AS n FROM users WHERE id > 1 ORDER BY name LIMIT 10"}, + {"create table", "CREATE TABLE test (id UInt64, name String) ENGINE = MergeTree() ORDER BY id"}, + {"insert", "INSERT INTO users (id, name) SELECT id, name FROM old_users"}, + {"alter", "ALTER TABLE users ADD COLUMN age UInt32"}, + {"complex select", ` + SELECT + u.id, + u.name, + count(*) AS order_count, + sum(o.amount) AS total + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + WHERE u.status = 'active' + GROUP BY u.id, u.name + HAVING count(*) > 0 + ORDER BY total DESC + LIMIT 100 + `}, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmts, err := parser.Parse(ctx, 
strings.NewReader(tt.query)) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + + for i, stmt := range stmts { + jsonBytes, err := json.MarshalIndent(stmt, "", " ") + if err != nil { + t.Fatalf("JSON marshal error for statement %d: %v", i, err) + } + + // Verify it's valid JSON by unmarshaling + var m map[string]interface{} + if err := json.Unmarshal(jsonBytes, &m); err != nil { + t.Fatalf("JSON unmarshal error for statement %d: %v", i, err) + } + + t.Logf("Statement %d JSON:\n%s", i, string(jsonBytes)) + } + }) + } +} + +// TestParserMultipleStatements tests parsing multiple statements +func TestParserMultipleStatements(t *testing.T) { + query := ` + SELECT 1; + SELECT 2; + SELECT 3 + ` + + ctx := context.Background() + stmts, err := parser.Parse(ctx, strings.NewReader(query)) + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if len(stmts) != 3 { + t.Fatalf("Expected 3 statements, got %d", len(stmts)) + } +} + +// TestParserContextCancellation tests that parsing respects context cancellation +func TestParserContextCancellation(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + // Give some time for cancellation + time.Sleep(5 * time.Millisecond) + + _, err := parser.Parse(ctx, strings.NewReader("SELECT 1")) + if err == nil { + // Context might not have been checked yet for simple queries + // This is acceptable behavior + t.Log("Context cancellation not triggered for simple query (acceptable)") + } +} + +// TestClickHouseASTComparison runs a detailed comparison with ClickHouse AST +func TestClickHouseASTComparison(t *testing.T) { + if !clickhouseAvailable() { + t.Skip("ClickHouse not available") + } + + // Test queries that exercise different AST node types + queries := []string{ + // Basic SELECT + "SELECT 1", + "SELECT id FROM users", + "SELECT id, name FROM users", + + // Expressions + "SELECT 1 + 2", + "SELECT 1 + 2 * 3", + "SELECT (1 + 2) * 3", + "SELECT -5", + "SELECT NOT 
true", + + // Literals + "SELECT 'hello'", + "SELECT 3.14", + "SELECT NULL", + "SELECT [1, 2, 3]", + + // Functions + "SELECT count(*)", + "SELECT sum(amount) FROM orders", + "SELECT toDate('2023-01-01')", + + // WHERE clause + "SELECT * FROM users WHERE id = 1", + "SELECT * FROM users WHERE id > 1 AND status = 'active'", + "SELECT * FROM users WHERE id IN (1, 2, 3)", + "SELECT * FROM users WHERE name LIKE '%test%'", + "SELECT * FROM users WHERE id BETWEEN 1 AND 10", + "SELECT * FROM users WHERE name IS NULL", + + // JOINs + "SELECT * FROM users u JOIN orders o ON u.id = o.user_id", + "SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id", + + // GROUP BY / ORDER BY + "SELECT count(*) FROM users GROUP BY status", + "SELECT * FROM users ORDER BY id", + "SELECT * FROM users ORDER BY id DESC", + "SELECT * FROM users ORDER BY id LIMIT 10", + + // Subqueries + "SELECT * FROM (SELECT 1) AS t", + "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)", + + // UNION + "SELECT 1 UNION ALL SELECT 2", + + // CASE + "SELECT CASE WHEN id > 1 THEN 'big' ELSE 'small' END FROM users", + + // CREATE TABLE + "CREATE TABLE test (id UInt64, name String) ENGINE = MergeTree() ORDER BY id", + + // DROP + "DROP TABLE IF EXISTS test", + + // INSERT + "INSERT INTO users (id, name) VALUES", + + // ALTER + "ALTER TABLE users ADD COLUMN age UInt32", + + // USE + "USE mydb", + + // TRUNCATE + "TRUNCATE TABLE users", + + // DESCRIBE + "DESCRIBE TABLE users", + + // SHOW + "SHOW TABLES", + } + + ctx := context.Background() + passed := 0 + failed := 0 + + for _, query := range queries { + // Parse with our parser + stmts, err := parser.Parse(ctx, strings.NewReader(query)) + if err != nil { + t.Logf("FAIL [parse error]: %s\n Error: %v", query, err) + failed++ + continue + } + + if len(stmts) == 0 { + t.Logf("FAIL [no statements]: %s", query) + failed++ + continue + } + + // Get ClickHouse's AST + chAST, err := getClickHouseAST(query) + if err != nil { + t.Logf("SKIP [clickhouse 
error]: %s\n Error: %v", query, err) + continue + } + + // Check if ClickHouse accepted the query + if strings.Contains(chAST, "Code:") || strings.Contains(chAST, "Exception:") { + t.Logf("SKIP [clickhouse rejected]: %s\n Response: %s", query, strings.TrimSpace(chAST)) + continue + } + + // Verify we can serialize to JSON + _, jsonErr := json.Marshal(stmts[0]) + if jsonErr != nil { + t.Logf("FAIL [json error]: %s\n Error: %v", query, jsonErr) + failed++ + continue + } + + t.Logf("PASS: %s", query) + passed++ + } + + t.Logf("\nSummary: %d passed, %d failed", passed, failed) + + if failed > 0 { + t.Errorf("%d queries failed to parse", failed) + } +} + +// BenchmarkParser benchmarks the parser performance +func BenchmarkParser(b *testing.B) { + query := ` + SELECT + u.id, + u.name, + count(*) AS order_count, + sum(o.amount) AS total + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + WHERE u.status = 'active' AND o.created_at > '2023-01-01' + GROUP BY u.id, u.name + HAVING count(*) > 0 + ORDER BY total DESC + LIMIT 100 + ` + + ctx := context.Background() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := parser.Parse(ctx, strings.NewReader(query)) + if err != nil { + b.Fatal(err) + } + } +} + +// Helper to run clickhouse client command +func runClickHouseClient(query string) (string, error) { + cmd := exec.Command("./clickhouse", "client", "--query", query) + out, err := cmd.CombinedOutput() + return string(out), err +} diff --git a/token/token.go b/token/token.go new file mode 100644 index 0000000000..ae61d35746 --- /dev/null +++ b/token/token.go @@ -0,0 +1,389 @@ +// Package token defines constants representing the lexical tokens of ClickHouse SQL. +package token + +// Token represents a lexical token. 
type Token int

// Token values. The declaration order is load-bearing: each constant's value
// is its index into the tokens table below, and keyword_beg/keyword_end
// bracket the keyword range for IsKeyword and the Keywords map. Do not
// reorder or insert without updating the table to match.
const (
	// Special tokens
	ILLEGAL Token = iota
	EOF
	WHITESPACE
	COMMENT

	// Literals
	IDENT  // identifiers
	NUMBER // integer or float literals
	STRING // string literals
	PARAM  // parameter placeholders like {name:Type}

	// Operators
	PLUS       // +
	MINUS      // -
	ASTERISK   // *
	SLASH      // /
	PERCENT    // %
	EQ         // =
	NEQ        // != or <>
	LT         // <
	GT         // >
	LTE        // <=
	GTE        // >=
	CONCAT     // ||
	ARROW      // ->
	COLONCOLON // ::

	// Delimiters
	LPAREN    // (
	RPAREN    // )
	LBRACKET  // [
	RBRACKET  // ]
	LBRACE    // {
	RBRACE    // }
	COMMA     // ,
	DOT       // .
	SEMICOLON // ;
	COLON     // :
	QUESTION  // ?

	// Keywords
	keyword_beg
	ADD
	ALIAS
	ALL
	ALTER
	AND
	ANTI
	ANY
	ARRAY
	AS
	ASC
	ASOF
	ATTACH
	BETWEEN
	BOTH
	BY
	CASE
	CAST
	CHECK
	CLUSTER
	COLLATE
	COLUMN
	CONSTRAINT
	CREATE
	CROSS
	DATABASE
	DATABASES
	DEFAULT
	DELETE
	DESC
	DESCRIBE
	DETACH
	DISTINCT
	DISTRIBUTED
	DROP
	ELSE
	END
	ENGINE
	EXCEPT
	EXISTS
	EXPLAIN
	EXTRACT
	FALSE
	FETCH
	FINAL
	FIRST
	FOR
	FORMAT
	FROM
	FULL
	FUNCTION
	GLOBAL
	GRANT
	GROUP
	HAVING
	IF
	ILIKE
	IN
	INDEX
	INF
	INNER
	INSERT
	INTERVAL
	INTO
	IS
	JOIN
	KEY
	KILL
	LEADING
	LEFT
	LIKE
	LIMIT
	LIVE
	LOCAL
	MATERIALIZED
	MODIFY
	NAN
	NATURAL
	NOT
	NULL
	NULLS
	OFFSET
	ON
	OPTIMIZE
	OR
	ORDER
	OUTER
	OUTFILE
	OVER
	PARTITION
	POPULATE
	PREWHERE
	PRIMARY
	RENAME
	REPLACE
	REVOKE
	RIGHT
	ROLLUP
	SAMPLE
	SELECT
	SEMI
	SET
	SETTINGS
	SHOW
	SUBSTRING
	SYSTEM
	TABLE
	TABLES
	TEMPORARY
	THEN
	TIES
	TO
	TOP
	TOTALS
	TRAILING
	TRIM
	TRUE
	TRUNCATE
	TTL
	UNION
	UPDATE
	USE
	USING
	VALUES
	VIEW
	WATCH
	WHEN
	WHERE
	WINDOW
	WITH
	keyword_end
)

// tokens maps each Token value to its display string: the literal spelling
// for operators/delimiters/keywords, or the category name for special tokens.
// Entries for keyword_beg/keyword_end are intentionally left empty.
var tokens = [...]string{
	ILLEGAL:    "ILLEGAL",
	EOF:        "EOF",
	WHITESPACE: "WHITESPACE",
	COMMENT:    "COMMENT",

	IDENT:  "IDENT",
	NUMBER: "NUMBER",
	STRING: "STRING",
	PARAM:  "PARAM",

	PLUS:       "+",
	MINUS:      "-",
	ASTERISK:   "*",
	SLASH:      "/",
	PERCENT:    "%",
	EQ:         "=",
	NEQ:        "!=",
	LT:         "<",
	GT:         ">",
	LTE:        "<=",
	GTE:        ">=",
	CONCAT:     "||",
	ARROW:      "->",
	COLONCOLON: "::",

	LPAREN:    "(",
	RPAREN:    ")",
	LBRACKET:  "[",
	RBRACKET:  "]",
	LBRACE:    "{",
	RBRACE:    "}",
	COMMA:     ",",
	DOT:       ".",
	SEMICOLON: ";",
	COLON:     ":",
	QUESTION:  "?",

	ADD:          "ADD",
	ALIAS:        "ALIAS",
	ALL:          "ALL",
	ALTER:        "ALTER",
	AND:          "AND",
	ANTI:         "ANTI",
	ANY:          "ANY",
	ARRAY:        "ARRAY",
	AS:           "AS",
	ASC:          "ASC",
	ASOF:         "ASOF",
	ATTACH:       "ATTACH",
	BETWEEN:      "BETWEEN",
	BOTH:         "BOTH",
	BY:           "BY",
	CASE:         "CASE",
	CAST:         "CAST",
	CHECK:        "CHECK",
	CLUSTER:      "CLUSTER",
	COLLATE:      "COLLATE",
	COLUMN:       "COLUMN",
	CONSTRAINT:   "CONSTRAINT",
	CREATE:       "CREATE",
	CROSS:        "CROSS",
	DATABASE:     "DATABASE",
	DATABASES:    "DATABASES",
	DEFAULT:      "DEFAULT",
	DELETE:       "DELETE",
	DESC:         "DESC",
	DESCRIBE:     "DESCRIBE",
	DETACH:       "DETACH",
	DISTINCT:     "DISTINCT",
	DISTRIBUTED:  "DISTRIBUTED",
	DROP:         "DROP",
	ELSE:         "ELSE",
	END:          "END",
	ENGINE:       "ENGINE",
	EXCEPT:       "EXCEPT",
	EXISTS:       "EXISTS",
	EXPLAIN:      "EXPLAIN",
	EXTRACT:      "EXTRACT",
	FALSE:        "FALSE",
	FETCH:        "FETCH",
	FINAL:        "FINAL",
	FIRST:        "FIRST",
	FOR:          "FOR",
	FORMAT:       "FORMAT",
	FROM:         "FROM",
	FULL:         "FULL",
	FUNCTION:     "FUNCTION",
	GLOBAL:       "GLOBAL",
	GRANT:        "GRANT",
	GROUP:        "GROUP",
	HAVING:       "HAVING",
	IF:           "IF",
	ILIKE:        "ILIKE",
	IN:           "IN",
	INDEX:        "INDEX",
	INF:          "INF",
	INNER:        "INNER",
	INSERT:       "INSERT",
	INTERVAL:     "INTERVAL",
	INTO:         "INTO",
	IS:           "IS",
	JOIN:         "JOIN",
	KEY:          "KEY",
	KILL:         "KILL",
	LEADING:      "LEADING",
	LEFT:         "LEFT",
	LIKE:         "LIKE",
	LIMIT:        "LIMIT",
	LIVE:         "LIVE",
	LOCAL:        "LOCAL",
	MATERIALIZED: "MATERIALIZED",
	MODIFY:       "MODIFY",
	NAN:          "NAN",
	NATURAL:      "NATURAL",
	NOT:          "NOT",
	NULL:         "NULL",
	NULLS:        "NULLS",
	OFFSET:       "OFFSET",
	ON:           "ON",
	OPTIMIZE:     "OPTIMIZE",
	OR:           "OR",
	ORDER:        "ORDER",
	OUTER:        "OUTER",
	OUTFILE:      "OUTFILE",
	OVER:         "OVER",
	PARTITION:    "PARTITION",
	POPULATE:     "POPULATE",
	PREWHERE:     "PREWHERE",
	PRIMARY:      "PRIMARY",
	RENAME:       "RENAME",
	REPLACE:      "REPLACE",
	REVOKE:       "REVOKE",
	RIGHT:        "RIGHT",
	ROLLUP:       "ROLLUP",
	SAMPLE:       "SAMPLE",
	SELECT:       "SELECT",
	SEMI:         "SEMI",
	SET:          "SET",
	SETTINGS:     "SETTINGS",
	SHOW:         "SHOW",
	SUBSTRING:    "SUBSTRING",
	SYSTEM:       "SYSTEM",
	TABLE:        "TABLE",
	TABLES:       "TABLES",
	TEMPORARY:    "TEMPORARY",
	THEN:         "THEN",
	TIES:         "TIES",
	TO:           "TO",
	TOP:          "TOP",
	TOTALS:       "TOTALS",
	TRAILING:     "TRAILING",
	TRIM:         "TRIM",
	TRUE:         "TRUE",
	TRUNCATE:     "TRUNCATE",
	TTL:          "TTL",
	UNION:        "UNION",
	UPDATE:       "UPDATE",
	USE:          "USE",
	USING:        "USING",
	VALUES:       "VALUES",
	VIEW:         "VIEW",
	WATCH:        "WATCH",
	WHEN:         "WHEN",
	WHERE:        "WHERE",
	WINDOW:       "WINDOW",
	WITH:         "WITH",
}

// String returns the display string for the token. It returns "" for values
// outside the table and for entries with no string (the unexported
// keyword_beg/keyword_end markers).
func (tok Token) String() string {
	if tok >= 0 && int(tok) < len(tokens) {
		return tokens[tok]
	}
	return ""
}

// Keywords maps keyword strings to their token types.
var Keywords map[string]Token

func init() {
	// Pre-size to the exact keyword count so the map never rehashes while
	// being populated.
	Keywords = make(map[string]Token, int(keyword_end-keyword_beg)-1)
	for i := keyword_beg + 1; i < keyword_end; i++ {
		Keywords[tokens[i]] = i
	}
}

// Lookup returns the token type for an identifier string.
// If the string is a keyword, it returns the keyword token.
// Otherwise, it returns IDENT. The match is case-sensitive: callers are
// expected to normalize case before calling if keywords may be lowercase.
func Lookup(ident string) Token {
	if tok, ok := Keywords[ident]; ok {
		return tok
	}
	return IDENT
}

// IsKeyword returns true if the token is a keyword.
func (tok Token) IsKeyword() bool {
	return tok > keyword_beg && tok < keyword_end
}

// Position represents a source position.
type Position struct {
	Offset int // byte offset
	Line   int // line number (1-based)
	Column int // column number (1-based)
}