Skip to content

Commit

Permalink
Merge pull request #11 from MarcoGorelli/struct
Browse files Browse the repository at this point in the history
Struct
  • Loading branch information
MarcoGorelli authored Jan 29, 2024
2 parents 59a6aa9 + bcb58fc commit 8bad3c9
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 6 deletions.
33 changes: 33 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ crate-type= ["cdylib"]
pyo3 = { version = "0.20.0", features = ["extension-module"] }
pyo3-polars = { version = "0.11.1", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
polars = { version = "0.37.0", default-features = false }
polars = { version = "0.37.0", features=["dtype-struct"], default-features = false }
# rust-stemmers = "1.2.0"

[target.'cfg(target_os = "linux")'.dependencies]
Expand Down
125 changes: 125 additions & 0 deletions docs/struct.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# STRUCTin'

> "Day one, I'm in love with your struct" Thumpasaurus (kinda)

How do we consume structs, and how do we return them?

To learn about structs, we'll write a plugin which takes a `Struct` as
input, and moves each value to the previous key (the first value wraps
around to the last key). So, for example, if
the input was `{'a': 1, 'b': 2., 'c': '3'}`, then the output will be
`{'a': 2., 'b': '3', 'c': 1}`.

On the Python side, usual business:

```python
def shift_struct(expr: IntoExpr) -> pl.Expr:
    """Apply the ``shift_struct`` plugin function to ``expr``.

    ``expr`` is first converted with ``parse_into_expr``; the plugin is then
    registered against the compiled library ``lib`` under the symbol
    ``"shift_struct"`` and flagged as elementwise.
    """
    return parse_into_expr(expr).register_plugin(
        lib=lib,
        symbol="shift_struct",
        is_elementwise=True,
    )
```

On the Rust side, we need to start by activating the necessary
feature - in `Cargo.toml`, please make this change:
```diff
-polars = { version = "0.37.0", default-features = false }
+polars = { version = "0.37.0", features=["dtype-struct"], default-features = false }
```

Then, we need to get the schema right.
```rust
/// Computes the output schema of `shift_struct`.
///
/// Field names keep their positions while the data types move one position
/// towards the front: output field `i` has the name of input field `i` and the
/// data type of input field `i + 1`, and the last output field gets the data
/// type of the first input field. The returned `Field` itself is named after
/// the first inner field of the input struct.
///
/// A struct without any inner fields is returned unchanged, mirroring the
/// guard in `shift_struct`.
///
/// # Panics
///
/// Panics if `input_fields` is empty or if the first input is not a
/// `DataType::Struct`.
fn shifted_struct(input_fields: &[Field]) -> PolarsResult<Field> {
    let field = &input_fields[0];
    match field.data_type() {
        DataType::Struct(fields) => {
            // Nothing to rotate; indexing `fields[0]` below would panic.
            if fields.is_empty() {
                return Ok(field.clone());
            }
            // The first field's dtype wraps around and ends up under the last name.
            let mut field_0 = fields[0].clone();
            let name = field_0.name().clone();
            field_0.set_name(fields[fields.len() - 1].name().clone());
            // Pair every dtype (from the second field on) with the name one slot earlier.
            let mut fields = fields[1..]
                .iter()
                .zip(fields[0..fields.len() - 1].iter())
                .map(|(fld, name)| Field::new(name.name(), fld.data_type().clone()))
                .collect::<Vec<_>>();
            fields.push(field_0);
            Ok(Field::new(&name, DataType::Struct(fields)))
        }
        _ => unreachable!(),
    }
}
```
In this case, I put the first field's name as the output struct's name, but it doesn't
really matter what we put, as Polars doesn't allow us to rename expressions within
plugins. You can always rename on the Python side if you really want to, but I'd suggest
just letting Polars follow its usual "left-hand rule".

The function definition is going to follow a similar logic:

```rust
/// Rotates the values of a struct column by one field.
///
/// Every inner series except the first is renamed to the name of the field
/// one position before it; the first inner series is renamed to the last
/// field's name and appended at the end. The result is rebuilt as a struct
/// series carrying the input struct's name.
///
/// A struct without inner fields is returned as a clone of the input.
///
/// # Errors
///
/// Propagates the error from `struct_()` when the first input is not a
/// struct, and any error returned by `StructChunked::new`.
#[polars_expr(output_type_func=shifted_struct)]
fn shift_struct(inputs: &[Series]) -> PolarsResult<Series> {
    let struct_ca = inputs[0].struct_()?;
    let columns = struct_ca.fields();
    let Some((first, rest)) = columns.split_first() else {
        return Ok(inputs[0].clone());
    };
    let last_idx = columns.len() - 1;

    let mut shifted: Vec<Series> = Vec::with_capacity(columns.len());
    // Each series after the first takes the name of its left-hand neighbour.
    for (values, target) in rest.iter().zip(&columns[..last_idx]) {
        let mut renamed = values.clone();
        renamed.rename(target.name());
        shifted.push(renamed);
    }
    // The first series wraps around to the final slot under the last name.
    let mut wrapped = first.clone();
    wrapped.rename(columns[last_idx].name());
    shifted.push(wrapped);

    let out = StructChunked::new(struct_ca.name(), &shifted)?;
    Ok(out.into_series())
}
```

Let's try this out. Put the following in `run.py`:

```python
import polars as pl
import minimal_plugin as mp

# Pack the plain columns `a`, `b` and `c` into a single struct column named `abc`.
df = pl.DataFrame(
{
"a": [1, 3, 8],
"b": [2.0, 3.1, 2.5],
"c": ["3", "7", "3"],
}
).select(abc=pl.struct("a", "b", "c"))
# Print the input struct next to the output of the `shift_struct` plugin.
print(df.with_columns(abc_shifted=mp.shift_struct("abc")))
```

Compile with `maturin develop` (or `maturin develop --release` if you're
benchmarking), and if you run `python run.py` you'll see:

```
shape: (3, 2)
┌─────────────┬─────────────┐
│ abc ┆ abc_shifted │
│ --- ┆ --- │
│ struct[3] ┆ struct[3] │
╞═════════════╪═════════════╡
│ {1,2.0,"3"} ┆ {2.0,"3",1} │
│ {3,3.1,"7"} ┆ {3.1,"7",3} │
│ {8,2.5,"3"} ┆ {2.5,"3",8} │
└─────────────┴─────────────┘
```

The values look right - but is the schema?
Let's take a look:

```python
import pprint
# Pretty-print the schema of the frame that includes the `shift_struct` output column.
pprint.pprint(df.with_columns(abc_shifted=mp.shift_struct("abc")).schema)
```
```
OrderedDict([('abc', Struct({'a': Int64, 'b': Float64, 'c': String})),
('abc_shifted', Struct({'a': Float64, 'b': String, 'c': Int64}))])
```
Looks correct!
8 changes: 8 additions & 0 deletions minimal_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,11 @@ def weighted_mean(expr: IntoExpr, weights: IntoExpr) -> pl.Expr:
is_elementwise=True,
args=[weights]
)

