Skip to content

Commit ccc7c27

Browse files
authored
Replace **kwargs with a ValidatorInfo object in validation functions (#423)
1 parent 5763f90 commit ccc7c27

18 files changed

+164
-150
lines changed

pydantic_core/core_schema.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ def __repr__(self) -> str:
7070
...
7171

7272

73+
class ValidationInfo(Protocol):
74+
"""
75+
Argument passed to validation functions.
76+
"""
77+
78+
data: Dict[str, Any]
79+
"""All of the fields and data being validated for this model."""
80+
context: Dict[str, Any]
81+
"""Current validation context."""
82+
config: CoreConfig | None
83+
"""The CoreConfig that applies to this validation."""
84+
85+
7386
ExpectedSerializationTypes = Literal[
7487
'none',
7588
'int',
@@ -1430,9 +1443,7 @@ def dict_schema(
14301443

14311444

14321445
class ValidatorFunction(Protocol):
1433-
def __call__(
1434-
self, __input_value: Any, *, data: Any, config: CoreConfig | None, context: Any, **future_kwargs: Any
1435-
) -> Any: # pragma: no cover
1446+
def __call__(self, __input_value: Any, __info: ValidationInfo) -> Any: # pragma: no cover
14361447
...
14371448

14381449

@@ -1461,7 +1472,7 @@ def function_before_schema(
14611472
from typing import Any
14621473
from pydantic_core import SchemaValidator, core_schema
14631474
1464-
def fn(v: Any, **kwargs) -> str:
1475+
def fn(v: Any, info: core_schema.ValidationInfo) -> str:
14651476
v_str = str(v)
14661477
assert 'hello' in v_str
14671478
return v_str + 'world'
@@ -1503,7 +1514,7 @@ def function_after_schema(
15031514
```py
15041515
from pydantic_core import SchemaValidator, core_schema
15051516
1506-
def fn(v: str, **kwargs) -> str:
1517+
def fn(v: str, info: core_schema.ValidationInfo) -> str:
15071518
assert 'hello' in v
15081519
return v + 'world'
15091520
@@ -1537,14 +1548,7 @@ def __call__(self, input_value: Any, outer_location: str | int | None = None) ->
15371548

15381549
class WrapValidatorFunction(Protocol):
15391550
def __call__(
1540-
self,
1541-
__input_value: Any,
1542-
*,
1543-
validator: CallableValidator,
1544-
data: Any,
1545-
config: CoreConfig | None,
1546-
context: Any,
1547-
**future_kwargs: Any,
1551+
self, __input_value: Any, __validator: CallableValidator, __info: ValidationInfo
15481552
) -> Any: # pragma: no cover
15491553
...
15501554

@@ -1575,7 +1579,7 @@ def function_wrap_schema(
15751579
```py
15761580
from pydantic_core import SchemaValidator, core_schema
15771581
1578-
def fn(v: str, *, validator, **kwargs) -> str:
1582+
def fn(v: str, validator: core_schema.CallableValidator, info: core_schema.ValidationInfo) -> str:
15791583
return validator(input_value=v) + 'world'
15801584
15811585
schema = core_schema.function_wrap_schema(function=fn, schema=core_schema.str_schema())
@@ -1619,7 +1623,7 @@ def function_plain_schema(
16191623
```py
16201624
from pydantic_core import SchemaValidator, core_schema
16211625
1622-
def fn(v: str, **kwargs) -> str:
1626+
def fn(v: str, info: core_schema.ValidationInfo) -> str:
16231627
assert 'hello' in v
16241628
return v + 'world'
16251629
@@ -1912,7 +1916,7 @@ def chain_schema(
19121916
```py
19131917
from pydantic_core import SchemaValidator, core_schema
19141918
1915-
def fn(v: str, **kwargs) -> str:
1919+
def fn(v: str, info: core_schema.ValidationInfo) -> str:
19161920
assert 'hello' in v
19171921
return v + ' world'
19181922
@@ -1956,7 +1960,7 @@ def lax_or_strict_schema(
19561960
```py
19571961
from pydantic_core import SchemaValidator, core_schema
19581962
1959-
def fn(v: str, **kwargs) -> str:
1963+
def fn(v: str, info: core_schema.ValidationInfo) -> str:
19601964
assert 'hello' in v
19611965
return v + ' world'
19621966

src/build_tools.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,6 @@ pub fn function_name(f: &PyAny) -> PyResult<String> {
203203
}
204204
}
205205

206-
macro_rules! kwargs {
207-
($py:ident, $($k:ident: $v:expr),* $(,)?) => {{
208-
Some(pyo3::types::IntoPyDict::into_py_dict([$((stringify!($k), $v.into_py($py)),)*], $py).into())
209-
}};
210-
}
211-
pub(crate) use kwargs;
212-
213206
pub fn safe_repr(v: &PyAny) -> Cow<str> {
214207
match v.repr() {
215208
Ok(r) => r.to_string_lossy(),

src/validators/function.rs

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use pyo3::intern;
33
use pyo3::prelude::*;
44
use pyo3::types::{PyAny, PyDict};
55

6-
use crate::build_tools::{function_name, kwargs, py_err, SchemaDict};
6+
use crate::build_tools::{function_name, py_err, SchemaDict};
77
use crate::errors::{
88
ErrorType, LocItem, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError,
99
};
@@ -86,10 +86,10 @@ impl Validator for FunctionBeforeValidator {
8686
slots: &'data [CombinedValidator],
8787
recursion_guard: &'s mut RecursionGuard,
8888
) -> ValResult<'data, PyObject> {
89-
let kwargs = kwargs!(py, data: extra.data, config: self.config.clone_ref(py), context: extra.context);
89+
let info = ValidationInfo::new(extra, &self.config, py);
9090
let value = self
9191
.func
92-
.call(py, (input.to_object(py),), kwargs)
92+
.call1(py, (input.to_object(py), info))
9393
.map_err(|e| convert_err(py, e, input))?;
9494

9595
self.validator
@@ -129,8 +129,8 @@ impl Validator for FunctionAfterValidator {
129129
recursion_guard: &'s mut RecursionGuard,
130130
) -> ValResult<'data, PyObject> {
131131
let v = self.validator.validate(py, input, extra, slots, recursion_guard)?;
132-
let kwargs = kwargs!(py, data: extra.data, config: self.config.clone_ref(py), context: extra.context);
133-
self.func.call(py, (v,), kwargs).map_err(|e| convert_err(py, e, input))
132+
let info = ValidationInfo::new(extra, &self.config, py);
133+
self.func.call1(py, (v, info)).map_err(|e| convert_err(py, e, input))
134134
}
135135

136136
fn get_name(&self) -> &str {
@@ -178,9 +178,9 @@ impl Validator for FunctionPlainValidator {
178178
_slots: &'data [CombinedValidator],
179179
_recursion_guard: &'s mut RecursionGuard,
180180
) -> ValResult<'data, PyObject> {
181-
let kwargs = kwargs!(py, data: extra.data, config: self.config.clone_ref(py), context: extra.context);
181+
let info = ValidationInfo::new(extra, &self.config, py);
182182
self.func
183-
.call(py, (input.to_object(py),), kwargs)
183+
.call1(py, (input.to_object(py), info))
184184
.map_err(|e| convert_err(py, e, input))
185185
}
186186

@@ -208,18 +208,12 @@ impl Validator for FunctionWrapValidator {
208208
slots: &'data [CombinedValidator],
209209
recursion_guard: &'s mut RecursionGuard,
210210
) -> ValResult<'data, PyObject> {
211-
let validator_kwarg = ValidatorCallable {
211+
let call_next_validator = ValidatorCallable {
212212
validator: InternalValidator::new(py, "ValidatorCallable", &self.validator, slots, extra, recursion_guard),
213213
};
214-
let kwargs = kwargs!(
215-
py,
216-
validator: validator_kwarg,
217-
data: extra.data,
218-
config: self.config.clone_ref(py),
219-
context: extra.context,
220-
);
214+
let info = ValidationInfo::new(extra, &self.config, py);
221215
self.func
222-
.call(py, (input.to_object(py),), kwargs)
216+
.call1(py, (input.to_object(py), call_next_validator, info))
223217
.map_err(|e| convert_err(py, e, input))
224218
}
225219

@@ -303,3 +297,23 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) ->
303297
ValError::InternalErr(err)
304298
}
305299
}
300+
301+
#[pyclass(module = "pydantic_core._pydantic_core")]
302+
pub struct ValidationInfo {
303+
#[pyo3(get)]
304+
data: Option<Py<PyDict>>,
305+
#[pyo3(get)]
306+
config: PyObject,
307+
#[pyo3(get)]
308+
context: Option<PyObject>,
309+
}
310+
311+
impl ValidationInfo {
312+
fn new(extra: &Extra, config: &PyObject, py: Python) -> Self {
313+
ValidationInfo {
314+
data: extra.data.map(|v| v.into()),
315+
config: config.clone_ref(py),
316+
context: extra.context.map(|v| v.into()),
317+
}
318+
}
319+
}

tests/benchmarks/complete_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ class MyModel:
33
# __slots__ is not required, but it avoids __fields_set__ falling into __dict__
44
__slots__ = '__dict__', '__fields_set__'
55

6-
def append_func(input_value, **kwargs):
6+
def append_func(input_value, info):
77
return f'{input_value} Changed'
88

9-
def wrap_function(input_value, *, validator, **kwargs):
9+
def wrap_function(input_value, validator, info):
1010
return f'Input {validator(input_value)} Changed'
1111

1212
return {

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ def validate_with_expected_error():
907907

908908
@pytest.mark.benchmark(group='raise-error')
909909
def test_dont_raise_error(benchmark):
910-
def f(input_value, **kwargs):
910+
def f(input_value, info):
911911
return input_value
912912

913913
v = SchemaValidator({'type': 'function', 'mode': 'plain', 'function': f})
@@ -919,7 +919,7 @@ def t():
919919

920920
@pytest.mark.benchmark(group='raise-error')
921921
def test_raise_error_value_error(benchmark):
922-
def f(input_value, **kwargs):
922+
def f(input_value, info):
923923
raise ValueError('this is a custom error')
924924

925925
v = SchemaValidator({'type': 'function', 'mode': 'plain', 'function': f})
@@ -936,7 +936,7 @@ def t():
936936

937937
@pytest.mark.benchmark(group='raise-error')
938938
def test_raise_error_custom(benchmark):
939-
def f(input_value, **kwargs):
939+
def f(input_value, info):
940940
raise PydanticCustomError('my_error', 'this is a custom error {foo}', {'foo': 'FOOBAR'})
941941

942942
v = SchemaValidator({'type': 'function', 'mode': 'plain', 'function': f})
@@ -1033,10 +1033,7 @@ def test_chain_list(benchmark):
10331033
validator = SchemaValidator(
10341034
{
10351035
'type': 'chain',
1036-
'steps': [
1037-
{'type': 'str'},
1038-
{'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: Decimal(v)},
1039-
],
1036+
'steps': [{'type': 'str'}, {'type': 'function', 'mode': 'plain', 'function': lambda v, info: Decimal(v)}],
10401037
}
10411038
)
10421039
assert validator.validate_python('42.42') == Decimal('42.42')
@@ -1047,7 +1044,7 @@ def test_chain_list(benchmark):
10471044
@pytest.mark.benchmark(group='chain')
10481045
def test_chain_function(benchmark):
10491046
validator = SchemaValidator(
1050-
{'type': 'function', 'mode': 'after', 'schema': {'type': 'str'}, 'function': lambda v, **kwargs: Decimal(v)}
1047+
{'type': 'function', 'mode': 'after', 'schema': {'type': 'str'}, 'function': lambda v, info: Decimal(v)}
10511048
)
10521049
assert validator.validate_python('42.42') == Decimal('42.42')
10531050

@@ -1061,8 +1058,8 @@ def test_chain_two_functions(benchmark):
10611058
'type': 'chain',
10621059
'steps': [
10631060
{'type': 'str'},
1064-
{'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: Decimal(v)},
1065-
{'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: v * 2},
1061+
{'type': 'function', 'mode': 'plain', 'function': lambda v, info: Decimal(v)},
1062+
{'type': 'function', 'mode': 'plain', 'function': lambda v, info: v * 2},
10661063
],
10671064
}
10681065
)
@@ -1080,10 +1077,10 @@ def test_chain_nested_functions(benchmark):
10801077
'type': 'function',
10811078
'schema': {'type': 'str'},
10821079
'mode': 'after',
1083-
'function': lambda v, **kwargs: Decimal(v),
1080+
'function': lambda v, info: Decimal(v),
10841081
},
10851082
'mode': 'after',
1086-
'function': lambda v, **kwargs: v * 2,
1083+
'function': lambda v, info: v * 2,
10871084
}
10881085
)
10891086
assert validator.validate_python('42.42') == Decimal('84.84')
@@ -1096,7 +1093,7 @@ def validate_yield(iterable, validator):
10961093
yield validator(item)
10971094

10981095

1099-
def generator_gen_python(v, *, validator, **_kwargs):
1096+
def generator_gen_python(v, validator, info):
11001097
try:
11011098
iterable = iter(v)
11021099
except TypeError:

tests/serializers/test_other.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,20 @@ def test_chain():
1616

1717

1818
def test_function_plain():
19-
s = SchemaSerializer(core_schema.function_plain_schema(lambda v, **kwargs: v + 1))
19+
s = SchemaSerializer(core_schema.function_plain_schema(lambda v, info: v + 1))
2020
# can't infer the type from plain function validators
2121
# insert_assert(plain_repr(s))
2222
assert plain_repr(s) == 'SchemaSerializer(serializer=Any(AnySerializer),slots=[])'
2323

2424

2525
def test_function_before():
26-
s = SchemaSerializer(core_schema.function_before_schema(lambda v, **kwargs: v + 1, core_schema.int_schema()))
26+
s = SchemaSerializer(core_schema.function_before_schema(lambda v, info: v + 1, core_schema.int_schema()))
2727
# insert_assert(plain_repr(s))
2828
assert plain_repr(s) == 'SchemaSerializer(serializer=Int(IntSerializer),slots=[])'
2929

3030

3131
def test_function_after():
32-
s = SchemaSerializer(core_schema.function_after_schema(core_schema.int_schema(), lambda v, **kwargs: v + 1))
32+
s = SchemaSerializer(core_schema.function_after_schema(core_schema.int_schema(), lambda v, info: v + 1))
3333
# insert_assert(plain_repr(s))
3434
assert plain_repr(s) == 'SchemaSerializer(serializer=Int(IntSerializer),slots=[])'
3535

0 commit comments

Comments
 (0)