diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index 14a69fb..85207fb 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -1,6 +1,6 @@ use crate::attribute::BitcodeAttrs; use crate::private; -use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; +use crate::shared::{remove_lifetimes, replace_lifetimes, VariantIndex}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{ @@ -111,6 +111,7 @@ impl crate::shared::Item for Item { self, crate_name: &Path, variant_count: usize, + variant_index: VariantIndex, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { @@ -126,7 +127,12 @@ impl crate::shared::Item for Item { .then(|| { let private = private(crate_name); let c_style = inners.is_empty(); - quote! { variants: #private::VariantDecoder<#de, #variant_count, #c_style>, } + let histogram = if c_style { + 0 + } else { + variant_count + }; + quote! { variants: #private::VariantDecoder<#de, #variant_index, #variant_count, #histogram>, } }) .unwrap_or_default(); quote! { @@ -165,7 +171,7 @@ impl crate::shared::Item for Item { if inner.is_empty() { quote! {} } else { - let i = variant_index(i); + let i = variant_index.instance_to_tokens(i); let length = decode_variants .then(|| { quote! { @@ -209,7 +215,7 @@ impl crate::shared::Item for Item { .map(|i| { let inner = inner(i); let pattern = pattern(i); - let i = variant_index(i); + let i = variant_index.instance_to_tokens(i); quote! { #i => { #inner @@ -221,7 +227,7 @@ impl crate::shared::Item for Item { quote! { match self.variants.decode() { #variants - // Safety: VariantDecoder::decode outputs numbers less than N. + // Safety: VariantDecoder<_, N, _>::decode outputs numbers less than N. _ => unsafe { ::core::hint::unreachable_unchecked() } } } diff --git a/bitcode_derive/src/encode.rs b/bitcode_derive/src/encode.rs index 9680229..b4532bb 100644 --- a/bitcode_derive/src/encode.rs +++ b/bitcode_derive/src/encode.rs @@ -1,6 +1,6 @@ use crate::attribute::BitcodeAttrs; use crate::private; -use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; +use crate::shared::{remove_lifetimes, replace_lifetimes, VariantIndex}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{parse_quote, Generics, Path, Type}; @@ -114,6 +114,7 @@ impl crate::shared::Item for Item { self, crate_name: &Path, variant_count: usize, + variant_index: VariantIndex, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { @@ -124,7 +125,7 @@ impl crate::shared::Item for Item { let variants = encode_variants .then(|| { let private = private(crate_name); - quote! { variants: #private::VariantEncoder<#variant_count>, } + quote! { variants: #private::VariantEncoder<#variant_index, #variant_count>, } }) .unwrap_or_default(); let inners: TokenStream = (0..variant_count).map(|i| inner(self, i)).collect(); @@ -149,7 +150,7 @@ impl crate::shared::Item for Item { let variants: TokenStream = (0..variant_count) .map(|i| { let pattern = pattern(i); - let i = variant_index(i); + let i = variant_index.instance_to_tokens(i); quote! { #pattern => #i, } diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs index 3d4f67e..5c123ae 100644 --- a/bitcode_derive/src/shared.rs +++ b/bitcode_derive/src/shared.rs @@ -9,9 +9,75 @@ use syn::{ Result, Type, WherePredicate, }; -type VariantIndex = u8; -pub fn variant_index(i: usize) -> VariantIndex { - i.try_into().unwrap() +#[derive(Copy, Clone, Debug)] +pub enum VariantIndex { + U8, + U16, + U32, +} + +impl VariantIndex { + pub fn new(variant_count: usize, ident: &Ident) -> Result { + for candidate in [Self::U8, Self::U16, Self::U32] { + if variant_count <= candidate.max_variants() { + return Ok(candidate); + } + } + err( + &ident, + &format!( + "enums with more than {} variants are not supported", + Self::U32.max_variants() + ), + ) + } + + fn max_variants(self) -> usize { + (match self { + Self::U8 => u8::MAX as usize, + Self::U16 => u16::MAX as usize, + Self::U32 => u32::MAX as usize, + }) + 1 + } + + /// If returns `false`, only C-style enums are supported. + pub fn supports_fields(self) -> bool { + match self { + Self::U8 => true, + _ => false, + } + } + + pub fn instance_to_tokens(self, index: usize) -> TokenStream { + match self { + Self::U8 => { + let n: u8 = index.try_into().unwrap(); + quote! {#n} + } + Self::U16 => { + let n: u16 = index.try_into().unwrap(); + quote! {#n} + } + Self::U32 => { + let n: u32 = index.try_into().unwrap(); + quote! {#n} + } + } + } +} + +impl ToTokens for VariantIndex { + fn to_tokens(&self, tokens: &mut TokenStream) { + use quote::TokenStreamExt; + tokens.append(Ident::new( + match self { + Self::U8 => "u8", + Self::U16 => "u16", + Self::U32 => "u32", + }, + Span::call_site(), + )); + } } pub trait Item: Copy + Sized { @@ -36,6 +102,7 @@ pub trait Item: Copy + Sized { self, crate_name: &Path, variant_count: usize, + variant_index: VariantIndex, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream; @@ -132,12 +199,20 @@ pub trait Derive { }) } Data::Enum(data_enum) => { - let max_variants = VariantIndex::MAX as usize + 1; - if data_enum.variants.len() > max_variants { - return err( - &ident, - &format!("enums with more than {max_variants} variants are not supported"), - ); + let variant_index = VariantIndex::new(data_enum.variants.len(), &ident)?; + + if !variant_index.supports_fields() { + for variant in &data_enum.variants { + if !variant.fields.is_empty() { + return err( + &ident, + &format!( + "enums with more than {} variants must not have any variants with fields", + VariantIndex::U8.max_variants() + ), + ); + } + } } // Used for adding `bounds` and skipping fields. Would be used by `#[bitcode(with_serde)]`. @@ -154,6 +229,7 @@ pub trait Derive { item.enum_impl( &attrs.crate_name, data_enum.variants.len(), + variant_index, |i| { let variant = &data_enum.variants[i]; let variant_name = &variant.ident; diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index a9beda4..aed0aad 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -3,14 +3,14 @@ use libfuzzer_sys::fuzz_target; extern crate bitcode; use arrayvec::{ArrayString, ArrayVec}; use bitcode::{Decode, DecodeOwned, Encode}; +use rust_decimal::Decimal; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::num::NonZeroU32; use std::time::Duration; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use rust_decimal::Decimal; #[inline(never)] fn test_derive(data: &[u8]) { @@ -140,6 +140,39 @@ fuzz_target!(|data: &[u8]| { pub enum Enum16 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P } #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] pub enum Enum17 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum300 { + V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, + V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, + V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, + V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, + V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, + V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, + V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, + V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, + V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, + V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, + V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, + V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, + V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, + V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, + V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, + V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, + V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, + V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, + V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, + V191, V192, V193, V194, V195, V196, V197, V198, V199, V200, + V201, V202, V203, V204, V205, V206, V207, V208, V209, V210, + V211, V212, V213, V214, V215, V216, V217, V218, V219, V220, + V221, V222, V223, V224, V225, V226, V227, V228, V229, V230, + V231, V232, V233, V234, V235, V236, V237, V238, V239, V240, + V241, V242, V243, V244, V245, V246, V247, V248, V249, V250, + V251, V252, V253, V254, V255, V256, V257, V258, V259, V260, + V261, V262, V263, V264, V265, V266, V267, V268, V269, V270, + V271, V272, V273, V274, V275, V276, V277, V278, V279, V280, + V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, + V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, + } } use enums::*; @@ -148,10 +181,20 @@ fuzz_target!(|data: &[u8]| { A, B, C(u16), - D { a: u8, b: u8, #[serde(skip)] #[bitcode(skip)] c: u8 }, + D { + a: u8, + b: u8, + #[serde(skip)] + #[bitcode(skip)] + c: u8, + }, E(String), F, - G(#[bitcode(skip)] #[serde(skip)] i16), + G( + #[bitcode(skip)] + #[serde(skip)] + i16, + ), P(BTreeMap), } @@ -219,6 +262,7 @@ fuzz_target!(|data: &[u8]| { Enum15, Enum16, Enum17, + Enum300, Enum, ArrayString<5>, ArrayString<70>, diff --git a/src/derive/option.rs b/src/derive/option.rs index b192bae..967aec6 100644 --- a/src/derive/option.rs +++ b/src/derive/option.rs @@ -7,7 +7,7 @@ use core::mem::MaybeUninit; use core::num::NonZeroUsize; pub struct OptionEncoder { - variants: VariantEncoder<2>, + variants: VariantEncoder, some: T::Encoder, } @@ -86,7 +86,7 @@ impl Buffer for OptionEncoder { } pub struct OptionDecoder<'a, T: Decode<'a>> { - variants: VariantDecoder<'a, 2, false>, + variants: VariantDecoder<'a, u8, 2, 2>, some: T::Decoder, } diff --git a/src/derive/result.rs b/src/derive/result.rs index 9ec6971..fb7dede 100644 --- a/src/derive/result.rs +++ b/src/derive/result.rs @@ -7,7 +7,7 @@ use core::mem::MaybeUninit; use core::num::NonZeroUsize; pub struct ResultEncoder { - variants: VariantEncoder<2>, + variants: VariantEncoder, ok: T::Encoder, err: E::Encoder, } @@ -55,7 +55,7 @@ impl Buffer for ResultEncoder { } pub struct ResultDecoder<'a, T: Decode<'a>, E: Decode<'a>> { - variants: VariantDecoder<'a, 2, false>, + variants: VariantDecoder<'a, u8, 2, 2>, ok: T::Decoder, err: E::Decoder, } diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 67463c3..c2c230d 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -1,23 +1,30 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::error::err; use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; use crate::pack::{pack_bytes_less_than, unpack_bytes_less_than}; +use crate::pack_ints::{pack_ints, unpack_ints, Int}; use alloc::vec::Vec; +use core::any::TypeId; use core::num::NonZeroUsize; #[derive(Default)] -pub struct VariantEncoder(VecImpl); +pub struct VariantEncoder(VecImpl); -impl Encoder for VariantEncoder { +impl Encoder for VariantEncoder { #[inline(always)] - fn encode(&mut self, v: &u8) { + fn encode(&mut self, v: &T) { unsafe { self.0.push_unchecked(*v) }; } } -impl Buffer for VariantEncoder { +impl Buffer for VariantEncoder { fn collect_into(&mut self, out: &mut Vec) { assert!(N >= 2); - pack_bytes_less_than::(self.0.as_slice(), out); + if TypeId::of::() != TypeId::of::() { + pack_ints(self.0.as_mut_slice(), out); + } else { + pack_bytes_less_than::(bytemuck::must_cast_slice::(self.0.as_slice()), out); + }; self.0.clear(); } @@ -26,13 +33,16 @@ impl Buffer for VariantEncoder { } } -pub struct VariantDecoder<'a, const N: usize, const C_STYLE: bool> { - variants: CowSlice<'a, u8>, - histogram: [usize; N], // Not required if C_STYLE. TODO don't reserve space for it. +pub struct VariantDecoder<'a, T: Int, const N: usize, const HISTOGRAM: usize> { + variants: CowSlice<'a, T::Une>, + // `HISTOGRAM` is 0 for C style (fieldless) enums. + histogram: [usize; HISTOGRAM], } // [(); N] doesn't implement Default. -impl Default for VariantDecoder<'_, N, C_STYLE> { +impl Default + for VariantDecoder<'_, T, N, HISTOGRAM> +{ fn default() -> Self { Self { variants: Default::default(), @@ -41,30 +51,58 @@ impl Default for VariantDecoder<'_, N, C_ST } } -// C style enums don't require length, so we can skip making a histogram for them. -impl<'a, const N: usize> VariantDecoder<'a, N, false> { +// C style enums (`HISTOGRAM` = 0) don't require length, so we +// can skip making a histogram for them. +impl<'a, T: Int, const N: usize> VariantDecoder<'a, T, N, N> { pub fn length(&self, variant_index: u8) -> usize { self.histogram[variant_index as usize] } } -impl<'a, const N: usize, const C_STYLE: bool> View<'a> for VariantDecoder<'a, N, C_STYLE> { +impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> View<'a> + for VariantDecoder<'a, T, N, HISTOGRAM> +{ fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { assert!(N >= 2); - if C_STYLE { - unpack_bytes_less_than::(input, length, &mut self.variants)?; + if TypeId::of::() != TypeId::of::() { + unpack_ints::(input, length, &mut self.variants)?; + + /// Checks that `unpacked` ints are less than `N`, hopefully + /// without a branch instruction for every int. + fn check_less_than, const N: usize>( + unpacked: &[T::Une], + ) -> Result<()> { + if 2u64.pow(core::mem::size_of::() as u32 * 8) > N as u64 + && unpacked + .iter() + .copied() + .map(T::from_unaligned) + .max() + .map(Into::into) + .unwrap_or(0) + >= N + { + return err("invalid enum variant index"); + } + Ok(()) + } + + check_less_than::(unsafe { self.variants.as_slice(length) })?; } else { - self.histogram = unpack_bytes_less_than::(input, length, &mut self.variants)?; + let out = self.variants.cast_mut::(); + self.histogram = unpack_bytes_less_than::(input, length, out)?; } Ok(()) } } -impl<'a, const N: usize, const C_STYLE: bool> Decoder<'a, u8> for VariantDecoder<'a, N, C_STYLE> { +impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> Decoder<'a, T> + for VariantDecoder<'a, T, N, HISTOGRAM> +{ // Guaranteed to output numbers less than N. #[inline(always)] - fn decode(&mut self) -> u8 { - unsafe { self.variants.mut_slice().next_unchecked() } + fn decode(&mut self) -> T { + T::from_unaligned(unsafe { self.variants.mut_slice().next_unchecked() }) } } @@ -138,3 +176,59 @@ mod tests { } crate::bench_encode_decode!(bool_enum_vec: Vec<_>); } + +#[cfg(test)] +mod test2 { + use crate::{decode, encode, Decode, Encode}; + use alloc::vec::Vec; + + #[cfg_attr(not(test), rustfmt::skip)] + #[derive(Encode, Decode, Debug, PartialEq)] + pub enum Enum300 { + V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, + V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, + V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, + V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, + V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, + V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, + V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, + V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, + V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, + V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, + V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, + V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, + V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, + V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, + V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, + V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, + V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, + V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, + V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, + V191, V192, V193, V194, V195, V196, V197, V198, V199, V200, + V201, V202, V203, V204, V205, V206, V207, V208, V209, V210, + V211, V212, V213, V214, V215, V216, V217, V218, V219, V220, + V221, V222, V223, V224, V225, V226, V227, V228, V229, V230, + V231, V232, V233, V234, V235, V236, V237, V238, V239, V240, + V241, V242, V243, V244, V245, V246, V247, V248, V249, V250, + V251, V252, V253, V254, V255, V256, V257, V258, V259, V260, + V261, V262, V263, V264, V265, V266, V267, V268, V269, V270, + V271, V272, V273, V274, V275, V276, V277, V278, V279, V280, + V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, + V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, + } + + #[allow(unused)] + #[test] + fn test_large_c_style_enum() { + assert!(matches!(decode(&encode(&Enum300::V42)), Ok(Enum300::V42))); + assert!(matches!(decode(&encode(&Enum300::V300)), Ok(Enum300::V300))); + } + + fn bench_data() -> Vec { + crate::random_data(1000) + .into_iter() + .map(|v: u16| unsafe { core::mem::transmute_copy::<_, Enum300>(&(v % 300)) }) + .collect() + } + crate::bench_encode_decode!(enum_300_variants_vec: Vec<_>); +} diff --git a/src/fast.rs b/src/fast.rs index 13adb2e..ebbe23e 100644 --- a/src/fast.rs +++ b/src/fast.rs @@ -332,11 +332,18 @@ impl<'a, T: Copy> NextUnchecked<'a, T> for &'a [T] { } /// Maybe owned [`FastSlice`]. Saves its allocation even if borrowing something. -#[derive(Default)] pub struct CowSlice<'borrowed, T> { slice: SliceImpl<'borrowed, T>, // Lifetime is min of 'borrowed and &'me self. vec: Vec, } +impl<'borrowed, T> Default for CowSlice<'borrowed, T> { + fn default() -> Self { + Self { + slice: Default::default(), + vec: Default::default(), + } + } +} impl<'borrowed, T> CowSlice<'borrowed, T> { /// Creates a [`CowSlice`] with an allocation of `vec`. None of `vec`'s elements are kept. pub fn with_allocation(mut vec: Vec) -> Self { diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 63057a6..331c467 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -240,6 +240,11 @@ impl<'a> EncoderWrapper<'a> { #[inline(always)] fn variant_index_u8(variant_index: u32) -> Result { if variant_index > u8::MAX as u32 { + // Properly optimizing the size of large enums would + // require `serde` to specify the variant count. + // + // Good news: the `derive` version of `bitcode` supports + // arbitrary-sized fieldless enums! err("enums with more than 256 variants are unsupported") } else { Ok(variant_index as u8)