Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions crates/core/src/udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,33 @@
use std::ptr::NonNull;
use std::sync::Arc;

use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
use datafusion::error::Result as DataFusionResult;
use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl, TableProvider};
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::execution::context::SessionContext;
use datafusion::execution::session_state::SessionState;
use datafusion::logical_expr::Expr;
use datafusion_ffi::udtf::FFI_TableFunction;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyImportError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple, PyType};
use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType};

use crate::context::PySessionContext;
use crate::errors::{py_datafusion_err, to_datafusion_err};
use crate::expr::PyExpr;
use crate::table::PyTable;

/// A pure-Python UDTF callable together with the metadata recorded about it
/// when the table function is constructed.
#[derive(Debug, Clone)]
pub(crate) struct PythonTableFunctionCallable {
/// The Python object that is called to produce a table provider.
pub(crate) callable: Arc<Py<PyAny>>,
/// Whether the callable's signature accepts a `session` keyword
/// argument (or `**kwargs`). When true, the calling session is passed
/// to the callable as the `session` keyword on each invocation;
/// otherwise only the positional expression arguments are passed.
pub(crate) accepts_session: bool,
}

/// Represents a user defined table function
#[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")]
#[derive(Debug, Clone)]
Expand All @@ -40,21 +53,21 @@ pub struct PyTableFunction {
pub(crate) inner: PyTableFunctionInner,
}

// TODO: Implement pure python based user defined table functions
#[derive(Debug, Clone)]
pub(crate) enum PyTableFunctionInner {
PythonFunction(Arc<Py<PyAny>>),
PythonFunction(PythonTableFunctionCallable),
FFIFunction(Arc<dyn TableFunctionImpl>),
}

