Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/bigframes/bigframes/core/array_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def relational_join(
for l_col, r_col in conditions
),
type=type,
nulls_equal=True, # pandas semantics
propogate_order=propogate_order or self.session._strictly_ordered,
)
return ArrayValue(join_node), (l_mapping, r_mapping)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
# Need to do this before replacing unsupported ops, as that will rewrite slice ops
result_node = rewrites.pull_up_limits(result_node)
result_node = _replace_unsupported_ops(result_node)
result_node = result_node.bottom_up(rewrites.simplify_join)
# prune before pulling up order to avoid unnecessary row_number() ops
result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
result_node = rewrites.defer_order(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
# Need to do this before replacing unsupported ops, as that will rewrite slice ops
result_node = rewrite.pull_up_limits(result_node)
result_node = _replace_unsupported_ops(result_node)
result_node = result_node.bottom_up(rewrite.simplify_join)
# prune before pulling up order to avoid unnecessary row_number() ops
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
result_node = rewrite.defer_order(
Expand Down
3 changes: 3 additions & 0 deletions packages/bigframes/bigframes/core/local_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def to_arrow(
else:
return schema, batches

def is_nullable(self, column_id: identifiers.ColumnId) -> bool:
    """Return True when the column's ``null_count`` is greater than zero.

    The column is looked up on ``self.data`` by ``column_id``; the result
    reflects the data currently held, not a declared schema.
    """
    column = self.data.column(column_id)
    return column.null_count > 0

def to_pyarrow_table(
self,
*,
Expand Down
24 changes: 12 additions & 12 deletions packages/bigframes/bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ class InNode(BigFrameNode, AdditiveNode):
right_child: BigFrameNode
left_col: ex.DerefOp
indicator_col: identifiers.ColumnId
# For matching left_col to right_child[0], if true, nulls match nulls, if false, nulls don't match nulls
nulls_equal: bool = True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment on what this field means? It's not very obvious to me

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added comment


def _validate(self):
assert len(self.right_child.fields) == 1
Expand Down Expand Up @@ -271,10 +273,7 @@ def additive_base(self) -> BigFrameNode:

@property
def joins_nulls(self) -> bool:
left_nullable = self.left_child.field_by_id[self.left_col.id].nullable
# assumption: right side has one column
right_nullable = self.right_child.fields[0].nullable
return left_nullable or right_nullable
return self.nulls_equal

@property
def _node_expressions(self):
Expand Down Expand Up @@ -316,6 +315,9 @@ class JoinNode(BigFrameNode):
right_child: BigFrameNode
conditions: typing.Tuple[typing.Tuple[ex.DerefOp, ex.DerefOp], ...]
type: typing.Literal["inner", "outer", "left", "right", "cross"]
# choose to treat nulls as equal or not for purposes of the join
# pandas treats nulls as equal, sql does not
nulls_equal: bool
propogate_order: bool

def _validate(self):
Expand Down Expand Up @@ -355,13 +357,7 @@ def fields(self) -> Sequence[Field]:

@property
def joins_nulls(self) -> bool:
for left_ref, right_ref in self.conditions:
if (
self.left_child.field_by_id[left_ref.id].nullable
and self.right_child.field_by_id[right_ref.id].nullable
):
return True
return False
return self.nulls_equal

@functools.cached_property
def variables_introduced(self) -> int:
Expand Down Expand Up @@ -675,7 +671,11 @@ class ReadLocalNode(LeafNode):
@property
def fields(self) -> Sequence[Field]:
fields = tuple(
Field(col_id, self.local_data_source.schema.get_type(source_id))
Field(
col_id,
self.local_data_source.schema.get_type(source_id),
nullable=self.local_data_source.is_nullable(source_id),
)
for col_id, source_id in self.scan_list.items
)

Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/core/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
rewrite_range_rolling,
simplify_complex_windows,
)
from bigframes.core.rewrite.nullity import simplify_join

__all__ = [
"as_sql_nodes",
Expand All @@ -55,4 +56,5 @@
"defer_selection",
"simplify_complex_windows",
"lower_udfs",
"simplify_join",
]
42 changes: 42 additions & 0 deletions packages/bigframes/bigframes/core/rewrite/nullity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from bigframes.core import nodes
import dataclasses


def simplify_join(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
    """Turn off null-matching on joins and isin checks that cannot need it.

    A ``JoinNode`` is returned unchanged if any of its conditions pairs two
    nullable columns; otherwise a copy with ``nulls_equal=False`` is returned.
    An ``InNode`` is handled the same way for its single key pair
    (``left_col`` against the first field of ``right_child``). Any other node
    type is returned as-is.
    """
    # When no key pair can be null on both sides, a null can never match a
    # null, so nulls_equal=False yields the same rows.
    if isinstance(node, nodes.JoinNode):
        # Possible refinement: always use nulls_equal=False and wrap only the
        # nullable key pairs in coalesce, which would be more granular than
        # this all-or-nothing check.
        both_sides_nullable = any(
            node.left_child.field_by_id[lhs.id].nullable
            and node.right_child.field_by_id[rhs.id].nullable
            for lhs, rhs in node.conditions
        )
    elif isinstance(node, nodes.InNode):
        both_sides_nullable = (
            node.left_child.field_by_id[node.left_col.id].nullable
            and node.right_child.fields[0].nullable
        )
    else:
        return node

    if both_sides_nullable:
        return node
    return dataclasses.replace(node, nulls_equal=False)
106 changes: 53 additions & 53 deletions packages/bigframes/tests/unit/core/compile/sqlglot/tpch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,80 +24,80 @@

TPCH_SCHEMAS = {
"LINEITEM": [
bigquery.SchemaField("L_ORDERKEY", "INTEGER"),
bigquery.SchemaField("L_PARTKEY", "INTEGER"),
bigquery.SchemaField("L_SUPPKEY", "INTEGER"),
bigquery.SchemaField("L_LINENUMBER", "INTEGER"),
bigquery.SchemaField("L_QUANTITY", "FLOAT"),
bigquery.SchemaField("L_EXTENDEDPRICE", "FLOAT"),
bigquery.SchemaField("L_DISCOUNT", "FLOAT"),
bigquery.SchemaField("L_TAX", "FLOAT"),
bigquery.SchemaField("L_RETURNFLAG", "STRING"),
bigquery.SchemaField("L_LINESTATUS", "STRING"),
bigquery.SchemaField("L_SHIPDATE", "DATE"),
bigquery.SchemaField("L_COMMITDATE", "DATE"),
bigquery.SchemaField("L_RECEIPTDATE", "DATE"),
bigquery.SchemaField("L_SHIPINSTRUCT", "STRING"),
bigquery.SchemaField("L_SHIPMODE", "STRING"),
bigquery.SchemaField("L_ORDERKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("L_PARTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("L_SUPPKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("L_LINENUMBER", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("L_QUANTITY", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("L_EXTENDEDPRICE", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("L_DISCOUNT", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("L_TAX", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("L_RETURNFLAG", "STRING", mode="REQUIRED"),
bigquery.SchemaField("L_LINESTATUS", "STRING", mode="REQUIRED"),
bigquery.SchemaField("L_SHIPDATE", "DATE", mode="REQUIRED"),
bigquery.SchemaField("L_COMMITDATE", "DATE", mode="REQUIRED"),
bigquery.SchemaField("L_RECEIPTDATE", "DATE", mode="REQUIRED"),
bigquery.SchemaField("L_SHIPINSTRUCT", "STRING", mode="REQUIRED"),
bigquery.SchemaField("L_SHIPMODE", "STRING", mode="REQUIRED"),
bigquery.SchemaField("L_COMMENT", "STRING"),
],
"ORDERS": [
bigquery.SchemaField("O_ORDERKEY", "INTEGER"),
bigquery.SchemaField("O_CUSTKEY", "INTEGER"),
bigquery.SchemaField("O_ORDERSTATUS", "STRING"),
bigquery.SchemaField("O_TOTALPRICE", "FLOAT"),
bigquery.SchemaField("O_ORDERDATE", "DATE"),
bigquery.SchemaField("O_ORDERPRIORITY", "STRING"),
bigquery.SchemaField("O_CLERK", "STRING"),
bigquery.SchemaField("O_SHIPPRIORITY", "INTEGER"),
bigquery.SchemaField("O_ORDERKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("O_CUSTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("O_ORDERSTATUS", "STRING", mode="REQUIRED"),
bigquery.SchemaField("O_TOTALPRICE", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("O_ORDERDATE", "DATE", mode="REQUIRED"),
bigquery.SchemaField("O_ORDERPRIORITY", "STRING", mode="REQUIRED"),
bigquery.SchemaField("O_CLERK", "STRING", mode="REQUIRED"),
bigquery.SchemaField("O_SHIPPRIORITY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("O_COMMENT", "STRING"),
],
"PART": [
bigquery.SchemaField("P_PARTKEY", "INTEGER"),
bigquery.SchemaField("P_NAME", "STRING"),
bigquery.SchemaField("P_MFGR", "STRING"),
bigquery.SchemaField("P_BRAND", "STRING"),
bigquery.SchemaField("P_TYPE", "STRING"),
bigquery.SchemaField("P_SIZE", "INTEGER"),
bigquery.SchemaField("P_CONTAINER", "STRING"),
bigquery.SchemaField("P_RETAILPRICE", "FLOAT"),
bigquery.SchemaField("P_PARTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("P_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_MFGR", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_BRAND", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_TYPE", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_SIZE", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("P_CONTAINER", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_RETAILPRICE", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("P_COMMENT", "STRING"),
],
"SUPPLIER": [
bigquery.SchemaField("S_SUPPKEY", "INTEGER"),
bigquery.SchemaField("S_NAME", "STRING"),
bigquery.SchemaField("S_ADDRESS", "STRING"),
bigquery.SchemaField("S_NATIONKEY", "INTEGER"),
bigquery.SchemaField("S_PHONE", "STRING"),
bigquery.SchemaField("S_ACCTBAL", "FLOAT"),
bigquery.SchemaField("S_SUPPKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("S_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("S_ADDRESS", "STRING", mode="REQUIRED"),
bigquery.SchemaField("S_NATIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("S_PHONE", "STRING", mode="REQUIRED"),
bigquery.SchemaField("S_ACCTBAL", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("S_COMMENT", "STRING"),
],
"PARTSUPP": [
bigquery.SchemaField("PS_PARTKEY", "INTEGER"),
bigquery.SchemaField("PS_SUPPKEY", "INTEGER"),
bigquery.SchemaField("PS_AVAILQTY", "INTEGER"),
bigquery.SchemaField("PS_SUPPLYCOST", "FLOAT"),
bigquery.SchemaField("PS_PARTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("PS_SUPPKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("PS_AVAILQTY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("PS_SUPPLYCOST", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("PS_COMMENT", "STRING"),
],
"CUSTOMER": [
bigquery.SchemaField("C_CUSTKEY", "INTEGER"),
bigquery.SchemaField("C_NAME", "STRING"),
bigquery.SchemaField("C_ADDRESS", "STRING"),
bigquery.SchemaField("C_NATIONKEY", "INTEGER"),
bigquery.SchemaField("C_PHONE", "STRING"),
bigquery.SchemaField("C_ACCTBAL", "FLOAT"),
bigquery.SchemaField("C_MKTSEGMENT", "STRING"),
bigquery.SchemaField("C_CUSTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("C_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("C_ADDRESS", "STRING", mode="REQUIRED"),
bigquery.SchemaField("C_NATIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("C_PHONE", "STRING", mode="REQUIRED"),
bigquery.SchemaField("C_ACCTBAL", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("C_MKTSEGMENT", "STRING", mode="REQUIRED"),
bigquery.SchemaField("C_COMMENT", "STRING"),
],
"NATION": [
bigquery.SchemaField("N_NATIONKEY", "INTEGER"),
bigquery.SchemaField("N_NAME", "STRING"),
bigquery.SchemaField("N_REGIONKEY", "INTEGER"),
bigquery.SchemaField("N_NATIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("N_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("N_REGIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("N_COMMENT", "STRING"),
],
"REGION": [
bigquery.SchemaField("R_REGIONKEY", "INTEGER"),
bigquery.SchemaField("R_NAME", "STRING"),
bigquery.SchemaField("R_REGIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("R_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("R_COMMENT", "STRING"),
],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ WITH `bfcte_0` AS (
AVG(`bfcol_43`) AS `bfcol_61`,
COUNT(`bfcol_41`) AS `bfcol_62`
FROM `bfcte_0`
WHERE
NOT `bfcol_44` IS NULL AND NOT `bfcol_45` IS NULL
GROUP BY
`bfcol_44`,
`bfcol_45`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ WITH `bfcte_0` AS (
`bfcol_8` AS `bfcol_24`
FROM `bfcte_3`
INNER JOIN `bfcte_2`
ON COALESCE(`bfcol_9`, 0) = COALESCE(`bfcol_7`, 0)
AND COALESCE(`bfcol_9`, 1) = COALESCE(`bfcol_7`, 1)
ON `bfcol_9` = `bfcol_7`
), `bfcte_5` AS (
SELECT
`bfcol_16` AS `bfcol_25`,
Expand All @@ -56,8 +55,7 @@ WITH `bfcte_0` AS (
`bfcol_5` AS `bfcol_35`
FROM `bfcte_4`
INNER JOIN `bfcte_1`
ON COALESCE(`bfcol_23`, 0) = COALESCE(`bfcol_2`, 0)
AND COALESCE(`bfcol_23`, 1) = COALESCE(`bfcol_2`, 1)
ON `bfcol_23` = `bfcol_2`
), `bfcte_6` AS (
SELECT
`bfcol_25`,
Expand Down Expand Up @@ -107,8 +105,7 @@ WITH `bfcte_0` AS (
), 2) AS `bfcol_83`
FROM `bfcte_5`
INNER JOIN `bfcte_0`
ON COALESCE(`bfcol_28`, 0) = COALESCE(`bfcol_0`, 0)
AND COALESCE(`bfcol_28`, 1) = COALESCE(`bfcol_0`, 1)
ON `bfcol_28` = `bfcol_0`
WHERE
(
(
Expand All @@ -133,13 +130,7 @@ WITH `bfcte_0` AS (
COALESCE(SUM(`bfcol_83`), 0) AS `bfcol_92`
FROM `bfcte_6`
WHERE
NOT `bfcol_76` IS NULL
AND NOT `bfcol_77` IS NULL
AND NOT `bfcol_80` IS NULL
AND NOT `bfcol_79` IS NULL
AND NOT `bfcol_82` IS NULL
AND NOT `bfcol_78` IS NULL
AND NOT `bfcol_81` IS NULL
NOT `bfcol_81` IS NULL
GROUP BY
`bfcol_76`,
`bfcol_77`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ WITH `bfcte_0` AS (
`bfcol_3` AS `bfcol_19`
FROM `bfcte_4`
INNER JOIN `bfcte_3`
ON COALESCE(`bfcol_18`, 0) = COALESCE(`bfcol_4`, 0)
AND COALESCE(`bfcol_18`, 1) = COALESCE(`bfcol_4`, 1)
ON `bfcol_18` = `bfcol_4`
), `bfcte_6` AS (
SELECT
`bfcol_19`,
Expand All @@ -46,8 +45,7 @@ WITH `bfcte_0` AS (
`bfcol_2` * `bfcol_1` AS `bfcol_40`
FROM `bfcte_5`
INNER JOIN `bfcte_1`
ON COALESCE(`bfcol_19`, 0) = COALESCE(`bfcol_0`, 0)
AND COALESCE(`bfcol_19`, 1) = COALESCE(`bfcol_0`, 1)
ON `bfcol_19` = `bfcol_0`
), `bfcte_7` AS (
SELECT
`bfcol_19`,
Expand All @@ -59,8 +57,7 @@ WITH `bfcte_0` AS (
`bfcol_13` * `bfcol_12` AS `bfcol_28`
FROM `bfcte_5`
INNER JOIN `bfcte_2`
ON COALESCE(`bfcol_19`, 0) = COALESCE(`bfcol_11`, 0)
AND COALESCE(`bfcol_19`, 1) = COALESCE(`bfcol_11`, 1)
ON `bfcol_19` = `bfcol_11`
), `bfcte_8` AS (
SELECT
COALESCE(SUM(`bfcol_40`), 0) AS `bfcol_44`
Expand All @@ -70,8 +67,6 @@ WITH `bfcte_0` AS (
`bfcol_27`,
COALESCE(SUM(`bfcol_28`), 0) AS `bfcol_35`
FROM `bfcte_7`
WHERE
NOT `bfcol_27` IS NULL
GROUP BY
`bfcol_27`
), `bfcte_10` AS (
Expand Down Expand Up @@ -101,8 +96,6 @@ WITH `bfcte_0` AS (
`bfcol_8`,
ANY_VALUE(`bfcol_51`) AS `bfcol_55`
FROM `bfcte_12`
WHERE
NOT `bfcol_7` IS NULL AND NOT `bfcol_8` IS NULL
GROUP BY
`bfcol_7`,
`bfcol_8`
Expand Down
Loading
Loading