Skip to content

Commit e353025

Browse files
refactor: Make join nullity optimizations more robust
1 parent 2e508b6 commit e353025

File tree

30 files changed

+186
-259
lines changed

30 files changed

+186
-259
lines changed

packages/bigframes/bigframes/core/array_value.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ def relational_join(
541541
for l_col, r_col in conditions
542542
),
543543
type=type,
544+
nulls_equal=True, # pandas semantics
544545
propogate_order=propogate_order or self.session._strictly_ordered,
545546
)
546547
return ArrayValue(join_node), (l_mapping, r_mapping)

packages/bigframes/bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
5050
# Need to do this before replacing unsupported ops, as that will rewrite slice ops
5151
result_node = rewrites.pull_up_limits(result_node)
5252
result_node = _replace_unsupported_ops(result_node)
53+
result_node = result_node.bottom_up(rewrites.simplify_join)
5354
# prune before pulling up order to avoid unnecessary row_number() ops
5455
result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
5556
result_node = rewrites.defer_order(

packages/bigframes/bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
5454
# Need to do this before replacing unsupported ops, as that will rewrite slice ops
5555
result_node = rewrite.pull_up_limits(result_node)
5656
result_node = _replace_unsupported_ops(result_node)
57+
result_node = result_node.bottom_up(rewrite.simplify_join)
5758
# prune before pulling up order to avoid unnecessary row_number() ops
5859
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
5960
result_node = rewrite.defer_order(

packages/bigframes/bigframes/core/local_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def to_arrow(
154154
else:
155155
return schema, batches
156156

157+
def is_nullable(self, column_id: identifiers.ColumnId) -> bool:
    """Report whether the stored column ``column_id`` holds at least one null.

    Nullability is derived from the actual data (its null count) rather than
    from a declared schema.
    """
    column = self.data.column(column_id)
    return column.null_count > 0
159+
157160
def to_pyarrow_table(
158161
self,
159162
*,

packages/bigframes/bigframes/core/nodes.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class InNode(BigFrameNode, AdditiveNode):
204204
right_child: BigFrameNode
205205
left_col: ex.DerefOp
206206
indicator_col: identifiers.ColumnId
207+
nulls_equal: bool = True
207208

208209
def _validate(self):
209210
assert len(self.right_child.fields) == 1
@@ -271,10 +272,7 @@ def additive_base(self) -> BigFrameNode:
271272

272273
@property
273274
def joins_nulls(self) -> bool:
274-
left_nullable = self.left_child.field_by_id[self.left_col.id].nullable
275-
# assumption: right side has one column
276-
right_nullable = self.right_child.fields[0].nullable
277-
return left_nullable or right_nullable
275+
return self.nulls_equal
278276

279277
@property
280278
def _node_expressions(self):
@@ -316,6 +314,8 @@ class JoinNode(BigFrameNode):
316314
right_child: BigFrameNode
317315
conditions: typing.Tuple[typing.Tuple[ex.DerefOp, ex.DerefOp], ...]
318316
type: typing.Literal["inner", "outer", "left", "right", "cross"]
317+
# Whether null join keys match each other: pandas treats nulls as equal, SQL does not.
318+
nulls_equal: bool
319319
propogate_order: bool
320320

321321
def _validate(self):
@@ -355,13 +355,7 @@ def fields(self) -> Sequence[Field]:
355355

356356
@property
357357
def joins_nulls(self) -> bool:
358-
for left_ref, right_ref in self.conditions:
359-
if (
360-
self.left_child.field_by_id[left_ref.id].nullable
361-
and self.right_child.field_by_id[right_ref.id].nullable
362-
):
363-
return True
364-
return False
358+
return self.nulls_equal
365359

366360
@functools.cached_property
367361
def variables_introduced(self) -> int:
@@ -675,7 +669,11 @@ class ReadLocalNode(LeafNode):
675669
@property
676670
def fields(self) -> Sequence[Field]:
677671
fields = tuple(
678-
Field(col_id, self.local_data_source.schema.get_type(source_id))
672+
Field(
673+
col_id,
674+
self.local_data_source.schema.get_type(source_id),
675+
nullable=self.local_data_source.is_nullable(source_id),
676+
)
679677
for col_id, source_id in self.scan_list.items
680678
)
681679

packages/bigframes/bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
rewrite_range_rolling,
3434
simplify_complex_windows,
3535
)
36+
from bigframes.core.rewrite.nullity import simplify_join
3637

3738
__all__ = [
3839
"as_sql_nodes",
@@ -55,4 +56,5 @@
5556
"defer_selection",
5657
"simplify_complex_windows",
5758
"lower_udfs",
59+
"simplify_join",
5860
]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from bigframes.core import nodes
18+
import dataclasses
19+
20+
21+
def simplify_join(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
    """Clear ``nulls_equal`` on a join/isin node when null handling is provably unneeded.

    Written as a per-node rewrite (the compilers apply it with ``bottom_up``).
    Nodes other than ``JoinNode`` and ``InNode`` are returned untouched.

    Args:
        node: Any plan node.

    Returns:
        ``node`` itself when null-aware matching may still be required or the
        node is not a join/isin; otherwise a copy with ``nulls_equal=False``.
    """
    if isinstance(node, nodes.JoinNode):
        # A null key can only ever match another null key, so null-safe
        # equality matters only if some condition is nullable on BOTH sides.
        # Possible refinement: always clear nulls_equal and coalesce just the
        # key pairs that are nullable on both sides.
        needs_null_handling = any(
            node.left_child.field_by_id[left_ref.id].nullable
            and node.right_child.field_by_id[right_ref.id].nullable
            for left_ref, right_ref in node.conditions
        )
    elif isinstance(node, nodes.InNode):
        # Membership differs from a join: SQL `IN` is three-valued, so a NULL
        # probe value, or a NULL in the searched set, can yield NULL rather
        # than FALSE. A nullable column on EITHER side therefore keeps the
        # null-aware form.
        # NOTE(review): presumes nulls_equal=False compiles to a plain SQL IN;
        # confirm against the compilers.
        # assumption: right side has one column
        needs_null_handling = (
            node.left_child.field_by_id[node.left_col.id].nullable
            or node.right_child.fields[0].nullable
        )
    else:
        return node

    if needs_null_handling:
        return node
    return dataclasses.replace(node, nulls_equal=False)

packages/bigframes/tests/unit/core/compile/sqlglot/tpch/conftest.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,80 +24,80 @@
2424

2525
TPCH_SCHEMAS = {
2626
"LINEITEM": [
27-
bigquery.SchemaField("L_ORDERKEY", "INTEGER"),
28-
bigquery.SchemaField("L_PARTKEY", "INTEGER"),
29-
bigquery.SchemaField("L_SUPPKEY", "INTEGER"),
30-
bigquery.SchemaField("L_LINENUMBER", "INTEGER"),
31-
bigquery.SchemaField("L_QUANTITY", "FLOAT"),
32-
bigquery.SchemaField("L_EXTENDEDPRICE", "FLOAT"),
33-
bigquery.SchemaField("L_DISCOUNT", "FLOAT"),
34-
bigquery.SchemaField("L_TAX", "FLOAT"),
35-
bigquery.SchemaField("L_RETURNFLAG", "STRING"),
36-
bigquery.SchemaField("L_LINESTATUS", "STRING"),
37-
bigquery.SchemaField("L_SHIPDATE", "DATE"),
38-
bigquery.SchemaField("L_COMMITDATE", "DATE"),
39-
bigquery.SchemaField("L_RECEIPTDATE", "DATE"),
40-
bigquery.SchemaField("L_SHIPINSTRUCT", "STRING"),
41-
bigquery.SchemaField("L_SHIPMODE", "STRING"),
27+
bigquery.SchemaField("L_ORDERKEY", "INTEGER", mode="REQUIRED"),
28+
bigquery.SchemaField("L_PARTKEY", "INTEGER", mode="REQUIRED"),
29+
bigquery.SchemaField("L_SUPPKEY", "INTEGER", mode="REQUIRED"),
30+
bigquery.SchemaField("L_LINENUMBER", "INTEGER", mode="REQUIRED"),
31+
bigquery.SchemaField("L_QUANTITY", "FLOAT", mode="REQUIRED"),
32+
bigquery.SchemaField("L_EXTENDEDPRICE", "FLOAT", mode="REQUIRED"),
33+
bigquery.SchemaField("L_DISCOUNT", "FLOAT", mode="REQUIRED"),
34+
bigquery.SchemaField("L_TAX", "FLOAT", mode="REQUIRED"),
35+
bigquery.SchemaField("L_RETURNFLAG", "STRING", mode="REQUIRED"),
36+
bigquery.SchemaField("L_LINESTATUS", "STRING", mode="REQUIRED"),
37+
bigquery.SchemaField("L_SHIPDATE", "DATE", mode="REQUIRED"),
38+
bigquery.SchemaField("L_COMMITDATE", "DATE", mode="REQUIRED"),
39+
bigquery.SchemaField("L_RECEIPTDATE", "DATE", mode="REQUIRED"),
40+
bigquery.SchemaField("L_SHIPINSTRUCT", "STRING", mode="REQUIRED"),
41+
bigquery.SchemaField("L_SHIPMODE", "STRING", mode="REQUIRED"),
4242
bigquery.SchemaField("L_COMMENT", "STRING"),
4343
],
4444
"ORDERS": [
45-
bigquery.SchemaField("O_ORDERKEY", "INTEGER"),
46-
bigquery.SchemaField("O_CUSTKEY", "INTEGER"),
47-
bigquery.SchemaField("O_ORDERSTATUS", "STRING"),
48-
bigquery.SchemaField("O_TOTALPRICE", "FLOAT"),
49-
bigquery.SchemaField("O_ORDERDATE", "DATE"),
50-
bigquery.SchemaField("O_ORDERPRIORITY", "STRING"),
51-
bigquery.SchemaField("O_CLERK", "STRING"),
52-
bigquery.SchemaField("O_SHIPPRIORITY", "INTEGER"),
45+
bigquery.SchemaField("O_ORDERKEY", "INTEGER", mode="REQUIRED"),
46+
bigquery.SchemaField("O_CUSTKEY", "INTEGER", mode="REQUIRED"),
47+
bigquery.SchemaField("O_ORDERSTATUS", "STRING", mode="REQUIRED"),
48+
bigquery.SchemaField("O_TOTALPRICE", "FLOAT", mode="REQUIRED"),
49+
bigquery.SchemaField("O_ORDERDATE", "DATE", mode="REQUIRED"),
50+
bigquery.SchemaField("O_ORDERPRIORITY", "STRING", mode="REQUIRED"),
51+
bigquery.SchemaField("O_CLERK", "STRING", mode="REQUIRED"),
52+
bigquery.SchemaField("O_SHIPPRIORITY", "INTEGER", mode="REQUIRED"),
5353
bigquery.SchemaField("O_COMMENT", "STRING"),
5454
],
5555
"PART": [
56-
bigquery.SchemaField("P_PARTKEY", "INTEGER"),
57-
bigquery.SchemaField("P_NAME", "STRING"),
58-
bigquery.SchemaField("P_MFGR", "STRING"),
59-
bigquery.SchemaField("P_BRAND", "STRING"),
60-
bigquery.SchemaField("P_TYPE", "STRING"),
61-
bigquery.SchemaField("P_SIZE", "INTEGER"),
62-
bigquery.SchemaField("P_CONTAINER", "STRING"),
63-
bigquery.SchemaField("P_RETAILPRICE", "FLOAT"),
56+
bigquery.SchemaField("P_PARTKEY", "INTEGER", mode="REQUIRED"),
57+
bigquery.SchemaField("P_NAME", "STRING", mode="REQUIRED"),
58+
bigquery.SchemaField("P_MFGR", "STRING", mode="REQUIRED"),
59+
bigquery.SchemaField("P_BRAND", "STRING", mode="REQUIRED"),
60+
bigquery.SchemaField("P_TYPE", "STRING", mode="REQUIRED"),
61+
bigquery.SchemaField("P_SIZE", "INTEGER", mode="REQUIRED"),
62+
bigquery.SchemaField("P_CONTAINER", "STRING", mode="REQUIRED"),
63+
bigquery.SchemaField("P_RETAILPRICE", "FLOAT", mode="REQUIRED"),
6464
bigquery.SchemaField("P_COMMENT", "STRING"),
6565
],
6666
"SUPPLIER": [
67-
bigquery.SchemaField("S_SUPPKEY", "INTEGER"),
68-
bigquery.SchemaField("S_NAME", "STRING"),
69-
bigquery.SchemaField("S_ADDRESS", "STRING"),
70-
bigquery.SchemaField("S_NATIONKEY", "INTEGER"),
71-
bigquery.SchemaField("S_PHONE", "STRING"),
72-
bigquery.SchemaField("S_ACCTBAL", "FLOAT"),
67+
bigquery.SchemaField("S_SUPPKEY", "INTEGER", mode="REQUIRED"),
68+
bigquery.SchemaField("S_NAME", "STRING", mode="REQUIRED"),
69+
bigquery.SchemaField("S_ADDRESS", "STRING", mode="REQUIRED"),
70+
bigquery.SchemaField("S_NATIONKEY", "INTEGER", mode="REQUIRED"),
71+
bigquery.SchemaField("S_PHONE", "STRING", mode="REQUIRED"),
72+
bigquery.SchemaField("S_ACCTBAL", "FLOAT", mode="REQUIRED"),
7373
bigquery.SchemaField("S_COMMENT", "STRING"),
7474
],
7575
"PARTSUPP": [
76-
bigquery.SchemaField("PS_PARTKEY", "INTEGER"),
77-
bigquery.SchemaField("PS_SUPPKEY", "INTEGER"),
78-
bigquery.SchemaField("PS_AVAILQTY", "INTEGER"),
79-
bigquery.SchemaField("PS_SUPPLYCOST", "FLOAT"),
76+
bigquery.SchemaField("PS_PARTKEY", "INTEGER", mode="REQUIRED"),
77+
bigquery.SchemaField("PS_SUPPKEY", "INTEGER", mode="REQUIRED"),
78+
bigquery.SchemaField("PS_AVAILQTY", "INTEGER", mode="REQUIRED"),
79+
bigquery.SchemaField("PS_SUPPLYCOST", "FLOAT", mode="REQUIRED"),
8080
bigquery.SchemaField("PS_COMMENT", "STRING"),
8181
],
8282
"CUSTOMER": [
83-
bigquery.SchemaField("C_CUSTKEY", "INTEGER"),
84-
bigquery.SchemaField("C_NAME", "STRING"),
85-
bigquery.SchemaField("C_ADDRESS", "STRING"),
86-
bigquery.SchemaField("C_NATIONKEY", "INTEGER"),
87-
bigquery.SchemaField("C_PHONE", "STRING"),
88-
bigquery.SchemaField("C_ACCTBAL", "FLOAT"),
89-
bigquery.SchemaField("C_MKTSEGMENT", "STRING"),
83+
bigquery.SchemaField("C_CUSTKEY", "INTEGER", mode="REQUIRED"),
84+
bigquery.SchemaField("C_NAME", "STRING", mode="REQUIRED"),
85+
bigquery.SchemaField("C_ADDRESS", "STRING", mode="REQUIRED"),
86+
bigquery.SchemaField("C_NATIONKEY", "INTEGER", mode="REQUIRED"),
87+
bigquery.SchemaField("C_PHONE", "STRING", mode="REQUIRED"),
88+
bigquery.SchemaField("C_ACCTBAL", "FLOAT", mode="REQUIRED"),
89+
bigquery.SchemaField("C_MKTSEGMENT", "STRING", mode="REQUIRED"),
9090
bigquery.SchemaField("C_COMMENT", "STRING"),
9191
],
9292
"NATION": [
93-
bigquery.SchemaField("N_NATIONKEY", "INTEGER"),
94-
bigquery.SchemaField("N_NAME", "STRING"),
95-
bigquery.SchemaField("N_REGIONKEY", "INTEGER"),
93+
bigquery.SchemaField("N_NATIONKEY", "INTEGER", mode="REQUIRED"),
94+
bigquery.SchemaField("N_NAME", "STRING", mode="REQUIRED"),
95+
bigquery.SchemaField("N_REGIONKEY", "INTEGER", mode="REQUIRED"),
9696
bigquery.SchemaField("N_COMMENT", "STRING"),
9797
],
9898
"REGION": [
99-
bigquery.SchemaField("R_REGIONKEY", "INTEGER"),
100-
bigquery.SchemaField("R_NAME", "STRING"),
99+
bigquery.SchemaField("R_REGIONKEY", "INTEGER", mode="REQUIRED"),
100+
bigquery.SchemaField("R_NAME", "STRING", mode="REQUIRED"),
101101
bigquery.SchemaField("R_COMMENT", "STRING"),
102102
],
103103
}

packages/bigframes/tests/unit/core/compile/sqlglot/tpch/snapshots/test_tpch/test_tpch_query/1/out.sql

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ WITH `bfcte_0` AS (
5454
AVG(`bfcol_43`) AS `bfcol_61`,
5555
COUNT(`bfcol_41`) AS `bfcol_62`
5656
FROM `bfcte_0`
57-
WHERE
58-
NOT `bfcol_44` IS NULL AND NOT `bfcol_45` IS NULL
5957
GROUP BY
6058
`bfcol_44`,
6159
`bfcol_45`

packages/bigframes/tests/unit/core/compile/sqlglot/tpch/snapshots/test_tpch/test_tpch_query/10/out.sql

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ WITH `bfcte_0` AS (
3939
`bfcol_8` AS `bfcol_24`
4040
FROM `bfcte_3`
4141
INNER JOIN `bfcte_2`
42-
ON COALESCE(`bfcol_9`, 0) = COALESCE(`bfcol_7`, 0)
43-
AND COALESCE(`bfcol_9`, 1) = COALESCE(`bfcol_7`, 1)
42+
ON `bfcol_9` = `bfcol_7`
4443
), `bfcte_5` AS (
4544
SELECT
4645
`bfcol_16` AS `bfcol_25`,
@@ -56,8 +55,7 @@ WITH `bfcte_0` AS (
5655
`bfcol_5` AS `bfcol_35`
5756
FROM `bfcte_4`
5857
INNER JOIN `bfcte_1`
59-
ON COALESCE(`bfcol_23`, 0) = COALESCE(`bfcol_2`, 0)
60-
AND COALESCE(`bfcol_23`, 1) = COALESCE(`bfcol_2`, 1)
58+
ON `bfcol_23` = `bfcol_2`
6159
), `bfcte_6` AS (
6260
SELECT
6361
`bfcol_25`,
@@ -107,8 +105,7 @@ WITH `bfcte_0` AS (
107105
), 2) AS `bfcol_83`
108106
FROM `bfcte_5`
109107
INNER JOIN `bfcte_0`
110-
ON COALESCE(`bfcol_28`, 0) = COALESCE(`bfcol_0`, 0)
111-
AND COALESCE(`bfcol_28`, 1) = COALESCE(`bfcol_0`, 1)
108+
ON `bfcol_28` = `bfcol_0`
112109
WHERE
113110
(
114111
(
@@ -133,13 +130,7 @@ WITH `bfcte_0` AS (
133130
COALESCE(SUM(`bfcol_83`), 0) AS `bfcol_92`
134131
FROM `bfcte_6`
135132
WHERE
136-
NOT `bfcol_76` IS NULL
137-
AND NOT `bfcol_77` IS NULL
138-
AND NOT `bfcol_80` IS NULL
139-
AND NOT `bfcol_79` IS NULL
140-
AND NOT `bfcol_82` IS NULL
141-
AND NOT `bfcol_78` IS NULL
142-
AND NOT `bfcol_81` IS NULL
133+
NOT `bfcol_81` IS NULL
143134
GROUP BY
144135
`bfcol_76`,
145136
`bfcol_77`,

0 commit comments

Comments
 (0)