#[pymethods]
impl PyTableFunction {
#[new]
#[pyo3(signature=(name, func, session))]
#[pyo3(signature=(name, func, session, accepts_session=false))]
pub fn new(
name: &str,
func: Bound<'_, PyAny>,
session: Option<Bound<PyAny>>,
accepts_session: bool,
) -> PyResult<Self> {
let inner = if func.hasattr("__datafusion_table_function__")? {
let py = func.py();
Expand All @@ -80,8 +93,10 @@ impl PyTableFunction {

PyTableFunctionInner::FFIFunction(foreign_func)
} else {
let py_obj = Arc::new(func.unbind());
PyTableFunctionInner::PythonFunction(py_obj)
PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable {
callable: Arc::new(func.unbind()),
accepts_session,
})
};

Ok(Self {
Expand All @@ -107,20 +122,59 @@ impl PyTableFunction {
}
}

/// Build a [`PySessionContext`] for the session that is invoking a UDTF.
///
/// Only a borrowed `&dyn Session` is available at call time, so the trait
/// object is downcast to [`SessionState`], the state is cloned, and a new
/// [`SessionContext`] is created from that clone and wrapped for Python.
///
/// # Errors
///
/// Returns [`DataFusionError::Execution`] when `session` is not backed by a
/// [`SessionState`]; such sessions cannot be handed to Python here.
fn py_session_from_session(session: &dyn Session) -> DataFusionResult<PySessionContext> {
    // Guard clause: anything other than a concrete `SessionState` is rejected.
    let Some(state) = session.as_any().downcast_ref::<SessionState>() else {
        return Err(DataFusionError::Execution(
            "Cannot expose this UDTF's calling session to Python: \
             the session is not a SessionState. Drop the `session` \
             keyword from the callback signature to fall back to the \
             expression-only call form."
                .to_string(),
        ));
    };

    let context = SessionContext::new_with_state(state.clone());
    Ok(PySessionContext::from(context))
}

#[allow(clippy::result_large_err)]
fn call_python_table_function(
func: &Arc<Py<PyAny>>,
args: &[Expr],
func: &PythonTableFunctionCallable,
args: TableFunctionArgs,
) -> DataFusionResult<Arc<dyn TableProvider>> {
let args = args
let py_session = if func.accepts_session {
Some(py_session_from_session(args.session())?)
} else {
None
};
let py_exprs = args
.exprs()
.iter()
.map(|arg| PyExpr::from(arg.clone()))
.collect::<Vec<_>>();

// move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
Python::attach(|py| {
let py_args = PyTuple::new(py, args)?;
let provider_obj = func.call1(py, py_args)?;
let py_args = PyTuple::new(py, py_exprs)?;
let provider_obj = if let Some(session) = py_session {
let kwargs = PyDict::new(py);
kwargs.set_item("session", session.into_pyobject(py)?)?;
func.callable.call(py, py_args, Some(&kwargs))?
} else {
func.callable.call1(py, py_args)?
};
let provider = provider_obj.bind(py).clone();

Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider, None)?.table)
Expand All @@ -132,8 +186,8 @@ impl TableFunctionImpl for PyTableFunction {
fn call_with_args(&self, args: TableFunctionArgs) -> DataFusionResult<Arc<dyn TableProvider>> {
match &self.inner {
PyTableFunctionInner::FFIFunction(func) => func.call_with_args(args),
PyTableFunctionInner::PythonFunction(obj) => {
call_python_table_function(obj, args.exprs())
PyTableFunctionInner::PythonFunction(callable) => {
call_python_table_function(callable, args)
}
}
}
Expand Down
52 changes: 51 additions & 1 deletion python/datafusion/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,47 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF:
)


def _callable_accepts_session_kwarg(func: object) -> bool:
"""Return True if ``func`` accepts a ``session`` keyword argument.

Used to opt a Python UDTF callback into receiving the calling
:class:`SessionContext` at invocation time. ``**kwargs`` callables
are treated as accepting it; built-ins and objects without an
introspectable signature fall back to ``False``.
"""
import inspect # noqa: PLC0415

try:
signature = inspect.signature(func)
except (TypeError, ValueError):
return False

for parameter in signature.parameters.values():
if parameter.name == "session":
return True
if parameter.kind is inspect.Parameter.VAR_KEYWORD:
return True
return False


def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]:
"""Adapt the raw internal session pyo3 object back to a Python wrapper.

The Rust call site forwards a ``datafusion._internal.SessionContext``,
but UDTF authors expect to interact with the public
:class:`datafusion.SessionContext` wrapper. This closure wraps the
internal object once per call before delegating to ``func``.
"""

@functools.wraps(func, updated=())
def adapter(*args: Any, session: Any, **kwargs: Any) -> Any:
wrapped = SessionContext.__new__(SessionContext)
wrapped.ctx = session
return func(*args, session=wrapped, **kwargs)

return adapter


class TableFunction:
"""Class for performing user-defined table functions (UDTF).

Expand All @@ -1066,10 +1107,19 @@ def __init__(
) -> None:
"""Instantiate a user-defined table function (UDTF).

If ``func``'s signature accepts a ``session`` keyword (or
``**kwargs``), the calling :class:`SessionContext` is threaded
through to it on each invocation. Use it inside the body to look
up registered tables, UDFs, or session configuration. Callables
whose signatures do not declare ``session`` are invoked with the
positional expression arguments only.

See :py:func:`udtf` for a convenience function and argument
descriptions.
"""
self._udtf = df_internal.TableFunction(name, func, ctx)
accepts_session = _callable_accepts_session_kwarg(func)
registered = _wrap_session_kwarg_for_udtf(func) if accepts_session else func
self._udtf = df_internal.TableFunction(name, registered, ctx, accepts_session)

def __call__(self, *args: Expr) -> Any:
"""Execute the UDTF and return a table provider."""
Expand Down
65 changes: 65 additions & 0 deletions python/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,68 @@ def string_arg_func(prefix: Expr) -> TableProviderExportable:
result = ctx.sql("SELECT * FROM string_arg_func('test')").collect()
assert len(result) == 1
assert result[0].schema.names == ["test_a", "test_b"]


def test_python_table_function_receives_session() -> None:
    """The calling context is delivered to a UDTF that declares ``session``."""
    ctx = SessionContext()
    received: list[SessionContext] = []

    @udtf("session_aware_func")
    def session_aware_func(*, session: SessionContext) -> TableProviderExportable:
        # Record what was passed in so the outer test can inspect it.
        received.append(session)
        rows = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
        return Table(ds.dataset([rows]))

    ctx.register_udtf(session_aware_func)
    batches = ctx.sql("SELECT * FROM session_aware_func()").collect()

    assert len(received) == 1
    assert isinstance(received[0], SessionContext)
    # The passed-in session must list the same tables as the calling context.
    expected_names = ctx.catalog().schema().names()
    assert received[0].catalog().schema().names() == expected_names
    assert batches[0].column(0).to_pylist() == [1, 2, 3]


def test_python_table_function_session_used_for_metadata() -> None:
    """A UDTF can read the caller's registered tables via ``session``."""
    ctx = SessionContext()
    ctx.register_batch("base_tbl", pa.RecordBatch.from_pydict({"x": [10, 20, 30]}))

    observed_names: list[set[str]] = []

    @udtf("table_inventory")
    def table_inventory(*, session: SessionContext) -> TableProviderExportable:
        # Capture the table names visible through the session we were given.
        visible = session.catalog().schema().names()
        observed_names.append(visible)
        rows = pa.RecordBatch.from_pydict({"name": ["base_tbl"]})
        return Table(ds.dataset([rows]))

    ctx.register_udtf(table_inventory)
    batches = ctx.sql("SELECT * FROM table_inventory()").collect()

    assert observed_names == [{"base_tbl"}]
    assert batches[0].column(0).to_pylist() == ["base_tbl"]


def test_python_table_function_class_callable_session_kwarg() -> None:
    """A callable instance whose ``__call__`` declares ``session`` receives it."""
    ctx = SessionContext()
    received: list[SessionContext] = []

    class SessionAware:
        def __call__(
            self, n: Expr, *, session: SessionContext
        ) -> TableProviderExportable:
            received.append(session)
            row_count = n.to_variant().value_i64()
            values = list(range(row_count))
            return Table(ds.dataset([pa.RecordBatch.from_pydict({"a": values})]))

    table_func = udtf(SessionAware(), "session_class_func")
    ctx.register_udtf(table_func)
    batches = ctx.sql("SELECT * FROM session_class_func(3)").collect()

    assert len(received) == 1
    assert isinstance(received[0], SessionContext)
    assert batches[0].column(0).to_pylist() == [0, 1, 2]
Loading