def shift_struct(expr: IntoExpr) -> pl.Expr:
    """Apply the ``shift_struct`` plugin function to ``expr``.

    ``expr`` is first converted with ``parse_into_expr``; the plugin is then
    registered against the compiled library ``lib`` under the symbol
    ``"shift_struct"`` and flagged as elementwise.
    """
    return parse_into_expr(expr).register_plugin(
        lib=lib,
        symbol="shift_struct",
        is_elementwise=True,
    )
17 changes: 12 additions & 5 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import polars as pl
import minimal_plugin as mp

# List-column example: calls the `weighted_mean` plugin on the `values` and `weights` columns.
df = pl.DataFrame({
'values': [[1, 3, 2], [5, 7]],
'weights': [[.5, .3, .2], [.1, .9]]
})
print(df.with_columns(weighted_mean = mp.weighted_mean('values', 'weights')))
# Struct example: pack the columns `a`, `b` and `c` into one struct column named `abc`.
df = pl.DataFrame(
{
"a": [1, 3, 8],
"b": [2.0, 3.1, 2.5],
"c": ["3", "7", "3"],
}
).select(abc=pl.struct("a", "b", "c"))
# Print the input struct next to the output of the `shift_struct` plugin.
print(df.with_columns(abc_shifted=mp.shift_struct("abc")))
import pprint
# Pretty-print the schema of the frame that includes the `shift_struct` output column.
pprint.pprint(df.with_columns(abc_shifted=mp.shift_struct("abc")).schema)
# print(df.lazy().with_columns(swapped= mp.shift_struct('a')).schema)
# print(df.lazy().with_columns(swapped= mp.shift_struct('a')).collect().schema)
41 changes: 41 additions & 0 deletions src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,44 @@ fn weighted_mean(inputs: &[Series]) -> PolarsResult<Series> {
);
Ok(out.into_series())
}

