Skip to content

Commit

Permalink
feat: upcast schemas if needed during set ops
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews committed Feb 20, 2025
1 parent d404520 commit e14cdb1
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
"t0"."id",
CAST("t0"."tinyint_col" AS BIGINT) AS "i",
CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t2"
UNION ALL
SELECT
*
FROM (
SELECT
"t0"."id",
"t0"."bigint_col" + 256 AS "i",
"t0"."string_col" AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t1"
) AS "t3"
ORDER BY
"t3"."id" ASC,
"t3"."i" ASC,
"t3"."s" ASC
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
"t0"."id",
CAST("t0"."tinyint_col" AS BIGINT) AS "i",
CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t2"
UNION
SELECT
*
FROM (
SELECT
"t0"."id",
"t0"."bigint_col" + 256 AS "i",
"t0"."string_col" AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t1"
) AS "t3"
ORDER BY
"t3"."id" ASC,
"t3"."i" ASC,
"t3"."s" ASC
17 changes: 17 additions & 0 deletions ibis/backends/tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,20 @@ def test_ctes_in_order():

sql = ibis.to_sql(expr, dialect="duckdb")
assert sql.find('"first" AS (') < sql.find('"second" AS (')


@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
def test_union_unified_schemas(snapshot, functional_alltypes, distinct):
a = functional_alltypes.select(
"id", i="tinyint_col", s=_.string_col.cast("!string")
)
b = functional_alltypes.select(
"id",
i=_.bigint_col + 256, # ensure doesn't fit in a tinyint
s=_.string_col.cast("string"),
)
expr = ibis.union(a, b, distinct=distinct).order_by("id", "i", "s")

assert expr.i.type() == b.i.type()
assert expr.s.type() == b.s.type()
snapshot.assert_match(to_sql(expr), "out.sql")
39 changes: 38 additions & 1 deletion ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import ibis
import ibis.expr.types as ir
from ibis import _
from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError
from ibis.backends.tests.errors import (
OracleDatabaseError,
PsycoPg2InternalError,
PyDruidProgrammingError,
)

pd = pytest.importorskip("pandas")

Expand Down Expand Up @@ -49,6 +53,39 @@ def test_union(backend, union_subsets, distinct):
backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
@pytest.mark.notyet(
["oracle"], raises=OracleDatabaseError, reason="does not support NOT NULL types"
)
def test_unified_schemas(backend, con, distinct):
a = con.table("functional_alltypes").select(
"id",
i="tinyint_col",
s=_.string_col.cast("!string"),
)
b = con.table("functional_alltypes").select(
"id",
i=_.bigint_col + 256, # ensure doesn't fit in a tinyint
s=_.string_col.cast("string"),
)

expr = ibis.union(a, b, distinct=distinct).order_by("id", "i", "s")
assert expr.i.type() == b.i.type()
assert expr.s.type() == b.s.type()
result = expr.execute()

expected = (
pd.concat([a.execute(), b.execute()], axis=0)
.sort_values(["id", "i", "s"])
.reset_index(drop=True)
)
if distinct:
expected = expected.drop_duplicates(["id", "i", "s"])

backend.assert_frame_equal(result, expected, check_dtype=False)


@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_union_mixed_distinct(backend, union_subsets):
(a, b, c), (da, db, dc) = union_subsets
Expand Down
78 changes: 57 additions & 21 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import ibis.expr.datatypes as dt
from ibis.common.annotations import attribute
from ibis.common.collections import (
ConflictingValuesError,
FrozenDict,
FrozenOrderedDict,
)
Expand Down Expand Up @@ -328,6 +327,61 @@ def schema(self):
return Schema({k: v.dtype for k, v in self.values.items()})


Relations = TypeVar("Relations", bound=tuple[Relation, ...])


def _unify_schemas(relations: Relations) -> Relations:
from ibis.expr.operations.generic import Cast

# TODO: hoist this up into the user facing API so we can see
# all the tables at once and give a better error message
errs = ["Table schemas must be unifiable for set operations."]
first, *rest = relations
all_names = set(first.schema.names)
for relation in rest:
all_names |= set(relation.schema.names)
for relation in relations:
if missing := all_names - set(relation.schema.names):
errs.append(f"Columns missing from {relation}: {missing}")

Check warning on line 345 in ibis/expr/operations/relations.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/operations/relations.py#L345

Added line #L345 was not covered by tests
if len(errs) > 1:
raise RelationError("\n".join(errs))

Check warning on line 347 in ibis/expr/operations/relations.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/operations/relations.py#L347

Added line #L347 was not covered by tests
# Make it so we get consistent column order, using the first relation
names = first.schema.names
assert set(names) == all_names

