Skip to content

Commit

Permalink
Change hardcoded table sizes into const-generic params. (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
anforowicz authored Jan 13, 2025
1 parent 7ded2fa commit df5729d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 44 deletions.
96 changes: 62 additions & 34 deletions src/decompress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,17 @@ pub const LITERAL_ENTRY: u32 = 0x8000;
pub const EXCEPTIONAL_ENTRY: u32 = 0x4000;
pub const SECONDARY_TABLE_ENTRY: u32 = 0x2000;

// See https://github.com/image-rs/fdeflate/issues/45 for discussion of the table sizes.
const DEFAULT_LITLEN_TABLE_SIZE: usize = 4096;
const DEFAULT_DIST_TABLE_SIZE: usize = 512;

/// The Decompressor state for a compressed block.
#[derive(Eq, PartialEq, Debug)]
struct CompressedBlock {
litlen_table: Box<[u32; 4096]>,
struct CompressedBlock<const LITLEN_TABLE_SIZE: usize, const DIST_TABLE_SIZE: usize> {
litlen_table: Box<[u32; LITLEN_TABLE_SIZE]>,
secondary_table: Vec<u16>,

dist_table: Box<[u32; 512]>,
dist_table: Box<[u32; DIST_TABLE_SIZE]>,
dist_secondary_table: Vec<u16>,

eof_code: u16,
Expand All @@ -91,7 +95,7 @@ enum State {
/// Decompressor for arbitrary zlib streams.
pub struct Decompressor {
/// State for decoding a compressed block.
compression: CompressedBlock,
compression: CompressedBlock<DEFAULT_LITLEN_TABLE_SIZE, DEFAULT_DIST_TABLE_SIZE>,
// State for decoding a block header.
header: BlockHeader,
// Number of bytes left for uncompressed block.
Expand Down Expand Up @@ -120,8 +124,8 @@ impl Decompressor {
Self {
bits: BitBuffer::new(),
compression: CompressedBlock {
litlen_table: Box::new([0; 4096]),
dist_table: Box::new([0; 512]),
litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]),
dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]),
secondary_table: Vec::new(),
dist_secondary_table: Vec::new(),
eof_code: 0,
Expand Down Expand Up @@ -224,11 +228,11 @@ impl Decompressor {
}

let input0 = self.bits.peek_bits(8);
let input1 = self.bits.peek_bits(16) >> 8 & 0xff;
let input1 = (self.bits.peek_bits(16) >> 8) & 0xff;
if input0 & 0x0f != 0x08
|| (input0 & 0xf0) > 0x70
|| input1 & 0x20 != 0
|| (input0 << 8 | input1) % 31 != 0
|| ((input0 << 8) | input1) % 31 != 0
{
return Err(DecompressionError::BadZlibHeader);
}
Expand Down Expand Up @@ -394,9 +398,11 @@ impl Decompressor {
// Build decoding tables if the previous block wasn't also a fixed block.
if !self.fixed_table {
self.fixed_table = true;
assert!(self.compression.litlen_table.len() >= FIXED_LITLEN_TABLE.len());
for chunk in self.compression.litlen_table.chunks_exact_mut(512) {
chunk.copy_from_slice(&FIXED_LITLEN_TABLE);
}
assert!(self.compression.dist_table.len() >= FIXED_DIST_TABLE.len());
for chunk in self.compression.dist_table.chunks_exact_mut(32) {
chunk.copy_from_slice(&FIXED_DIST_TABLE);
}
Expand Down Expand Up @@ -551,7 +557,9 @@ impl Decompressor {
}
}

impl CompressedBlock {
impl<const LITLEN_TABLE_SIZE: usize, const DIST_TABLE_SIZE: usize>
CompressedBlock<LITLEN_TABLE_SIZE, DIST_TABLE_SIZE>
{
fn build_tables(&mut self, hlit: usize, code_lengths: &[u8]) -> Result<(), DecompressionError> {
// If there is no code assigned for the EOF symbol then the bitstream is invalid.
if code_lengths[256] == 0 {
Expand Down Expand Up @@ -610,6 +618,20 @@ impl CompressedBlock {
mut output_index: usize,
queued_output: &mut Option<QueuedOutput>,
) -> Result<(CompressedBlockStatus, usize), DecompressionError> {
// `litlen_table_mask` (and `dist_table_mask`) calculation assumes that `LITLEN_TABLE_SIZE`
// (or `DIST_TABLE_SIZE`) is a power of two.
assert!(LITLEN_TABLE_SIZE.count_ones() == 1);
assert!(DIST_TABLE_SIZE.count_ones() == 1);
let litlen_table_mask = (LITLEN_TABLE_SIZE as u64) - 1;
let litlen_table_bits = LITLEN_TABLE_SIZE.trailing_zeros();
let dist_table_mask = (DIST_TABLE_SIZE as u64) - 1;
let dist_table_bits = DIST_TABLE_SIZE.trailing_zeros();
// Lower bound on table sizes, because 1a) RFC1951 uses at most 15 bits for codewords and
// 1b) we can fit at most 8 bits of `overflow_bits_mask` in the last byte of a primary
// table entry.
assert!(litlen_table_bits + 8 >= 15);
assert!(dist_table_bits + 8 >= 15);

// Fast decoding loop.
//
// This loop is optimized for speed and is the main decoding loop for the decompressor,
Expand All @@ -623,23 +645,25 @@ impl CompressedBlock {
// the bit buffer. This is because when the input is non-empty, the bit buffer actually
// has 64-bits of valid data (even though nbits will be in 56..=63).
bit_buffer.fill_buffer(remaining_input);
let mut litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize];
let mut litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
while output_index + 8 <= output.len() && remaining_input.len() >= 8 {
// First check whether the next symbol is a literal. This code does up to 2 additional
// table lookups to decode more literals.
let mut bits;
let mut litlen_code_bits = litlen_entry as u8;
if litlen_entry & LITERAL_ENTRY != 0 {
let litlen_entry2 =
self.litlen_table[(bit_buffer.buffer >> litlen_code_bits & 0xfff) as usize];
let litlen_entry2 = self.litlen_table
[((bit_buffer.buffer >> litlen_code_bits) & litlen_table_mask) as usize];
let litlen_code_bits2 = litlen_entry2 as u8;
let litlen_entry3 = self.litlen_table[(bit_buffer.buffer
>> (litlen_code_bits + litlen_code_bits2)
& 0xfff) as usize];
let litlen_entry3 = self.litlen_table[((bit_buffer.buffer
>> (litlen_code_bits + litlen_code_bits2))
& litlen_table_mask)
as usize];
let litlen_code_bits3 = litlen_entry3 as u8;
let litlen_entry4 = self.litlen_table[(bit_buffer.buffer
>> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)
& 0xfff) as usize];
let litlen_entry4 = self.litlen_table[((bit_buffer.buffer
>> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3))
& litlen_table_mask)
as usize];

let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
output[output_index] = (litlen_entry >> 16) as u8;
Expand Down Expand Up @@ -692,16 +716,17 @@ impl CompressedBlock {
litlen_code_bits,
)
} else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
let secondary_table_index =
(litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
let secondary_table_index = (litlen_entry >> 16)
+ ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff));
let secondary_entry = self.secondary_table[secondary_table_index as usize];
let litlen_symbol = secondary_entry >> 4;
let litlen_code_bits = (secondary_entry & 0xf) as u8;

match litlen_symbol {
0..=255 => {
bit_buffer.consume_bits(litlen_code_bits);
litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize];
litlen_entry =
self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
bit_buffer.fill_buffer(remaining_input);
output[output_index] = litlen_symbol as u8;
output_index += 1;
Expand Down Expand Up @@ -729,7 +754,7 @@ impl CompressedBlock {
let length = length_base as usize + (bits & length_extra_mask) as usize;
bits >>= length_extra_bits;

let dist_entry = self.dist_table[(bits & 0x1ff) as usize];
let dist_entry = self.dist_table[(bits & dist_table_mask) as usize];
let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
(
(dist_entry >> 16) as u16,
Expand All @@ -740,7 +765,7 @@ impl CompressedBlock {
return Err(DecompressionError::InvalidDistanceCode);
} else {
let secondary_table_index =
(dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
(dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff));
let secondary_entry = self.dist_secondary_table[secondary_table_index as usize];
let dist_symbol = (secondary_entry >> 4) as usize;
if dist_symbol >= 30 {
Expand All @@ -764,7 +789,7 @@ impl CompressedBlock {
litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits,
);
bit_buffer.fill_buffer(remaining_input);
litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize];
litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];

let copy_length = length.min(output.len() - output_index);
if dist == 1 {
Expand Down Expand Up @@ -817,12 +842,13 @@ impl CompressedBlock {
}

let mut bits = bit_buffer.buffer;
let litlen_entry = self.litlen_table[(bits & 0xfff) as usize];
let litlen_entry = self.litlen_table[(bits & litlen_table_mask) as usize];
let litlen_code_bits = litlen_entry as u8;

if litlen_entry & LITERAL_ENTRY != 0 {
// Fast path: the next symbol is <= 12 bits and a literal, the table specifies the
// output bytes and we can directly write them to the output buffer.
// Fast path: the next symbol is <= `litlen_table_bits` bits and a literal, the
// table specifies the output bytes and we can directly write them to the output
// buffer.
let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;

if bit_buffer.nbits < litlen_code_bits {
Expand Down Expand Up @@ -860,8 +886,8 @@ impl CompressedBlock {
litlen_code_bits,
)
} else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
let secondary_table_index =
(litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
let secondary_table_index = (litlen_entry >> 16)
+ ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff));
let secondary_entry = self.secondary_table[secondary_table_index as usize];
let litlen_symbol = secondary_entry >> 4;
let litlen_code_bits = (secondary_entry & 0xf) as u8;
Expand Down Expand Up @@ -898,20 +924,22 @@ impl CompressedBlock {
let length = length_base as usize + (bits & length_extra_mask) as usize;
bits >>= length_extra_bits;

let dist_entry = self.dist_table[(bits & 0x1ff) as usize];
let dist_entry = self.dist_table[(bits & dist_table_mask) as usize];
let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
(
(dist_entry >> 16) as u16,
(dist_entry >> 8) as u8 & 0xf,
dist_entry as u8,
)
} else if bit_buffer.nbits > litlen_code_bits + length_extra_bits + 9 {
} else if bit_buffer.nbits
> litlen_code_bits + length_extra_bits + dist_table_bits as u8
{
if dist_entry >> 8 == 0 {
return Err(DecompressionError::InvalidDistanceCode);
}

let secondary_table_index =
(dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
(dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff));
let secondary_entry = self.dist_secondary_table[secondary_table_index as usize];
let dist_symbol = (secondary_entry >> 4) as usize;
if dist_symbol >= 30 {
Expand Down Expand Up @@ -1184,8 +1212,8 @@ mod tests {
#[test]
fn fixed_tables() {
let mut compression = CompressedBlock {
litlen_table: Box::new([0; 4096]),
dist_table: Box::new([0; 512]),
litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]),
dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]),
secondary_table: Vec::new(),
dist_secondary_table: Vec::new(),
eof_code: 0,
Expand Down
4 changes: 2 additions & 2 deletions src/huffman.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ pub fn build_table(
let codeword1 = codes[sym1];
let codeword2 = codes[sym2];
let codeword = codeword1 | (codeword2 << len1);
let entry = (sym1 as u32) << 16
| (sym2 as u32) << 24
let entry = ((sym1 as u32) << 16)
| ((sym2 as u32) << 24)
| LITERAL_ENTRY
| (2 << 8);
primary_table[codeword as usize] = entry | (length as u32);
Expand Down
19 changes: 11 additions & 8 deletions src/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ pub(crate) const DIST_SYM_TO_DIST_BASE: [u16; 30] = [
2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577,
];

/// The main litlen_table uses a 12-bit input to lookup the meaning of the symbol. The table is
/// split into 4 sections:
/// By default the main litlen_table uses a 12-bit input to lookup the meaning of the symbol
/// (`const`-generic parameters of `CompressedBlock` can ask for a non-default table size).
/// The table entries have 4 possible flavours:
///
/// aaaaaaaa_bbbbbbbb_100000yy_0000xxxx x = input_advance_bits, y = output_advance_bytes (literal)
/// 0000000z_zzzzzzzz_00000yyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base (length)
Expand All @@ -103,7 +104,7 @@ pub(crate) const LITLEN_TABLE_ENTRIES: [u32; 288] = {
// 00000000_iiiiiiii_10000001_0000???? (? = will be filled by huffman::build_table)
// aaaaaaaa_bbbbbbbb_100000yy_0000xxxx
// x = input_advance_bits, y = output_advance_bytes (literal)
entries[i] = (i as u32) << 16 | LITERAL_ENTRY | (1 << 8);
entries[i] = ((i as u32) << 16) | LITERAL_ENTRY | (1 << 8);
i += 1;
}

Expand All @@ -113,23 +114,25 @@ pub(crate) const LITLEN_TABLE_ENTRIES: [u32; 288] = {
// 0000000z_zzzzzzzz_00000yyy_0000???? (? = will be filled by huffman::build_table)
// 0000000z_zzzzzzzz_00000yyy_0000xxxx
// x = input_advance_bits, y = extra_bits, z = distance_base (length)
entries[i] = (LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16
| (LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8;
entries[i] = ((LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16)
| ((LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8);
i += 1;
}
entries
};

/// The distance table is a 512-entry table that maps 9 bits of distance symbols to their meaning.
/// The distance table is by default a 512-entry table that maps 9 bits of distance symbols to
/// their meaning. (`const`-generic parameters of `CompressedBlock` can ask for a non-default
/// table size).
///
/// 00000000_00000000_00000000_00000000 symbol is more than 9 bits
/// zzzzzzzz_zzzzzzzz_0000yyyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base
pub(crate) const DISTANCE_TABLE_ENTRIES: [u32; 32] = {
let mut entries = [0; 32];
let mut i = 0;
while i < 30 {
entries[i] = (DIST_SYM_TO_DIST_BASE[i] as u32) << 16
| (DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8
entries[i] = ((DIST_SYM_TO_DIST_BASE[i] as u32) << 16)
| ((DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8)
| LITERAL_ENTRY;
i += 1;
}
Expand Down

0 comments on commit df5729d

Please sign in to comment.