Skip to content

Commit 938a0bb

Browse files
authored
Skip extra/unknown fields checks when using from_attributes=True (#537)
1 parent 8dca91d commit 938a0bb

File tree

3 files changed

+51
-8
lines changed

3 files changed

+51
-8
lines changed

src/lookup_key.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ impl PathItem {
478478
}
479479

480480
/// wrapper around `getattr` that returns `Ok(None)` for attribute errors, but returns other errors
481-
/// We dont check `try_from_attributes` because that check was performed on the top level object before we got here
481+
/// We don't check `try_from_attributes` because that check was performed on the top level object before we got here
482482
fn py_get_attrs<'a>(obj: &'a PyAny, attr_name: &Py<PyString>) -> PyResult<Option<&'a PyAny>> {
483483
match obj.getattr(attr_name.extract::<&PyString>(obj.py())?) {
484484
Ok(attr) => Ok(Some(attr)),

src/validators/typed_dict.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ impl Validator for TypedDictValidator {
156156

157157
// we only care about which keys have been used if we're iterating over the object for extra after
158158
// the first pass
159-
let mut used_keys: Option<AHashSet<&str>> = match self.extra_behavior {
160-
ExtraBehavior::Allow | ExtraBehavior::Forbid => Some(AHashSet::with_capacity(self.fields.len())),
159+
let mut used_keys: Option<AHashSet<&str>> = match (&self.extra_behavior, &dict) {
160+
(_, GenericMapping::PyGetAttr(_, _)) => None,
161+
(ExtraBehavior::Allow | ExtraBehavior::Forbid, _) => Some(AHashSet::with_capacity(self.fields.len())),
161162
_ => None,
162163
};
163164

tests/validators/test_typed_dict.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
from dataclasses import dataclass
55
from datetime import datetime
6-
from typing import Any, Dict, Mapping, Union
6+
from typing import Any, Dict, List, Mapping, Union
77

88
import pytest
99
from dirty_equals import FunctionCheck, HasRepr, IsStr
@@ -1192,13 +1192,55 @@ class MyDataclass:
11921192
}
11931193
)
11941194

1195-
assert v.validate_python(Foobar()) == ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})
1196-
assert v.validate_python(MyDataclass()) == ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})
1197-
assert v.validate_python(Cls(a=1, b=2, c='ham')) == ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})
1198-
assert v.validate_python(Cls(a=1, b=datetime(2000, 1, 1))) == ({'a': 1, 'b': datetime(2000, 1, 1)}, {'a', 'b'})
1195+
assert v.validate_python(Foobar()) == ({'a': 1}, {'a'})
1196+
assert v.validate_python(MyDataclass()) == ({'a': 1}, {'a'})
1197+
assert v.validate_python(Cls(a=1, b=2, c='ham')) == ({'a': 1}, {'a'})
1198+
assert v.validate_python(Cls(a=1, b=datetime(2000, 1, 1))) == ({'a': 1}, {'a'})
11991199
assert v.validate_python(Cls(a=1, b=datetime.now, c=lambda: 42)) == ({'a': 1}, {'a'})
12001200

12011201

1202+
def test_from_attributes_extra_ignore_no_attributes_accessed() -> None:
1203+
v = SchemaValidator(
1204+
{
1205+
'type': 'typed-dict',
1206+
'fields': {'a': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}},
1207+
'from_attributes': True,
1208+
'extra_behavior': 'ignore',
1209+
}
1210+
)
1211+
1212+
accessed: List[str] = []
1213+
1214+
class Source:
1215+
a = 1
1216+
b = 2
1217+
1218+
def __getattribute__(self, __name: str) -> Any:
1219+
accessed.append(__name)
1220+
return super().__getattribute__(__name)
1221+
1222+
assert v.validate_python(Source()) == {'a': 1}
1223+
assert 'a' in accessed and 'b' not in accessed
1224+
1225+
1226+
def test_from_attributes_extra_forbid() -> None:
1227+
class Source:
1228+
a = 1
1229+
b = 2
1230+
1231+
v = SchemaValidator(
1232+
{
1233+
'type': 'typed-dict',
1234+
'return_fields_set': True,
1235+
'fields': {'a': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}},
1236+
'from_attributes': True,
1237+
'extra_behavior': 'forbid',
1238+
}
1239+
)
1240+
1241+
assert v.validate_python(Source()) == ({'a': 1}, {'a'})
1242+
1243+
12021244
def foobar():
12031245
pass
12041246

0 commit comments

Comments
 (0)