Skip to content

Commit

Permalink
Merge pull request #47 from MarcoGorelli/validate-dtype-in-list-plugin
Browse files Browse the repository at this point in the history
validate dtypes in list plugin
  • Loading branch information
MarcoGorelli authored Sep 2, 2024
2 parents 89f149f + 196a7b3 commit 8978fe9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
11 changes: 10 additions & 1 deletion docs/lists.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ For now, let's just look at how to use this utility:
fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
let values = inputs[0].list()?;
let weights = &inputs[1].list()?;
polars_ensure!(
values.dtype() == &DataType::List(Box::new(DataType::Int64)),
ComputeError: "Expected `values` to be of type `List(Int64)`, got: {}", values.dtype()
);
polars_ensure!(
weights.dtype() == &DataType::List(Box::new(DataType::Float64)),
ComputeError: "Expected `weights` to be of type `List(Float64)`, got: {}", weights.dtype()
);

let out: Float64Chunked = binary_amortized_elementwise(
values,
Expand Down Expand Up @@ -131,7 +139,8 @@ limitations:

- it assumes that each inner element of `values` and `weights` has the same
length - it would be better to raise an error if this assumption is not met
- it only accepts `Int64` values (see section 2 for how you could make it more generic).
- it only accepts `Int64` `values` and `Float64` `weights`
(see section 2 for how you could make it more generic).

To try it out, we compile with `maturin develop` (or `maturin develop --release` if you're
benchmarking), and then we should be able to run `run.py`:
Expand Down
4 changes: 4 additions & 0 deletions docs/lists_in_lists_out.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ fn list_idx_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
#[polars_expr(output_type_func=list_idx_dtype)]
fn non_zero_indices(inputs: &[Series]) -> PolarsResult<Series> {
let ca = inputs[0].list()?;
polars_ensure!(
ca.dtype() == &DataType::List(Box::new(DataType::Int64)),
ComputeError: "Expected `List(Int64)`, got: {}", ca.dtype()
);

let out: ListChunked = ca.apply_amortized(|s| {
let s: &Series = s.as_ref();
Expand Down
12 changes: 12 additions & 0 deletions src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ where
fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
let values = inputs[0].list()?;
let weights = &inputs[1].list()?;
polars_ensure!(
values.dtype() == &DataType::List(Box::new(DataType::Int64)),
ComputeError: "Expected `values` to be of type `List(Int64)`, got: {}", values.dtype()
);
polars_ensure!(
weights.dtype() == &DataType::List(Box::new(DataType::Float64)),
ComputeError: "Expected `weights` to be of type `List(Float64)`, got: {}", weights.dtype()
);

let out: Float64Chunked = binary_amortized_elementwise(
values,
Expand Down Expand Up @@ -292,6 +300,10 @@ fn list_idx_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
#[polars_expr(output_type_func=list_idx_dtype)]
fn non_zero_indices(inputs: &[Series]) -> PolarsResult<Series> {
let ca = inputs[0].list()?;
polars_ensure!(
ca.dtype() == &DataType::List(Box::new(DataType::Int64)),
ComputeError: "Expected `List(Int64)`, got: {}", ca.dtype()
);

let out: ListChunked = ca.apply_amortized(|s| {
let s: &Series = s.as_ref();
Expand Down

0 comments on commit 8978fe9

Please sign in to comment.