Skip to content

Commit 3c553b7

Browse files
authored
Computed fields (#547)
* implement computed fields 🎉 * dataclass properties * catch errors in properties * fix include and exclude * add example from pydantic/pydantic#2625, fix by_alias * fix on older python
1 parent f8cafe7 commit 3c553b7

File tree

10 files changed

+499
-23
lines changed

10 files changed

+499
-23
lines changed

pydantic_core/core_schema.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,29 @@ def model_ser_schema(cls: Type[Any], schema: CoreSchema) -> ModelSerSchema:
412412
]
413413

414414

415+
class ComputedField(TypedDict, total=False):
416+
type: Required[Literal['computed-field']]
417+
property_name: Required[str]
418+
json_return_type: JsonReturnTypes
419+
alias: str
420+
421+
422+
def computed_field(
423+
property_name: str, *, json_return_type: JsonReturnTypes | None = None, alias: str | None = None
424+
) -> ComputedField:
425+
"""
426+
ComputedFields are properties of a model or dataclass that are included in serialization.
427+
428+
Args:
429+
property_name: The name of the property on the model or dataclass
430+
json_return_type: The type that the property returns if `mode='json'`
431+
alias: The name to use in the serialized output
432+
"""
433+
return dict_not_none(
434+
type='computed-field', property_name=property_name, json_return_type=json_return_type, alias=alias
435+
)
436+
437+
415438
class AnySchema(TypedDict, total=False):
416439
type: Required[Literal['any']]
417440
ref: str
@@ -2633,6 +2656,7 @@ def typed_dict_field(
26332656
class TypedDictSchema(TypedDict, total=False):
26342657
type: Required[Literal['typed-dict']]
26352658
fields: Required[Dict[str, TypedDictField]]
2659+
computed_fields: List[ComputedField]
26362660
strict: bool
26372661
extra_validator: CoreSchema
26382662
return_fields_set: bool
@@ -2649,6 +2673,7 @@ class TypedDictSchema(TypedDict, total=False):
26492673
def typed_dict_schema(
26502674
fields: Dict[str, TypedDictField],
26512675
*,
2676+
computed_fields: list[ComputedField] | None = None,
26522677
strict: bool | None = None,
26532678
extra_validator: CoreSchema | None = None,
26542679
return_fields_set: bool | None = None,
@@ -2675,6 +2700,7 @@ def typed_dict_schema(
26752700
26762701
Args:
26772702
fields: The fields to use for the typed dict
2703+
computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model
26782704
strict: Whether the typed dict is strict
26792705
extra_validator: The extra validator to use for the typed dict
26802706
return_fields_set: Whether the typed dict should return a fields set
@@ -2689,6 +2715,7 @@ def typed_dict_schema(
26892715
return dict_not_none(
26902716
type='typed-dict',
26912717
fields=fields,
2718+
computed_fields=computed_fields,
26922719
strict=strict,
26932720
extra_validator=extra_validator,
26942721
return_fields_set=return_fields_set,
@@ -2851,6 +2878,7 @@ class DataclassArgsSchema(TypedDict, total=False):
28512878
type: Required[Literal['dataclass-args']]
28522879
dataclass_name: Required[str]
28532880
fields: Required[List[DataclassField]]
2881+
computed_fields: List[ComputedField]
28542882
populate_by_name: bool # default: False
28552883
collect_init_only: bool # default: False
28562884
ref: str
@@ -2863,6 +2891,7 @@ def dataclass_args_schema(
28632891
dataclass_name: str,
28642892
fields: list[DataclassField],
28652893
*,
2894+
computed_fields: List[ComputedField] | None = None,
28662895
populate_by_name: bool | None = None,
28672896
collect_init_only: bool | None = None,
28682897
ref: str | None = None,
@@ -2890,6 +2919,7 @@ def dataclass_args_schema(
28902919
Args:
28912920
dataclass_name: The name of the dataclass being validated
28922921
fields: The fields to use for the dataclass
2922+
computed_fields: Computed fields to use when serializing the dataclass
28932923
populate_by_name: Whether to populate by name
28942924
collect_init_only: Whether to collect init only fields into a dict to pass to `__post_init__`
28952925
ref: optional unique identifier of the schema, used to reference the schema in other places
@@ -2901,6 +2931,7 @@ def dataclass_args_schema(
29012931
type='dataclass-args',
29022932
dataclass_name=dataclass_name,
29032933
fields=fields,
2934+
computed_fields=computed_fields,
29042935
populate_by_name=populate_by_name,
29052936
collect_init_only=collect_init_only,
29062937
ref=ref,

src/serializers/computed_fields.rs

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
use pyo3::intern;
2+
use pyo3::prelude::*;
3+
use pyo3::types::{PyDict, PyList, PyString};
4+
5+
use serde::ser::SerializeMap;
6+
use serde::Serialize;
7+
8+
use crate::build_tools::SchemaDict;
9+
use crate::serializers::filter::SchemaFilter;
10+
11+
use super::errors::py_err_se_err;
12+
use super::infer::{infer_serialize, infer_serialize_known, infer_to_python, infer_to_python_known};
13+
use super::ob_type::ObType;
14+
use super::Extra;
15+
16+
use super::type_serializers::function::get_json_return_type;
17+
18+
#[derive(Debug, Clone)]
19+
pub(super) struct ComputedFields(Vec<ComputedField>);
20+
21+
impl ComputedFields {
22+
pub fn new(schema: &PyDict) -> PyResult<Option<Self>> {
23+
let py = schema.py();
24+
if let Some(computed_fields) = schema.get_as::<&PyList>(intern!(py, "computed_fields"))? {
25+
let computed_fields = computed_fields
26+
.iter()
27+
.map(ComputedField::new)
28+
.collect::<PyResult<Vec<_>>>()?;
29+
Ok(Some(Self(computed_fields)))
30+
} else {
31+
Ok(None)
32+
}
33+
}
34+
35+
pub fn to_python(
36+
&self,
37+
model: &PyAny,
38+
output_dict: &PyDict,
39+
filter: &SchemaFilter<isize>,
40+
include: Option<&PyAny>,
41+
exclude: Option<&PyAny>,
42+
extra: &Extra,
43+
) -> PyResult<()> {
44+
for computed_fields in self.0.iter() {
45+
computed_fields.to_python(model, output_dict, filter, include, exclude, extra)?;
46+
}
47+
Ok(())
48+
}
49+
50+
pub fn serde_serialize<S: serde::ser::Serializer>(
51+
&self,
52+
model: &PyAny,
53+
map: &mut S::SerializeMap,
54+
filter: &SchemaFilter<isize>,
55+
include: Option<&PyAny>,
56+
exclude: Option<&PyAny>,
57+
extra: &Extra,
58+
) -> Result<(), S::Error> {
59+
for computed_field in self.0.iter() {
60+
let property_name_py = computed_field.property_name_py.as_ref(model.py());
61+
if let Some((next_include, next_exclude)) = filter
62+
.key_filter(property_name_py, include, exclude)
63+
.map_err(py_err_se_err)?
64+
{
65+
let cfs = ComputedFieldSerializer {
66+
model,
67+
computed_field,
68+
include: next_include,
69+
exclude: next_exclude,
70+
extra,
71+
};
72+
let key = match extra.by_alias {
73+
true => computed_field.alias.as_str(),
74+
false => computed_field.property_name.as_str(),
75+
};
76+
map.serialize_entry(key, &cfs)?;
77+
}
78+
}
79+
Ok(())
80+
}
81+
}
82+
83+
#[derive(Debug, Clone)]
84+
struct ComputedField {
85+
property_name: String,
86+
property_name_py: Py<PyString>,
87+
return_ob_type: Option<ObType>,
88+
alias: String,
89+
alias_py: Py<PyString>,
90+
}
91+
92+
impl ComputedField {
93+
pub fn new(schema: &PyAny) -> PyResult<Self> {
94+
let py = schema.py();
95+
let schema: &PyDict = schema.downcast()?;
96+
let property_name: &PyString = schema.get_as_req(intern!(py, "property_name"))?;
97+
let return_ob_type = get_json_return_type(schema)?;
98+
let alias_py: &PyString = schema.get_as(intern!(py, "alias"))?.unwrap_or(property_name);
99+
Ok(Self {
100+
property_name: property_name.extract()?,
101+
property_name_py: property_name.into_py(py),
102+
return_ob_type,
103+
alias: alias_py.extract()?,
104+
alias_py: alias_py.into_py(py),
105+
})
106+
}
107+
108+
fn to_python(
109+
&self,
110+
model: &PyAny,
111+
output_dict: &PyDict,
112+
filter: &SchemaFilter<isize>,
113+
include: Option<&PyAny>,
114+
exclude: Option<&PyAny>,
115+
extra: &Extra,
116+
) -> PyResult<()> {
117+
let py = model.py();
118+
let property_name_py = self.property_name_py.as_ref(py);
119+
120+
if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? {
121+
let next_value = model.getattr(property_name_py)?;
122+
123+
// TODO fix include & exclude
124+
let value = match self.return_ob_type {
125+
Some(ref ob_type) => infer_to_python_known(ob_type, next_value, next_include, next_exclude, extra),
126+
None => infer_to_python(next_value, next_include, next_exclude, extra),
127+
}?;
128+
let key = match extra.by_alias {
129+
true => self.alias_py.as_ref(py),
130+
false => property_name_py,
131+
};
132+
output_dict.set_item(key, value)?;
133+
}
134+
Ok(())
135+
}
136+
}
137+
138+
pub(crate) struct ComputedFieldSerializer<'py> {
139+
model: &'py PyAny,
140+
computed_field: &'py ComputedField,
141+
include: Option<&'py PyAny>,
142+
exclude: Option<&'py PyAny>,
143+
extra: &'py Extra<'py>,
144+
}
145+
146+
impl<'py> Serialize for ComputedFieldSerializer<'py> {
147+
fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
148+
let py = self.model.py();
149+
let property_name_py = self.computed_field.property_name_py.as_ref(py);
150+
let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?;
151+
152+
match self.computed_field.return_ob_type {
153+
Some(ref ob_type) => {
154+
infer_serialize_known(ob_type, next_value, serializer, self.include, self.exclude, self.extra)
155+
}
156+
None => infer_serialize(next_value, serializer, self.include, self.exclude, self.extra),
157+
}
158+
}
159+
}

src/serializers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub(crate) use extra::{Extra, SerMode, SerializationState};
1414
pub use shared::CombinedSerializer;
1515
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
1616

17+
mod computed_fields;
1718
mod config;
1819
mod errors;
1920
mod extra;

src/serializers/type_serializers/dataclass.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@ use ahash::AHashMap;
66

77
use crate::build_context::BuildContext;
88
use crate::build_tools::{py_error_type, SchemaDict};
9-
use crate::serializers::filter::SchemaFilter;
10-
use crate::serializers::shared::CombinedSerializer;
119

1210
use super::model::ModelSerializer;
1311
use super::typed_dict::{TypedDictField, TypedDictSerializer};
14-
use super::BuildSerializer;
12+
use super::{BuildSerializer, CombinedSerializer, ComputedFields, SchemaFilter};
1513

1614
pub struct DataclassArgsBuilder;
1715

@@ -48,8 +46,9 @@ impl BuildSerializer for DataclassArgsBuilder {
4846
}
4947

5048
let filter = SchemaFilter::from_vec_hash(py, exclude)?;
49+
let computed_fields = ComputedFields::new(schema)?;
5150

52-
Ok(TypedDictSerializer::new(fields, false, filter).into())
51+
Ok(TypedDictSerializer::new(fields, false, filter, computed_fields).into())
5352
}
5453
}
5554

src/serializers/type_serializers/format.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ impl BuildSerializer for FormatSerializer {
8484
}
8585
}
8686
}
87+
8788
impl FormatSerializer {
8889
fn call(&self, value: &PyAny) -> Result<PyObject, String> {
8990
let py = value.py();

src/serializers/type_serializers/function.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ impl SerializationCallable {
447447
}
448448
}
449449

450-
fn get_json_return_type(schema: &PyDict) -> PyResult<Option<ObType>> {
450+
pub fn get_json_return_type(schema: &PyDict) -> PyResult<Option<ObType>> {
451451
match schema.get_as::<&str>(intern!(schema.py(), "json_return_type"))? {
452452
Some(t) => Ok(Some(
453453
ObType::from_str(t).map_err(|_| py_error_type!("Unknown return type {:?}", t))?,

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub mod union;
2323
pub mod url;
2424
pub mod with_default;
2525

26+
pub(self) use super::computed_fields::ComputedFields;
2627
pub(self) use super::config::utf8_py_error;
2728
pub(self) use super::errors::{py_err_se_err, PydanticSerializationError};
2829
pub(self) use super::extra::{Extra, ExtraOwned, SerMode};

0 commit comments

Comments
 (0)