Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers.base import C
from ibis.backends.sql.rewrites import convert_pandas_udf_to_pyarrow
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowSchema, PyArrowType
Expand Down Expand Up @@ -268,14 +269,17 @@ def _register_udfs(self, expr: ir.Expr) -> None:
if udf_node.__input_type__ == InputType.PYARROW:
udf = self._compile_pyarrow_udf(udf_node)
self.con.register_udf(udf)
if udf_node.__input_type__ == InputType.PANDAS:
udf = self._compile_pandas_udf(udf_node)
self.con.register_udf(udf)

for udf_node in expr.op().find(ops.ElementWiseVectorizedUDF):
udf = self._compile_elementwise_udf(udf_node)
self.con.register_udf(udf)

def _compile_pyarrow_udf(self, udf_node):
def _compile_udf(self, udf_node, func):
return df.udf(
udf_node.__func__,
func,
input_types=[PyArrowType.from_ibis(arg.dtype) for arg in udf_node.args],
return_type=PyArrowType.from_ibis(udf_node.dtype),
volatility=getattr(udf_node, "__config__", {}).get(
Expand All @@ -284,6 +288,13 @@ def _compile_pyarrow_udf(self, udf_node):
name=udf_node.__func_name__,
)

def _compile_pyarrow_udf(self, udf_node):
    """Compile a PyArrow-input UDF node via `_compile_udf`.

    The node's own ``__func__`` is passed through unchanged as the function
    to compile.
    """
    return self._compile_udf(
        udf_node,
        func=udf_node.__func__,
    )

def _compile_pandas_udf(self, udf_node):
    """Compile a pandas-input UDF node via `_compile_udf`.

    The node's ``__func__`` is first wrapped with
    `convert_pandas_udf_to_pyarrow`, and the wrapper is what gets compiled.
    """
    wrapped = convert_pandas_udf_to_pyarrow(udf_node.__func__)
    return self._compile_udf(
        udf_node,
        func=wrapped,
    )

def _compile_elementwise_udf(self, udf_node):
return df.udf(
udf_node.func,
Expand Down
25 changes: 17 additions & 8 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers.base import STAR, AlterTable, C, RenameTable
from ibis.backends.sql.rewrites import convert_pandas_udf_to_pyarrow
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType

Expand Down Expand Up @@ -1739,20 +1740,22 @@ def _register_udfs(self, expr: ir.Expr) -> None:
if registration_func is not None:
registration_func(con)

def _register_udf(self, udf_node: ops.ScalarUDF):
def _register_udf(
self,
udf_node: ops.ScalarUDF,
*,
func: callable | None = None,
input_type: InputType | None = None,
):
type_mapper = self.compiler.type_mapper
input_types = [
type_mapper.to_string(param.annotation.pattern.dtype)
for param in udf_node.__signature__.parameters.values()
]

def register_udf(con):
return con.create_function(
name=type(udf_node).__name__,
function=udf_node.__func__,
parameters=input_types,
function=func or udf_node.__func__,
parameters=[type_mapper.to_string(arg.dtype) for arg in udf_node.args],
return_type=type_mapper.to_string(udf_node.dtype),
type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__],
type=_UDF_INPUT_TYPE_MAPPING[input_type or udf_node.__input_type__],
**udf_node.__config__,
)

Expand All @@ -1761,6 +1764,12 @@ def register_udf(con):
_register_python_udf = _register_udf
_register_pyarrow_udf = _register_udf

def _register_pandas_udf(self, pandas_udf_node: ops.ScalarUDF) -> str:
    """Register a pandas UDF by routing it through the PyArrow UDF path.

    The node's ``__func__`` is wrapped with `convert_pandas_udf_to_pyarrow`
    and handed to `_register_udf`, with the input type overridden to
    ``InputType.PYARROW``.
    """
    # NOTE(review): the `str` return annotation may not match what
    # `_register_udf` actually returns; confirm against its body.
    wrapped = convert_pandas_udf_to_pyarrow(pandas_udf_node.__func__)
    return self._register_udf(
        pandas_udf_node,
        func=wrapped,
        input_type=InputType.PYARROW,
    )

def _get_temp_view_definition(self, name: str, definition: str) -> str:
return sge.Create(
this=sg.to_identifier(name, quoted=self.compiler.quoted),
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def visit_StandardDev(self, op, *, arg, how, where):

def visit_ScalarUDF(self, op, **kw):
input_type = op.__input_type__
if input_type in (InputType.PYARROW, InputType.BUILTIN):
if input_type in (InputType.PYARROW, InputType.BUILTIN, InputType.PANDAS):
return self.f.anon[op.__func_name__](*kw.values())
else:
raise NotImplementedError(
Expand Down
31 changes: 30 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from __future__ import annotations

import functools
import operator
import sys
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from functools import reduce
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -670,3 +671,31 @@ def argument_replacer(_, y, **kwargs):
return ops.Subtract(y, 1)

return _.copy(body=_.body.replace(argument_replacer))


def convert_pandas_udf_to_pyarrow(pandas_udf: Callable) -> Callable:
    """Wrap a pandas UDF so that it can be called with PyArrow inputs.

    Backends that can run PyArrow UDFs but not pandas UDFs can use the
    returned wrapper in place of the original function.

    Parameters
    ----------
    pandas_udf
        The pandas UDF to adapt.

    Returns
    -------
    Callable
        A function that calls ``to_pandas()`` on every positional and keyword
        argument, invokes `pandas_udf` on the converted values, and converts
        the result back with ``pyarrow.Array.from_pandas``.
    """

    @functools.wraps(pandas_udf)
    def pyarrow_udf(*pa_args, **pa_kwargs):
        # pyarrow is imported when the wrapper is called, not when it is built.
        import pyarrow as pa

        positional = tuple(value.to_pandas() for value in pa_args)
        keyword = {name: value.to_pandas() for name, value in pa_kwargs.items()}
        return pa.Array.from_pandas(pandas_udf(*positional, **keyword))

    return pyarrow_udf
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
add_one_pandas,
marks=[
mark.notyet(
["duckdb", "datafusion", "polars", "sqlite"],
["polars", "sqlite"],
raises=NotImplementedError,
reason="backend doesn't support pandas UDFs",
),
Expand Down
14 changes: 11 additions & 3 deletions ibis/expr/operations/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,16 @@ class InputType(enum.Enum):
PYTHON = enum.auto()


class _UDFMixin:
    # Annotation-only mixin: it declares the metadata attributes a UDF
    # operation class is expected to expose, and assigns no values itself.
    # NOTE(review): where these attributes are populated is not visible
    # here; confirm against the UDF construction code.

    # Which InputType member the wrapped function expects its arguments as.
    __input_type__: InputType
    # The Python callable backing the UDF.
    __func__: Callable
    # Name of the function; presumably the name used when registering or
    # calling it in a backend (verify against callers).
    __func_name__: str
    # Extra UDF configuration; presumably backend-specific options (verify).
    __config__: FrozenDict
    # NOTE(review): usage is not visible in this view; confirm its meaning.
    __udf_namespace__: ops.Namespace


@public
class ScalarUDF(ops.Impure):
class ScalarUDF(ops.Impure, _UDFMixin):
@attribute
def shape(self):
if not (args := getattr(self, "args")): # noqa: B009
Expand All @@ -65,7 +73,7 @@ def shape(self):


@public
class AggUDF(ops.Reduction, ops.Impure):
class AggUDF(ops.Reduction, ops.Impure, _UDFMixin):
where: Optional[ops.Value[dt.Boolean]] = None


Expand Down Expand Up @@ -479,7 +487,7 @@ def pandas(
... def str_cap(x: str) -> str:
... # note usage of pandas `str` method
... return x.str.capitalize()
>>> str_cap(t.str_col) # doctest: +SKIP
>>> str_cap(t.str_col)
┏━━━━━━━━━━━━━━━━━━━━━━━┓
┃ string_cap_0(str_col) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━┩
Expand Down
Loading