Merge pull request #33 from condekind/docs/vec_of_option

Recommendation against Vec<Option<T>> with examples

MarcoGorelli authored Aug 6, 2024
2 parents 8be02d7 + e032868 commit 0d178a3

Showing 7 changed files with 324 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/aggregate.md
@@ -1,4 +1,4 @@
# 13. In (the) aggregate
# 14. In (the) aggregate

Enough transforming columns! Let's aggregate them instead.

2 changes: 1 addition & 1 deletion docs/publishing.md
@@ -1,4 +1,4 @@
# 12. Publishing your plugin to PyPI and becoming famous
# 13. Publishing your plugin to PyPI and becoming famous

Here are the steps you should follow:

205 changes: 205 additions & 0 deletions docs/vec_of_option.md
@@ -0,0 +1,205 @@

# 12. `Vec<Option<T>>` vs. `Vec<T>`

> "I got, I got, I got, I got options" – _Pitbull_, before writing his first Polars plugin
In the plugins we looked at so far, we typically created an iterator of options and let Polars collect it into a `ChunkedArray`.
Sometimes, however, you need to store intermediate values in a `Vec`. You might be tempted to make it a `Vec<Option<T>>`, where
missing values are `None` and present values are `Some`...
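
For instance, a hypothetical helper that buffers intermediate results as options might look like this (`running_sum` is purely illustrative, not part of this tutorial's plugin):

```rust
use polars::prelude::*;

// A hypothetical sketch of the tempting approach: buffer intermediate results in a
// Vec<Option<i64>>, with None standing in for missing values.
fn running_sum(ca: &Int64Chunked) -> Vec<Option<i64>> {
    let mut buffer: Vec<Option<i64>> = Vec::with_capacity(ca.len());
    let mut acc: Option<i64> = Some(0);
    for opt_value in ca.iter() {
        // Once a null is encountered, the running sum stays missing.
        acc = match (acc, opt_value) {
            (Some(total), Some(v)) => Some(total + v),
            _ => None,
        };
        buffer.push(acc);
    }
    buffer
}
```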

🛑 BUT WAIT!

Did you know that `Vec<Option<i32>>` occupies twice as much memory as `Vec<i32>`? Let's prove it:

```rust
use std::mem::size_of_val;

fn main() {
    let vector: Vec<i32> = vec![1, 2, 3];
    println!("{}", size_of_val(&*vector));
    // Output: 12

    let vector: Vec<Option<i32>> = vec![Some(1), Some(2), Some(3)];
    println!("{}", size_of_val(&*vector));
    // Output: 24
}
```
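
Where does the doubling come from? Every bit pattern of an `i32` is a valid value, so `Option<i32>` needs an extra discriminant, and alignment then rounds each element up to 8 bytes. Types with a spare bit pattern (a "niche"), such as `NonZeroI32`, don't pay this cost. A quick check:

```rust
use std::mem::size_of;
use std::num::NonZeroI32;

fn main() {
    println!("{}", size_of::<i32>()); // 4
    println!("{}", size_of::<Option<i32>>()); // 8: a discriminant byte, padded for alignment
    println!("{}", size_of::<Option<NonZeroI32>>()); // 4: zero is free to encode None
}
```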

So... how can we create an output which includes missing values, without allocating twice as much memory as necessary?

## Validity mask

Instead of creating a vector of options, we can create a vector of primitive values with zeroes in place of the missing values, and use
a validity mask to indicate which values are missing. One example of this can be seen in Polars' `interpolate_impl`, which does the heavy lifting behind
[`Series.interpolate`](https://docs.pola.rs/api/python/version/0.18/reference/series/api/polars.Series.interpolate.html):

```rust
fn interpolate_impl<T, I>(chunked_arr: &ChunkedArray<T>, interpolation_branch: I) -> ChunkedArray<T>
where
    T: PolarsNumericType,
    I: Fn(T::Native, T::Native, IdxSize, T::Native, &mut Vec<T::Native>),
{
    // This implementation differs from pandas in that boundary Nones are not removed.
    // This prevents a lot of errors due to expressions leading to different lengths.
    if !chunked_arr.has_nulls() || chunked_arr.null_count() == chunked_arr.len() {
        return chunked_arr.clone();
    }

    // We first find the first and last so that we can set the null buffer.
    let first = chunked_arr.first_non_null().unwrap();
    let last = chunked_arr.last_non_null().unwrap() + 1;

    // Fill out with `first` nulls.
    let mut out = Vec::with_capacity(chunked_arr.len());
    let mut iter = chunked_arr.iter().skip(first);
    for _ in 0..first {
        out.push(Zero::zero());
    }

    // The next element of `iter` is definitely `Some(Some(v))`, because we skipped the
    // first `first` elements, and if all values were missing we'd have done an early return.
    let mut low = iter.next().unwrap().unwrap();
    out.push(low);
    while let Some(next) = iter.next() {
        if let Some(v) = next {
            out.push(v);
            low = v;
        } else {
            let mut steps = 1 as IdxSize;
            for next in iter.by_ref() {
                steps += 1;
                if let Some(high) = next {
                    let steps_n: T::Native = NumCast::from(steps).unwrap();
                    interpolation_branch(low, high, steps, steps_n, &mut out);
                    out.push(high);
                    low = high;
                    break;
                }
            }
        }
    }
    if first != 0 || last != chunked_arr.len() {
        let mut validity = MutableBitmap::with_capacity(chunked_arr.len());
        validity.extend_constant(chunked_arr.len(), true);

        for i in 0..first {
            validity.set(i, false);
        }

        for i in last..chunked_arr.len() {
            validity.set(i, false);
            out.push(Zero::zero())
        }

        let array = PrimitiveArray::new(
            T::get_dtype().to_arrow(CompatLevel::newest()),
            out.into(),
            Some(validity.into()),
        );
        ChunkedArray::with_chunk(chunked_arr.name(), array)
    } else {
        ChunkedArray::from_vec(chunked_arr.name(), out)
    }
}
```

That's a lot to digest at once, so let's take small steps and focus on the core logic.
At the start, we store the indexes of the first and last non-null values. For example, for a column `[None, None, 3, None, None, 9, 11, None]`, `first` would be `2` and `last` would be `7`:

```rust
let first = chunked_arr.first_non_null().unwrap();
let last = chunked_arr.last_non_null().unwrap() + 1;
```

We then create a vector `out` to store the result values in, and in places where we'd like
the output to be missing, we push zeroes (we'll see below how we tell Polars that these are
to be considered missing, rather than ordinary zeroes):

```rust
let mut out = Vec::with_capacity(chunked_arr.len());
for _ in 0..first {
    out.push(Zero::zero());
}
```

We then skip the first `first` elements and start interpolating (note how we write `out.push(low)`, not `out.push(Some(low))`;
we gloss over the rest, as it's not related to the main focus of this chapter):

```rust
let mut iter = chunked_arr.iter().skip(first);
let mut low = iter.next().unwrap().unwrap();
out.push(low);
while let Some(next) = iter.next() {
    // Interpolation logic
}
```

Now, after _most_ of the work is done and we've filled up most of `out`,
we create a validity mask and set it to `false` for elements which we'd like to declare as missing:

```rust
if first != 0 || last != chunked_arr.len() {
    // A validity mask is created for the vector, initially all set to true
    let mut validity = MutableBitmap::with_capacity(chunked_arr.len());
    validity.extend_constant(chunked_arr.len(), true);

    for i in 0..first {
        // The indexes corresponding to the zeroes before the first valid value
        // are set to false (invalid)
        validity.set(i, false);
    }

    for i in last..chunked_arr.len() {
        // The indexes corresponding to the values after the last valid value
        // are set to false (invalid)
        validity.set(i, false);

        // Push zeroes after the last valid value, as many as there are nulls
        // at the end, just like it was done before the first valid value.
        out.push(Zero::zero())
    }

    let array = PrimitiveArray::new(
        T::get_dtype().to_arrow(CompatLevel::newest()),
        out.into(),
        Some(validity.into()),
    );
    ChunkedArray::with_chunk(chunked_arr.name(), array)
} else {
    ChunkedArray::from_vec(chunked_arr.name(), out)
}
```

The `MutableBitmap` only requires one bit per element (one byte per 8 elements), so the total space used is much less than it would've been
if we'd created `out` as a vector of options!
Further, note how the validity mask is only allocated when the output actually contains nulls - if there are no leading or trailing nulls, we
save even more memory by not having a validity mask at all!
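
To distill the pattern, here's a small self-contained sketch (not taken from the Polars source; the helper name is made up, and import paths and signatures can vary between Polars versions) which builds an `Int64Chunked` from plain values plus a validity bitmap:

```rust
use polars::prelude::*;
use polars_arrow::array::PrimitiveArray;
use polars_arrow::bitmap::MutableBitmap;
use polars_arrow::datatypes::ArrowDataType;

// Prepend one missing value to `values`, storing a zero placeholder in the data buffer
// and marking it as invalid in the validity mask.
fn with_leading_null(name: &str, values: &[i64]) -> Int64Chunked {
    let mut out: Vec<i64> = Vec::with_capacity(values.len() + 1);
    out.push(0); // placeholder; the validity mask will mark it as missing
    out.extend_from_slice(values);

    let mut validity = MutableBitmap::with_capacity(out.len());
    validity.extend_constant(out.len(), true);
    validity.set(0, false); // the placeholder is not a real zero

    let array = PrimitiveArray::new(ArrowDataType::Int64, out.into(), Some(validity.into()));
    ChunkedArray::with_chunk(name, array)
}
```

Only the values buffer plus one bit per element get allocated: no `Option` discriminants anywhere.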

## Sentinel values

Let's look at another example of where it's possible to avoid allocating a vector of options. This example comes
from the Polars-XDT plugin. There's one function there which creates a temporary `idx` vector in which, for
each element, we store the index of the previous element larger than it. If an element has no previous larger
element, then rather than storing `None` (thus forcing all non-missing elements to be `Some`), we can just
store `-1`.

Take a look at [this diff from a PR](https://github.com/pola-rs/polars-xdt/pull/79/files#diff-991878a926639bba03bcc36a2790f73181b358f2ff59e0256f9ad76aa707be35), which does exactly that;
most of its changes are along the lines of:

```diff
- if i < Some(0) {
- idx.push(None);
+ if i < 0 {
+ idx.push(-1);
```

There's no functional behaviour change, but we already know the memory benefits!
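
To see the saving in isolation, here's a hypothetical sketch (not the Polars-XDT code) comparing the two layouts:

```rust
use std::mem::size_of_val;

fn main() {
    // Index of the previous larger element, with None meaning "no such element".
    let idx: Vec<Option<i64>> = vec![None, Some(0), None, Some(2)];
    println!("{}", size_of_val(&*idx));
    // Output: 64

    // The same information, with -1 as a sentinel for "no such element".
    let idx: Vec<i64> = vec![-1, 0, -1, 2];
    println!("{}", size_of_val(&*idx));
    // Output: 32
}
```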

## Conclusion

In general, _if you can allocate a `Vec<T>` instead of a `Vec<Option<T>>`,_ __do it!__

!!! note

    This advice only applies if you're creating a vector to store results in. If you're collecting
    an iterator of options into a chunked array, then Polars already optimises this for you.
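
For reference, the pattern that note refers to is the one from the earlier chapters: collect an iterator of options directly into a `ChunkedArray` and let Polars build the values buffer and validity mask for you. A rough sketch (the function name is illustrative):

```rust
use polars::prelude::*;

// Halve even values, mapping odd values to missing; Polars handles the validity mask.
fn halve_evens(ca: &Int64Chunked) -> Int64Chunked {
    ca.iter()
        .map(|opt| opt.and_then(|v| if v % 2 == 0 { Some(v / 2) } else { None }))
        .collect()
}
```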
9 changes: 9 additions & 0 deletions minimal_plugin/__init__.py
@@ -131,3 +131,12 @@ def vertical_weighted_mean(values: IntoExpr, weights: IntoExpr) -> pl.Expr:
        is_elementwise=False,
        returns_scalar=True,
    )


def interpolate(expr: IntoExpr) -> pl.Expr:
    return register_plugin(
        args=[expr],
        lib=lib,
        symbol="interpolate",
        is_elementwise=False,
    )
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -25,6 +25,7 @@ nav:
- lists_in_lists_out.md
- struct.md
- lost_in_space.md
- vec_of_option.md
- publishing.md
- aggregate.md
- where_to_go.md
8 changes: 7 additions & 1 deletion run.py
@@ -21,4 +21,10 @@
    'weights': [.5, .3, .2, .1, .9],
    'group': ['a', 'a', 'a', 'b', 'b'],
})
print(df.group_by('group').agg(weighted_mean = mp.vertical_weighted_mean('values', 'weights')))

df = pl.DataFrame({
    'a': [None, None, 3, None, None, 9, 11, None],
})
result = df.with_columns(interpolate=mp.interpolate('a'))
print(result)
100 changes: 100 additions & 0 deletions src/expressions.rs
@@ -1,11 +1,14 @@
#![allow(clippy::unused_unit)]
use polars::export::num::{NumCast, Zero};
use polars::prelude::arity::broadcast_binary_elementwise;
use polars::prelude::*;
use polars_arrow::bitmap::MutableBitmap;
use pyo3_polars::derive::polars_expr;
use pyo3_polars::export::polars_core::export::num::Signed;
use pyo3_polars::export::polars_core::utils::arrow::array::PrimitiveArray;
use pyo3_polars::export::polars_core::utils::CustomIterTools;
use serde::Deserialize;
use std::ops::{Add, Div, Mul, Sub};

use crate::utils::binary_amortized_elementwise;

@@ -332,3 +335,100 @@ fn vertical_weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
    let result = numerator / denominator;
    Ok(Series::new("", vec![result]))
}

fn linear_itp<T>(low: T, step: T, slope: T) -> T
where
    T: Sub<Output = T> + Mul<Output = T> + Add<Output = T> + Div<Output = T>,
{
    low + step * slope
}

#[inline]
fn signed_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec<T>)
where
    T: Sub<Output = T> + Mul<Output = T> + Add<Output = T> + Div<Output = T> + NumCast + Copy,
{
    let slope = (high - low) / steps_n;
    for step_i in 1..steps {
        let step_i: T = NumCast::from(step_i).unwrap();
        let v = linear_itp(low, step_i, slope);
        out.push(v)
    }
}

fn interpolate_impl<T, I>(chunked_arr: &ChunkedArray<T>, interpolation_branch: I) -> ChunkedArray<T>
where
    T: PolarsNumericType,
    I: Fn(T::Native, T::Native, IdxSize, T::Native, &mut Vec<T::Native>),
{
    // This implementation differs from pandas in that boundary Nones are not removed.
    // This prevents a lot of errors due to expressions leading to different lengths.
    if chunked_arr.null_count() == 0 || chunked_arr.null_count() == chunked_arr.len() {
        return chunked_arr.clone();
    }

    // We first find the first and last so that we can set the null buffer.
    let first = chunked_arr.first_non_null().unwrap();
    let last = chunked_arr.last_non_null().unwrap() + 1;

    // Fill out with `first` nulls.
    let mut out = Vec::with_capacity(chunked_arr.len());
    let mut iter = chunked_arr.iter().skip(first);
    for _ in 0..first {
        out.push(Zero::zero());
    }

    // The next element of `iter` is definitely `Some(Some(v))`, because we skipped the
    // first `first` elements, and if all values were missing we'd have done an early return.
    let mut low = iter.next().unwrap().unwrap();
    out.push(low);
    while let Some(next) = iter.next() {
        if let Some(v) = next {
            out.push(v);
            low = v;
        } else {
            let mut steps = 1 as IdxSize;
            for next in iter.by_ref() {
                steps += 1;
                if let Some(high) = next {
                    let steps_n: T::Native = NumCast::from(steps).unwrap();
                    interpolation_branch(low, high, steps, steps_n, &mut out);
                    out.push(high);
                    low = high;
                    break;
                }
            }
        }
    }
    if first != 0 || last != chunked_arr.len() {
        let mut validity = MutableBitmap::with_capacity(chunked_arr.len());
        validity.extend_constant(chunked_arr.len(), true);

        for i in 0..first {
            validity.set(i, false);
        }

        for i in last..chunked_arr.len() {
            validity.set(i, false);
            out.push(Zero::zero())
        }

        let array = PrimitiveArray::new(
            T::get_dtype().to_arrow(true),
            out.into(),
            Some(validity.into()),
        );
        ChunkedArray::with_chunk(chunked_arr.name(), array)
    } else {
        ChunkedArray::from_vec(chunked_arr.name(), out)
    }
}

#[polars_expr(output_type=Int64)]
fn interpolate(inputs: &[Series]) -> PolarsResult<Series> {
    let s = &inputs[0];
    let ca = s.i64()?;
    let mut out: Int64Chunked = interpolate_impl(ca, signed_interp::<i64>);
    out.rename(ca.name());
    Ok(out.into_series())
}
