Skip to content

Commit b3f27e6

Browse files
authored
Merge pull request #46 from condekind/docs/arrays
Array chapter
2 parents 8978fe9 + b8f1f16 commit b3f27e6

File tree

14 files changed

+416
-6
lines changed

14 files changed

+416
-6
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ crate-type= ["cdylib"]
1313

1414
[dependencies]
1515
pyo3 = { version = "0.22.2", features = ["extension-module", "abi3-py38"] }
16-
pyo3-polars = { version = "0.16.1", features = ["derive", "dtype-struct", "dtype-decimal"] }
16+
pyo3-polars = { version = "0.16.1", features = ["derive", "dtype-struct", "dtype-decimal", "dtype-array"] }
1717
serde = { version = "1", features = ["derive"] }
1818
polars = { version = "0.42.0", features = ["dtype-struct"], default-features = false }
1919
polars-arrow = { version = "0.42.0", default-features = false }
20-
polars-core = { version = "0.42.0", default-features = false }
20+
polars-core = { version = "0.42.0", features = ["dtype-array"], default-features = false }
2121
polars-sql = { version = "0.42.0", default-features = false }
2222
reverse_geocoder = "4.1.1"
2323
# rust-stemmers = "1.2.0"

docs/aggregate.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# 14. In (the) aggregate
1+
# 15. In (the) aggregate
22

33
Enough transorming columns! Let's aggregate them instead.
44

