Skip to content

Commit 21132a8

Browse files
authored
derive(FromPyObject): adds default option (#4829)
* derive(FromPyObject): adds default option Takes an optional expression to set a custom value that is not the one from the Default trait * Documentation, testing and hygiene * Support enum variant named fields and cover failures
1 parent 1840bc5 commit 21132a8

File tree

8 files changed

+258
-14
lines changed

8 files changed

+258
-14
lines changed

guide/src/conversions/traits.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,48 @@ If the input is neither a string nor an integer, the error message will be:
488488
- apply a custom function to convert the field from Python the desired Rust type.
489489
- the argument must be the name of the function as a string.
490490
- the function signature must be `fn(&Bound<PyAny>) -> PyResult<T>` where `T` is the Rust type of the argument.
491+
- `pyo3(default)`, `pyo3(default = ...)`
492+
- if the argument is set, uses the given default value.
493+
- in this case, the argument must be a Rust expression returning a value of the desired Rust type.
494+
- if the argument is not set, [`Default::default`](https://doc.rust-lang.org/std/default/trait.Default.html#tymethod.default) is used.
495+
- note that the default value is only used if the field is not set.
496+
If the field is set and the conversion function from Python to Rust fails, an exception is raised and the default value is not used.
497+
- this attribute is only supported on named fields.
498+
499+
For example, the code below applies the given conversion function on the `"value"` dict item to compute its length or fall back to the type default value (0):
500+
501+
```rust
502+
use pyo3::prelude::*;
503+
504+
#[derive(FromPyObject)]
505+
struct RustyStruct {
506+
#[pyo3(item("value"), default, from_py_with = "Bound::<'_, PyAny>::len")]
507+
len: usize,
508+
#[pyo3(item)]
509+
other: usize,
510+
}
511+
#
512+
# use pyo3::types::PyDict;
513+
# fn main() -> PyResult<()> {
514+
# Python::with_gil(|py| -> PyResult<()> {
515+
# // Filled case
516+
# let dict = PyDict::new(py);
517+
# dict.set_item("value", (1,)).unwrap();
518+
# dict.set_item("other", 1).unwrap();
519+
# let result = dict.extract::<RustyStruct>()?;
520+
# assert_eq!(result.len, 1);
521+
# assert_eq!(result.other, 1);
522+
#
523+
# // Empty case
524+
# let dict = PyDict::new(py);
525+
# dict.set_item("other", 1).unwrap();
526+
# let result = dict.extract::<RustyStruct>()?;
527+
# assert_eq!(result.len, 0);
528+
# assert_eq!(result.other, 1);
529+
# Ok(())
530+
# })
531+
# }
532+
```
491533

492534
### `IntoPyObject`
493535
The ['IntoPyObject'] trait defines the to-python conversion for a Rust type. All types in PyO3 implement this trait,

newsfragments/4829.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`derive(FromPyObject)` allow a `default` attribute to set a default value for extracted fields of named structs. The default value is either provided explicitly or fetched via `Default::default()`.

pyo3-macros-backend/src/attributes.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ impl<K: ToTokens, V: ToTokens> ToTokens for OptionalKeywordAttribute<K, V> {
351351

352352
pub type FromPyWithAttribute = KeywordAttribute<kw::from_py_with, LitStrValue<ExprPath>>;
353353

354+
pub type DefaultAttribute = OptionalKeywordAttribute<Token![default], Expr>;
355+
354356
/// For specifying the path to the pyo3 crate.
355357
pub type CrateAttribute = KeywordAttribute<Token![crate], LitStrValue<Path>>;
356358

pyo3-macros-backend/src/frompyobject.rs

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
1+
use crate::attributes::{
2+
self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute,
3+
};
24
use crate::utils::Ctx;
35
use proc_macro2::TokenStream;
4-
use quote::{format_ident, quote};
6+
use quote::{format_ident, quote, ToTokens};
57
use syn::{
68
ext::IdentExt,
79
parenthesized,
@@ -90,6 +92,7 @@ struct NamedStructField<'a> {
9092
ident: &'a syn::Ident,
9193
getter: Option<FieldGetter>,
9294
from_py_with: Option<FromPyWithAttribute>,
95+
default: Option<DefaultAttribute>,
9396
}
9497

9598
struct TupleStructField {
@@ -144,6 +147,10 @@ impl<'a> Container<'a> {
144147
attrs.getter.is_none(),
145148
field.span() => "`getter` is not permitted on tuple struct elements."
146149
);
150+
ensure_spanned!(
151+
attrs.default.is_none(),
152+
field.span() => "`default` is not permitted on tuple struct elements."
153+
);
147154
Ok(TupleStructField {
148155
from_py_with: attrs.from_py_with,
149156
})
@@ -193,10 +200,15 @@ impl<'a> Container<'a> {
193200
ident,
194201
getter: attrs.getter,
195202
from_py_with: attrs.from_py_with,
203+
default: attrs.default,
196204
})
197205
})
198206
.collect::<Result<Vec<_>>>()?;
199-
if options.transparent {
207+
if struct_fields.iter().all(|field| field.default.is_some()) {
208+
bail_spanned!(
209+
fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
210+
)
211+
} else if options.transparent {
200212
ensure_spanned!(
201213
struct_fields.len() == 1,
202214
fields.span() => "transparent structs and variants can only have 1 field"
@@ -346,18 +358,33 @@ impl<'a> Container<'a> {
346358
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
347359
}
348360
};
349-
let extractor = match &field.from_py_with {
350-
None => {
351-
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
352-
}
353-
Some(FromPyWithAttribute {
354-
value: expr_path, ..
355-
}) => {
356-
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
357-
}
361+
let extractor = if let Some(FromPyWithAttribute {
362+
value: expr_path, ..
363+
}) = &field.from_py_with
364+
{
365+
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &value, #struct_name, #field_name)?)
366+
} else {
367+
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
368+
};
369+
let extracted = if let Some(default) = &field.default {
370+
let default_expr = if let Some(default_expr) = &default.value {
371+
default_expr.to_token_stream()
372+
} else {
373+
quote!(::std::default::Default::default())
374+
};
375+
quote!(if let ::std::result::Result::Ok(value) = #getter {
376+
#extractor
377+
} else {
378+
#default_expr
379+
})
380+
} else {
381+
quote!({
382+
let value = #getter?;
383+
#extractor
384+
})
358385
};
359386

360-
fields.push(quote!(#ident: #extractor));
387+
fields.push(quote!(#ident: #extracted));
361388
}
362389

363390
quote!(::std::result::Result::Ok(#self_ty{#fields}))
@@ -458,6 +485,7 @@ impl ContainerOptions {
458485
struct FieldPyO3Attributes {
459486
getter: Option<FieldGetter>,
460487
from_py_with: Option<FromPyWithAttribute>,
488+
default: Option<DefaultAttribute>,
461489
}
462490

463491
#[derive(Clone, Debug)]
@@ -469,6 +497,7 @@ enum FieldGetter {
469497
enum FieldPyO3Attribute {
470498
Getter(FieldGetter),
471499
FromPyWith(FromPyWithAttribute),
500+
Default(DefaultAttribute),
472501
}
473502

474503
impl Parse for FieldPyO3Attribute {
@@ -512,6 +541,8 @@ impl Parse for FieldPyO3Attribute {
512541
}
513542
} else if lookahead.peek(attributes::kw::from_py_with) {
514543
input.parse().map(FieldPyO3Attribute::FromPyWith)
544+
} else if lookahead.peek(Token![default]) {
545+
input.parse().map(FieldPyO3Attribute::Default)
515546
} else {
516547
Err(lookahead.error())
517548
}
@@ -523,6 +554,7 @@ impl FieldPyO3Attributes {
523554
fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
524555
let mut getter = None;
525556
let mut from_py_with = None;
557+
let mut default = None;
526558

527559
for attr in attrs {
528560
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
@@ -542,6 +574,13 @@ impl FieldPyO3Attributes {
542574
);
543575
from_py_with = Some(from_py_with_attr);
544576
}
577+
FieldPyO3Attribute::Default(default_attr) => {
578+
ensure_spanned!(
579+
default.is_none(),
580+
attr.span() => "`default` may only be provided once"
581+
);
582+
default = Some(default_attr);
583+
}
545584
}
546585
}
547586
}
@@ -550,6 +589,7 @@ impl FieldPyO3Attributes {
550589
Ok(FieldPyO3Attributes {
551590
getter,
552591
from_py_with,
592+
default,
553593
})
554594
}
555595
}

src/tests/hygiene/misc.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ struct Derive3 {
1212
f: i32,
1313
#[pyo3(item(42))]
1414
g: i32,
15+
#[pyo3(default)]
16+
h: i32,
1517
} // struct case
1618

1719
#[derive(crate::FromPyObject)]

tests/test_frompyobject.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,3 +686,117 @@ fn test_with_keyword_item() {
686686
assert_eq!(result, expected);
687687
});
688688
}
689+
690+
#[derive(Debug, FromPyObject, PartialEq, Eq)]
691+
pub struct WithDefaultItem {
692+
#[pyo3(item, default)]
693+
opt: Option<usize>,
694+
#[pyo3(item)]
695+
value: usize,
696+
}
697+
698+
#[test]
699+
fn test_with_default_item() {
700+
Python::with_gil(|py| {
701+
let dict = PyDict::new(py);
702+
dict.set_item("value", 3).unwrap();
703+
let result = dict.extract::<WithDefaultItem>().unwrap();
704+
let expected = WithDefaultItem {
705+
value: 3,
706+
opt: None,
707+
};
708+
assert_eq!(result, expected);
709+
});
710+
}
711+
712+
#[derive(Debug, FromPyObject, PartialEq, Eq)]
713+
pub struct WithExplicitDefaultItem {
714+
#[pyo3(item, default = 1)]
715+
opt: usize,
716+
#[pyo3(item)]
717+
value: usize,
718+
}
719+
720+
#[test]
721+
fn test_with_explicit_default_item() {
722+
Python::with_gil(|py| {
723+
let dict = PyDict::new(py);
724+
dict.set_item("value", 3).unwrap();
725+
let result = dict.extract::<WithExplicitDefaultItem>().unwrap();
726+
let expected = WithExplicitDefaultItem { value: 3, opt: 1 };
727+
assert_eq!(result, expected);
728+
});
729+
}
730+
731+
#[derive(Debug, FromPyObject, PartialEq, Eq)]
732+
pub struct WithDefaultItemAndConversionFunction {
733+
#[pyo3(item, default, from_py_with = "Bound::<'_, PyAny>::len")]
734+
opt: usize,
735+
#[pyo3(item)]
736+
value: usize,
737+
}
738+
739+
#[test]
740+
fn test_with_default_item_and_conversion_function() {
741+
Python::with_gil(|py| {
742+
// Filled case
743+
let dict = PyDict::new(py);
744+
dict.set_item("opt", (1,)).unwrap();
745+
dict.set_item("value", 3).unwrap();
746+
let result = dict
747+
.extract::<WithDefaultItemAndConversionFunction>()
748+
.unwrap();
749+
let expected = WithDefaultItemAndConversionFunction { opt: 1, value: 3 };
750+
assert_eq!(result, expected);
751+
752+
// Empty case
753+
let dict = PyDict::new(py);
754+
dict.set_item("value", 3).unwrap();
755+
let result = dict
756+
.extract::<WithDefaultItemAndConversionFunction>()
757+
.unwrap();
758+
let expected = WithDefaultItemAndConversionFunction { opt: 0, value: 3 };
759+
assert_eq!(result, expected);
760+
761+
// Error case
762+
let dict = PyDict::new(py);
763+
dict.set_item("value", 3).unwrap();
764+
dict.set_item("opt", 1).unwrap();
765+
assert!(dict
766+
.extract::<WithDefaultItemAndConversionFunction>()
767+
.is_err());
768+
});
769+
}
770+
771+
#[derive(Debug, FromPyObject, PartialEq, Eq)]
772+
pub enum WithDefaultItemEnum {
773+
#[pyo3(from_item_all)]
774+
Foo {
775+
a: usize,
776+
#[pyo3(default)]
777+
b: usize,
778+
},
779+
NeverUsedA {
780+
a: usize,
781+
},
782+
}
783+
784+
#[test]
785+
fn test_with_default_item_enum() {
786+
Python::with_gil(|py| {
787+
// A and B filled
788+
let dict = PyDict::new(py);
789+
dict.set_item("a", 1).unwrap();
790+
dict.set_item("b", 2).unwrap();
791+
let result = dict.extract::<WithDefaultItemEnum>().unwrap();
792+
let expected = WithDefaultItemEnum::Foo { a: 1, b: 2 };
793+
assert_eq!(result, expected);
794+
795+
// A filled
796+
let dict = PyDict::new(py);
797+
dict.set_item("a", 1).unwrap();
798+
let result = dict.extract::<WithDefaultItemEnum>().unwrap();
799+
let expected = WithDefaultItemEnum::Foo { a: 1, b: 0 };
800+
assert_eq!(result, expected);
801+
});
802+
}

tests/ui/invalid_frompy_derive.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,4 +213,21 @@ struct FromItemAllConflictAttrWithArgs {
213213
field: String,
214214
}
215215

216+
#[derive(FromPyObject)]
217+
struct StructWithOnlyDefaultValues {
218+
#[pyo3(default)]
219+
field: String,
220+
}
221+
222+
#[derive(FromPyObject)]
223+
enum EnumVariantWithOnlyDefaultValues {
224+
Foo {
225+
#[pyo3(default)]
226+
field: String,
227+
},
228+
}
229+
230+
#[derive(FromPyObject)]
231+
struct NamedTuplesWithDefaultValues(#[pyo3(default)] String);
232+
216233
fn main() {}

tests/ui/invalid_frompy_derive.stderr

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ error: transparent structs and variants can only have 1 field
8484
70 | | },
8585
| |_____^
8686

87-
error: expected one of: `attribute`, `item`, `from_py_with`
87+
error: expected one of: `attribute`, `item`, `from_py_with`, `default`
8888
--> tests/ui/invalid_frompy_derive.rs:76:12
8989
|
9090
76 | #[pyo3(attr)]
@@ -223,3 +223,29 @@ error: The struct is already annotated with `from_item_all`, `attribute` is not
223223
|
224224
210 | #[pyo3(from_item_all)]
225225
| ^^^^^^^^^^^^^
226+
227+
error: cannot derive FromPyObject for structs and variants with only default values
228+
--> tests/ui/invalid_frompy_derive.rs:217:36
229+
|
230+
217 | struct StructWithOnlyDefaultValues {
231+
| ____________________________________^
232+
218 | | #[pyo3(default)]
233+
219 | | field: String,
234+
220 | | }
235+
| |_^
236+
237+
error: cannot derive FromPyObject for structs and variants with only default values
238+
--> tests/ui/invalid_frompy_derive.rs:224:9
239+
|
240+
224 | Foo {
241+
| _________^
242+
225 | | #[pyo3(default)]
243+
226 | | field: String,
244+
227 | | },
245+
| |_____^
246+
247+
error: `default` is not permitted on tuple struct elements.
248+
--> tests/ui/invalid_frompy_derive.rs:231:37
249+
|
250+
231 | struct NamedTuplesWithDefaultValues(#[pyo3(default)] String);
251+
| ^

0 commit comments

Comments
 (0)