Skip to content

Commit b57325f

Browse files
timsaucerclaude
andcommitted
feat: pass calling SessionContext to Python UDTF callbacks
DataFusion 53 added `TableFunctionImpl::call_with_args(TableFunctionArgs)` where `TableFunctionArgs` carries both the positional expression arguments and the calling `&dyn Session`. The pure-Python UDTF path previously discarded everything but the exprs. Thread the session through when the user callback's signature opts in by declaring a `session` keyword parameter (or `**kwargs`). At call time we downcast the `&dyn Session` to its canonical `SessionState` impl and build a fresh `SessionContext` over the same Arc-shared state, exposed to Python as a `datafusion.SessionContext` wrapper. Existing callbacks whose signatures do not declare `session` continue to be called with the positional expression arguments only — no behavior change for current users. Note: a UDTF body cannot drive a fresh `ctx.sql(...).collect()` on the passed-in session because the outer SQL execution already holds the tokio runtime. Use the session for metadata access (catalogs, UDF lookups, config) rather than nested DataFrame collection. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent dac9ec6 commit b57325f

3 files changed

Lines changed: 186 additions & 17 deletions

File tree

crates/core/src/udtf.rs

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,33 @@
1818
use std::ptr::NonNull;
1919
use std::sync::Arc;
2020

21-
use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
22-
use datafusion::error::Result as DataFusionResult;
21+
use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl, TableProvider};
22+
use datafusion::error::{DataFusionError, Result as DataFusionResult};
23+
use datafusion::execution::context::SessionContext;
24+
use datafusion::execution::session_state::SessionState;
2325
use datafusion::logical_expr::Expr;
2426
use datafusion_ffi::udtf::FFI_TableFunction;
2527
use pyo3::IntoPyObjectExt;
2628
use pyo3::exceptions::{PyImportError, PyTypeError};
2729
use pyo3::prelude::*;
28-
use pyo3::types::{PyCapsule, PyTuple, PyType};
30+
use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType};
2931

3032
use crate::context::PySessionContext;
3133
use crate::errors::{py_datafusion_err, to_datafusion_err};
3234
use crate::expr::PyExpr;
3335
use crate::table::PyTable;
3436

37+
/// A pure-Python UDTF callable plus the metadata we discovered about it
38+
/// at registration time.
39+
#[derive(Debug, Clone)]
40+
pub(crate) struct PythonTableFunctionCallable {
41+
pub(crate) callable: Arc<Py<PyAny>>,
42+
/// Whether the callable's signature accepts a ``session`` keyword
43+
/// argument (or ``**kwargs``). When true the calling
44+
/// :class:`SessionContext` is threaded through on each invocation.
45+
pub(crate) accepts_session: bool,
46+
}
47+
3548
/// Represents a user defined table function
3649
#[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")]
3750
#[derive(Debug, Clone)]
@@ -40,21 +53,21 @@ pub struct PyTableFunction {
4053
pub(crate) inner: PyTableFunctionInner,
4154
}
4255

43-
// TODO: Implement pure python based user defined table functions
4456
#[derive(Debug, Clone)]
4557
pub(crate) enum PyTableFunctionInner {
46-
PythonFunction(Arc<Py<PyAny>>),
58+
PythonFunction(PythonTableFunctionCallable),
4759
FFIFunction(Arc<dyn TableFunctionImpl>),
4860
}
4961

