11use pyo3:: exceptions:: { PyTypeError , PyValueError } ;
22use pyo3:: intern;
33use pyo3:: sync:: GILOnceCell ;
4- use pyo3:: types:: { IntoPyDict , PyDict , PyTuple , PyType } ;
4+ use pyo3:: types:: { IntoPyDict , PyDict , PyString , PyTuple , PyType } ;
55use pyo3:: { prelude:: * , PyTypeInfo } ;
66
77use crate :: build_tools:: { is_strict, schema_or_config_same} ;
@@ -28,6 +28,18 @@ pub fn get_decimal_type(py: Python) -> &Bound<'_, PyType> {
2828 . bind ( py)
2929}
3030
31+ fn validate_as_decimal ( py : Python , schema : & Bound < ' _ , PyDict > , key : & str ) -> PyResult < Option < Py < PyAny > > > {
32+ match schema. get_as :: < Bound < ' _ , PyAny > > ( & PyString :: new ( py, key) ) ? {
33+ Some ( value) => match value. validate_decimal ( false , py) {
34+ Ok ( v) => Ok ( Some ( v. into_inner ( ) . unbind ( ) ) ) ,
35+ Err ( _) => Err ( PyValueError :: new_err ( format ! (
36+ "'{key}' must be coercible to a Decimal instance" ,
37+ ) ) ) ,
38+ } ,
39+ None => Ok ( None ) ,
40+ }
41+ }
42+
3143#[ derive( Debug , Clone ) ]
3244pub struct DecimalValidator {
3345 strict : bool ,
@@ -50,6 +62,7 @@ impl BuildValidator for DecimalValidator {
5062 _definitions : & mut DefinitionsBuilder < CombinedValidator > ,
5163 ) -> PyResult < CombinedValidator > {
5264 let py = schema. py ( ) ;
65+
5366 let allow_inf_nan = schema_or_config_same ( schema, config, intern ! ( py, "allow_inf_nan" ) ) ?. unwrap_or ( false ) ;
5467 let decimal_places = schema. get_as ( intern ! ( py, "decimal_places" ) ) ?;
5568 let max_digits = schema. get_as ( intern ! ( py, "max_digits" ) ) ?;
@@ -58,16 +71,17 @@ impl BuildValidator for DecimalValidator {
5871 "allow_inf_nan=True cannot be used with max_digits or decimal_places" ,
5972 ) ) ;
6073 }
74+
6175 Ok ( Self {
6276 strict : is_strict ( schema, config) ?,
6377 allow_inf_nan,
6478 check_digits : decimal_places. is_some ( ) || max_digits. is_some ( ) ,
6579 decimal_places,
66- multiple_of : schema . get_as ( intern ! ( py, "multiple_of" ) ) ?,
67- le : schema . get_as ( intern ! ( py, "le" ) ) ?,
68- lt : schema . get_as ( intern ! ( py, "lt" ) ) ?,
69- ge : schema . get_as ( intern ! ( py, "ge" ) ) ?,
70- gt : schema . get_as ( intern ! ( py, "gt" ) ) ?,
80+ multiple_of : validate_as_decimal ( py, schema , "multiple_of" ) ?,
81+ le : validate_as_decimal ( py, schema , "le" ) ?,
82+ lt : validate_as_decimal ( py, schema , "lt" ) ?,
83+ ge : validate_as_decimal ( py, schema , "ge" ) ?,
84+ gt : validate_as_decimal ( py, schema , "gt" ) ?,
7185 max_digits,
7286 }
7387 . into ( ) )
0 commit comments