Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 99d73ef

Browse files
refactor: Define CTE-related SQL nodes for emitter (#2495)
1 parent 5c43efb commit 99d73ef

File tree

26 files changed

+875
-624
lines changed

26 files changed

+875
-624
lines changed

bigframes/core/bigframe_node.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -330,22 +330,12 @@ def top_down(
330330
"""
331331
Perform a top-down transformation of the BigFrameNode tree.
332332
"""
333-
to_process = [self]
334-
results: Dict[BigFrameNode, BigFrameNode] = {}
335333

336-
while to_process:
337-
item = to_process.pop()
338-
if item not in results.keys():
339-
item_result = transform(item)
340-
results[item] = item_result
341-
to_process.extend(item_result.child_nodes)
334+
@functools.cache
335+
def recursive_transform(node: BigFrameNode) -> BigFrameNode:
336+
return transform(node).transform_children(recursive_transform)
342337

343-
to_process = [self]
344-
# for each processed item, replace its children
345-
for item in reversed(list(results.keys())):
346-
results[item] = results[item].transform_children(lambda x: results[x])
347-
348-
return results[self]
338+
return recursive_transform(self)
349339

350340
def bottom_up(
351341
self: BigFrameNode,

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
6262
if request.sort_rows:
6363
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
6464
encoded_type_refs = data_type_logger.encode_type_refs(result_node)
65+
# TODO: Extract CTEs earlier
66+
result_node = typing.cast(nodes.ResultNode, rewrite.extract_ctes(result_node))
6567
sql = _compile_result_node(result_node)
6668
return configs.CompileResult(
6769
sql,
@@ -74,6 +76,8 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
7476
result_node = dataclasses.replace(result_node, order_by=None)
7577
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
7678
encoded_type_refs = data_type_logger.encode_type_refs(result_node)
79+
# TODO: Extract CTEs earlier
80+
result_node = typing.cast(nodes.ResultNode, rewrite.extract_ctes(result_node))
7781
sql = _compile_result_node(result_node)
7882
# Return the ordering iff no extra columns are needed to define the row order
7983
if ordering is not None:
@@ -94,6 +98,7 @@ def _remap_variables(
9498
result_node, _ = rewrite.remap_variables(
9599
node, map(identifiers.ColumnId, uid_gen.get_uid_stream("bfcol_"))
96100
)
101+
result_node.validate_tree()
97102
return typing.cast(nodes.ResultNode, result_node)
98103

99104

@@ -102,13 +107,16 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
102107
# of nodes using the same generator.
103108
uid_gen = guid.SequentialUIDGenerator()
104109
root = _remap_variables(root, uid_gen)
110+
# Remap variables creates too mayn new
111+
# root = rewrite.select_pullup(root, prefer_source_names=False)
105112
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))
106113

107114
# Have to bind schema as the final step before compilation.
108115
# Probably, should defer even further
109116
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
110117

111-
sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root), uid_gen)
118+
# TODO: Bake all IDs in tree, stop passing uid_gen to emitters
119+
sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root, uid_gen), uid_gen)
112120
return sqlglot_ir_obj.sql
113121

114122

@@ -121,7 +129,7 @@ def compile_node(
121129
for current_node in list(node.iter_nodes_topo()):
122130
if current_node.child_nodes == ():
123131
# For leaf node, generates a dumpy child to pass the UID generator.
124-
child_results = tuple([sqlglot_ir.SQLGlotIR(uid_gen=uid_gen)])
132+
child_results = tuple([sqlglot_ir.SQLGlotIR.empty(uid_gen=uid_gen)])
125133
else:
126134
# Child nodes should have been compiled in the reverse topological order.
127135
child_results = tuple(
@@ -256,6 +264,23 @@ def compile_isin_join(
256264
)
257265

258266

267+
@_compile_node.register
268+
def compile_cte_ref_node(node: sql_nodes.SqlCteRefNode, child: sqlglot_ir.SQLGlotIR):
269+
return sqlglot_ir.SQLGlotIR.from_cte_ref(
270+
node.cte_name,
271+
uid_gen=child.uid_gen,
272+
)
273+
274+
275+
@_compile_node.register
276+
def compile_with_ctes_node(
277+
node: sql_nodes.SqlWithCtesNode,
278+
child: sqlglot_ir.SQLGlotIR,
279+
*ctes: sqlglot_ir.SQLGlotIR,
280+
):
281+
return child.with_ctes(tuple(zip(node.cte_names, ctes)))
282+
283+
259284
@_compile_node.register
260285
def compile_concat(
261286
node: nodes.ConcatNode, *children: sqlglot_ir.SQLGlotIR
@@ -271,7 +296,7 @@ def compile_concat(
271296
]
272297

273298
return sqlglot_ir.SQLGlotIR.from_union(
274-
[child._as_select() for child in children],
299+
[child.expr.as_select_all() for child in children],
275300
output_aliases=output_aliases,
276301
uid_gen=uid_gen,
277302
)

bigframes/core/compile/sqlglot/expressions/comparison_ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,39 @@
3333
@register_unary_op(ops.IsInOp, pass_op=True)
3434
def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:
3535
values = []
36+
# bools are not comparable to non-bools in SQL, so we need to cast the expression to INT64 if the values contain non-bools.
37+
must_upcast_bools = dtypes.is_numeric(expr.dtype, include_bool=False) or any(
38+
dtypes.is_numeric(dtypes.bigframes_type(type(value)), include_bool=False)
39+
for value in op.values
40+
if not _is_null(value)
41+
)
3642
for value in op.values:
3743
if _is_null(value):
3844
continue
3945
dtype = dtypes.bigframes_type(type(value))
4046
if dtypes.can_compare(expr.dtype, dtype):
47+
if must_upcast_bools and dtype == dtypes.BOOL_DTYPE:
48+
value = int(value)
4149
values.append(sge.convert(value))
4250

51+
sg_lexpr: sge.Expression = expr.expr
52+
if expr.dtype == dtypes.BOOL_DTYPE and must_upcast_bools:
53+
sg_lexpr = sge.cast(expr.expr, "INT64")
54+
4355
if op.match_nulls:
4456
contains_nulls = any(_is_null(value) for value in op.values)
4557
if contains_nulls:
4658
if len(values) == 0:
47-
return sge.Is(this=expr.expr, expression=sge.Null())
48-
return sge.Is(this=expr.expr, expression=sge.Null()) | sge.In(
49-
this=expr.expr, expressions=values
59+
return sge.Is(this=sg_lexpr, expression=sge.Null())
60+
return sge.Is(this=sg_lexpr, expression=sge.Null()) | sge.In(
61+
this=sg_lexpr, expressions=values
5062
)
5163

5264
if len(values) == 0:
5365
return sge.convert(False)
5466

5567
return sge.func(
56-
"COALESCE", sge.In(this=expr.expr, expressions=values), sge.convert(False)
68+
"COALESCE", sge.In(this=sg_lexpr, expressions=values), sge.convert(False)
5769
)
5870

5971

0 commit comments

Comments
 (0)