Commit 4e6d2a0

Remove validate_pycapsule
Bound<'_, PyCapsule>::pointer_checked performs the same validation and is already used across the codebase.
1 parent 1160d5a commit 4e6d2a0
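
For context, the pattern every call site converges on is a single pointer_checked call: passing Some(name) makes pyo3 verify the capsule's name before handing back a non-null pointer, which is the validation the removed helper performed by hand. A minimal sketch assembled from the src/udaf.rs hunk below; the standalone function shape and the PyResult return type are illustrative, while the pointer_checked call and the final Arc conversion mirror lines in that hunk:

use std::ptr::NonNull;
use std::sync::Arc;

use datafusion::logical_expr::AggregateUDFImpl;
use datafusion_ffi::udaf::FFI_AggregateUDF;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

fn aggregate_udf_impl_from_capsule(
    capsule: &Bound<'_, PyCapsule>,
) -> PyResult<Arc<dyn AggregateUDFImpl>> {
    // One call checks the capsule name and yields a NonNull pointer.
    let data: NonNull<FFI_AggregateUDF> = capsule
        .pointer_checked(Some(c"datafusion_aggregate_udf"))?
        .cast();
    // SAFETY: pointer_checked returned a non-null pointer for a capsule
    // whose name matched, so it points at the FFI struct the producer stored.
    let udaf = unsafe { data.as_ref() };
    Ok(udaf.into())
}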

File tree

11 files changed (+31, -106 lines)

examples/datafusion-ffi-example/src/utils.rs

Lines changed: 0 additions & 20 deletions
@@ -35,30 +35,10 @@ pub(crate) fn ffi_logical_codec_from_pycapsule(
     };
 
     let capsule = capsule.cast::<PyCapsule>()?;
-    validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;
-
     let data: NonNull<FFI_LogicalExtensionCodec> = capsule
         .pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))?
         .cast();
     let codec = unsafe { data.as_ref() };
 
     Ok(codec.clone())
 }
-
-pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
-    let capsule_name = capsule.name()?;
-    if capsule_name.is_none() {
-        return Err(PyValueError::new_err(format!(
-            "Expected {name} PyCapsule to have name set."
-        )));
-    }
-
-    let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
-    if capsule_name != name {
-        return Err(PyValueError::new_err(format!(
-            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
-        )));
-    }
-
-    Ok(())
-}
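
Note that these call sites already ran both checks back to back: validate_pycapsule compared the capsule name by hand, and pointer_checked(Some(name)) compares it again internally (per the commit message, "the same validation"). Dropping the helper therefore removes duplicated work rather than any safety check, as the redundant pair from the function above shows:

// Redundant pair at the old call sites: the first line re-did what the
// second already enforces via Some(name).
validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;              // removed
let data: NonNull<FFI_LogicalExtensionCodec> = capsule
    .pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))?        // kept
    .cast();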

src/array.rs

Lines changed: 1 addition & 5 deletions
@@ -22,13 +22,11 @@ use arrow::array::{Array, ArrayRef};
 use arrow::datatypes::{Field, FieldRef};
 use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
 use arrow::pyarrow::ToPyArrow;
-use pyo3::ffi::c_str;
 use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
 use pyo3::types::PyCapsule;
 use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods};
 
 use crate::errors::PyDataFusionResult;
-use crate::utils::validate_pycapsule;
 
 /// A Python object which implements the Arrow PyCapsule for importing
 /// into other libraries.
@@ -53,10 +51,8 @@ impl PyArrowArrayExportable {
         requested_schema: Option<Bound<'py, PyCapsule>>,
     ) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
         let field = if let Some(schema_capsule) = requested_schema {
-            validate_pycapsule(&schema_capsule, "arrow_schema")?;
-
             let data: NonNull<FFI_ArrowSchema> = schema_capsule
-                .pointer_checked(Some(c_str!("arrow_schema")))?
+                .pointer_checked(Some(c"arrow_schema"))?
                 .cast();
             let schema_ptr = unsafe { data.as_ref() };
             let desired_field = Field::try_from(schema_ptr)?;
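
Most hunks also swap pyo3's c_str! macro for Rust's native C-string literal (c"...", stable since Rust 1.77), which is what lets each file drop its `use pyo3::ffi::c_str;` import. Both forms produce a &'static CStr:

use std::ffi::CStr;

// c"arrow_schema" is a native C-string literal; it yields the same
// &'static CStr that c_str!("arrow_schema") expanded to, with no macro import.
static ARROW_SCHEMA_NAME: &CStr = c"arrow_schema";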

src/catalog.rs

Lines changed: 3 additions & 8 deletions
@@ -32,16 +32,14 @@ use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
 use datafusion_ffi::schema_provider::FFI_SchemaProvider;
 use pyo3::IntoPyObjectExt;
 use pyo3::exceptions::PyKeyError;
-use pyo3::ffi::c_str;
 use pyo3::prelude::*;
 use pyo3::types::PyCapsule;
 
 use crate::dataset::Dataset;
 use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
 use crate::table::PyTable;
 use crate::utils::{
-    create_logical_extension_capsule, extract_logical_extension_codec, validate_pycapsule,
-    wait_for_future,
+    create_logical_extension_capsule, extract_logical_extension_codec, wait_for_future,
 };
 
 #[pyclass(
@@ -658,9 +656,8 @@ fn extract_catalog_provider_from_pyobj(
     }
 
     let provider = if let Ok(capsule) = catalog_provider.cast::<PyCapsule>() {
-        validate_pycapsule(capsule, "datafusion_catalog_provider")?;
         let data: NonNull<FFI_CatalogProvider> = capsule
-            .pointer_checked(Some(c_str!("datafusion_catalog_provider")))?
+            .pointer_checked(Some(c"datafusion_catalog_provider"))?
             .cast();
         let provider = unsafe { data.as_ref() };
         let provider: Arc<dyn CatalogProvider + Send> = provider.into();
@@ -691,10 +688,8 @@ fn extract_schema_provider_from_pyobj(
     }
 
     let provider = if let Ok(capsule) = schema_provider.cast::<PyCapsule>() {
-        validate_pycapsule(capsule, "datafusion_schema_provider")?;
-
         let data: NonNull<FFI_SchemaProvider> = capsule
-            .pointer_checked(Some(c_str!("datafusion_schema_provider")))?
+            .pointer_checked(Some(c"datafusion_schema_provider"))?
             .cast();
         let provider = unsafe { data.as_ref() };
         let provider: Arc<dyn SchemaProvider + Send> = provider.into();

src/context.rs

Lines changed: 5 additions & 12 deletions
@@ -55,7 +55,6 @@ use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
 use object_store::ObjectStore;
 use pyo3::IntoPyObjectExt;
 use pyo3::exceptions::{PyKeyError, PyValueError};
-use pyo3::ffi::c_str;
 use pyo3::prelude::*;
 use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
 use url::Url;
@@ -84,7 +83,7 @@ use crate::udtf::PyTableFunction;
 use crate::udwf::PyWindowUDF;
 use crate::utils::{
     create_logical_extension_capsule, extract_logical_extension_codec, get_global_ctx,
-    get_tokio_runtime, spawn_future, validate_pycapsule, wait_for_future,
+    get_tokio_runtime, spawn_future, wait_for_future,
 };
 
 /// Configuration options for a SessionContext
@@ -671,12 +670,9 @@ impl PySessionContext {
                 .call1((codec_capsule,))?;
         }
 
-        let provider = if let Ok(capsule) = provider.cast::<PyCapsule>().map_err(py_datafusion_err)
-        {
-            validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;
-
+        let provider = if let Ok(capsule) = provider.cast::<PyCapsule>() {
            let data: NonNull<FFI_CatalogProviderList> = capsule
-                .pointer_checked(Some(c_str!("datafusion_catalog_provider_list")))?
+                .pointer_checked(Some(c"datafusion_catalog_provider_list"))?
                 .cast();
             let provider = unsafe { data.as_ref() };
             let provider: Arc<dyn CatalogProviderList + Send> = provider.into();
@@ -709,12 +705,9 @@ impl PySessionContext {
                 .call1((codec_capsule,))?;
         }
 
-        let provider = if let Ok(capsule) = provider.cast::<PyCapsule>().map_err(py_datafusion_err)
-        {
-            validate_pycapsule(capsule, "datafusion_catalog_provider")?;
-
+        let provider = if let Ok(capsule) = provider.cast::<PyCapsule>() {
             let data: NonNull<FFI_CatalogProvider> = capsule
-                .pointer_checked(Some(c_str!("datafusion_catalog_provider")))?
+                .pointer_checked(Some(c"datafusion_catalog_provider"))?
                 .cast();
             let provider = unsafe { data.as_ref() };
             let provider: Arc<dyn CatalogProvider + Send> = provider.into();

src/dataframe.rs

Lines changed: 2 additions & 5 deletions
@@ -45,7 +45,6 @@ use futures::{StreamExt, TryStreamExt};
 use parking_lot::Mutex;
 use pyo3::PyErr;
 use pyo3::exceptions::PyValueError;
-use pyo3::ffi::c_str;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
@@ -58,7 +57,7 @@ use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::{PyRecordBatchStream, poll_next_batch};
 use crate::sql::logical::PyLogicalPlan;
 use crate::table::{PyTable, TempViewTable};
-use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};
+use crate::utils::{is_ipython_env, spawn_future, wait_for_future};
 
 /// File-level static CStr for the Arrow array stream capsule name.
 static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
@@ -1117,10 +1116,8 @@ impl PyDataFrame {
         let mut projection: Option<SchemaRef> = None;
 
         if let Some(schema_capsule) = requested_schema {
-            validate_pycapsule(&schema_capsule, "arrow_schema")?;
-
             let data: NonNull<FFI_ArrowSchema> = schema_capsule
-                .pointer_checked(Some(c_str!("arrow_schema")))?
+                .pointer_checked(Some(c"arrow_schema"))?
                 .cast();
             let schema_ptr = unsafe { data.as_ref() };
             let desired_schema = Schema::try_from(schema_ptr)?;

src/dataset_exec.rs

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ use futures::{TryStreamExt, stream};
 use pyo3::prelude::*;
 use pyo3::types::{PyDict, PyIterator, PyList};
 
-use crate::errors::PyDataFusionResult;
+use crate::errors::{PyDataFusionResult, to_datafusion_err};
 use crate::pyarrow_filter_expression::PyArrowFilterExpression;
 
 struct PyArrowBatchesAdapter {
@@ -128,7 +128,7 @@ impl DatasetExec {
         )?;
 
         let fragments_iter = pylist.call1((fragments_iterator,))?;
-        let fragments = fragments_iter.cast::<PyList>().map_err(PyErr::from)?;
+        let fragments = fragments_iter.cast::<PyList>().map_err(to_datafusion_err)?;
 
         let projected_statistics = Statistics::new_unknown(&schema);
         let plan_properties = Arc::new(PlanProperties::new(
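
The map_err change here follows the surrounding signature: this constructor evidently returns a DataFusion Result, so the PyO3 cast error has to become a DataFusionError (via the newly imported to_datafusion_err) rather than a PyErr before `?` can propagate it. A sketch of the assumed helper shape; the actual definition lives in crate::errors and may differ:

use datafusion::error::DataFusionError;

// Assumed shape: fold any debuggable error (such as a PyO3 cast error)
// into a DataFusionError so `?` works in functions returning
// datafusion's Result.
fn to_datafusion_err(err: impl std::fmt::Debug) -> DataFusionError {
    DataFusionError::Execution(format!("{err:?}"))
}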

src/udaf.rs

Lines changed: 2 additions & 5 deletions
@@ -27,14 +27,13 @@ use datafusion::logical_expr::{
     Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf,
 };
 use datafusion_ffi::udaf::FFI_AggregateUDF;
-use pyo3::ffi::c_str;
 use pyo3::prelude::*;
 use pyo3::types::{PyCapsule, PyTuple};
 
 use crate::common::data_type::PyScalarValue;
 use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
 use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, validate_pycapsule};
+use crate::utils::parse_volatility;
 
 #[derive(Debug)]
 struct RustAccumulator {
@@ -157,10 +156,8 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
 }
 
 fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
-    validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
-
     let data: NonNull<FFI_AggregateUDF> = capsule
-        .pointer_checked(Some(c_str!("datafusion_aggregate_udf")))?
+        .pointer_checked(Some(c"datafusion_aggregate_udf"))?
         .cast();
     let udaf = unsafe { data.as_ref() };
     let udaf: Arc<dyn AggregateUDFImpl> = udaf.into();

src/udf.rs

Lines changed: 4 additions & 7 deletions
@@ -32,14 +32,13 @@ use datafusion::logical_expr::{
     Volatility,
 };
 use datafusion_ffi::udf::FFI_ScalarUDF;
-use pyo3::ffi::c_str;
 use pyo3::prelude::*;
 use pyo3::types::{PyCapsule, PyTuple};
 
 use crate::array::PyArrowArrayExportable;
-use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
+use crate::errors::{PyDataFusionResult, to_datafusion_err};
 use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, validate_pycapsule};
+use crate::utils::parse_volatility;
 
 /// This struct holds the Python written function that is a
 /// ScalarUDF.
@@ -194,11 +193,9 @@ impl PyScalarUDF {
     pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
         if func.hasattr("__datafusion_scalar_udf__")? {
             let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
-            let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
-            validate_pycapsule(capsule, "datafusion_scalar_udf")?;
-
+            let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
             let data: NonNull<FFI_ScalarUDF> = capsule
-                .pointer_checked(Some(c_str!("datafusion_scalar_udf")))?
+                .pointer_checked(Some(c"datafusion_scalar_udf"))?
                 .cast();
             let udf = unsafe { data.as_ref() };
             let udf: Arc<dyn ScalarUDFImpl> = udf.into();

src/udtf.rs

Lines changed: 2 additions & 6 deletions
@@ -24,15 +24,13 @@ use datafusion::logical_expr::Expr;
 use datafusion_ffi::udtf::FFI_TableFunction;
 use pyo3::IntoPyObjectExt;
 use pyo3::exceptions::{PyImportError, PyTypeError};
-use pyo3::ffi::c_str;
 use pyo3::prelude::*;
 use pyo3::types::{PyCapsule, PyTuple, PyType};
 
 use crate::context::PySessionContext;
 use crate::errors::{py_datafusion_err, to_datafusion_err};
 use crate::expr::PyExpr;
 use crate::table::PyTable;
-use crate::utils::validate_pycapsule;
 
 /// Represents a user defined table function
 #[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")]
@@ -73,11 +71,9 @@ impl PyTableFunction {
                 err
             }
         })?;
-        let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
-        validate_pycapsule(capsule, "datafusion_table_function")?;
-
+        let capsule = capsule.cast::<PyCapsule>()?;
         let data: NonNull<FFI_TableFunction> = capsule
-            .pointer_checked(Some(c_str!("datafusion_table_function")))?
+            .pointer_checked(Some(c"datafusion_table_function"))?
             .cast();
         let ffi_func = unsafe { data.as_ref() };
         let foreign_func: Arc<dyn TableFunctionImpl> = ffi_func.to_owned().into();

src/udwf.rs

Lines changed: 4 additions & 7 deletions
@@ -33,14 +33,13 @@ use datafusion::logical_expr::{
 use datafusion::scalar::ScalarValue;
 use datafusion_ffi::udwf::FFI_WindowUDF;
 use pyo3::exceptions::PyValueError;
-use pyo3::ffi::c_str;
 use pyo3::prelude::*;
 use pyo3::types::{PyCapsule, PyList, PyTuple};
 
 use crate::common::data_type::PyScalarValue;
-use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
+use crate::errors::{PyDataFusionResult, to_datafusion_err};
 use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, validate_pycapsule};
+use crate::utils::parse_volatility;
 
 #[derive(Debug)]
 struct RustPartitionEvaluator {
@@ -262,11 +261,9 @@ impl PyWindowUDF {
             func
         };
 
-        let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
-        validate_pycapsule(capsule, "datafusion_window_udf")?;
-
+        let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
         let data: NonNull<FFI_WindowUDF> = capsule
-            .pointer_checked(Some(c_str!("datafusion_window_udf")))?
+            .pointer_checked(Some(c"datafusion_window_udf"))?
             .cast();
         let udwf = unsafe { data.as_ref() };
         let udwf: Arc<dyn WindowUDFImpl> = udwf.into();