5062
#[pymethods]
5163
impl PyTableFunction {
5264
#[new]
53-
#[pyo3(signature=(name, func, session))]
65+
#[pyo3(signature=(name, func, session, accepts_session=false))]
5466
pub fn new(
5567
name: &str,
5668
func: Bound<'_, PyAny>,
5769
session: Option<Bound<PyAny>>,
70+
accepts_session: bool,
5871
) -> PyResult<Self> {
5972
let inner = if func.hasattr("__datafusion_table_function__")? {
6073
let py = func.py();
@@ -80,8 +93,10 @@ impl PyTableFunction {
8093

8194
PyTableFunctionInner::FFIFunction(foreign_func)
8295
} else {
83-
let py_obj = Arc::new(func.unbind());
84-
PyTableFunctionInner::PythonFunction(py_obj)
96+
PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable {
97+
callable: Arc::new(func.unbind()),
98+
accepts_session,
99+
})
85100
};
86101

87102
Ok(Self {
@@ -107,20 +122,59 @@ impl PyTableFunction {
107122
}
108123
}
109124

125+
/// Materialize a fresh :class:`PySessionContext` from the borrowed
126+
/// ``&dyn Session`` handed in at call time.
127+
///
128+
/// Upstream invokes ``call_with_args`` with a trait-object reference
129+
/// rather than an owned context; we downcast it to the canonical
130+
/// :class:`SessionState` impl and rebuild a :class:`SessionContext`
131+
/// (sharing the same registries via the Arc-heavy interior of
132+
/// :class:`SessionState`). Returns an error if the trait object is a
133+
/// non-:class:`SessionState` implementation (e.g. a foreign FFI
134+
/// session) — those are not exposed to Python today.
135+
fn py_session_from_session(session: &dyn Session) -> DataFusionResult<PySessionContext> {
136+
let state = session
137+
.as_any()
138+
.downcast_ref::<SessionState>()
139+
.ok_or_else(|| {
140+
DataFusionError::Execution(
141+
"Cannot expose this UDTF's calling session to Python: \
142+
the session is not a SessionState. Drop the `session` \
143+
keyword from the callback signature to fall back to the \
144+
expression-only call form."
145+
.to_string(),
146+
)
147+
})?;
148+
Ok(PySessionContext::from(SessionContext::new_with_state(
149+
state.clone(),
150+
)))
151+
}
152+
110153
#[allow(clippy::result_large_err)]
111154
fn call_python_table_function(
112-
func: &Arc<Py<PyAny>>,
113-
args: &[Expr],
155+
func: &PythonTableFunctionCallable,
156+
args: TableFunctionArgs,
114157
) -> DataFusionResult<Arc<dyn TableProvider>> {
115-
let args = args
158+
let py_session = if func.accepts_session {
159+
Some(py_session_from_session(args.session())?)
160+
} else {
161+
None
162+
};
163+
let py_exprs = args
164+
.exprs()
116165
.iter()
117166
.map(|arg| PyExpr::from(arg.clone()))
118167
.collect::<Vec<_>>();
119168

120-
// move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
121169
Python::attach(|py| {
122-
let py_args = PyTuple::new(py, args)?;
123-
let provider_obj = func.call1(py, py_args)?;
170+
let py_args = PyTuple::new(py, py_exprs)?;
171+
let provider_obj = if let Some(session) = py_session {
172+
let kwargs = PyDict::new(py);
173+
kwargs.set_item("session", session.into_pyobject(py)?)?;
174+
func.callable.call(py, py_args, Some(&kwargs))?
175+
} else {
176+
func.callable.call1(py, py_args)?
177+
};
124178
let provider = provider_obj.bind(py).clone();
125179

126180
Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider, None)?.table)
@@ -132,8 +186,8 @@ impl TableFunctionImpl for PyTableFunction {
132186
fn call_with_args(&self, args: TableFunctionArgs) -> DataFusionResult<Arc<dyn TableProvider>> {
133187
match &self.inner {
134188
PyTableFunctionInner::FFIFunction(func) => func.call_with_args(args),
135-
PyTableFunctionInner::PythonFunction(obj) => {
136-
call_python_table_function(obj, args.exprs())
189+
PyTableFunctionInner::PythonFunction(callable) => {
190+
call_python_table_function(callable, args)
137191
}
138192
}
139193
}

python/datafusion/user_defined.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,47 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF:
10541054
)
10551055

10561056

1057+
def _callable_accepts_session_kwarg(func: object) -> bool:
1058+
"""Return True if ``func`` accepts a ``session`` keyword argument.
1059+
1060+
Used to opt a Python UDTF callback into receiving the calling
1061+
:class:`SessionContext` at invocation time. ``**kwargs`` callables
1062+
are treated as accepting it; built-ins and objects without an
1063+
introspectable signature fall back to ``False``.
1064+
"""
1065+
import inspect # noqa: PLC0415
1066+
1067+
try:
1068+
signature = inspect.signature(func)
1069+
except (TypeError, ValueError):
1070+
return False
1071+
1072+
for parameter in signature.parameters.values():
1073+
if parameter.name == "session":
1074+
return True
1075+
if parameter.kind is inspect.Parameter.VAR_KEYWORD:
1076+
return True
1077+
return False
1078+
1079+
1080+
def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]:
    """Return ``func`` wrapped so its ``session`` keyword is a public wrapper.

    The returned callable requires a ``session`` keyword argument, places it
    on the ``ctx`` attribute of a :class:`SessionContext` instance created
    without running ``__init__``, and forwards that instance to ``func`` as
    ``session``. All other positional and keyword arguments pass through
    unchanged.
    """

    def _with_public_session(*args: Any, session: Any, **kwargs: Any) -> Any:
        # Bypass ``__init__``: the incoming object is assigned directly as
        # the wrapper's ``ctx`` attribute rather than building a new context.
        public_ctx = SessionContext.__new__(SessionContext)
        public_ctx.ctx = session
        kwargs["session"] = public_ctx
        return func(*args, **kwargs)

    # ``updated=()`` copies the metadata attributes but skips merging
    # ``func.__dict__`` into the wrapper.
    return functools.wraps(func, updated=())(_with_public_session)
1096+
1097+
10571098
class TableFunction:
10581099
"""Class for performing user-defined table functions (UDTF).
10591100
@@ -1066,10 +1107,19 @@ def __init__(
10661107
) -> None:
10671108
"""Instantiate a user-defined table function (UDTF).
10681109
1110+
If ``func``'s signature accepts a ``session`` keyword (or
1111+
``**kwargs``), the calling :class:`SessionContext` is threaded
1112+
through to it on each invocation. Use it inside the body to look
1113+
up registered tables, UDFs, or session configuration. Callables
1114+
whose signatures do not declare ``session`` are invoked with the
1115+
positional expression arguments only.
1116+
10691117
See :py:func:`udtf` for a convenience function and argument
10701118
descriptions.
10711119
"""
1072-
self._udtf = df_internal.TableFunction(name, func, ctx)
1120+
accepts_session = _callable_accepts_session_kwarg(func)
1121+
registered = _wrap_session_kwarg_for_udtf(func) if accepts_session else func
1122+
self._udtf = df_internal.TableFunction(name, registered, ctx, accepts_session)
10731123

10741124
def __call__(self, *args: Expr) -> Any:
10751125
"""Execute the UDTF and return a table provider."""

python/tests/test_udtf.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,68 @@ def string_arg_func(prefix: Expr) -> TableProviderExportable:
134134
result = ctx.sql("SELECT * FROM string_arg_func('test')").collect()
135135
assert len(result) == 1
136136
assert result[0].schema.names == ["test_a", "test_b"]
137+
138+
139+
def test_python_table_function_receives_session() -> None:
    """A UDTF whose signature declares ``session`` gets the calling ctx."""
    ctx = SessionContext()
    # Register a table up front so the catalog comparison below is not a
    # comparison of two empty schemas, which any fresh context would satisfy.
    ctx.register_batch("probe_tbl", pa.RecordBatch.from_pydict({"x": [1]}))
    captured: list[SessionContext] = []

    @udtf("session_aware_func")
    def session_aware_func(*, session: SessionContext) -> TableProviderExportable:
        captured.append(session)
        batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
        return Table(ds.dataset([batch]))

    ctx.register_udtf(session_aware_func)
    result = ctx.sql("SELECT * FROM session_aware_func()").collect()

    assert len(captured) == 1
    assert isinstance(captured[0], SessionContext)
    # Seeing the caller's table confirms the wrapper points at the caller's state.
    visible = captured[0].catalog().schema().names()
    assert "probe_tbl" in visible
    assert visible == ctx.catalog().schema().names()
    assert result[0].column(0).to_pylist() == [1, 2, 3]
158+
159+
160+
def test_python_table_function_session_used_for_metadata() -> None:
    """A UDTF body can read the caller's registered tables via ``session``."""
    ctx = SessionContext()
    ctx.register_batch("base_tbl", pa.RecordBatch.from_pydict({"x": [10, 20, 30]}))

    observed: list[set[str]] = []

    @udtf("table_inventory")
    def table_inventory(*, session: SessionContext) -> TableProviderExportable:
        # Record which tables the passed-in session can see.
        visible = session.catalog().schema().names()
        observed.append(visible)
        rows = pa.RecordBatch.from_pydict({"name": ["base_tbl"]})
        return Table(ds.dataset([rows]))

    ctx.register_udtf(table_inventory)
    batches = ctx.sql("SELECT * FROM table_inventory()").collect()

    assert observed == [{"base_tbl"}]
    assert batches[0].column(0).to_pylist() == ["base_tbl"]
180+
181+
182+
def test_python_table_function_class_callable_session_kwarg() -> None:
    """An instance whose ``__call__`` declares ``session`` receives it as well."""
    ctx = SessionContext()
    sessions_seen: list[SessionContext] = []

    class SessionAware:
        def __call__(
            self, n: Expr, *, session: SessionContext
        ) -> TableProviderExportable:
            sessions_seen.append(session)
            # The SQL literal arrives as an expression; extract its integer value.
            row_count = n.to_variant().value_i64()
            values = list(range(row_count))
            return Table(ds.dataset([pa.RecordBatch.from_pydict({"a": values})]))

    table_func = udtf(SessionAware(), "session_class_func")
    ctx.register_udtf(table_func)
    batches = ctx.sql("SELECT * FROM session_class_func(3)").collect()

    assert len(sessions_seen) == 1
    assert isinstance(sessions_seen[0], SessionContext)
    assert batches[0].column(0).to_pylist() == [0, 1, 2]

0 commit comments

Comments
 (0)