Skip to content

Support exclude_if callable at field level #1535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2817,7 +2817,8 @@ class TypedDictField(TypedDict, total=False):
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
serialization_alias: str
serialization_exclude: bool # default: False
metadata: Dict[str, Any]
exclude_if: Callable[[Any], bool] # default None
metadata: Any


def typed_dict_field(
Expand All @@ -2827,7 +2828,8 @@ def typed_dict_field(
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
serialization_alias: str | None = None,
serialization_exclude: bool | None = None,
metadata: Dict[str, Any] | None = None,
exclude_if: Callable[[Any], bool] | None = None,
metadata: Any = None,
) -> TypedDictField:
"""
Returns a schema that matches a typed dict field, e.g.:
Expand All @@ -2844,6 +2846,7 @@ def typed_dict_field(
validation_alias: The alias(es) to use to find the field in the validation data
serialization_alias: The alias to use as a key when serializing
serialization_exclude: Whether to exclude the field when serializing
exclude_if: Callable that determines whether to exclude a field during serialization based on its value.
metadata: Any other information you want to include with the schema, not used by pydantic-core
"""
return _dict_not_none(
Expand All @@ -2853,6 +2856,7 @@ def typed_dict_field(
validation_alias=validation_alias,
serialization_alias=serialization_alias,
serialization_exclude=serialization_exclude,
exclude_if=exclude_if,
metadata=metadata,
)

Expand Down Expand Up @@ -2943,6 +2947,7 @@ class ModelField(TypedDict, total=False):
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
serialization_alias: str
serialization_exclude: bool # default: False
exclude_if: Callable[[Any], bool] # default: None
frozen: bool
metadata: Dict[str, Any]

Expand All @@ -2953,6 +2958,7 @@ def model_field(
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
serialization_alias: str | None = None,
serialization_exclude: bool | None = None,
exclude_if: Callable[[Any], bool] | None = None,
frozen: bool | None = None,
metadata: Dict[str, Any] | None = None,
) -> ModelField:
Expand All @@ -2970,6 +2976,7 @@ def model_field(
validation_alias: The alias(es) to use to find the field in the validation data
serialization_alias: The alias to use as a key when serializing
serialization_exclude: Whether to exclude the field when serializing
exclude_if: Callable that determines whether to exclude a field during serialization based on its value.
frozen: Whether the field is frozen
metadata: Any other information you want to include with the schema, not used by pydantic-core
"""
Expand All @@ -2979,6 +2986,7 @@ def model_field(
validation_alias=validation_alias,
serialization_alias=serialization_alias,
serialization_exclude=serialization_exclude,
exclude_if=exclude_if,
frozen=frozen,
metadata=metadata,
)
Expand Down Expand Up @@ -3171,7 +3179,8 @@ class DataclassField(TypedDict, total=False):
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
serialization_alias: str
serialization_exclude: bool # default: False
metadata: Dict[str, Any]
exclude_if: Callable[[Any], bool] # default: None
metadata: Any


def dataclass_field(
Expand All @@ -3184,7 +3193,8 @@ def dataclass_field(
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
serialization_alias: str | None = None,
serialization_exclude: bool | None = None,
metadata: Dict[str, Any] | None = None,
exclude_if: Callable[[Any], bool] | None = None,
metadata: Any = None,
frozen: bool | None = None,
) -> DataclassField:
"""
Expand All @@ -3210,6 +3220,7 @@ def dataclass_field(
validation_alias: The alias(es) to use to find the field in the validation data
serialization_alias: The alias to use as a key when serializing
serialization_exclude: Whether to exclude the field when serializing
exclude_if: Callable that determines whether to exclude a field during serialization based on its value.
metadata: Any other information you want to include with the schema, not used by pydantic-core
frozen: Whether the field is frozen
"""
Expand All @@ -3223,6 +3234,7 @@ def dataclass_field(
validation_alias=validation_alias,
serialization_alias=serialization_alias,
serialization_exclude=serialization_exclude,
exclude_if=exclude_if,
metadata=metadata,
frozen=frozen,
)
Expand Down
58 changes: 39 additions & 19 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub(super) struct SerField {
// None serializer means exclude
pub serializer: Option<CombinedSerializer>,
pub required: bool,
pub exclude_if: Option<Py<PyAny>>,
}

impl_py_gc_traverse!(SerField { serializer });
Expand All @@ -40,6 +41,7 @@ impl SerField {
alias: Option<String>,
serializer: Option<CombinedSerializer>,
required: bool,
exclude_if: Option<Py<PyAny>>,
) -> Self {
let alias_py = alias
.as_ref()
Expand All @@ -50,6 +52,7 @@ impl SerField {
alias_py,
serializer,
required,
exclude_if,
}
}

Expand All @@ -72,6 +75,18 @@ impl SerField {
}
}

fn exclude_if(exclude_if_callable: &Option<Py<PyAny>>, value: &Bound<'_, PyAny>) -> PyResult<bool> {
if let Some(exclude_if_callable) = exclude_if_callable {
let py = value.py();
let result = exclude_if_callable.call1(py, (value,))?;
let exclude = result.extract::<bool>(py)?;
if exclude {
return Ok(true);
}
}
Ok(false)
}

fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &CombinedSerializer) -> PyResult<bool> {
if extra.exclude_defaults {
if let Some(default) = serializer.get_default(value.py())? {
Expand All @@ -80,6 +95,7 @@ fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &Combine
}
}
}
// If neither condition is met, do not exclude the field
Ok(false)
}

Expand Down Expand Up @@ -176,16 +192,16 @@ impl GeneralFieldsSerializer {
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
if let Some(field) = op_field {
if let Some(ref serializer) = field.serializer {
if !exclude_default(&value, &field_extra, serializer)? {
let value = serializer.to_python(
&value,
next_include.as_ref(),
next_exclude.as_ref(),
&field_extra,
)?;
let output_key = field.get_key_py(output_dict.py(), &field_extra);
output_dict.set_item(output_key, value)?;
if exclude_default(&value, &field_extra, serializer)? {
continue;
}
if exclude_if(&field.exclude_if, &value)? {
continue;
}
let value =
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?;
let output_key = field.get_key_py(output_dict.py(), &field_extra);
output_dict.set_item(output_key, value)?;
}

if field.required {
Expand Down Expand Up @@ -263,17 +279,21 @@ impl GeneralFieldsSerializer {
if let Some((next_include, next_exclude)) = filter {
if let Some(field) = self.fields.get(key_str) {
if let Some(ref serializer) = field.serializer {
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
let s = PydanticSerializer::new(
&value,
serializer,
next_include.as_ref(),
next_exclude.as_ref(),
&field_extra,
);
let output_key = field.get_key_json(key_str, &field_extra);
map.serialize_entry(&output_key, &s)?;
if exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
continue;
}
if exclude_if(&field.exclude_if, &value).map_err(py_err_se_err)? {
continue;
}
let s = PydanticSerializer::new(
&value,
serializer,
next_include.as_ref(),
next_exclude.as_ref(),
&field_extra,
);
let output_key = field.get_key_json(key_str, &field_extra);
map.serialize_entry(&output_key, &s)?;
}
} else if self.mode == FieldsMode::TypedDictAllow {
let output_key = infer_json_key(&key, &field_extra).map_err(py_err_se_err)?;
Expand Down
8 changes: 6 additions & 2 deletions src/serializers/type_serializers/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,18 @@ impl BuildSerializer for DataclassArgsBuilder {
let key_py: Py<PyString> = PyString::new_bound(py, &name).into();

if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
fields.insert(name, SerField::new(py, key_py, None, None, true));
fields.insert(name, SerField::new(py, key_py, None, None, true, None));
} else {
let schema = field_info.get_as_req(intern!(py, "schema"))?;
let serializer = CombinedSerializer::build(&schema, config, definitions)
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", index, e))?;

let alias = field_info.get_as(intern!(py, "serialization_alias"))?;
fields.insert(name, SerField::new(py, key_py, alias, Some(serializer), true));
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
fields.insert(
name,
SerField::new(py, key_py, alias, Some(serializer), true, exclude_if),
);
}
}

Expand Down
9 changes: 6 additions & 3 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,18 @@ impl BuildSerializer for ModelFieldsBuilder {
let key_py: Py<PyString> = key_py.into();

if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
fields.insert(key, SerField::new(py, key_py, None, None, true));
fields.insert(key, SerField::new(py, key_py, None, None, true, None));
} else {
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;

let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
let schema = field_info.get_as_req(intern!(py, "schema"))?;
let serializer = CombinedSerializer::build(&schema, config, definitions)
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?;

fields.insert(key, SerField::new(py, key_py, alias, Some(serializer), true));
fields.insert(
key,
SerField::new(py, key_py, alias, Some(serializer), true, exclude_if),
);
}
}

Expand Down
9 changes: 6 additions & 3 deletions src/serializers/type_serializers/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,17 @@ impl BuildSerializer for TypedDictBuilder {
let required = field_info.get_as(intern!(py, "required"))?.unwrap_or(total);

if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
fields.insert(key, SerField::new(py, key_py, None, None, required));
fields.insert(key, SerField::new(py, key_py, None, None, required, None));
} else {
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;

let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
let schema = field_info.get_as_req(intern!(py, "schema"))?;
let serializer = CombinedSerializer::build(&schema, config, definitions)
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?;
fields.insert(key, SerField::new(py, key_py, alias, Some(serializer), required));
fields.insert(
key,
SerField::new(py, key_py, alias, Some(serializer), required, exclude_if),
);
}
}

Expand Down
10 changes: 8 additions & 2 deletions tests/serializers/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_serialization_exclude():
core_schema.dataclass_args_schema(
'Foo',
[
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
core_schema.dataclass_field(name='a', schema=core_schema.str_schema(), exclude_if=lambda x: x == 'bye'),
core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_exclude=True),
],
),
Expand All @@ -63,12 +63,18 @@ def test_serialization_exclude():
s = SchemaSerializer(schema)
assert s.to_python(Foo(a='hello', b=b'more')) == {'a': 'hello'}
assert s.to_python(Foo(a='hello', b=b'more'), mode='json') == {'a': 'hello'}
# a = 'bye' excludes it
assert s.to_python(Foo(a='bye', b=b'more'), mode='json') == {}
j = s.to_json(Foo(a='hello', b=b'more'))

if on_pypy:
assert json.loads(j) == {'a': 'hello'}
else:
assert j == b'{"a":"hello"}'
j = s.to_json(Foo(a='bye', b=b'more'))
if on_pypy:
assert json.loads(j) == {}
else:
assert j == b'{}'


def test_serialization_alias():
Expand Down
14 changes: 12 additions & 2 deletions tests/serializers/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,9 @@ def __init__(self, **kwargs):
MyModel,
core_schema.typed_dict_schema(
{
'a': core_schema.typed_dict_field(core_schema.any_schema()),
'a': core_schema.typed_dict_field(
core_schema.any_schema(), exclude_if=lambda x: isinstance(x, int) and x >= 2
),
'b': core_schema.typed_dict_field(core_schema.any_schema()),
'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True),
}
Expand All @@ -535,6 +537,14 @@ def __init__(self, **kwargs):
assert s.to_json(m, exclude={'b'}) == b'{"a":1}'
assert calls == 6

m = MyModel(a=2, b=b'foobar', c='excluded')
assert s.to_python(m) == {'b': b'foobar'}
assert calls == 7
assert s.to_python(m, mode='json') == {'b': 'foobar'}
assert calls == 8
assert s.to_json(m) == b'{"b":"foobar"}'
assert calls == 9


def test_function_plain_model():
calls = 0
Expand All @@ -553,7 +563,7 @@ def __init__(self, **kwargs):
MyModel,
core_schema.typed_dict_schema(
{
'a': core_schema.typed_dict_field(core_schema.any_schema()),
'a': core_schema.typed_dict_field(core_schema.any_schema(), exclude_if=lambda x: x == 100),
'b': core_schema.typed_dict_field(core_schema.any_schema()),
'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True),
}
Expand Down
26 changes: 26 additions & 0 deletions tests/serializers/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,32 @@ def test_include_exclude_args(params):
assert json.loads(s.to_json(value, include=include, exclude=exclude)) == expected


def test_exclude_if():
s = SchemaSerializer(
core_schema.model_schema(
BasicModel,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(core_schema.int_schema(), exclude_if=lambda x: x > 1),
'b': core_schema.model_field(core_schema.str_schema(), exclude_if=lambda x: 'foo' in x),
'c': core_schema.model_field(
core_schema.str_schema(), serialization_exclude=True, exclude_if=lambda x: 'foo' in x
),
}
),
)
)
assert s.to_python(BasicModel(a=0, b='bar', c='bar')) == {'a': 0, 'b': 'bar'}
assert s.to_python(BasicModel(a=2, b='bar', c='bar')) == {'b': 'bar'}
assert s.to_python(BasicModel(a=0, b='foo', c='bar')) == {'a': 0}
assert s.to_python(BasicModel(a=2, b='foo', c='bar')) == {}

assert s.to_json(BasicModel(a=0, b='bar', c='bar')) == b'{"a":0,"b":"bar"}'
assert s.to_json(BasicModel(a=2, b='bar', c='bar')) == b'{"b":"bar"}'
assert s.to_json(BasicModel(a=0, b='foo', c='bar')) == b'{"a":0}'
assert s.to_json(BasicModel(a=2, b='foo', c='bar')) == b'{}'


def test_alias():
s = SchemaSerializer(
core_schema.model_schema(
Expand Down
Loading