Merge pull request #39 from MarcoGorelli/rewrite-lists
Rewrite lists
MarcoGorelli authored Aug 14, 2024
2 parents 131a09f + 04a47eb commit e125fec
Showing 6 changed files with 94 additions and 147 deletions.
64 changes: 44 additions & 20 deletions docs/lists.md
@@ -50,22 +50,47 @@ def weighted_mean(expr: IntoExpr, weights: IntoExpr) -> pl.Expr:
)
```

On the Rust side, we'll make use of `binary_amortized_elementwise`, which you
can find in `src/utils.rs` (if you followed the instructions in [Prerequisites]).
Don't worry about understanding it.
Some of its details (such as `.as_ref()` to get a `Series` out of an `UnstableSeries`) are
optimizations with some gotchas - unless you really know what you're doing, I'd suggest
just using `binary_amortized_elementwise` directly. Hopefully a utility like this
can be added to Polars itself, so that plugin authors won't need to worry about it.

To use it, just add
On the Rust side, we'll define a helper function which will let us work with
pairs of list chunked arrays:

```rust
use crate::utils::binary_amortized_elementwise;
fn binary_amortized_elementwise<'a, T, K, F>(
lhs: &'a ListChunked,
rhs: &'a ListChunked,
mut f: F,
) -> ChunkedArray<T>
where
T: PolarsDataType,
T::Array: ArrayFromIter<Option<K>>,
F: FnMut(&AmortSeries, &AmortSeries) -> Option<K> + Copy,
{
{
let (lhs, rhs) = align_chunks_binary(lhs, rhs);
lhs.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(lhs, rhs)| match (lhs, rhs) {
(Some(lhs), Some(rhs)) => f(&lhs, &rhs),
_ => None,
})
.collect_ca(lhs.name())
}
}
```
to the top of `src/expressions.rs`, after the previous imports.

We just need to write a function which accepts two `Series`, computes their dot product, and then
divides by the sum of the weights:
That's a bit of a mouthful, so let's try to make sense of it.

- As we learned in [Prerequisites], Polars Series are backed by chunked arrays.
  `align_chunks_binary` ensures that the chunks of both inputs have the same lengths,
  rechunking under the hood for us if necessary (see the sketch below);
- `amortized_iter` returns an iterator of `AmortSeries`, each of which corresponds
to a row from our input.
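
To get a feel for what `align_chunks_binary` does for us, here is a small standalone
sketch (not part of the plugin, and assuming a Polars version in which `Series::append`
returns `PolarsResult`): two columns whose chunks don't line up come back with matching
chunk lengths after alignment.

```rust
use polars::prelude::*;
use polars_core::utils::align_chunks_binary;

fn main() -> PolarsResult<()> {
    // "a" is a single chunk of length 4; "b" is built from two chunks of length 2.
    let a = Series::new("a", &[1i64, 2, 3, 4]);
    let mut b = Series::new("b", &[10i64, 20]);
    b.append(&Series::new("b", &[30i64, 40]))?;

    // After alignment, both sides expose chunks of identical lengths, so zipping
    // their (amortized) iterators row-by-row is safe.
    let (a_aligned, b_aligned) = align_chunks_binary(a.i64()?, b.i64()?);
    let a_lens: Vec<usize> = a_aligned.chunks().iter().map(|c| c.len()).collect();
    let b_lens: Vec<usize> = b_aligned.chunks().iter().map(|c| c.len()).collect();
    assert_eq!(a_lens, b_lens);
    Ok(())
}
```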

We'll explain more about `AmortSeries` in a future iteration of this tutorial.
For now, let's just look at how to use this utility:

- we pass it two `ListChunked`s as inputs;
- we also pass a function which takes two `AmortSeries` and produces a scalar
value.

```rust
#[polars_expr(output_type=Float64)]
@@ -76,9 +76,9 @@ fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
let out: Float64Chunked = binary_amortized_elementwise(
values,
weights,
|values_inner: &Series, weights_inner: &Series| -> Option<f64> {
let values_inner = values_inner.i64().unwrap();
let weights_inner = weights_inner.f64().unwrap();
|values_inner: &AmortSeries, weights_inner: &AmortSeries| -> Option<f64> {
let values_inner = values_inner.as_ref().i64().unwrap();
let weights_inner = weights_inner.as_ref().f64().unwrap();
if values_inner.len() == 0 {
// Mirror Polars, and return None for empty mean.
return None
@@ -101,13 +101,12 @@ fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
}
```

Note: this function has some limitations:
If you just need to get a problem solved, this function works! But let's note its
limitations:

- it assumes that each pair of inner lists in `values` and `weights` has the same
  length; it would be better to raise an error if this assumption is not met
- it only accepts `Int64` values () (see section 2 for how you could make it more generic).

Nonetheless, if you just need to get a problem solved, it works!
- it only accepts `Int64` values (see section 2 for how you could make it more generic).

To try it out, we compile with `maturin develop` (or `maturin develop --release` if you're
benchmarking), and then we should be able to run `run.py`:
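The contents of `run.py` are elided in this diff view, but as a rough sketch
(assuming the Python package is importable as `minimal_plugin`, as elsewhere in
this tutorial), a call might look like this:

```python
# Rough sketch only: `minimal_plugin` and the column names are assumptions,
# not taken from the elided file.
import polars as pl
import minimal_plugin as mp

df = pl.DataFrame(
    {
        "values": [[1, 3, 2], [5, 7]],
        "weights": [[0.5, 0.3, 0.2], [0.1, 0.9]],
    }
)
print(df.with_columns(weighted_mean=mp.weighted_mean("values", "weights")))
```
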
45 changes: 12 additions & 33 deletions docs/lost_in_space.md
@@ -62,43 +62,22 @@ use reverse_geocoder::ReverseGeocoder;

#[polars_expr(output_type=String)]
fn reverse_geocode(inputs: &[Series]) -> PolarsResult<Series> {
let lat = inputs[0].f64()?;
let lon = inputs[1].f64()?;
let latitude = inputs[0].f64()?;
let longitude = inputs[1].f64()?;
let geocoder = ReverseGeocoder::new();

let (lhs, rhs) = align_chunks_binary(lat, lon);
let chunks = lhs
.downcast_iter()
.zip(rhs.downcast_iter())
.map(|(lat_arr, lon_arr)| {
let mut mutarr = MutablePlString::with_capacity(lat_arr.len());

for (lat_opt_val, lon_opt_val) in lat_arr.iter().zip(lon_arr.iter()) {
match (lat_opt_val, lon_opt_val) {
(Some(lat_val), Some(lon_val)) => {
let res = &geocoder.search((*lat_val, *lon_val)).record.name;
mutarr.push(Some(res))
}
_ => mutarr.push_null(),
}
}

mutarr.freeze().boxed()
})
.collect();
let out: StringChunked = unsafe { ChunkedArray::from_chunks("placeholder", chunks) };
let out = binary_elementwise_into_string_amortized(latitude, longitude, |lhs, rhs, out| {
let search_result = geocoder.search((lhs, rhs));
write!(out, "{}", search_result.record.name).unwrap();
});
Ok(out.into_series())
}
```
That's a bit of a mouthful, so let's try to make sense of it.

- As we learned about in [Prerequisites], Polars Series are backed by chunked arrays.
`align_chunks_binary` just ensures that the chunks have the same lengths. It may need
to rechunk under the hood for us;
- `downcast_iter` returns an iterator of Arrow Arrays. Using `zip`, we iterate over
respective pairs of Arrow Arrays from `lhs` and `rhs`
- to learn about `MutablePlString`, please read
[Why we have rewritten our String type](https://pola.rs/posts/polars-string-type/)

We use the utility function `binary_elementwise_into_string_amortized`,
a binary version of the `apply_into_string_amortized` function we learned
about in the [Stringify] chapter.

[Stringify]: ../stringify/

To run it, put the following in `run.py`:
```python
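# Hypothetical sketch only: the original run.py contents are not shown in this
# diff view, and `minimal_plugin` / `reverse_geocode` are assumed from the rest
# of the tutorial.
import polars as pl
import minimal_plugin as mp

df = pl.DataFrame(
    {
        "lat": [37.7749, 51.5074],
        "lon": [-122.4194, -0.1278],
    }
)
print(df.with_columns(city=mp.reverse_geocode("lat", "lon")))
```
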
2 changes: 2 additions & 0 deletions docs/requirements-docs.txt
@@ -0,0 +1,2 @@
mkdocs
mkdocs-material
104 changes: 36 additions & 68 deletions src/expressions.rs
@@ -1,17 +1,19 @@
#![allow(clippy::unused_unit)]
use polars::export::num::{NumCast, Zero};
use polars::prelude::arity::broadcast_binary_elementwise;
use polars::prelude::arity::{
binary_elementwise_into_string_amortized, broadcast_binary_elementwise,
};
use polars::prelude::*;
use polars_arrow::bitmap::MutableBitmap;
use polars_core::series::amortized_iter::AmortSeries;
use polars_core::utils::align_chunks_binary;
use pyo3_polars::derive::polars_expr;
use pyo3_polars::export::polars_core::export::num::Signed;
use pyo3_polars::export::polars_core::utils::arrow::array::PrimitiveArray;
use pyo3_polars::export::polars_core::utils::CustomIterTools;
use serde::Deserialize;
use std::ops::{Add, Div, Mul, Sub};

use crate::utils::binary_amortized_elementwise;

fn same_output_type(input_fields: &[Field]) -> PolarsResult<Field> {
let field = &input_fields[0];
Ok(field.clone())
@@ -111,39 +111,6 @@ fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> {
Ok(out.into_series())
}

// #[polars_expr(output_type=String)]
// fn reverse_geocode(inputs: &[Series]) -> PolarsResult<Series> {
// let binding = inputs[0].struct_()?.field_by_name("lat")?;
// let latitude = binding.f64()?;
// let binding = inputs[0].struct_()?.field_by_name("lon")?;
// let longitude = binding.f64()?;
// let geocoder = ReverseGeocoder::new();
// let (lhs, rhs) = align_chunks_binary(latitude, longitude);
// let iter = lhs.downcast_iter().zip(rhs.downcast_iter()).map(
// |(lhs_arr, rhs_arr)| -> LargeStringArray {
// let mut buf = String::new();
// let mut mutarr: MutableUtf8Array<i64> =
// MutableUtf8Array::with_capacities(lhs_arr.len(), lhs_arr.len() * 20);

// for (lhs_opt_val, rhs_opt_val) in lhs_arr.iter().zip(rhs_arr.iter()) {
// match (lhs_opt_val, rhs_opt_val) {
// (Some(lhs_val), Some(rhs_val)) => {
// buf.clear();
// let search_result = geocoder.search((*lhs_val, *rhs_val));
// write!(buf, "{}", search_result.record.name).unwrap();
// mutarr.push(Some(&buf))
// }
// _ => mutarr.push_null(),
// }
// }
// let arr: Utf8Array<i64> = mutarr.into();
// arr
// },
// );
// let out = StringChunked::from_chunk_iter(lhs.name(), iter);
// Ok(out.into_series())
// }

#[polars_expr(output_type=Int64)]
fn abs_i64_fast(inputs: &[Series]) -> PolarsResult<Series> {
let s = &inputs[0];
@@ -187,6 +156,28 @@ fn add_suffix(inputs: &[Series], kwargs: AddSuffixKwargs) -> PolarsResult<Series
// Ok(out.into_series())
// }

fn binary_amortized_elementwise<'a, T, K, F>(
lhs: &'a ListChunked,
rhs: &'a ListChunked,
mut f: F,
) -> ChunkedArray<T>
where
T: PolarsDataType,
T::Array: ArrayFromIter<Option<K>>,
F: FnMut(&AmortSeries, &AmortSeries) -> Option<K> + Copy,
{
{
let (lhs, rhs) = align_chunks_binary(lhs, rhs);
lhs.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(lhs, rhs)| match (lhs, rhs) {
(Some(lhs), Some(rhs)) => f(&lhs, &rhs),
_ => None,
})
.collect_ca(lhs.name())
}
}

#[polars_expr(output_type=Float64)]
fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
let values = inputs[0].list()?;
@@ -195,9 +186,9 @@ fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
let out: Float64Chunked = binary_amortized_elementwise(
values,
weights,
|values_inner: &Series, weights_inner: &Series| -> Option<f64> {
let values_inner = values_inner.i64().unwrap();
let weights_inner = weights_inner.f64().unwrap();
|values_inner: &AmortSeries, weights_inner: &AmortSeries| -> Option<f64> {
let values_inner = values_inner.as_ref().i64().unwrap();
let weights_inner = weights_inner.as_ref().f64().unwrap();
if values_inner.is_empty() {
// Mirror Polars, and return None for empty mean.
return None;
@@ -260,40 +251,17 @@ fn shift_struct(inputs: &[Series]) -> PolarsResult<Series> {
StructChunked::from_series(struct_.name(), &fields).map(|ca| ca.into_series())
}

use polars_arrow::array::MutablePlString;
use polars_core::utils::align_chunks_binary;
use reverse_geocoder::ReverseGeocoder;

#[polars_expr(output_type=String)]
fn reverse_geocode(inputs: &[Series]) -> PolarsResult<Series> {
let lhs = inputs[0].f64()?;
let rhs = inputs[1].f64()?;
let latitude = inputs[0].f64()?;
let longitude = inputs[1].f64()?;
let geocoder = ReverseGeocoder::new();

let (lhs, rhs) = align_chunks_binary(lhs, rhs);
let chunks = lhs
.downcast_iter()
.zip(rhs.downcast_iter())
.map(|(lhs_arr, rhs_arr)| {
let mut buf = String::new();
let mut mutarr = MutablePlString::with_capacity(lhs_arr.len());

for (lhs_opt_val, rhs_opt_val) in lhs_arr.iter().zip(rhs_arr.iter()) {
match (lhs_opt_val, rhs_opt_val) {
(Some(lhs_val), Some(rhs_val)) => {
let res = &geocoder.search((*lhs_val, *rhs_val)).record.name;
buf.clear();
write!(buf, "{res}").unwrap();
mutarr.push(Some(&buf))
}
_ => mutarr.push_null(),
}
}

mutarr.freeze().boxed()
})
.collect();
let out: StringChunked = unsafe { ChunkedArray::from_chunks(lhs.name(), chunks) };
let out = binary_elementwise_into_string_amortized(latitude, longitude, |lhs, rhs, out| {
let search_result = geocoder.search((lhs, rhs));
write!(out, "{}", search_result.record.name).unwrap();
});
Ok(out.into_series())
}

1 change: 0 additions & 1 deletion src/lib.rs
@@ -1,5 +1,4 @@
mod expressions;
mod utils;

#[cfg(target_os = "linux")]
use jemallocator::Jemalloc;
25 changes: 0 additions & 25 deletions src/utils.rs

This file was deleted.