unified_types: dict[str, dt.DataType] = {}
for name in names:
types = [relation.schema[name] for relation in relations]
try:
unified_types[name] = dt.highest_precedence(types)
except IbisTypeError:
errs.append(f"Unable to find a common dtype for column {name}")
errs.append(f"types: {types}")

Check warning on line 359 in ibis/expr/operations/relations.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/operations/relations.py#L357-L359

Added lines #L357 - L359 were not covered by tests
if len(errs) > 1:
raise RelationError("\n".join(errs))

Check warning on line 361 in ibis/expr/operations/relations.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/operations/relations.py#L361

Added line #L361 was not covered by tests

def get_new_relation(relation: Relation):
cols: dict[str, Value] = {}
unchanged = True
if relation.schema.names != names:
# order is different, will need to reorder
unchanged = False
for name in names:
old_type = relation.schema[name]
new_type = unified_types[name]
f = Field(relation, name)
if old_type == new_type:
cols[name] = f
else:
cols[name] = Cast(f, new_type)
unchanged = False
if unchanged:
return relation
return Project(relation, cols)

return tuple(get_new_relation(relation) for relation in relations)


@public
class Set(Relation):
"""Base class for set operations."""
Expand All @@ -337,26 +391,8 @@ class Set(Relation):
distinct: bool = False
values = FrozenOrderedDict()

def __init__(self, left, right, **kwargs):
err_msg = "Table schemas must be equal for set operations."
try:
missing_from_left = right.schema - left.schema
missing_from_right = left.schema - right.schema
except ConflictingValuesError as e:
raise RelationError(err_msg + "\n" + str(e)) from e
if missing_from_left or missing_from_right:
msgs = [err_msg]
if missing_from_left:
msgs.append(f"Columns missing from the left:\n{missing_from_left}.")
if missing_from_right:
msgs.append(f"Columns missing from the right:\n{missing_from_right}.")
raise RelationError("\n".join(msgs))

if left.schema.names != right.schema.names:
# rewrite so that both sides have the columns in the same order making it
# easier for the backends to implement set operations
cols = {name: Field(right, name) for name in left.schema.names}
right = Project(right, cols)
def __init__(self, left: Relation, right: Relation, **kwargs):
left, right = _unify_schemas((left, right))
super().__init__(left=left, right=right, **kwargs)

@attribute
Expand Down
7 changes: 6 additions & 1 deletion ibis/tests/expr/test_set_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,13 @@ class D:

@pytest.mark.parametrize("method", ["union", "intersect", "difference"])
def test_operation_requires_equal_schemas(method):
with pytest.raises(RelationError, match="`c`: string != float64"):
with pytest.raises(RelationError) as e:

Check warning on line 43 in ibis/tests/expr/test_set_operations.py

View check run for this annotation

Codecov / codecov/patch

ibis/tests/expr/test_set_operations.py#L43

Added line #L43 was not covered by tests
getattr(a, method)(d)
e_str = str(e.value)
assert "Table schemas must be unifiable for set operations" in e_str
assert "Int64(nullable=True)" in e_str
assert "String(nullable=True)" in e_str
assert "Float64(nulladsble=True)" in e_str

Check warning on line 49 in ibis/tests/expr/test_set_operations.py

View check run for this annotation

Codecov / codecov/patch

ibis/tests/expr/test_set_operations.py#L45-L49

Added lines #L45 - L49 were not covered by tests


@pytest.mark.parametrize("method", ["union", "intersect", "difference"])
Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

set_ops_schema_top = [("key", "string"), ("value", "double")]
set_ops_schema_bottom = [("key", "string"), ("key2", "string"), ("value", "double")]
setops_relation_error_message = "Table schemas must be equal for set operations"
setops_relation_error_message = "Table schemas must be unifiable for set operations"


@pytest.fixture
Expand Down
3 changes: 1 addition & 2 deletions ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,7 @@ def test_null_column():
def test_null_column_union():
s = ibis.table([("a", "string"), ("b", "double")])
t = ibis.table([("a", "string")])
with pytest.raises(ibis.common.exceptions.RelationError):
s.union(t.mutate(b=ibis.null())) # needs a type
assert s.union(t.mutate(b=ibis.null())).schema() == s.schema()

Check warning on line 644 in ibis/tests/expr/test_value_exprs.py

View check run for this annotation

Codecov / codecov/patch

ibis/tests/expr/test_value_exprs.py#L644

Added line #L644 was not covered by tests
assert s.union(t.mutate(b=ibis.null().cast("double"))).schema() == s.schema()


Expand Down

0 comments on commit e14cdb1

Please sign in to comment.