Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change hardcoded table sizes into const-generic params. #49

Merged
merged 2 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 62 additions & 34 deletions src/decompress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,17 @@ pub const LITERAL_ENTRY: u32 = 0x8000;
pub const EXCEPTIONAL_ENTRY: u32 = 0x4000;
pub const SECONDARY_TABLE_ENTRY: u32 = 0x2000;

// See https://github.com/image-rs/fdeflate/issues/45 for discussion of the table sizes.
const DEFAULT_LITLEN_TABLE_SIZE: usize = 4096;
const DEFAULT_DIST_TABLE_SIZE: usize = 512;

/// The Decompressor state for a compressed block.
#[derive(Eq, PartialEq, Debug)]
struct CompressedBlock {
litlen_table: Box<[u32; 4096]>,
struct CompressedBlock<const LITLEN_TABLE_SIZE: usize, const DIST_TABLE_SIZE: usize> {
litlen_table: Box<[u32; LITLEN_TABLE_SIZE]>,
secondary_table: Vec<u16>,

dist_table: Box<[u32; 512]>,
dist_table: Box<[u32; DIST_TABLE_SIZE]>,
dist_secondary_table: Vec<u16>,

eof_code: u16,
Expand All @@ -91,7 +95,7 @@ enum State {
/// Decompressor for arbitrary zlib streams.
pub struct Decompressor {
/// State for decoding a compressed block.
compression: CompressedBlock,
compression: CompressedBlock<DEFAULT_LITLEN_TABLE_SIZE, DEFAULT_DIST_TABLE_SIZE>,
// State for decoding a block header.
header: BlockHeader,
// Number of bytes left for uncompressed block.
Expand Down Expand Up @@ -120,8 +124,8 @@ impl Decompressor {
Self {
bits: BitBuffer::new(),
compression: CompressedBlock {
litlen_table: Box::new([0; 4096]),
dist_table: Box::new([0; 512]),
litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]),
dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]),
secondary_table: Vec::new(),
dist_secondary_table: Vec::new(),
eof_code: 0,
Expand Down Expand Up @@ -224,11 +228,11 @@ impl Decompressor {
}

let input0 = self.bits.peek_bits(8);
let input1 = self.bits.peek_bits(16) >> 8 & 0xff;
let input1 = (self.bits.peek_bits(16) >> 8) & 0xff;
if input0 & 0x0f != 0x08
|| (input0 & 0xf0) > 0x70
|| input1 & 0x20 != 0
|| (input0 << 8 | input1) % 31 != 0
|| ((input0 << 8) | input1) % 31 != 0
{
return Err(DecompressionError::BadZlibHeader);
}
Expand Down Expand Up @@ -394,9 +398,11 @@ impl Decompressor {
// Build decoding tables if the previous block wasn't also a fixed block.
if !self.fixed_table {
self.fixed_table = true;
assert!(self.compression.litlen_table.len() >= FIXED_LITLEN_TABLE.len());
for chunk in self.compression.litlen_table.chunks_exact_mut(512) {
chunk.copy_from_slice(&FIXED_LITLEN_TABLE);
}
assert!(self.compression.dist_table.len() >= FIXED_DIST_TABLE.len());
for chunk in self.compression.dist_table.chunks_exact_mut(32) {
chunk.copy_from_slice(&FIXED_DIST_TABLE);
}
Expand Down Expand Up @@ -551,7 +557,9 @@ impl Decompressor {
}
}

impl CompressedBlock {
impl<const LITLEN_TABLE_SIZE: usize, const DIST_TABLE_SIZE: usize>
CompressedBlock<LITLEN_TABLE_SIZE, DIST_TABLE_SIZE>
{
fn build_tables(&mut self, hlit: usize, code_lengths: &[u8]) -> Result<(), DecompressionError> {
// If there is no code assigned for the EOF symbol then the bitstream is invalid.
if code_lengths[256] == 0 {
Expand Down Expand Up @@ -610,6 +618,20 @@ impl CompressedBlock {
mut output_index: usize,
queued_output: &mut Option<QueuedOutput>,
) -> Result<(CompressedBlockStatus, usize), DecompressionError> {
// `litlen_table_mask` (and `dist_table_mask`) calculation assumes that `LITLEN_TABLE_SIZE`
// (or `DIST_TABLE_SIZE`) is a power of two.
assert!(LITLEN_TABLE_SIZE.count_ones() == 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

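Doesn't make a huge difference, but there's actually an `is_power_of_two` method on the integer types, so this could be written as `assert!(LITLEN_TABLE_SIZE.is_power_of_two());`.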

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this assertion could be placed in a `const { ... }` block so that it is checked at compile time rather than at run time.

assert!(DIST_TABLE_SIZE.count_ones() == 1);
let litlen_table_mask = (LITLEN_TABLE_SIZE as u64) - 1;
let litlen_table_bits = LITLEN_TABLE_SIZE.trailing_zeros();
let dist_table_mask = (DIST_TABLE_SIZE as u64) - 1;
let dist_table_bits = DIST_TABLE_SIZE.trailing_zeros();
// Lower bound on table sizes, because 1a) RFC1951 uses at most 15 bits for codewords and
// 1b) we can fit at most 8 bits of `overflow_bits_mask` in the last byte of a primary
// table entry.
assert!(litlen_table_bits + 8 >= 15);
assert!(dist_table_bits + 8 >= 15);

// Fast decoding loop.
//
// This loop is optimized for speed and is the main decoding loop for the decompressor,
Expand All @@ -623,23 +645,25 @@ impl CompressedBlock {
// the bit buffer. This is because when the input is non-empty, the bit buffer actually
// has 64-bits of valid data (even though nbits will be in 56..=63).
bit_buffer.fill_buffer(remaining_input);
let mut litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize];
let mut litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
while output_index + 8 <= output.len() && remaining_input.len() >= 8 {
// First check whether the next symbol is a literal. This code does up to 2 additional
// table lookups to decode more literals.
let mut bits;
let mut litlen_code_bits = litlen_entry as u8;
if litlen_entry & LITERAL_ENTRY != 0 {
let litlen_entry2 =
self.litlen_table[(bit_buffer.buffer >> litlen_code_bits & 0xfff) as usize];
let litlen_entry2 = self.litlen_table
[((bit_buffer.buffer >> litlen_code_bits) & litlen_table_mask) as usize];
let litlen_code_bits2 = litlen_entry2 as u8;
let litlen_entry3 = self.litlen_table[(bit_buffer.buffer
>> (litlen_code_bits + litlen_code_bits2)
& 0xfff) as usize];
let litlen_entry3 = self.litlen_table[((bit_buffer.buffer
>> (litlen_code_bits + litlen_code_bits2))
& litlen_table_mask)
as usize];
let litlen_code_bits3 = litlen_entry3 as u8;
let litlen_entry4 = self.litlen_table[(bit_buffer.buffer
>> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)
& 0xfff) as usize];
let litlen_entry4 = self.litlen_table[((bit_buffer.buffer
>> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3))
& litlen_table_mask)
as usize];

let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
output[output_index] = (litlen_entry >> 16) as u8;
Expand Down Expand Up @@ -692,16 +716,17 @@ impl CompressedBlock {
litlen_code_bits,
)
} else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
let secondary_table_index =
(litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
let secondary_table_index = (litlen_entry >> 16)
+ ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff));
let secondary_entry = self.secondary_table[secondary_table_index as usize];
let litlen_symbol = secondary_entry >> 4;
let litlen_code_bits = (secondary_entry & 0xf) as u8;

match litlen_symbol {
0..=255 => {
bit_buffer.consume_bits(litlen_code_bits);
litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize];
litlen_entry =
self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
bit_buffer.fill_buffer(remaining_input);
output[output_index] = litlen_symbol as u8;
output_index += 1;
Expand Down Expand Up @@ -729,7 +754,7 @@ impl CompressedBlock {
let length = length_base as usize + (bits & length_extra_mask) as usize;
bits >>= length_extra_bits;

let dist_entry = self.dist_table[(bits & 0x1ff) as usize];
let dist_entry = self.dist_table[(bits & dist_table_mask) as usize];
let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
(
(dist_entry >> 16) as u16,
Expand All @@ -740,7 +765,7 @@ impl CompressedBlock {
return Err(DecompressionError::InvalidDistanceCode);
} else {
let secondary_table_index =
(dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
(dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff));
let secondary_entry = self.dist_secondary_table[secondary_table_index as usize];
let dist_symbol = (secondary_entry >> 4) as usize;
if dist_symbol >= 30 {
Expand All @@ -764,7 +789,7 @@ impl CompressedBlock {
litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits,
);
bit_buffer.fill_buffer(remaining_input);
litlen_entry = self.litlen_table[(bit_buffer.buffer & 0xfff) as usize];
litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];

let copy_length = length.min(output.len() - output_index);
if dist == 1 {
Expand Down Expand Up @@ -817,12 +842,13 @@ impl CompressedBlock {
}

let mut bits = bit_buffer.buffer;
let litlen_entry = self.litlen_table[(bits & 0xfff) as usize];
let litlen_entry = self.litlen_table[(bits & litlen_table_mask) as usize];
let litlen_code_bits = litlen_entry as u8;

if litlen_entry & LITERAL_ENTRY != 0 {
// Fast path: the next symbol is <= 12 bits and a literal, the table specifies the
// output bytes and we can directly write them to the output buffer.
// Fast path: the next symbol is <= `litlen_table_bits` bits and a literal, the
// table specifies the output bytes and we can directly write them to the output
// buffer.
let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;

if bit_buffer.nbits < litlen_code_bits {
Expand Down Expand Up @@ -860,8 +886,8 @@ impl CompressedBlock {
litlen_code_bits,
)
} else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
let secondary_table_index =
(litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
let secondary_table_index = (litlen_entry >> 16)
+ ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff));
let secondary_entry = self.secondary_table[secondary_table_index as usize];
let litlen_symbol = secondary_entry >> 4;
let litlen_code_bits = (secondary_entry & 0xf) as u8;
Expand Down Expand Up @@ -898,20 +924,22 @@ impl CompressedBlock {
let length = length_base as usize + (bits & length_extra_mask) as usize;
bits >>= length_extra_bits;

let dist_entry = self.dist_table[(bits & 0x1ff) as usize];
let dist_entry = self.dist_table[(bits & dist_table_mask) as usize];
let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
(
(dist_entry >> 16) as u16,
(dist_entry >> 8) as u8 & 0xf,
dist_entry as u8,
)
} else if bit_buffer.nbits > litlen_code_bits + length_extra_bits + 9 {
} else if bit_buffer.nbits
> litlen_code_bits + length_extra_bits + dist_table_bits as u8
{
if dist_entry >> 8 == 0 {
return Err(DecompressionError::InvalidDistanceCode);
}

let secondary_table_index =
(dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
(dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff));
let secondary_entry = self.dist_secondary_table[secondary_table_index as usize];
let dist_symbol = (secondary_entry >> 4) as usize;
if dist_symbol >= 30 {
Expand Down Expand Up @@ -1184,8 +1212,8 @@ mod tests {
#[test]
fn fixed_tables() {
let mut compression = CompressedBlock {
litlen_table: Box::new([0; 4096]),
dist_table: Box::new([0; 512]),
litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]),
dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]),
secondary_table: Vec::new(),
dist_secondary_table: Vec::new(),
eof_code: 0,
Expand Down
4 changes: 2 additions & 2 deletions src/huffman.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ pub fn build_table(
let codeword1 = codes[sym1];
let codeword2 = codes[sym2];
let codeword = codeword1 | (codeword2 << len1);
let entry = (sym1 as u32) << 16
| (sym2 as u32) << 24
let entry = ((sym1 as u32) << 16)
| ((sym2 as u32) << 24)
| LITERAL_ENTRY
| (2 << 8);
primary_table[codeword as usize] = entry | (length as u32);
Expand Down
19 changes: 11 additions & 8 deletions src/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ pub(crate) const DIST_SYM_TO_DIST_BASE: [u16; 30] = [
2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577,
];

/// The main litlen_table uses a 12-bit input to lookup the meaning of the symbol. The table is
/// split into 4 sections:
/// By default the main litlen_table uses a 12-bit input to lookup the meaning of the symbol
/// (`const`-generic parameters of `CompressedBlock` can ask for a non-default table size).
/// The table entries have 4 possible flavours:
///
/// aaaaaaaa_bbbbbbbb_100000yy_0000xxxx x = input_advance_bits, y = output_advance_bytes (literal)
/// 0000000z_zzzzzzzz_00000yyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base (length)
Expand All @@ -103,7 +104,7 @@ pub(crate) const LITLEN_TABLE_ENTRIES: [u32; 288] = {
// 00000000_iiiiiiii_10000001_0000???? (? = will be filled by huffman::build_table)
// aaaaaaaa_bbbbbbbb_100000yy_0000xxxx
// x = input_advance_bits, y = output_advance_bytes (literal)
entries[i] = (i as u32) << 16 | LITERAL_ENTRY | (1 << 8);
entries[i] = ((i as u32) << 16) | LITERAL_ENTRY | (1 << 8);
i += 1;
}

Expand All @@ -113,23 +114,25 @@ pub(crate) const LITLEN_TABLE_ENTRIES: [u32; 288] = {
// 0000000z_zzzzzzzz_00000yyy_0000???? (? = will be filled by huffman::build_table)
// 0000000z_zzzzzzzz_00000yyy_0000xxxx
// x = input_advance_bits, y = extra_bits, z = distance_base (length)
entries[i] = (LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16
| (LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8;
entries[i] = ((LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16)
| ((LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8);
i += 1;
}
entries
};

/// The distance table is a 512-entry table that maps 9 bits of distance symbols to their meaning.
/// The distance table is by default a 512-entry table that maps 9 bits of distance symbols to
/// their meaning. (`const`-generic parameters of `CompressedBlock` can ask for a non-default
/// table size).
///
/// 00000000_00000000_00000000_00000000 symbol is more than 9 bits
/// zzzzzzzz_zzzzzzzz_0000yyyy_0000xxxx x = input_advance_bits, y = extra_bits, z = distance_base
pub(crate) const DISTANCE_TABLE_ENTRIES: [u32; 32] = {
let mut entries = [0; 32];
let mut i = 0;
while i < 30 {
entries[i] = (DIST_SYM_TO_DIST_BASE[i] as u32) << 16
| (DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8
entries[i] = ((DIST_SYM_TO_DIST_BASE[i] as u32) << 16)
| ((DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8)
| LITERAL_ENTRY;
i += 1;
}
Expand Down
Loading