Skip to content

Commit 665722a

Browse files
authored
fix: Invalid conversion from non-bit numpy bools (#24312)
1 parent 1e8cab3 commit 665722a

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

crates/polars-python/src/series/construction.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::borrow::Cow;
33
use arrow::array::Array;
44
use arrow::bitmap::BitmapBuilder;
55
use arrow::types::NativeType;
6-
use numpy::{Element, PyArray1, PyArrayMethods};
6+
use numpy::{Element, PyArray1, PyArrayMethods, PyUntypedArrayMethods};
77
use polars_core::prelude::*;
88
use polars_core::utils::CustomIterTools;
99
use pyo3::exceptions::{PyTypeError, PyValueError};
@@ -58,8 +58,13 @@ impl PySeries {
5858
_strict: bool,
5959
) -> PyResult<Self> {
6060
let array = array.readonly();
61-
let vals = array.as_slice().unwrap();
62-
py.enter_polars_series(|| Ok(Series::new(name.into(), vals)))
61+
62+
// We use raw ptr methods to read this as a u8 slice to work around PyO3/rust-numpy#509.
63+
assert!(array.is_contiguous());
64+
let data_ptr = array.data().cast::<u8>();
65+
let data_len = array.len();
66+
let vals = unsafe { core::slice::from_raw_parts(data_ptr, data_len) };
67+
py.enter_polars_series(|| Series::new(name.into(), vals).cast(&DataType::Boolean))
6368
}
6469

6570
#[staticmethod]

py-polars/tests/unit/interop/numpy/test_numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numpy.testing import assert_array_equal
66

77
import polars as pl
8+
from polars.testing import assert_series_equal
89

910

1011
@pytest.fixture(
@@ -100,3 +101,8 @@ def test_init_from_numpy_values(np_dtype_cls: Any, expected_pl_dtype: Any) -> No
100101
# test init from raw numpy values (vs arrays)
101102
s = pl.Series("n", [np_dtype_cls(0), np_dtype_cls(4), np_dtype_cls(8)])
102103
assert s.dtype == expected_pl_dtype
104+
105+
106+
def test_from_numpy_nonbit_bools_24296() -> None:
107+
a = np.array([24, 15, 32, 1, 0], dtype=np.uint8).view(bool)
108+
assert_series_equal(pl.Series(a), pl.Series([True, True, True, True, False]))

0 commit comments

Comments
 (0)