8 changes: 8 additions & 0 deletions ibis/backends/datafusion/__init__.py
@@ -534,6 +534,7 @@ def to_pyarrow_batches(
self._register_in_memory_tables(expr)

table_expr = expr.as_table()
table_expr = _cast_any_uuids_to_blob(table_expr)
raw_sql = self.compile(table_expr, **kwargs)

frame = self.con.sql(raw_sql)
@@ -806,3 +807,10 @@ def _pandas(source: pd.DataFrame, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pandas")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_pandas(source, name=tmp_name)


def _cast_any_uuids_to_blob(t: ir.Table) -> ir.Table:
"""When duckdb materializes UUIDs to arrow, by default it returns them as pa.string()s. This pre-converts them to BLOB."""
return t.mutate(
**{col: t[col].cast("binary") for col in t.columns if t[col].type().is_uuid()}
)
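
A quick sketch of what this helper does, using a hypothetical unbound table (the same helper is duplicated in the duckdb backend further down):

```python
import ibis

t = ibis.table({"u": "uuid", "n": "int64"}, name="t")  # hypothetical schema
pre = _cast_any_uuids_to_blob(t)
assert pre.u.type().is_binary()  # UUID column is pre-cast to binary
assert pre.n.type().is_int64()   # other columns are left untouched
```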
8 changes: 8 additions & 0 deletions ibis/backends/datafusion/udfs.py
@@ -129,3 +129,11 @@ def regex_split(s: str, pattern: str) -> list[str]:
)
pattern = patterns[0].as_py()
return pc.split_pattern_regex(s, pattern)


def cast_uuid_to_binary(string_array: dt.string) -> dt.binary:
    """Convert UUID strings (with dashes) into their 16 raw bytes."""
    without_dashes = pc.replace_substring(string_array, pattern="-", replacement="")
    # pyarrow can cast hex strings to ints, but not to binary
    # (https://github.com/apache/arrow/commit/012248a7aca9373a39d5ac8ce9e496b8df0f10e6),
    # so decode each value at the Python level instead.
    return pa.array(
        [bytes.fromhex(v) if v is not None else None for v in without_dashes.to_pylist()],
        type=pa.binary(16),
    )
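
A quick sanity check of the intended conversion, as a sketch with a made-up value:

```python
import pyarrow as pa

arr = pa.array(["00000000-0000-0000-0000-000000000001"])
out = cast_uuid_to_binary(arr)
assert out.type == pa.binary(16)
assert out[0].as_py() == bytes.fromhex("00000000000000000000000000000001")
```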
8 changes: 8 additions & 0 deletions ibis/backends/duckdb/__init__.py
@@ -1321,6 +1321,7 @@ def _to_duckdb_relation(
"""
self._run_pre_execute_hooks(expr)
table_expr = expr.as_table()
table_expr = _cast_any_uuids_to_blob(table_expr)
sql = self.compile(table_expr, limit=limit, params=params, **kwargs)
if table_expr.schema().geospatial:
self._load_extensions(["spatial"])
@@ -1798,3 +1799,10 @@ def _pyarrow_rbr(source, table_name, _conn, **_: Any):
# Ensure the reader isn't marked as started, in case the name is
# being overwritten.
_conn._record_batch_readers_consumed[table_name] = False


def _cast_any_uuids_to_blob(t: ir.Table) -> ir.Table:
"""When duckdb materializes UUIDs to arrow, by default it returns them as pa.string()s. This pre-converts them to BLOB."""
return t.mutate(
**{col: t[col].cast("binary") for col in t.columns if t[col].type().is_uuid()}
)
15 changes: 12 additions & 3 deletions ibis/backends/postgres/__init__.py
@@ -735,9 +735,18 @@ def _batches(self: Self, *, struct_type: pa.StructType, query: str):
):
cur = cursor.execute(query)
while batch := cur.fetchmany(chunk_size):
yield pa.RecordBatch.from_struct_array(
pa.array(batch, type=struct_type)
)
columns = []
names = []
for i, (name, typ) in enumerate(raw_schema.items()):
col = [row[i] for row in batch]
if typ.is_uuid():
col = [v.bytes if v is not None else None for v in col]
columns.append(pa.array(col, type=typ.to_pyarrow()))
names.append(name)
# pa.array(batch, raw_schema.as_struct().to_pyarrow())
# is not implemented for extension types (eg UUID)
# so we have to create individual arrays for each column.
yield pa.RecordBatch.from_arrays(columns, names=names)

self._run_pre_execute_hooks(expr)

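Because `pa.array(batch, type=struct_type)` is not implemented for extension types such as UUID, the batches are now assembled column by column. A minimal sketch of that construction with made-up rows and types:

```python
import pyarrow as pa

rows = [(1, "a"), (2, None)]  # shaped like cursor.fetchmany() output
names = ["id", "val"]
types = [pa.int64(), pa.string()]

columns = [
    pa.array([row[i] for row in rows], type=typ) for i, typ in enumerate(types)
]
batch = pa.RecordBatch.from_arrays(columns, names=names)
assert batch.num_rows == 2
```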
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/datafusion.py
@@ -146,6 +146,8 @@ def visit_Cast(self, op, *, arg, to):
from ibis.formats.pyarrow import PyArrowType

return self.f.arrow_cast(arg, f"{PyArrowType.from_ibis(to)}".capitalize())
if from_.is_uuid() and to.is_binary():
return self.f.cast_uuid_to_binary(arg)
return self.cast(arg, to)

def visit_Arbitrary(self, op, *, arg, where):
16 changes: 12 additions & 4 deletions ibis/backends/sql/compilers/duckdb.py
@@ -421,17 +421,25 @@ def visit_TimestampFromYMDHMS(
return self.f[func](*args)

def visit_Cast(self, op, *, arg, to):
dtype = op.arg.dtype
from_ = op.arg.dtype
if to.is_interval():
func = self.f[f"to_{_INTERVAL_SUFFIXES[to.unit.short]}"]
return func(sg.cast(arg, to=self.type_mapper.from_ibis(dt.int32)))
elif to.is_timestamp() and dtype.is_numeric():
elif to.is_timestamp() and from_.is_numeric():
return self.f.to_timestamp(arg)
elif to.is_geospatial():
if dtype.is_binary():
if from_.is_binary():
return self.f.st_geomfromwkb(arg)
elif dtype.is_string():
elif from_.is_string():
return self.f.st_geomfromtext(arg)
elif from_.is_uuid() and to.is_binary():
# In duckdb <= 1.3 there is no direct UUID-to-BLOB cast, so strip the dashes
# from the string form and unhex the result:
#   unhex(replace(cast(uuid_val AS VARCHAR), '-', ''))
# Once https://github.com/duckdb/duckdb/pull/18027 is released (duckdb 1.4?)
# this can be simplified to `CAST(uuid_val AS BLOB)`.
hex_string = self.f.replace(
sg.cast(arg, to=self.type_mapper.from_ibis(dt.string)), "-", ""
)
return self.f.unhex(hex_string)

return self.cast(arg, to)

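For reference, a sketch of the SQL this branch is meant to emit for a UUID-to-binary cast on duckdb <= 1.3; the table and column names are made up and the exact rendering may differ:

```python
import ibis

t = ibis.table({"u": "uuid"}, name="t")  # hypothetical table
print(ibis.to_sql(t.u.cast("binary"), dialect="duckdb"))
# roughly: SELECT UNHEX(REPLACE(CAST("t0"."u" AS TEXT), '-', '')) FROM "t" AS "t0"
```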
37 changes: 28 additions & 9 deletions ibis/backends/sqlite/__init__.py
@@ -3,6 +3,7 @@
import contextlib
import functools
import sqlite3
import uuid
from typing import TYPE_CHECKING, Any

import sqlglot as sg
@@ -37,6 +38,7 @@
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import Self


@functools.cache
@@ -153,7 +155,7 @@ def _post_connect(
register_all(self.con)
self.con.execute("PRAGMA case_sensitive_like = ON")

def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> sqlite3.Cursor:
if not isinstance(query, str):
query = query.sql(dialect=self.name)
return self.con.execute(query, **kwargs)
@@ -345,17 +347,34 @@ def to_pyarrow_batches(
) -> pa.ipc.RecordBatchReader:
import pyarrow as pa

raw_schema = expr.as_table().schema()

def _batches(*, query: str):
with self._safe_raw_sql(query) as cursor:
while batch := cursor.fetchmany(chunk_size):
columns = []
names = []
for i, (name, typ) in enumerate(raw_schema.items()):
col = [row[i] for row in batch]
if typ.is_uuid():
col = [
uuid.UUID(v).bytes if v is not None else None
for v in col
]
columns.append(pa.array(col, type=typ.to_pyarrow()))
names.append(name)
# pa.array(batch, raw_schema.as_struct().to_pyarrow())
# is not implemented for extension types (eg UUID)
# so we have to create individual arrays for each column.
yield pa.RecordBatch.from_arrays(columns, names=names)

self._run_pre_execute_hooks(expr)

schema = expr.as_table().schema()
with self._safe_raw_sql(
self.compile(expr, limit=limit, params=params)
) as cursor:
df = self._fetch_from_cursor(cursor, schema)
table = pa.Table.from_pandas(
df, schema=schema.to_pyarrow(), preserve_index=False
query = self.compile(expr, limit=limit, params=params)
return pa.RecordBatchReader.from_batches(
raw_schema.to_pyarrow(),
_batches(query=query),
)
return table.to_reader(max_chunksize=chunk_size)

def _generate_create_table(self, table: sge.Table, schema: sch.Schema):
target = sge.Schema(
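The sqlite backend now streams results through a `RecordBatchReader` built from a generator of batches rather than materializing a pandas DataFrame first. A minimal sketch of that pattern with a made-up schema:

```python
import pyarrow as pa

schema = pa.schema({"x": pa.int64()})  # hypothetical schema

def batches():
    yield pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["x"])
    yield pa.RecordBatch.from_arrays([pa.array([3])], names=["x"])

reader = pa.RecordBatchReader.from_batches(schema, batches())
assert reader.read_all().num_rows == 3
```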
26 changes: 26 additions & 0 deletions ibis/backends/tests/test_uuid.py
@@ -81,3 +81,29 @@ def test_uuid_unique_each_row(con):
con.tables.functional_alltypes.mutate(uuid=ibis.uuid()).limit(2).uuid.nunique()
)
assert expr.execute() == 2


@pytest.mark.notimpl(
["polars"],
raises=NotImplementedError,
)
def test_uuid_scalar_to_pyarrow(con):
expr = ibis.uuid(TEST_UUID)
result = con.to_pyarrow(expr)
assert result.type.extension_name == "arrow.uuid"
result_python = result.as_py()
assert result_python == TEST_UUID


@pytest.mark.notimpl(
["polars"],
raises=NotImplementedError,
)
def test_uuid_column_to_pyarrow(con):
expr = con.tables.functional_alltypes.mutate(uuid=ibis.uuid()).limit(2).uuid
result = con.to_pyarrow(expr)
assert result.type.extension_name == "arrow.uuid"
result_python = result.to_pylist()
assert len(result_python) == 2
assert isinstance(result_python[0], uuid.UUID)
assert isinstance(result_python[1], uuid.UUID)
26 changes: 25 additions & 1 deletion ibis/formats/pyarrow.py
@@ -20,6 +20,25 @@
import pyarrow.dataset as ds


# Compare version components numerically; comparing version strings
# lexicographically would misorder e.g. "9.0.0" and "18.0.0".
if tuple(map(int, pa.__version__.split(".")[:2])) >= (18, 0):
uuid_type = pa.uuid()
else:

class UUIDType(pa.ExtensionType):
def __init__(self):
super().__init__(pa.binary(16), "arrow.uuid")

def __arrow_ext_serialize__(self) -> bytes:
# No parameters are necessary
return b""

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
return cls()

uuid_type = UUIDType()
pa.register_extension_type(uuid_type)
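
On pyarrow older than 18, the fallback class above registers a compatible `arrow.uuid` extension type backed by `binary(16)` storage. A minimal check, as a sketch assuming this branch is installed:

```python
import pyarrow as pa
from ibis.formats.pyarrow import uuid_type  # module-level name added above

storage = pa.array([b"\x00" * 15 + b"\x01"], type=pa.binary(16))
arr = pa.ExtensionArray.from_storage(uuid_type, storage)
assert arr.type.extension_name == "arrow.uuid"
assert arr.storage.type == pa.binary(16)
```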

_from_pyarrow_types = {
pa.int8(): dt.Int8,
pa.int16(): dt.Int16,
@@ -69,7 +88,7 @@
dt.Unknown: pa.string(),
dt.MACADDR: pa.string(),
dt.INET: pa.string(),
dt.UUID: pa.string(),
dt.UUID: uuid_type,
dt.JSON: pa.string(),
}

@@ -108,6 +127,11 @@ def to_ibis(cls, typ: pa.DataType, nullable=True) -> dt.DataType:
return dt.Map(key_dtype, value_dtype, nullable=nullable)
elif pa.types.is_dictionary(typ):
return cls.to_ibis(typ.value_type)
elif getattr(typ, "extension_name", None) == "arrow.uuid":
return dt.UUID(nullable=nullable)
# TODO: should this be
# elif getattr(typ, "extension_name", "").startswith("geoarrow."):
# to be agnostic to the package that actually implements the extension type?
elif (
isinstance(typ, pa.ExtensionType)
and type(typ).__module__ == "geoarrow.types.type_pyarrow"
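A quick round-trip sketch of the new UUID mapping through `PyArrowType` (assumes this branch is installed):

```python
import ibis.expr.datatypes as dt
from ibis.formats.pyarrow import PyArrowType

typ = PyArrowType.from_ibis(dt.UUID())     # the arrow.uuid extension type
assert typ.extension_name == "arrow.uuid"
assert PyArrowType.to_ibis(typ).is_uuid()  # and back to dt.UUID
```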