From 83976304d09e4eb73e53e564ccb8c62d7a145ee0 Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Mon, 11 Sep 2023 15:58:30 +0100 Subject: [PATCH] Support alternative HashMap hasher --- prost-derive/src/field/map.rs | 61 ++-- src/encoding.rs | 580 ++++++++++++++++++++++------------ 2 files changed, 420 insertions(+), 221 deletions(-) diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index aabceb1f1..44018a7cd 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -1,8 +1,7 @@ use anyhow::{bail, Error}; use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::punctuated::Punctuated; -use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Token}; +use syn::{punctuated::Punctuated, Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Token}; use crate::field::{scalar, set_option, tag_attr}; @@ -298,11 +297,6 @@ impl Field { /// The Debug tries to convert any enumerations met into the variants if possible, instead of /// outputting the raw numbers. pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { - let type_name = match self.map_ty { - MapTy::HashMap => Ident::new("HashMap", Span::call_site()), - MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()), - }; - // A fake field for generating the debug wrapper let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper)); let key = self.key_ty.rust_type(); @@ -333,20 +327,51 @@ impl Field { } let value = ty.rust_type(); - quote! { - struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>); - impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { - #fmt + match self.map_ty { + MapTy::HashMap => { + let map = Ident::new("HashMap", Span::call_site()); + quote! { + struct #wrapper_name<'a, S>(&'a ::#libname::collections::#map<#key, #value, S>); + impl<'a, S> ::core::fmt::Debug for #wrapper_name<'a, S> { + #fmt + } + } + } + MapTy::BTreeMap => { + let map = Ident::new("BTreeMap", Span::call_site()); + quote! { + struct #wrapper_name<'a>(&'a ::#libname::collections::#map<#key, #value>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + #fmt + } + } } } } - ValueTy::Message => quote! { - struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>); - impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V> - where - V: ::core::fmt::Debug + 'a, - { - #fmt + ValueTy::Message => match self.map_ty { + MapTy::HashMap => { + let name = Ident::new("HashMap", Span::call_site()); + quote! { + struct #wrapper_name<'a, V: 'a, S>(&'a ::#libname::collections::#name<#key, V, S>); + impl<'a, V, S> ::core::fmt::Debug for #wrapper_name<'a, V, S> + where + V: ::core::fmt::Debug + 'a, + { + #fmt + } + } + } + MapTy::BTreeMap => { + let map = Ident::new("BTreeMap", Span::call_site()); + quote! { + struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#map<#key, V>); + impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V> + where + V: ::core::fmt::Debug + 'a, + { + #fmt + } + } } }, } diff --git a/src/encoding.rs b/src/encoding.rs index e4d2aa274..e8db1c635 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -4,21 +4,12 @@ #![allow(clippy::implicit_hasher, clippy::ptr_arg)] -use alloc::collections::BTreeMap; -use alloc::format; -use alloc::string::String; -use alloc::vec::Vec; -use core::cmp::min; -use core::convert::TryFrom; -use core::mem; -use core::str; -use core::u32; -use core::usize; +use alloc::{collections::BTreeMap, format, string::String, vec::Vec}; +use core::{cmp::min, convert::TryFrom, mem, str, u32, usize}; use ::bytes::{Buf, BufMut, Bytes}; -use crate::DecodeError; -use crate::Message; +use crate::{DecodeError, Message}; /// Encodes an integer value into LEB128 variable length format, and writes it to the buffer. /// The buffer must have enough remaining space (maximum 10 bytes). @@ -679,8 +670,10 @@ macro_rules! fixed_width { mod test { use proptest::prelude::*; - use super::super::test::{check_collection_type, check_type}; - use super::*; + use super::{ + super::test::{check_collection_type, check_type}, + *, + }; proptest! { #[test] @@ -854,8 +847,10 @@ pub mod string { mod test { use proptest::prelude::*; - use super::super::test::{check_collection_type, check_type}; - use super::*; + use super::{ + super::test::{check_collection_type, check_type}, + *, + }; proptest! { #[test] @@ -1017,8 +1012,10 @@ pub mod bytes { mod test { use proptest::prelude::*; - use super::super::test::{check_collection_type, check_type}; - use super::*; + use super::{ + super::test::{check_collection_type, check_type}, + *, + }; proptest! { #[test] @@ -1223,209 +1220,386 @@ pub mod group { } } -/// Rust doesn't have a `Map` trait, so macros are currently the best way to be -/// generic over `HashMap` and `BTreeMap`. -macro_rules! map { - ($map_ty:ident) => { - use crate::encoding::*; - use core::hash::Hash; - - /// Generic protobuf map encode function. - pub fn encode( - key_encode: KE, - key_encoded_len: KL, - val_encode: VE, - val_encoded_len: VL, - tag: u32, - values: &$map_ty, - buf: &mut B, - ) where - K: Default + Eq + Hash + Ord, - V: Default + PartialEq, - B: BufMut, - KE: Fn(u32, &K, &mut B), - KL: Fn(u32, &K) -> usize, - VE: Fn(u32, &V, &mut B), - VL: Fn(u32, &V) -> usize, - { - encode_with_default( - key_encode, - key_encoded_len, - val_encode, - val_encoded_len, - &V::default(), - tag, - values, - buf, - ) - } +#[cfg(feature = "std")] +pub mod hash_map { + use crate::encoding::*; + use core::hash::{BuildHasher, Hash}; + use std::collections::HashMap; - /// Generic protobuf map merge function. - pub fn merge( - key_merge: KM, - val_merge: VM, - values: &mut $map_ty, - buf: &mut B, - ctx: DecodeContext, - ) -> Result<(), DecodeError> - where - K: Default + Eq + Hash + Ord, - V: Default, - B: Buf, - KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, - VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, - { - merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx) - } + /// Generic protobuf map encode function. + #[inline] + pub fn encode( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + tag: u32, + values: &HashMap, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + B: BufMut, + KE: Fn(u32, &K, &mut B), + KL: Fn(u32, &K) -> usize, + VE: Fn(u32, &V, &mut B), + VL: Fn(u32, &V) -> usize, + { + encode_with_default( + key_encode, + key_encoded_len, + val_encode, + val_encoded_len, + &V::default(), + tag, + values, + buf, + ) + } - /// Generic protobuf map encode function. - pub fn encoded_len( - key_encoded_len: KL, - val_encoded_len: VL, - tag: u32, - values: &$map_ty, - ) -> usize - where - K: Default + Eq + Hash + Ord, - V: Default + PartialEq, - KL: Fn(u32, &K) -> usize, - VL: Fn(u32, &V) -> usize, - { - encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values) - } + /// Generic protobuf map merge function. + #[inline] + pub fn merge( + key_merge: KM, + val_merge: VM, + values: &mut HashMap, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + V: Default, + S: BuildHasher, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx) + } - /// Generic protobuf map encode function with an overridden value default. - /// - /// This is necessary because enumeration values can have a default value other - /// than 0 in proto2. - pub fn encode_with_default( - key_encode: KE, - key_encoded_len: KL, - val_encode: VE, - val_encoded_len: VL, - val_default: &V, - tag: u32, - values: &$map_ty, - buf: &mut B, - ) where - K: Default + Eq + Hash + Ord, - V: PartialEq, - B: BufMut, - KE: Fn(u32, &K, &mut B), - KL: Fn(u32, &K) -> usize, - VE: Fn(u32, &V, &mut B), - VL: Fn(u32, &V) -> usize, - { - for (key, val) in values.iter() { - let skip_key = key == &K::default(); - let skip_val = val == val_default; + /// Generic protobuf map encode function. + #[inline] + pub fn encoded_len( + key_encoded_len: KL, + val_encoded_len: VL, + tag: u32, + values: &HashMap, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + KL: Fn(u32, &K) -> usize, + VL: Fn(u32, &V) -> usize, + { + encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values) + } - let len = (if skip_key { 0 } else { key_encoded_len(1, key) }) - + (if skip_val { 0 } else { val_encoded_len(2, val) }); + /// Generic protobuf map encode function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + #[inline] + pub fn encode_with_default( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + val_default: &V, + tag: u32, + values: &HashMap, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: PartialEq, + B: BufMut, + KE: Fn(u32, &K, &mut B), + KL: Fn(u32, &K) -> usize, + VE: Fn(u32, &V, &mut B), + VL: Fn(u32, &V) -> usize, + { + for (key, val) in values.iter() { + let skip_key = key == &K::default(); + let skip_val = val == val_default; - encode_key(tag, WireType::LengthDelimited, buf); - encode_varint(len as u64, buf); - if !skip_key { - key_encode(1, key, buf); - } - if !skip_val { - val_encode(2, val, buf); - } + let len = (if skip_key { 0 } else { key_encoded_len(1, key) }) + + (if skip_val { 0 } else { val_encoded_len(2, val) }); + + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(len as u64, buf); + if !skip_key { + key_encode(1, key, buf); + } + if !skip_val { + val_encode(2, val, buf); } } + } - /// Generic protobuf map merge function with an overridden value default. - /// - /// This is necessary because enumeration values can have a default value other - /// than 0 in proto2. - pub fn merge_with_default( - key_merge: KM, - val_merge: VM, - val_default: V, - values: &mut $map_ty, - buf: &mut B, - ctx: DecodeContext, - ) -> Result<(), DecodeError> - where - K: Default + Eq + Hash + Ord, - B: Buf, - KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, - VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, - { - let mut key = Default::default(); - let mut val = val_default; - ctx.limit_reached()?; - merge_loop( - &mut (&mut key, &mut val), - buf, - ctx.enter_recursion(), - |&mut (ref mut key, ref mut val), buf, ctx| { - let (tag, wire_type) = decode_key(buf)?; - match tag { - 1 => key_merge(wire_type, key, buf, ctx), - 2 => val_merge(wire_type, val, buf, ctx), - _ => skip_field(wire_type, tag, buf, ctx), - } - }, - )?; - values.insert(key, val); - - Ok(()) - } + /// Generic protobuf map merge function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + #[inline] + pub fn merge_with_default( + key_merge: KM, + val_merge: VM, + val_default: V, + values: &mut HashMap, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + S: BuildHasher, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + let mut key = Default::default(); + let mut val = val_default; + ctx.limit_reached()?; + merge_loop( + &mut (&mut key, &mut val), + buf, + ctx.enter_recursion(), + |&mut (ref mut key, ref mut val), buf, ctx| { + let (tag, wire_type) = decode_key(buf)?; + match tag { + 1 => key_merge(wire_type, key, buf, ctx), + 2 => val_merge(wire_type, val, buf, ctx), + _ => skip_field(wire_type, tag, buf, ctx), + } + }, + )?; + values.insert(key, val); - /// Generic protobuf map encode function with an overridden value default. - /// - /// This is necessary because enumeration values can have a default value other - /// than 0 in proto2. - pub fn encoded_len_with_default( - key_encoded_len: KL, - val_encoded_len: VL, - val_default: &V, - tag: u32, - values: &$map_ty, - ) -> usize - where - K: Default + Eq + Hash + Ord, - V: PartialEq, - KL: Fn(u32, &K) -> usize, - VL: Fn(u32, &V) -> usize, - { - key_len(tag) * values.len() - + values - .iter() - .map(|(key, val)| { - let len = (if key == &K::default() { - 0 - } else { - key_encoded_len(1, key) - }) + (if val == val_default { - 0 - } else { - val_encoded_len(2, val) - }); - encoded_len_varint(len as u64) + len - }) - .sum::() - } - }; -} + Ok(()) + } -#[cfg(feature = "std")] -pub mod hash_map { - use std::collections::HashMap; - map!(HashMap); + /// Generic protobuf map encode function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + #[inline] + pub fn encoded_len_with_default( + key_encoded_len: KL, + val_encoded_len: VL, + val_default: &V, + tag: u32, + values: &HashMap, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: PartialEq, + KL: Fn(u32, &K) -> usize, + VL: Fn(u32, &V) -> usize, + { + key_len(tag) * values.len() + + values + .iter() + .map(|(key, val)| { + let len = (if key == &K::default() { + 0 + } else { + key_encoded_len(1, key) + }) + (if val == val_default { + 0 + } else { + val_encoded_len(2, val) + }); + encoded_len_varint(len as u64) + len + }) + .sum::() + } } pub mod btree_map { - map!(BTreeMap); + use crate::encoding::*; + use core::hash::Hash; + + /// Generic protobuf map encode function. + pub fn encode( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + tag: u32, + values: &BTreeMap, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + B: BufMut, + KE: Fn(u32, &K, &mut B), + KL: Fn(u32, &K) -> usize, + VE: Fn(u32, &V, &mut B), + VL: Fn(u32, &V) -> usize, + { + encode_with_default( + key_encode, + key_encoded_len, + val_encode, + val_encoded_len, + &V::default(), + tag, + values, + buf, + ) + } + + /// Generic protobuf map merge function. + pub fn merge( + key_merge: KM, + val_merge: VM, + values: &mut BTreeMap, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + V: Default, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx) + } + + /// Generic protobuf map encode function. + pub fn encoded_len( + key_encoded_len: KL, + val_encoded_len: VL, + tag: u32, + values: &BTreeMap, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + KL: Fn(u32, &K) -> usize, + VL: Fn(u32, &V) -> usize, + { + encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values) + } + + /// Generic protobuf map encode function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn encode_with_default( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + val_default: &V, + tag: u32, + values: &BTreeMap, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: PartialEq, + B: BufMut, + KE: Fn(u32, &K, &mut B), + KL: Fn(u32, &K) -> usize, + VE: Fn(u32, &V, &mut B), + VL: Fn(u32, &V) -> usize, + { + for (key, val) in values.iter() { + let skip_key = key == &K::default(); + let skip_val = val == val_default; + + let len = (if skip_key { 0 } else { key_encoded_len(1, key) }) + + (if skip_val { 0 } else { val_encoded_len(2, val) }); + + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(len as u64, buf); + if !skip_key { + key_encode(1, key, buf); + } + if !skip_val { + val_encode(2, val, buf); + } + } + } + + /// Generic protobuf map merge function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn merge_with_default( + key_merge: KM, + val_merge: VM, + val_default: V, + values: &mut BTreeMap, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + let mut key = Default::default(); + let mut val = val_default; + ctx.limit_reached()?; + merge_loop( + &mut (&mut key, &mut val), + buf, + ctx.enter_recursion(), + |&mut (ref mut key, ref mut val), buf, ctx| { + let (tag, wire_type) = decode_key(buf)?; + match tag { + 1 => key_merge(wire_type, key, buf, ctx), + 2 => val_merge(wire_type, val, buf, ctx), + _ => skip_field(wire_type, tag, buf, ctx), + } + }, + )?; + values.insert(key, val); + + Ok(()) + } + + /// Generic protobuf map encode function with an overridden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn encoded_len_with_default( + key_encoded_len: KL, + val_encoded_len: VL, + val_default: &V, + tag: u32, + values: &BTreeMap, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: PartialEq, + KL: Fn(u32, &K) -> usize, + VL: Fn(u32, &V) -> usize, + { + key_len(tag) * values.len() + + values + .iter() + .map(|(key, val)| { + let len = (if key == &K::default() { + 0 + } else { + key_encoded_len(1, key) + }) + (if val == val_default { + 0 + } else { + val_encoded_len(2, val) + }); + encoded_len_varint(len as u64) + len + }) + .sum::() + } } #[cfg(test)] mod test { use alloc::string::ToString; - use core::borrow::Borrow; - use core::fmt::Debug; - use core::u64; + use core::{borrow::Borrow, fmt::Debug, u64}; use ::bytes::{Bytes, BytesMut}; use proptest::{prelude::*, test_runner::TestCaseResult};