Skip to content

Commit 7d7ca9f

Browse files
authored
general to_json method and more inferred types (#372)
* general "to_json" method and more inferred types * support or filtering generators * increase CI width * allow windows tests to fail * another xfail * remove incorrect xfail
1 parent a478fa4 commit 7d7ca9f

File tree

18 files changed

+339
-100
lines changed

18 files changed

+339
-100
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ on:
99
pull_request: {}
1010

1111
env:
12-
COLUMNS: 120
12+
COLUMNS: 150
1313

1414
jobs:
1515
test-cpython:

pydantic_core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Url,
1212
ValidationError,
1313
__version__,
14+
to_json,
1415
)
1516
from .core_schema import CoreConfig, CoreSchema
1617

@@ -29,4 +30,5 @@
2930
'PydanticOmit',
3031
'PydanticSerializationError',
3132
'PydanticSerializationUnexpectedValue',
33+
'to_json',
3234
)

pydantic_core/_pydantic_core.pyi

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ from typing import Any, TypedDict
55
from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType
66

77
if sys.version_info < (3, 11):
8-
from typing_extensions import NotRequired, TypeAlias
8+
from typing_extensions import Literal, NotRequired, TypeAlias
99
else:
10-
from typing import NotRequired, TypeAlias
10+
from typing import Literal, NotRequired, TypeAlias
1111

1212
__all__ = (
1313
'__version__',
@@ -75,6 +75,17 @@ class SchemaSerializer:
7575
warnings: bool = True,
7676
) -> bytes: ...
7777

78+
def to_json(
79+
value: Any,
80+
indent: int | None = None,
81+
include: IncEx = None,
82+
exclude: IncEx = None,
83+
exclude_none: bool = False,
84+
round_trip: bool = False,
85+
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
86+
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
87+
) -> bytes: ...
88+
7889
class Url:
7990
scheme: str
8091
username: 'str | None'

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ mod validators;
2424
pub use self::url::{PyMultiHostUrl, PyUrl};
2525
pub use build_tools::SchemaError;
2626
pub use errors::{list_all_errors, PydanticCustomError, PydanticKnownError, PydanticOmit, ValidationError};
27-
pub use serializers::{PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer};
27+
pub use serializers::{to_json, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer};
2828
pub use validators::SchemaValidator;
2929

3030
pub fn get_version() -> String {
@@ -52,6 +52,7 @@ fn _pydantic_core(_py: Python, m: &PyModule) -> PyResult<()> {
5252
m.add_class::<PyUrl>()?;
5353
m.add_class::<PyMultiHostUrl>()?;
5454
m.add_class::<SchemaSerializer>()?;
55+
m.add_function(wrap_pyfunction!(to_json, m)?)?;
5556
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
5657
Ok(())
5758
}

src/serializers/config.rs

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::borrow::Cow;
22
use std::str::{from_utf8, Utf8Error};
33

44
use pyo3::prelude::*;
5-
use pyo3::types::{PyBytes, PyDelta, PyDict};
5+
use pyo3::types::{PyDelta, PyDict};
66
use pyo3::{intern, PyNativeType};
77

88
use serde::ser::Error;
@@ -27,6 +27,15 @@ impl SerializationConfig {
2727
bytes_mode,
2828
})
2929
}
30+
31+
pub fn from_args(timedelta_mode: Option<&str>, bytes_mode: Option<&str>) -> PyResult<Self> {
32+
let timedelta_mode = TimedeltaMode::from_str(timedelta_mode)?;
33+
let bytes_mode = BytesMode::from_str(bytes_mode)?;
34+
Ok(Self {
35+
timedelta_mode,
36+
bytes_mode,
37+
})
38+
}
3039
}
3140