docs/arrays.md

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
2+
# 11. ARRAY, captain!
3+
4+
We've talked about lists, structs, but what about arrays?
5+
6+
In this section we're gonna cover how to deal with fixed sized arrays, e.g., x and y coordinates of 2d points *in the same column*:
7+
8+
```python
9+
points = pl.Series(
10+
"points",
11+
[
12+
[6.63, 8.35],
13+
[7.19, 4.85],
14+
[2.1, 4.21],
15+
[3.4, 6.13],
16+
],
17+
dtype=pl.Array(pl.Float64, 2),
18+
)
19+
df = pl.DataFrame(points)
20+
21+
print(df)
22+
```
23+
24+
```
25+
shape: (4, 1)
26+
┌───────────────┐
27+
│ points │
28+
│ --- │
29+
│ array[f64, 2] │
30+
╞═══════════════╡
31+
│ [6.63, 8.35] │
32+
│ [7.19, 4.85] │
33+
│ [2.1, 4.21] │
34+
│ [3.4, 6.13] │
35+
└───────────────┘
36+
```
37+
38+
Let's get to work - what if we wanted to make a plugin that takes a Series like `points` above, and, likewise, returned a Series of arrays?
39+
Turns out we _can_ do it! But it's a little bit tricky.
40+
41+
__First of all__, we need to include `features = ["dtype-array"]` in both `pyo3-polars` and `polars-core` in our `Cargo.toml`.
42+
43+
Now let's create a plugin that calculates the midpoint between a reference point and each point in a Series like the one above.
44+
This should illustrate both how to unpack an array inside our Rust code and also return a Series of the same type.
45+
46+
We'll start by registering our plugin:
47+
48+
```python
49+
def midpoint_2d(expr: IntoExpr, ref_point: tuple[float, float]) -> pl.Expr:
50+
return register_plugin_function(
51+
args=[expr],
52+
plugin_path=Path(__file__).parent,
53+
function_name="midpoint_2d",
54+
is_elementwise=True,
55+
kwargs={"ref_point": ref_point},
56+
)
57+
```
58+
59+
As you can see, we included an additional kwarg: `ref_point`, which we annotated with the type `tuple: [float, float]`.
60+
In our Rust code, we won't receive it as a tuple, though, it'll also be an array.
61+
This isn't crucial for this example, so just accept it for now.
62+
As you saw in the __arguments__ chapter, we take kwargs by defining a struct for them:
63+
64+
```rust
65+
#[derive(Deserialize)]
66+
struct MidPoint2DKwargs {
67+
ref_point: [f64; 2],
68+
}
69+
```
70+
71+
And we can finally move to the actual plugin code:
72+
73+
```rust
74+
// We need this to ensure the output is of dtype array.
75+
// Unfortunately, polars plugins do not support something similar to:
76+
// #[polars_expr(output_type=Array)]
77+
pub fn point_2d_output(_: &[Field]) -> PolarsResult<Field> {
78+
Ok(Field::new(
79+
"point_2d",
80+
DataType::Array(Box::new(DataType::Float64), 2),
81+
))
82+
}
83+
84+
#[polars_expr(output_type_func=point_2d_output)]
85+
fn midpoint_2d(inputs: &[Series], kwargs: MidPoint2DKwargs) -> PolarsResult<Series> {
86+
let ca: &ArrayChunked = inputs[0].array()?;
87+
let ref_point = kwargs.ref_point;
88+
89+
let out: ArrayChunked = unsafe {
90+
ca.try_apply_amortized_same_type(|row| {
91+
let s = row.as_ref();
92+
let ca = s.f64()?;
93+
let out_inner: Float64Chunked = ca
94+
.iter()
95+
.enumerate()
96+
.map(|(idx, opt_val)| {
97+
opt_val.map(|val| {
98+
(val + ref_point[idx]) / 2.0f64
99+
})
100+
}).collect_trusted();
101+
Ok(out_inner.into_series())
102+
})}?;
103+
104+
Ok(out.into_series())
105+
}
106+
```
107+
108+
Uh-oh, unsafe, we're doomed!
109+
110+
Hold on a moment - it's true that we need unsafe here, but let's not freak out.
111+
If we read the docs of `try_apply_amortized_same_type`, we see the following:
112+
113+
> ```rust
114+
> /// Try apply a closure `F` to each array.
115+
> ///
116+
> /// # Safety
117+
> /// Return series of `F` must has the same dtype and number of elements as input if it is Ok.
118+
> pub unsafe fn try_apply_amortized_same_type<F>(&self, mut f: F) -> PolarsResult<Self>
119+
> where
120+
> F: FnMut(AmortSeries) -> PolarsResult<Series>,
121+
> ```
122+
123+
124+
In this example, we can uphold that contract - we know we're returning a Series with the same number of elements and same dtype as the input!
125+
126+
Still, the code looks a bit scary, doesn't it? So let's break it down:
127+
128+
```rust
129+
let out: ArrayChunked = unsafe {
130+
131+
// This is similar to apply_values, but it's amortized and made specifically
132+
// for arrays.
133+
ca.try_apply_amortized_same_type(|row| {
134+
let s = row.as_ref();
135+
// `s` is a Series which contains two elements.
136+
// We unpack it similarly to the way we've been unpacking Series in the
137+
// previous chapters:
138+
//
139+
// Previously we've been doing this to unpack a column we had behind a
140+
// Series - this time, inside this closure, the Series contains the two
141+
// elements composing the "row" (x and y):
142+
let ca = s.f64()?;
143+
144+
// There are many ways to extract the x and y coordinates from ca.
145+
// Here, we remain idiomatic and consistent with what we've been doing
146+
// in the past - iterate, enumerate and map:
147+
let out_inner: Float64Chunked = ca
148+
.iter()
149+
.enumerate()
150+
.map(|(idx, opt_val)| {
151+
152+
// We only use map here because opt_val is an Option
153+
opt_val.map(|val| {
154+
155+
// Here's where the simple logic of calculating a
156+
// midpoint happens. We take the coordinate (`val`) at
157+
// index `idx`, add it to the `idx-th` entry of our
158+
// reference point (which is a coordinate of our point),
159+
// then divide it by two, since we're dealing with 2d
160+
// points only.
161+
(val + ref_point[idx]) / 2.0f64
162+
})
163+
// Our map already returns Some or None, so we don't have to
164+
// worry about wrapping the result in, e.g., Some()
165+
}).collect_trusted();
166+
167+
// At last, we convert out_inner (which is a Float64Chunked) back to a
168+
// Series
169+
Ok(out_inner.into_series())
170+
})}?;
171+
172+
// And finally, we convert our ArrayChunked into a Series, ready to ship to
173+
// Python-land:
174+
Ok(out.into_series())
175+
```
176+
177+
That's it. What does the result look like?
178+
In `run.py`, we have:
179+
180+
```python
181+
import polars as pl
182+
from minimal_plugin import midpoint_2d
183+
184+
points = pl.Series(
185+
"points",
186+
[
187+
[6.63, 8.35],
188+
[7.19, 4.85],
189+
[2.1, 4.21],
190+
[3.4, 6.13],
191+
[2.48, 9.26],
192+
[9.41, 7.26],
193+
[7.45, 8.85],
194+
[6.58, 5.22],
195+
[6.05, 5.77],
196+
[8.57, 4.16],
197+
[3.22, 4.98],
198+
[6.62, 6.62],
199+
[9.36, 7.44],
200+
[8.34, 3.43],
201+
[4.47, 7.61],
202+
[4.34, 5.05],
203+
[5.0, 5.05],
204+
[5.0, 5.0],
205+
[2.07, 7.8],
206+
[9.45, 9.6],
207+
[3.1, 3.26],
208+
[4.37, 5.72],
209+
],
210+
dtype=pl.Array(pl.Float64, 2),
211+
)
212+
df = pl.DataFrame(points)
213+
214+
# Now we call our plugin:
215+
result = df.with_columns(midpoints=midpoint_2d("points", ref_point=(5.0, 5.0)))
216+
print(result)
217+
```
218+
219+
Let's compile and run it:
220+
```shell
221+
maturin develop
222+
223+
python run.py
224+
```
225+
226+
🥁:
227+
```
228+
shape: (22, 2)
229+
┌───────────────┬────────────────┐
230+
│ points ┆ midpoints │
231+
│ --- ┆ --- │
232+
│ array[f64, 2] ┆ array[f64, 2] │
233+
╞═══════════════╪════════════════╡
234+
│ [6.63, 8.35] ┆ [5.815, 6.675] │
235+
│ [7.19, 4.85] ┆ [6.095, 4.925] │
236+
│ [2.1, 4.21] ┆ [3.55, 4.605] │
237+
│ [3.4, 6.13] ┆ [4.2, 5.565] │
238+
│ [2.48, 9.26] ┆ [3.74, 7.13] │
239+
│ … ┆ … │
240+
│ [5.0, 5.0] ┆ [5.0, 5.0] │
241+
│ [2.07, 7.8] ┆ [3.535, 6.4] │
242+
│ [9.45, 9.6] ┆ [7.225, 7.3] │
243+
│ [3.1, 3.26] ┆ [4.05, 4.13] │
244+
│ [4.37, 5.72] ┆ [4.685, 5.36] │
245+
└───────────────┴────────────────┘
246+
```
247+
248+
249+
!!!note
250+
Notice how the dtype remains the same.
251+
As an exercise, try to achieve the same in pure-Python (without Rust plugins)
252+
without explicitly casting the type of the Series.
253+
254+
Hurray, we did it!
255+
And why exactly go through all this trouble instead of just doing the same thing in pure Python?
256+
For performance of course!
257+
258+
_Spoilers ahead if you haven't tried the exercise from the note above_
259+
260+
With the following implementation in Python, we can take some measurements:
261+
262+
```python
263+
ref_point = (5.0, 5.0)
264+
265+
def using_plugin(df=df, ref_point=ref_point):
266+
result = df.with_columns(midpoints=midpoint_2d("points", ref_point=ref_point))
267+
return result
268+
269+
def midpoint(points:pl.Series) -> pl.Series:
270+
result=[]
271+
for point in points:
272+
result.append([(point[0]+ref_point[0])/2, (point[1]+ref_point[1])/2])
273+
return pl.Series(result, dtype=pl.Array(pl.Float64, 2))
274+
275+
def using_python(df=df, ref_point=ref_point):
276+
result = (
277+
df.with_columns(
278+
midpoints=pl.col('points').map_batches(midpoint, return_dtype=pl.Array(pl.Float64, 2))
279+
)
280+
)
281+
return result
282+
```
283+
284+
For the sake of brevity, some extra methods to generate and parse an input file were left out of the code above, as well as the `timeit` bits.
285+
By measuring both versions with 1.000.000 points a few times and taking the average, we got the following result:
286+
287+
```
288+
Using plugin:
289+
min: 0.5307095803339811
290+
max: 0.5741689523274545
291+
mean +/- stderr: 0.5524565599986263 +/- 0.0064489015434971925
292+
293+
Using python:
294+
min: 6.682447870339577
295+
max: 6.99253460233255
296+
mean +/- stderr: 6.808615755191394 +/- 0.03757884107880601
297+
```
298+
299+
A speedup of __12x__, that's a __big win__!
300+
301+
!!!note
302+
When benchmarking Rust code, remember to use `maturin develop --release`, otherwise the timings will be much slower!

