From a5445383ee772c85c2b3ff7b71b3ce6cc0432b29 Mon Sep 17 00:00:00 2001 From: cynecx Date: Fri, 22 Nov 2024 23:36:54 +0100 Subject: [PATCH] more conformance fixes to oneof handling and nan/inf serialization --- conformance/failing_tests.txt | 6 +++++ prost-derive/src/serde/de.rs | 15 +++++++++--- prost-types/src/serde.rs | 9 +++++++- prost/src/serde/de.rs | 2 +- prost/src/serde/de/option.rs | 43 +++++++++++++++++++++++++++++++++++ prost/src/serde/private.rs | 2 +- 6 files changed, 71 insertions(+), 6 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index 14e277477..f0e344b50 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -14,3 +14,9 @@ Recommended.FieldMaskTooManyUnderscore.JsonOutput Recommended.Proto2.JsonInput.FieldNameExtension.Validator Recommended.Proto3.JsonInput.FieldMaskInvalidCharacter Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator + +# These just aren't possible right now, because we can't represent unknown +# enum values in the message struct (they are stored as i32). +Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInMapValue.ProtobufOutput +Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInOptionalField.ProtobufOutput +Recommended.Proto3.JsonInput.IgnoreUnknownEnumStringValueInRepeatedField.ProtobufOutput diff --git a/prost-derive/src/serde/de.rs b/prost-derive/src/serde/de.rs index d30e20cfa..32746b819 100644 --- a/prost-derive/src/serde/de.rs +++ b/prost-derive/src/serde/de.rs @@ -124,15 +124,24 @@ pub fn impl_for_message( field_matches.push(quote! { __Field::#field_variant_ident(key) => { if _private::Option::is_some(&#field_variant_ident) { - return _private::Err( - <__A::Error as _serde::de::Error>::duplicate_field(#field_ident_str) + let __val = _serde::de::MapAccess::next_value_seed( + &mut __map, + _private::DesIntoWithConfig::<_private::NullDeserializer, ()>::new( + __config + ), ); + match __val { + _private::Ok(()) => continue, + _private::Err(_) => return _private::Err( + <__A::Error as _serde::de::Error>::duplicate_field(#field_ident_str) + ), + } } let __val = _serde::de::MapAccess::next_value_seed( &mut __map, _private::OneOfDeserializer(key, __config), )?; - if __val.is_some() { + if _private::Option::is_some(&__val) { #field_variant_ident = _private::Some(__val); } } diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs index d996d4803..22b1cdea2 100644 --- a/prost-types/src/serde.rs +++ b/prost-types/src/serde.rs @@ -290,7 +290,14 @@ impl CustomSerialize for Value { { match self.kind.as_ref() { Some(value::Kind::NullValue(_)) | None => serializer.serialize_none(), - Some(value::Kind::NumberValue(val)) => serializer.serialize_f64(*val), + Some(value::Kind::NumberValue(val)) => { + if val.is_nan() || val.is_infinite() { + return Err(serde::ser::Error::custom(format!( + "serializing a value::Kind::NumberValue, which is {val}, is not possible" + ))); + } + serializer.serialize_f64(*val) + } Some(value::Kind::StringValue(val)) => serializer.serialize_str(val), Some(value::Kind::BoolValue(val)) => serializer.serialize_bool(*val), Some(value::Kind::StructValue(val)) => { diff --git a/prost/src/serde/de.rs b/prost/src/serde/de.rs index c9bf1d4f4..498cc9cd3 100644 --- a/prost/src/serde/de.rs +++ b/prost/src/serde/de.rs @@ -118,7 +118,7 @@ pub use forward::ForwardDeserializer; pub use map::MapDeserializer; pub use message::MessageDeserializer; pub use oneof::{DeserializeOneOf, OneOfDeserializer}; -pub use option::OptionDeserializer; +pub use option::{NullDeserializer, OptionDeserializer}; pub use r#enum::{DeserializeEnum, EnumDeserializer}; pub use scalar::{BoolDeserializer, FloatDeserializer, IntDeserializer}; pub use vec::VecDeserializer; diff --git a/prost/src/serde/de/option.rs b/prost/src/serde/de/option.rs index 429ab0595..407fa02c7 100644 --- a/prost/src/serde/de/option.rs +++ b/prost/src/serde/de/option.rs @@ -62,3 +62,46 @@ where true } } + +pub struct NullDeserializer; + +impl DeserializeInto<()> for NullDeserializer { + #[inline] + fn deserialize_into<'de, D: serde::Deserializer<'de>>( + deserializer: D, + _config: &DeserializerConfig, + ) -> Result<(), D::Error> { + struct Visitor; + + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a null value") + } + + #[inline] + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(()) + } + + #[inline] + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(()) + } + } + + deserializer.deserialize_option(Visitor) + } + + #[inline] + fn can_deserialize_null() -> bool { + true + } +} diff --git a/prost/src/serde/private.rs b/prost/src/serde/private.rs index 7e1287ead..146afd8ad 100644 --- a/prost/src/serde/private.rs +++ b/prost/src/serde/private.rs @@ -33,5 +33,5 @@ pub use super::de::{ BoolDeserializer, BytesDeserializer, CustomDeserialize, DefaultDeserializer, DesIntoWithConfig, DesWithConfig, DeserializeEnum, DeserializeInto, DeserializeOneOf, EnumDeserializer, FloatDeserializer, ForwardDeserializer, IntDeserializer, MapDeserializer, MessageDeserializer, - OneOfDeserializer, OptionDeserializer, VecDeserializer, + NullDeserializer, OneOfDeserializer, OptionDeserializer, VecDeserializer, };