3241
#[derive(Debug, Clone)]
@@ -41,7 +50,11 @@ impl TimedeltaMode {
4150
Some(c) => c.get_as::<&str>(intern!(c.py(), "ser_json_timedelta"))?,
4251
None => None,
4352
};
44-
match raw_mode {
53+
Self::from_str(raw_mode)
54+
}
55+
56+
pub fn from_str(s: Option<&str>) -> PyResult<Self> {
57+
match s {
4558
Some("iso8601") => Ok(Self::Iso8601),
4659
Some("float") => Ok(Self::Float),
4760
Some(s) => py_err!(
@@ -113,7 +126,11 @@ impl BytesMode {
113126
Some(c) => c.get_as::<&str>(intern!(c.py(), "ser_json_bytes"))?,
114127
None => None,
115128
};
116-
let base64_config = match raw_mode {
129+
Self::from_str(raw_mode)
130+
}
131+
132+
pub fn from_str(s: Option<&str>) -> PyResult<Self> {
133+
let base64_config = match s {
117134
Some("utf8") => None,
118135
Some("base64") => Some(base64::Config::new(base64::CharacterSet::UrlSafe, true)),
119136
Some(s) => return py_err!("Invalid bytes serialization mode: `{}`, expected `utf8` or `base64`", s),
@@ -122,23 +139,21 @@ impl BytesMode {
122139
Ok(Self { base64_config })
123140
}
124141

125-
pub fn bytes_to_string<'py>(&self, py_bytes: &'py PyBytes) -> PyResult<Cow<'py, str>> {
142+
pub fn bytes_to_string<'py>(&self, py: Python, bytes: &'py [u8]) -> PyResult<Cow<'py, str>> {
126143
if let Some(config) = self.base64_config {
127-
Ok(Cow::Owned(base64::encode_config(py_bytes.as_bytes(), config)))
144+
Ok(Cow::Owned(base64::encode_config(bytes, config)))
128145
} else {
129-
py_bytes_to_str(py_bytes).map(Cow::Borrowed)
146+
from_utf8(bytes)
147+
.map_err(|err| utf8_py_error(py, err, bytes))
148+
.map(Cow::Borrowed)
130149
}
131150
}
132151

133-
pub fn serialize_bytes<S: serde::ser::Serializer>(
134-
&self,
135-
py_bytes: &PyBytes,
136-
serializer: S,
137-
) -> Result<S::Ok, S::Error> {
152+
pub fn serialize_bytes<S: serde::ser::Serializer>(&self, bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error> {
138153
if let Some(config) = self.base64_config {
139-
serializer.serialize_str(&base64::encode_config(py_bytes.as_bytes(), config))
154+
serializer.serialize_str(&base64::encode_config(bytes, config))
140155
} else {
141-
match from_utf8(py_bytes.as_bytes()) {
156+
match from_utf8(bytes) {
142157
Ok(s) => serializer.serialize_str(s),
143158
Err(e) => Err(Error::custom(e.to_string())),
144159
}
@@ -152,9 +167,3 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {
152167
Err(err) => err,
153168
}
154169
}
155-
156-
fn py_bytes_to_str(py_bytes: &PyBytes) -> PyResult<&str> {
157-
let py = py_bytes.py();
158-
let data = py_bytes.as_bytes();
159-
from_utf8(data).map_err(|err| utf8_py_error(py, err, data))
160-
}

src/serializers/filter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use pyo3::types::{PyBool, PyDict, PySet, PyString};
99
use crate::build_tools::SchemaDict;
1010

1111
#[derive(Debug, Clone, Default)]
12-
pub(super) struct SchemaFilter<T> {
12+
pub(crate) struct SchemaFilter<T> {
1313
include: Option<AHashSet<T>>,
1414
exclude: Option<AHashSet<T>>,
1515
}

src/serializers/infer.rs

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
use std::borrow::Cow;
2-
use std::str::from_utf8;
32

43
use pyo3::exceptions::PyTypeError;
54
use pyo3::intern;
65
use pyo3::prelude::*;
76
use pyo3::types::{
8-
PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTime, PyTuple,
7+
PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString,
8+
PyTime, PyTuple,
99
};
1010

1111
use serde::ser::{Serialize, SerializeMap, SerializeSeq, Serializer};
1212

1313
use crate::build_tools::{py_err, safe_repr};
14+
use crate::serializers::filter::SchemaFilter;
1415
use crate::url::{PyMultiHostUrl, PyUrl};
1516

16-
use super::config::utf8_py_error;
1717
use super::errors::{py_err_se_err, PydanticSerializationError};
1818
use super::extra::{Extra, SerMode};
1919
use super::filter::AnyFilter;
@@ -97,21 +97,23 @@ pub(crate) fn infer_to_python_known(
9797
// have to do this to make sure subclasses of for example str are upcast to `str`
9898
ObType::IntSubclass => value.extract::<i64>()?.into_py(py),
9999
ObType::FloatSubclass => value.extract::<f64>()?.into_py(py),
100+
ObType::Decimal => value.to_string().into_py(py),
100101
ObType::StrSubclass => value.extract::<&str>()?.into_py(py),
101102
ObType::Bytes => extra
102103
.config
103104
.bytes_mode
104-
.bytes_to_string(value.downcast()?)
105+
.bytes_to_string(py, value.downcast::<PyBytes>()?.as_bytes())
105106
.map(|s| s.into_py(py))?,
106107
ObType::Bytearray => {
107108
let py_byte_array: &PyByteArray = value.downcast()?;
108109
// see https://docs.rs/pyo3/latest/pyo3/types/struct.PyByteArray.html#method.as_bytes
109110
// for why this is marked unsafe
110111
let bytes = unsafe { py_byte_array.as_bytes() };
111-
match from_utf8(bytes) {
112-
Ok(s) => s.into_py(py),
113-
Err(err) => return Err(utf8_py_error(py, err, bytes)),
114-
}
112+
extra
113+
.config
114+
.bytes_mode
115+
.bytes_to_string(py, bytes)
116+
.map(|s| s.into_py(py))?
115117
}
116118
ObType::Tuple => {
117119
let elements = serialize_seq_filter!(PyTuple);
@@ -163,6 +165,20 @@ pub(crate) fn infer_to_python_known(
163165
let v = value.getattr(intern!(py, "value"))?;
164166
infer_to_python(v, include, exclude, extra)?.into_py(py)
165167
}
168+
ObType::Generator => {
169+
let py_seq: &PyIterator = value.downcast()?;
170+
let mut items = Vec::new();
171+
let filter = AnyFilter::new();
172+
173+
for (index, r) in py_seq.iter()?.enumerate() {
174+
let element = r?;
175+
let op_next = filter.value_filter(index, include, exclude)?;
176+
if let Some((next_include, next_exclude)) = op_next {
177+
items.push(infer_to_python(element, next_include, next_exclude, extra)?);
178+
}
179+
}
180+
PyList::new(py, items).into_py(py)
181+
}
166182
ObType::Unknown => return Err(unknown_type_error(value)),
167183
},
168184
_ => match ob_type {
@@ -199,6 +215,17 @@ pub(crate) fn infer_to_python_known(
199215
}
200216
ObType::Dataclass => serialize_dict(object_to_dict(value, false, extra)?)?,
201217
ObType::PydanticModel => serialize_dict(object_to_dict(value, true, extra)?)?,
218+
ObType::Generator => {
219+
let iter = super::type_serializers::generator::SerializationIterator::new(
220+
value.downcast()?,
221+
super::type_serializers::any::AnySerializer::default().into(),
222+
SchemaFilter::default(),
223+
include,
224+
exclude,
225+
extra,
226+
);
227+
iter.into_py(py)
228+
}
202229
_ => value.into_py(py),
203230
},
204231
};
@@ -321,21 +348,19 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
321348
ObType::Int | ObType::IntSubclass => serialize!(i64),
322349
ObType::Bool => serialize!(bool),
323350
ObType::Float | ObType::FloatSubclass => serialize!(f64),
351+
ObType::Decimal => value.to_string().serialize(serializer),
324352
ObType::Str | ObType::StrSubclass => {
325353
let py_str: &PyString = value.downcast().map_err(py_err_se_err)?;
326354
super::type_serializers::string::serialize_py_str(py_str, serializer)
327355
}
328356
ObType::Bytes => {
329357
let py_bytes: &PyBytes = value.downcast().map_err(py_err_se_err)?;
330-
extra.config.bytes_mode.serialize_bytes(py_bytes, serializer)
358+
extra.config.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer)
331359
}
332360
ObType::Bytearray => {
333361
let py_byte_array: &PyByteArray = value.downcast().map_err(py_err_se_err)?;
334362
let bytes = unsafe { py_byte_array.as_bytes() };
335-
match from_utf8(bytes) {
336-
Ok(s) => serializer.serialize_str(s),
337-
Err(e) => Err(py_err_se_err(e)),
338-
}
363+
extra.config.bytes_mode.serialize_bytes(bytes, serializer)
339364
}
340365
ObType::Dict => serialize_dict!(value.downcast::<PyDict>().map_err(py_err_se_err)?),
341366
ObType::List => serialize_seq_filter!(PyList),
@@ -378,6 +403,20 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
378403
let v = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?;
379404
infer_serialize(v, serializer, include, exclude, extra)
380405
}
406+
ObType::Generator => {
407+
let py_seq: &PyIterator = value.downcast().map_err(py_err_se_err)?;
408+
let mut seq = serializer.serialize_seq(None)?;
409+
let filter = AnyFilter::new();
410+
for (index, r) in py_seq.iter().map_err(py_err_se_err)?.enumerate() {
411+
let element = r.map_err(py_err_se_err)?;
412+
let op_next = filter.value_filter(index, include, exclude).map_err(py_err_se_err)?;
413+
if let Some((next_include, next_exclude)) = op_next {
414+
let item_serializer = SerializeInfer::new(element, next_include, next_exclude, extra);
415+
seq.serialize_element(&item_serializer)?
416+
}
417+
}
418+
seq.end()
419+
}
381420
ObType::Unknown => return Err(py_err_se_err(unknown_type_error(value))),
382421
};
383422
extra.rec_guard.pop(value_id);
@@ -399,19 +438,20 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: &ObType, key: &'py PyAny, extra
399438
ObType::Int | ObType::IntSubclass | ObType::Float | ObType::FloatSubclass => {
400439
super::type_serializers::simple::to_str_json_key(key)
401440
}
441+
ObType::Decimal => Ok(Cow::Owned(key.to_string())),
402442
ObType::Bool => super::type_serializers::simple::bool_json_key(key),
403443
ObType::Str | ObType::StrSubclass => {
404444
let py_str: &PyString = key.downcast()?;
405445
Ok(Cow::Borrowed(py_str.to_str()?))
406446
}
407-
ObType::Bytes => extra.config.bytes_mode.bytes_to_string(key.downcast()?),
447+
ObType::Bytes => extra
448+
.config
449+
.bytes_mode
450+
.bytes_to_string(key.py(), key.downcast::<PyBytes>()?.as_bytes()),
408451
ObType::Bytearray => {
409452
let py_byte_array: &PyByteArray = key.downcast()?;
410453
let bytes = unsafe { py_byte_array.as_bytes() };
411-
match from_utf8(bytes) {
412-
Ok(s) => Ok(Cow::Borrowed(s)),
413-
Err(err) => Err(utf8_py_error(key.py(), err, bytes)),
414-
}
454+
extra.config.bytes_mode.bytes_to_string(key.py(), bytes)
415455
}
416456
ObType::Datetime => {
417457
let py_dt: &PyDateTime = key.downcast()?;
@@ -447,7 +487,7 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: &ObType, key: &'py PyAny, extra
447487
}
448488
Ok(Cow::Owned(key_build.finish()))
449489
}
450-
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict => {
490+
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
451491
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
452492
}
453493
ObType::Dataclass | ObType::PydanticModel => {

src/serializers/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,39 @@ impl SchemaSerializer {
135135
)
136136
}
137137
}
138+
139+
#[allow(clippy::too_many_arguments)]
140+
#[pyfunction]
141+
pub fn to_json(
142+
py: Python,
143+
value: &PyAny,
144+
indent: Option<usize>,
145+
include: Option<&PyAny>,
146+
exclude: Option<&PyAny>,
147+
exclude_none: Option<bool>,
148+
round_trip: Option<bool>,
149+
timedelta_mode: Option<&str>,
150+
bytes_mode: Option<&str>,
151+
) -> PyResult<PyObject> {
152+
let warnings = CollectWarnings::new(None);
153+
let rec_guard = SerRecursionGuard::default();
154+
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode)?;
155+
let extra = Extra::new(
156+
py,
157+
&SerMode::Json,
158+
&[],
159+
None,
160+
&warnings,
161+
None,
162+
None,
163+
exclude_none,
164+
round_trip,
165+
&config,
166+
&rec_guard,
167+
);
168+
let serializer = type_serializers::any::AnySerializer::default().into();
169+
let bytes = to_json_bytes(value, &serializer, include, exclude, &extra, indent, 1024)?;
170+
warnings.final_check(py)?;
171+
let py_bytes = PyBytes::new(py, &bytes);
172+
Ok(py_bytes.into())
173+
}

0 commit comments

Comments
 (0)