Skip to content

Commit b03562e

Browse files
committed
Add support for Python based TableProviderFactory
This adds the ability to register Python based TableProviderFactory instances to the SessionContext.
1 parent 29dcd98 commit b03562e

File tree

5 files changed

+122
-21
lines changed

5 files changed

+122
-21
lines changed

python/datafusion/catalog.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from datafusion import DataFrame, SessionContext
3131
from datafusion.context import TableProviderExportable
32+
from datafusion.expr import CreateExternalTable
3233

3334
try:
3435
from warnings import deprecated # Python 3.13+
@@ -243,6 +244,15 @@ def kind(self) -> str:
243244
return self._inner.kind
244245

245246

247+
class TableProviderFactory(ABC):
248+
"""Abstract class for defining a Python based Table Provider Factory."""
249+
250+
@abstractmethod
251+
def create(self, cmd: CreateExternalTable) -> Table:
252+
"""Create a table using the :class:`CreateExternalTable`."""
253+
...
254+
255+
246256
class TableProviderFactoryExportable(Protocol):
247257
"""Type hint for object that has __datafusion_table_provider_factory__ PyCapsule.
248258

python/datafusion/context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
CatalogProviderExportable,
3838
CatalogProviderList,
3939
CatalogProviderListExportable,
40+
TableProviderFactory,
4041
TableProviderFactoryExportable,
4142
)
4243
from datafusion.dataframe import DataFrame
@@ -832,7 +833,9 @@ def deregister_table(self, name: str) -> None:
832833
self.ctx.deregister_table(name)
833834

834835
def register_table_factory(
835-
self, format: str, factory: TableProviderFactoryExportable
836+
self,
837+
format: str,
838+
factory: TableProviderFactory | TableProviderFactoryExportable,
836839
) -> None:
837840
"""Register a :py:class:`~datafusion.TableProviderFactoryExportable`.
838841