/// Computes the output schema of `shift_struct`.
///
/// Field names keep their positions while the data types move one position
/// towards the front: output field `i` has the name of input field `i` and the
/// data type of input field `i + 1`, and the last output field gets the data
/// type of the first input field. The returned `Field` itself is named after
/// the first inner field of the input struct.
///
/// A struct without any inner fields is returned unchanged, mirroring the
/// guard in `shift_struct`.
///
/// # Panics
///
/// Panics if `input_fields` is empty or if the first input is not a
/// `DataType::Struct`.
fn shifted_struct(input_fields: &[Field]) -> PolarsResult<Field> {
    let field = &input_fields[0];
    match field.data_type() {
        DataType::Struct(fields) => {
            // Nothing to rotate; indexing `fields[0]` below would panic.
            if fields.is_empty() {
                return Ok(field.clone());
            }
            // The first field's dtype wraps around and ends up under the last name.
            let mut field_0 = fields[0].clone();
            let name = field_0.name().clone();
            field_0.set_name(fields[fields.len() - 1].name().clone());
            // Pair every dtype (from the second field on) with the name one slot earlier.
            let mut fields = fields[1..]
                .iter()
                .zip(fields[0..fields.len() - 1].iter())
                .map(|(fld, name)| Field::new(name.name(), fld.data_type().clone()))
                .collect::<Vec<_>>();
            fields.push(field_0);
            Ok(Field::new(&name, DataType::Struct(fields)))
        }
        _ => unreachable!(),
    }
}

/// Rotates the values of a struct column by one field.
///
/// Every inner series except the first is renamed to the name of the field
/// one position before it; the first inner series is renamed to the last
/// field's name and appended at the end. The result is rebuilt as a struct
/// series carrying the input struct's name.
///
/// A struct without inner fields is returned as a clone of the input.
///
/// # Errors
///
/// Propagates the error from `struct_()` when the first input is not a
/// struct, and any error returned by `StructChunked::new`.
#[polars_expr(output_type_func=shifted_struct)]
fn shift_struct(inputs: &[Series]) -> PolarsResult<Series> {
    let struct_ca = inputs[0].struct_()?;
    let columns = struct_ca.fields();
    let Some((first, rest)) = columns.split_first() else {
        return Ok(inputs[0].clone());
    };
    let last_idx = columns.len() - 1;

    let mut shifted: Vec<Series> = Vec::with_capacity(columns.len());
    // Each series after the first takes the name of its left-hand neighbour.
    for (values, target) in rest.iter().zip(&columns[..last_idx]) {
        let mut renamed = values.clone();
        renamed.rename(target.name());
        shifted.push(renamed);
    }
    // The first series wraps around to the final slot under the last name.
    let mut wrapped = first.clone();
    wrapped.rename(columns[last_idx].name());
    shifted.push(wrapped);

    let out = StructChunked::new(struct_ca.name(), &shifted)?;
    Ok(out.into_series())
}

0 comments on commit 8bad3c9

Please sign in to comment.