From df5729def9dbab605eb6d691b19e90241bbd47b3 Mon Sep 17 00:00:00 2001 From: Lukasz Anforowicz Date: Mon, 13 Jan 2025 14:44:51 -0800 Subject: [PATCH] Change hardcoded table sizes into `const`-generic params. (#49) --- src/decompress.rs | 96 ++++++++++++++++++++++++++++++----------------- src/huffman.rs | 4 +- src/tables.rs | 19 ++++++---- 3 files changed, 75 insertions(+), 44 deletions(-) diff --git a/src/decompress.rs b/src/decompress.rs index 41ef77b..0b1bba9 100644 --- a/src/decompress.rs +++ b/src/decompress.rs @@ -62,13 +62,17 @@ pub const LITERAL_ENTRY: u32 = 0x8000; pub const EXCEPTIONAL_ENTRY: u32 = 0x4000; pub const SECONDARY_TABLE_ENTRY: u32 = 0x2000; +// See https://github.com/image-rs/fdeflate/issues/45 for discussion of the table sizes. +const DEFAULT_LITLEN_TABLE_SIZE: usize = 4096; +const DEFAULT_DIST_TABLE_SIZE: usize = 512; + /// The Decompressor state for a compressed block. #[derive(Eq, PartialEq, Debug)] -struct CompressedBlock { - litlen_table: Box<[u32; 4096]>, +struct CompressedBlock { + litlen_table: Box<[u32; LITLEN_TABLE_SIZE]>, secondary_table: Vec, - dist_table: Box<[u32; 512]>, + dist_table: Box<[u32; DIST_TABLE_SIZE]>, dist_secondary_table: Vec, eof_code: u16, @@ -91,7 +95,7 @@ enum State { /// Decompressor for arbitrary zlib streams. pub struct Decompressor { /// State for decoding a compressed block. - compression: CompressedBlock, + compression: CompressedBlock, // State for decoding a block header. header: BlockHeader, // Number of bytes left for uncompressed block. @@ -120,8 +124,8 @@ impl Decompressor { Self { bits: BitBuffer::new(), compression: CompressedBlock { - litlen_table: Box::new([0; 4096]), - dist_table: Box::new([0; 512]), + litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]), + dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]), secondary_table: Vec::new(), dist_secondary_table: Vec::new(), eof_code: 0, @@ -224,11 +228,11 @@ impl Decompressor { } let input0 = self.bits.peek_bits(8); - let input1 = self.bits.peek_bits(16) >> 8 & 0xff; + let input1 = (self.bits.peek_bits(16) >> 8) & 0xff; if input0 & 0x0f != 0x08 || (input0 & 0xf0) > 0x70 || input1 & 0x20 != 0 - || (input0 << 8 | input1) % 31 != 0 + || ((input0 << 8) | input1) % 31 != 0 { return Err(DecompressionError::BadZlibHeader); } @@ -394,9 +398,11 @@ impl Decompressor { // Build decoding tables if the previous block wasn't also a fixed block. if !self.fixed_table { self.fixed_table = true; + assert!(self.compression.litlen_table.len() >= FIXED_LITLEN_TABLE.len()); for chunk in self.compression.litlen_table.chunks_exact_mut(512) { chunk.copy_from_slice(&FIXED_LITLEN_TABLE); } + assert!(self.compression.dist_table.len() >= FIXED_DIST_TABLE.len()); for chunk in self.compression.dist_table.chunks_exact_mut(32) { chunk.copy_from_slice(&FIXED_DIST_TABLE); } @@ -551,7 +557,9 @@ impl Decompressor { } } -impl CompressedBlock { +impl + CompressedBlock +{ fn build_tables(&mut self, hlit: usize, code_lengths: &[u8]) -> Result<(), DecompressionError> { // If there is no code assigned for the EOF symbol then the bitstream is invalid. if code_lengths[256] == 0 { @@ -610,6 +618,20 @@ impl CompressedBlock { mut output_index: usize, queued_output: &mut Option, ) -> Result<(CompressedBlockStatus, usize), DecompressionError> { + // `litlen_table_mask` (and `dist_table_mask`) calculation assumes that `LITLEN_TABLE_SIZE` + // (or `DIST_TABLE_SIZE`) is a power of two. + assert!(LITLEN_TABLE_SIZE.count_ones() == 1); + assert!(DIST_TABLE_SIZE.count_ones() == 1); + let litlen_table_mask = (LITLEN_TABLE_SIZE as u64) - 1; + let litlen_table_bits = LITLEN_TABLE_SIZE.trailing_zeros(); + let dist_table_mask = (DIST_TABLE_SIZE as u64) - 1; + let dist_table_bits = DIST_TABLE_SIZE.trailing_zeros(); + // Lower bound on table sizes, because 1a) RFC1951 uses at most 15 bits for codewords and + // 1b) we can fit at most 8 bits of `overflow_bits_mask` in the last byte of a primary + // table entry. + assert!(litlen_table_bits + 8 >= 15); + assert!(dist_table_bits + 8 >= 15); + // Fast decoding loop. // // This loop is optimized for speed and is the main decoding loop for the decompressor, @@ -623,23 +645,25 @@ impl CompressedBlock { // the bit buffer. This is because when the input is non-empty, the bit buffer actually // has 64-bits of valid data (even though nbits will be in 56..=63). bit_buffer.fill_buffer(remaining_input); - let mut litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize]; + let mut litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize]; while output_index + 8 <= output.len() && remaining_input.len() >= 8 { // First check whether the next symbol is a literal. This code does up to 2 additional // table lookups to decode more literals. let mut bits; let mut litlen_code_bits = litlen_entry as u8; if litlen_entry & LITERAL_ENTRY != 0 { - let litlen_entry2 = - self.litlen_table[(bit_buffer.buffer >> litlen_code_bits & 0xfff) as usize]; + let litlen_entry2 = self.litlen_table + [((bit_buffer.buffer >> litlen_code_bits) & litlen_table_mask) as usize]; let litlen_code_bits2 = litlen_entry2 as u8; - let litlen_entry3 = self.litlen_table[(bit_buffer.buffer - >> (litlen_code_bits + litlen_code_bits2) - & 0xfff) as usize]; + let litlen_entry3 = self.litlen_table[((bit_buffer.buffer + >> (litlen_code_bits + litlen_code_bits2)) + & litlen_table_mask) + as usize]; let litlen_code_bits3 = litlen_entry3 as u8; - let litlen_entry4 = self.litlen_table[(bit_buffer.buffer - >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3) - & 0xfff) as usize]; + let litlen_entry4 = self.litlen_table[((bit_buffer.buffer + >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)) + & litlen_table_mask) + as usize]; let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize; output[output_index] = (litlen_entry >> 16) as u8; @@ -692,8 +716,8 @@ impl CompressedBlock { litlen_code_bits, ) } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 { - let secondary_table_index = - (litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff)); + let secondary_table_index = (litlen_entry >> 16) + + ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff)); let secondary_entry = self.secondary_table[secondary_table_index as usize]; let litlen_symbol = secondary_entry >> 4; let litlen_code_bits = (secondary_entry & 0xf) as u8; @@ -701,7 +725,8 @@ impl CompressedBlock { match litlen_symbol { 0..=255 => { bit_buffer.consume_bits(litlen_code_bits); - litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize]; + litlen_entry = + self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize]; bit_buffer.fill_buffer(remaining_input); output[output_index] = litlen_symbol as u8; output_index += 1; @@ -729,7 +754,7 @@ impl CompressedBlock { let length = length_base as usize + (bits & length_extra_mask) as usize; bits >>= length_extra_bits; - let dist_entry = self.dist_table[(bits & 0x1ff) as usize]; + let dist_entry = self.dist_table[(bits & dist_table_mask) as usize]; let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 { ( (dist_entry >> 16) as u16, @@ -740,7 +765,7 @@ impl CompressedBlock { return Err(DecompressionError::InvalidDistanceCode); } else { let secondary_table_index = - (dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff)); + (dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff)); let secondary_entry = self.dist_secondary_table[secondary_table_index as usize]; let dist_symbol = (secondary_entry >> 4) as usize; if dist_symbol >= 30 { @@ -764,7 +789,7 @@ impl CompressedBlock { litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits, ); bit_buffer.fill_buffer(remaining_input); - litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize]; + litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize]; let copy_length = length.min(output.len() - output_index); if dist == 1 { @@ -817,12 +842,13 @@ impl CompressedBlock { } let mut bits = bit_buffer.buffer; - let litlen_entry = self.litlen_table[(bits & 0xfff) as usize]; + let litlen_entry = self.litlen_table[(bits & litlen_table_mask) as usize]; let litlen_code_bits = litlen_entry as u8; if litlen_entry & LITERAL_ENTRY != 0 { - // Fast path: the next symbol is <= 12 bits and a literal, the table specifies the - // output bytes and we can directly write them to the output buffer. + // Fast path: the next symbol is <= `litlen_table_bits` bits and a literal, the + // table specifies the output bytes and we can directly write them to the output + // buffer. let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize; if bit_buffer.nbits < litlen_code_bits { @@ -860,8 +886,8 @@ impl CompressedBlock { litlen_code_bits, ) } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 { - let secondary_table_index = - (litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff)); + let secondary_table_index = (litlen_entry >> 16) + + ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff)); let secondary_entry = self.secondary_table[secondary_table_index as usize]; let litlen_symbol = secondary_entry >> 4; let litlen_code_bits = (secondary_entry & 0xf) as u8; @@ -898,20 +924,22 @@ impl CompressedBlock { let length = length_base as usize + (bits & length_extra_mask) as usize; bits >>= length_extra_bits; - let dist_entry = self.dist_table[(bits & 0x1ff) as usize]; + let dist_entry = self.dist_table[(bits & dist_table_mask) as usize]; let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 { ( (dist_entry >> 16) as u16, (dist_entry >> 8) as u8 & 0xf, dist_entry as u8, ) - } else if bit_buffer.nbits > litlen_code_bits + length_extra_bits + 9 { + } else if bit_buffer.nbits + > litlen_code_bits + length_extra_bits + dist_table_bits as u8 + { if dist_entry >> 8 == 0 { return Err(DecompressionError::InvalidDistanceCode); } let secondary_table_index = - (dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff)); + (dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff)); let secondary_entry = self.dist_secondary_table[secondary_table_index as usize]; let dist_symbol = (secondary_entry >> 4) as usize; if dist_symbol >= 30 { @@ -1184,8 +1212,8 @@ mod tests { #[test] fn fixed_tables() { let mut compression = CompressedBlock { - litlen_table: Box::new([0; 4096]), - dist_table: Box::new([0; 512]), + litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]), + dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]), secondary_table: Vec::new(), dist_secondary_table: Vec::new(), eof_code: 0, diff --git a/src/huffman.rs b/src/huffman.rs index 10e7330..59af32d 100644 --- a/src/huffman.rs +++ b/src/huffman.rs @@ -118,8 +118,8 @@ pub fn build_table( let codeword1 = codes[sym1]; let codeword2 = codes[sym2]; let codeword = codeword1 | (codeword2 << len1); - let entry = (sym1 as u32) << 16 - | (sym2 as u32) << 24 + let entry = ((sym1 as u32) << 16) + | ((sym2 as u32) << 24) | LITERAL_ENTRY | (2 << 8); primary_table[codeword as usize] = entry | (length as u32); diff --git a/src/tables.rs b/src/tables.rs index 44dec1e..567565a 100644 --- a/src/tables.rs +++ b/src/tables.rs @@ -87,8 +87,9 @@ pub(crate) const DIST_SYM_TO_DIST_BASE: [u16; 30] = [ 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, ]; -/// The main litlen_table uses a 12-bit input to lookup the meaning of the symbol. The table is -/// split into 4 sections: +/// By default the main litlen_table uses a 12-bit input to lookup the meaning of the symbol +/// (`const`-generic parameters of `CompressedBlock` can ask for a non-default table size). +/// The table entries have 4 possible flavours: /// /// aaaaaaaa_bbbbbbbb_100000yy_0000xxxx x = input_advance_bits, y = output_advance_bytes (literal) /// 0000000z_zzzzzzzz_00000yyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base (length) @@ -103,7 +104,7 @@ pub(crate) const LITLEN_TABLE_ENTRIES: [u32; 288] = { // 00000000_iiiiiiii_10000001_0000???? (? = will be filled by huffman::build_table) // aaaaaaaa_bbbbbbbb_100000yy_0000xxxx // x = input_advance_bits, y = output_advance_bytes (literal) - entries[i] = (i as u32) << 16 | LITERAL_ENTRY | (1 << 8); + entries[i] = ((i as u32) << 16) | LITERAL_ENTRY | (1 << 8); i += 1; } @@ -113,14 +114,16 @@ pub(crate) const LITLEN_TABLE_ENTRIES: [u32; 288] = { // 0000000z_zzzzzzzz_00000yyy_0000???? (? = will be filled by huffman::build_table) // 0000000z_zzzzzzzz_00000yyy_0000xxxx // x = input_advance_bits, y = extra_bits, z = distance_base (length) - entries[i] = (LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16 - | (LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8; + entries[i] = ((LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16) + | ((LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8); i += 1; } entries }; -/// The distance table is a 512-entry table that maps 9 bits of distance symbols to their meaning. +/// The distance table is by default a 512-entry table that maps 9 bits of distance symbols to +/// their meaning. (`const`-generic parameters of `CompressedBlock` can ask for a non-default +/// table size). /// /// 00000000_00000000_00000000_00000000 symbol is more than 9 bits /// zzzzzzzz_zzzzzzzz_0000yyyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base @@ -128,8 +131,8 @@ pub(crate) const DISTANCE_TABLE_ENTRIES: [u32; 32] = { let mut entries = [0; 32]; let mut i = 0; while i < 30 { - entries[i] = (DIST_SYM_TO_DIST_BASE[i] as u32) << 16 - | (DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8 + entries[i] = ((DIST_SYM_TO_DIST_BASE[i] as u32) << 16) + | ((DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8) | LITERAL_ENTRY; i += 1; }