Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: upcast schemas during set ops #10727

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
"t0"."id",
CAST("t0"."tinyint_col" AS BIGINT) AS "i",
CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t2"
UNION ALL
SELECT
*
FROM (
SELECT
"t0"."id",
"t0"."bigint_col" + 256 AS "i",
"t0"."string_col" AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t1"
) AS "t3"
ORDER BY
"t3"."id" ASC,
"t3"."i" ASC,
"t3"."s" ASC
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
"t0"."id",
CAST("t0"."tinyint_col" AS BIGINT) AS "i",
CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t2"
UNION
SELECT
*
FROM (
SELECT
"t0"."id",
"t0"."bigint_col" + 256 AS "i",
"t0"."string_col" AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t1"
) AS "t3"
ORDER BY
"t3"."id" ASC,
"t3"."i" ASC,
"t3"."s" ASC
17 changes: 17 additions & 0 deletions ibis/backends/tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,20 @@ def test_ctes_in_order():

sql = ibis.to_sql(expr, dialect="duckdb")
assert sql.find('"first" AS (') < sql.find('"second" AS (')


@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
def test_union_unified_schemas(snapshot, functional_alltypes, distinct):
a = functional_alltypes.select(
"id", i="tinyint_col", s=_.string_col.cast("!string")
)
b = functional_alltypes.select(
"id",
i=_.bigint_col + 256, # ensure doesn't fit in a tinyint
s=_.string_col.cast("string"),
)
expr = ibis.union(a, b, distinct=distinct).order_by("id", "i", "s")

assert expr.i.type() == b.i.type()
assert expr.s.type() == b.s.type()
snapshot.assert_match(to_sql(expr), "out.sql")
39 changes: 38 additions & 1 deletion ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import ibis
import ibis.expr.types as ir
from ibis import _
from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError
from ibis.backends.tests.errors import (
OracleDatabaseError,
PsycoPg2InternalError,
PyDruidProgrammingError,
)

pd = pytest.importorskip("pandas")

Expand Down Expand Up @@ -49,6 +53,39 @@ def test_union(backend, union_subsets, distinct):
backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
@pytest.mark.notyet(
["oracle"], raises=OracleDatabaseError, reason="does not support NOT NULL types"
)
def test_unified_schemas(backend, con, distinct):
a = con.table("functional_alltypes").select(
"id",
i="tinyint_col",
s=_.string_col.cast("!string"),
)
b = con.table("functional_alltypes").select(
"id",
i=_.bigint_col + 256, # ensure doesn't fit in a tinyint
s=_.string_col.cast("string"),
)

expr = ibis.union(a, b, distinct=distinct).order_by("id", "i", "s")
assert expr.i.type() == b.i.type()
assert expr.s.type() == b.s.type()
result = expr.execute()

expected = (
pd.concat([a.execute(), b.execute()], axis=0)
.sort_values(["id", "i", "s"])
.reset_index(drop=True)
)
if distinct:
expected = expected.drop_duplicates(["id", "i", "s"])

backend.assert_frame_equal(result, expected, check_dtype=False)


@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_union_mixed_distinct(backend, union_subsets):
(a, b, c), (da, db, dc) = union_subsets
Expand Down
78 changes: 57 additions & 21 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import ibis.expr.datatypes as dt
from ibis.common.annotations import attribute
from ibis.common.collections import (
ConflictingValuesError,
FrozenDict,
FrozenOrderedDict,
)
Expand Down Expand Up @@ -328,6 +327,61 @@ def schema(self):
return Schema({k: v.dtype for k, v in self.values.items()})


Relations = TypeVar("Relations", bound=tuple[Relation, ...])


def _unify_schemas(relations: Relations) -> Relations:
from ibis.expr.operations.generic import Cast

# TODO: hoist this up into the user facing API so we can see
# all the tables at once and give a better error message
errs = ["Table schemas must be unifiable for set operations."]
first, *rest = relations
all_names = set(first.schema.names)
for relation in rest:
all_names |= set(relation.schema.names)
for relation in relations:
if missing := all_names - set(relation.schema.names):
errs.append(f"Columns missing from {relation}: {missing}")
if len(errs) > 1:
raise RelationError("\n".join(errs))
# Make it so we get consistent column order, using the first relation
names = first.schema.names
assert set(names) == all_names

unified_types: dict[str, dt.DataType] = {}
for name in names:
types = [relation.schema[name] for relation in relations]
try:
unified_types[name] = dt.highest_precedence(types)
except IbisTypeError:
errs.append(f"Unable to find a common dtype for column {name}")
errs.append(f"types: {types}")
if len(errs) > 1:
raise RelationError("\n".join(errs))

def get_new_relation(relation: Relation):
cols: dict[str, Value] = {}
unchanged = True
if relation.schema.names != names:
# order is different, will need to reorder
unchanged = False
for name in names:
old_type = relation.schema[name]
new_type = unified_types[name]
f = Field(relation, name)
if old_type == new_type:
cols[name] = f
else:
cols[name] = Cast(f, new_type)
unchanged = False
if unchanged:
return relation
return Project(relation, cols)