python/tests/test_catalog.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ def register_catalog(
120120
self.catalogs[name] = catalog
121121

122122

123+
class CustomTableProviderFactory(dfn.catalog.TableProviderFactory):
124+
def create(self, cmd: dfn.expr.CreateExternalTable):
125+
assert cmd.name() == "test_table_factory"
126+
return create_dataset()
127+
128+
123129
def test_python_catalog_provider_list(ctx: SessionContext):
124130
ctx.register_catalog_provider_list(CustomCatalogProviderList())
125131

@@ -314,3 +320,24 @@ def my_table_function_udtf() -> Table:
314320
assert len(result[0]) == 1
315321
assert len(result[0][0]) == 1
316322
assert result[0][0][0].as_py() == 3
323+
324+
325+
def test_register_python_table_provider_factory(ctx: SessionContext):
326+
ctx.register_table_factory("CUSTOM_FACTORY", CustomTableProviderFactory())
327+
328+
ctx.sql("""
329+
CREATE EXTERNAL TABLE test_table_factory
330+
STORED AS CUSTOM_FACTORY
331+
LOCATION foo;
332+
""").collect()
333+
334+
result = ctx.sql("SELECT * FROM test_table_factory;").collect()
335+
336+
expect = [
337+
pa.RecordBatch.from_arrays(
338+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
339+
names=["a", "b"],
340+
)
341+
]
342+
343+
assert result == expect

src/context.rs

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ use crate::record_batch::PyRecordBatchStream;
7878
use crate::sql::logical::PyLogicalPlan;
7979
use crate::sql::util::replace_placeholders_with_strings;
8080
use crate::store::StorageContexts;
81-
use crate::table::PyTable;
81+
use crate::table::{PyTable, RustWrappedPyTableProviderFactory};
8282
use crate::udaf::PyAggregateUDF;
8383
use crate::udf::PyScalarUDF;
8484
use crate::udtf::PyTableFunction;
@@ -663,22 +663,31 @@ impl PySessionContext {
663663
pub fn register_table_factory(
664664
&self,
665665
format: &str,
666-
factory: Bound<'_, PyAny>,
666+
mut factory: Bound<'_, PyAny>,
667667
) -> PyDataFusionResult<()> {
668-
let py = factory.py();
669-
let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
670-
671-
let capsule = factory
672-
.getattr("__datafusion_table_provider_factory__")?
673-
.call1((codec_capsule,))?;
674-
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
675-
validate_pycapsule(capsule, "datafusion_table_provider_factory")?;
676-
677-
let factory: NonNull<FFI_TableProviderFactory> = capsule
678-
.pointer_checked(Some(c_str!("datafusion_table_provider_factory")))?
679-
.cast();
680-
let factory = unsafe { factory.as_ref() };
681-
let factory: Arc<dyn TableProviderFactory> = factory.into();
668+
if factory.hasattr("__datafusion_table_provider_factory__")? {
669+
let py = factory.py();
670+
let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
671+
factory = factory
672+
.getattr("__datafusion_table_provider_factory__")?
673+
.call1((codec_capsule,))?;
674+
}
675+
676+
let factory: Arc<dyn TableProviderFactory> =
677+
if let Ok(capsule) = factory.cast::<PyCapsule>().map_err(py_datafusion_err) {
678+
validate_pycapsule(capsule, "datafusion_table_provider_factory")?;
679+
680+
let data: NonNull<FFI_TableProviderFactory> = capsule
681+
.pointer_checked(Some(c_str!("datafusion_table_provider_factory")))?
682+
.cast();
683+
let factory = unsafe { data.as_ref() };
684+
factory.into()
685+
} else {
686+
Arc::new(RustWrappedPyTableProviderFactory::new(
687+
factory.into(),
688+
self.logical_codec.clone(),
689+
))
690+
};
682691

683692
let st = self.ctx.state_ref();
684693
let mut lock = st.write();

src/table.rs

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,23 @@ use std::sync::Arc;
2121
use arrow::datatypes::SchemaRef;
2222
use arrow::pyarrow::ToPyArrow;
2323
use async_trait::async_trait;
24-
use datafusion::catalog::Session;
24+
use datafusion::catalog::{Session, TableProviderFactory};
2525
use datafusion::common::Column;
2626
use datafusion::datasource::{TableProvider, TableType};
27-
use datafusion::logical_expr::{Expr, LogicalPlanBuilder, TableProviderFilterPushDown};
27+
use datafusion::logical_expr::{
28+
CreateExternalTable, Expr, LogicalPlanBuilder, TableProviderFilterPushDown,
29+
};
2830
use datafusion::physical_plan::ExecutionPlan;
2931
use datafusion::prelude::DataFrame;
32+
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
3033
use pyo3::IntoPyObjectExt;
3134
use pyo3::prelude::*;
3235

3336
use crate::context::PySessionContext;
3437
use crate::dataframe::PyDataFrame;
3538
use crate::dataset::Dataset;
36-
use crate::utils::table_provider_from_pycapsule;
39+
use crate::expr::create_external_table::PyCreateExternalTable;
40+
use crate::{errors, utils};
3741

3842
/// This struct is used as a common method for all TableProviders,
3943
/// whether they refer to an FFI provider, an internally known
@@ -91,7 +95,7 @@ impl PyTable {
9195
Some(session) => session,
9296
None => PySessionContext::global_ctx()?.into_bound_py_any(obj.py())?,
9397
};
94-
table_provider_from_pycapsule(obj.clone(), session)?
98+
utils::table_provider_from_pycapsule(obj.clone(), session)?
9599
} {
96100
Ok(PyTable::from(provider))
97101
} else {
@@ -206,3 +210,51 @@ impl TableProvider for TempViewTable {
206210
Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
207211
}
208212
}
213+
214+
#[derive(Debug)]
215+
pub(crate) struct RustWrappedPyTableProviderFactory {
216+
pub(crate) table_provider_factory: Py<PyAny>,
217+
pub(crate) codec: Arc<FFI_LogicalExtensionCodec>,
218+
}
219+
220+
impl RustWrappedPyTableProviderFactory {
221+
pub fn new(table_provider_factory: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
222+
Self {
223+
table_provider_factory,
224+
codec,
225+
}
226+
}
227+
228+
fn create_inner(
229+
&self,
230+
cmd: CreateExternalTable,
231+
codec: Bound<PyAny>,
232+
) -> PyResult<Arc<dyn TableProvider>> {
233+
Python::attach(|py| {
234+
let provider = self.table_provider_factory.bind(py);
235+
let cmd = PyCreateExternalTable::from(cmd);
236+
237+
provider
238+
.call_method1("create", (cmd,))
239+
.and_then(|t| PyTable::new(t, Some(codec)))
240+
.map(|t| t.table())
241+
})
242+
}
243+
}
244+
245+
#[async_trait]
246+
impl TableProviderFactory for RustWrappedPyTableProviderFactory {
247+
async fn create(
248+
&self,
249+
_: &dyn Session,
250+
cmd: &CreateExternalTable,
251+
) -> datafusion::common::Result<Arc<dyn TableProvider>> {
252+
Python::attach(|py| {
253+
let codec = utils::create_logical_extension_capsule(py, self.codec.as_ref())
254+
.map_err(errors::to_datafusion_err)?;
255+
256+
self.create_inner(cmd.clone(), codec.into_any())
257+
.map_err(errors::to_datafusion_err)
258+
})
259+
}
260+
}

0 commit comments

Comments
 (0)