docs/assets/array00.png

8.29 KB
Loading

docs/assets/array01.png

43.9 KB
Loading

docs/lost_in_space.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# 11. Lost in space
1+
# 12. Lost in space
22

33
Suppose, hypothetically speaking, that you're lost somewhere and only have access
44
to your latitude, your longitude, and a laptop on which you can write a Polars Plugin.

docs/publishing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# 13. Publishing your plugin to PyPI and becoming famous
1+
# 14. Publishing your plugin to PyPI and becoming famous
22

33
Here are the steps you should follow:
44

docs/vec_of_option.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
# 12. `Vec<Option<T>>` vs. `Vec<T>`
2+
# 13. `Vec<Option<T>>` vs. `Vec<T>`
33

44
> "I got, I got, I got, I got options" – _Pitbull_, before writing his first Polars plugin
55

minimal_plugin/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,13 @@ def life_step(left: IntoExpr, mid: IntoExpr, right: IntoExpr) -> pl.Expr:
166166
function_name="life_step",
167167
is_elementwise=False,
168168
)
169+
170+
171+
def midpoint_2d(expr: IntoExpr, ref_point: tuple[float, float]) -> pl.Expr:
172+
return register_plugin_function(
173+
args=[expr],
174+
plugin_path=LIB,
175+
function_name="midpoint_2d",
176+
is_elementwise=True,
177+
kwargs={"ref_point": ref_point},
178+
)

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ nav:
2424
- lists.md
2525
- lists_in_lists_out.md
2626
- struct.md
27+
- arrays.md
2728
- lost_in_space.md
2829
- vec_of_option.md
2930
- publishing.md

0 commit comments

Comments
 (0)