return tuple(get_new_relation(relation) for relation in relations)


@public
class Set(Relation):
"""Base class for set operations."""
Expand All @@ -337,26 +391,8 @@ class Set(Relation):
distinct: bool = False
values = FrozenOrderedDict()

def __init__(self, left, right, **kwargs):
err_msg = "Table schemas must be equal for set operations."
try:
missing_from_left = right.schema - left.schema
missing_from_right = left.schema - right.schema
except ConflictingValuesError as e:
raise RelationError(err_msg + "\n" + str(e)) from e
if missing_from_left or missing_from_right:
msgs = [err_msg]
if missing_from_left:
msgs.append(f"Columns missing from the left:\n{missing_from_left}.")
if missing_from_right:
msgs.append(f"Columns missing from the right:\n{missing_from_right}.")
raise RelationError("\n".join(msgs))

if left.schema.names != right.schema.names:
# rewrite so that both sides have the columns in the same order making it
# easier for the backends to implement set operations
cols = {name: Field(right, name) for name in left.schema.names}
right = Project(right, cols)
def __init__(self, left: Relation, right: Relation, **kwargs):
left, right = _unify_schemas((left, right))
super().__init__(left=left, right=right, **kwargs)

@attribute
Expand Down
7 changes: 6 additions & 1 deletion ibis/tests/expr/test_set_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,13 @@ class D:

@pytest.mark.parametrize("method", ["union", "intersect", "difference"])
def test_operation_requires_equal_schemas(method):
with pytest.raises(RelationError, match="`c`: string != float64"):
with pytest.raises(RelationError) as e:
getattr(a, method)(d)
e_str = str(e.value)
assert "Table schemas must be unifiable for set operations" in e_str
assert "Int64(nullable=True)" in e_str
assert "String(nullable=True)" in e_str
assert "Float64(nullable=True)" in e_str


@pytest.mark.parametrize("method", ["union", "intersect", "difference"])
Expand Down
44 changes: 9 additions & 35 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,37 +23,26 @@
from ibis.expr.types import Column, Table
from ibis.tests.util import assert_equal, assert_pickle_roundtrip


@pytest.fixture
def set_ops_schema_top():
return [("key", "string"), ("value", "double")]


@pytest.fixture
def set_ops_schema_bottom():
return [("key", "string"), ("key2", "string"), ("value", "double")]
set_ops_schema_top = [("key", "string"), ("value", "double")]
set_ops_schema_bottom = [("key", "string"), ("key2", "string"), ("value", "double")]
setops_relation_error_message = "Table schemas must be unifiable for set operations"


@pytest.fixture
def setops_table_foo(set_ops_schema_top):
def setops_table_foo():
return ibis.table(set_ops_schema_top, "foo")


@pytest.fixture
def setops_table_bar(set_ops_schema_top):
def setops_table_bar():
return ibis.table(set_ops_schema_top, "bar")


@pytest.fixture
def setops_table_baz(set_ops_schema_bottom):
def setops_table_baz():
return ibis.table(set_ops_schema_bottom, "baz")


@pytest.fixture
def setops_relation_error_message():
return "Table schemas must be equal for set operations"


def test_empty_schema():
table = api.table([], "foo")
assert not table.schema()
Expand Down Expand Up @@ -1350,12 +1339,7 @@ def test_unravel_compound_equijoin(table):
assert joined.op() == expected


def test_union(
setops_table_foo,
setops_table_bar,
setops_table_baz,
setops_relation_error_message,
):
def test_union(setops_table_foo, setops_table_bar, setops_table_baz: Table):
result = setops_table_foo.union(setops_table_bar)
assert isinstance(result.op(), ops.Union)
assert not result.op().distinct
Expand All @@ -1367,25 +1351,15 @@ def test_union(
setops_table_foo.union(setops_table_baz)


def test_intersection(
setops_table_foo,
setops_table_bar,
setops_table_baz,
setops_relation_error_message,
):
def test_intersection(setops_table_foo, setops_table_bar, setops_table_baz):
result = setops_table_foo.intersect(setops_table_bar)
assert isinstance(result.op(), ops.Intersection)

with pytest.raises(RelationError, match=setops_relation_error_message):
setops_table_foo.intersect(setops_table_baz)


def test_difference(
setops_table_foo,
setops_table_bar,
setops_table_baz,
setops_relation_error_message,
):
def test_difference(setops_table_foo, setops_table_bar, setops_table_baz):
result = setops_table_foo.difference(setops_table_bar)
assert isinstance(result.op(), ops.Difference)

Expand Down
3 changes: 1 addition & 2 deletions ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,7 @@ def test_null_column():
def test_null_column_union():
s = ibis.table([("a", "string"), ("b", "double")])
t = ibis.table([("a", "string")])
with pytest.raises(ibis.common.exceptions.RelationError):
s.union(t.mutate(b=ibis.null())) # needs a type
assert s.union(t.mutate(b=ibis.null())).schema() == s.schema()
assert s.union(t.mutate(b=ibis.null().cast("double"))).schema() == s.schema()


Expand Down