Skip to content

Commit

Permalink
encode huffman tables according to spec
Browse files Browse the repository at this point in the history
  • Loading branch information
KillingSpark committed Oct 15, 2024
1 parent ea4e286 commit afa9b52
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 19 deletions.
7 changes: 7 additions & 0 deletions src/encoding/bit_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ impl BitWriter {
}
}

/// Appends `data` unchanged to the end of the output buffer.
///
/// # Panics
///
/// Panics if `misaligned()` reports a non-zero value, i.e. when the writer
/// is not in a state where whole bytes can be appended.
pub fn append_bytes(&mut self, data: &[u8]) {
    assert!(
        self.misaligned() == 0,
        "Don't append bytes when writer is misaligned"
    );
    self.output.extend_from_slice(data);
}

/// Converts `bits` into a `u64` and forwards it, together with `num_bits`,
/// to `write_bits_64`.
pub fn write_bits(&mut self, bits: impl Into<u64>, num_bits: usize) {
    let widened: u64 = bits.into();
    self.write_bits_64(widened, num_bits);
}
Expand Down
3 changes: 2 additions & 1 deletion src/fse/fse_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ impl<'t> FSEDecoder<'t> {
if self.table.accuracy_log == 0 {
return Err(FSEDecoderError::TableIsUninitialized);
}
self.state = self.table.decode[bits.get_bits(self.table.accuracy_log) as usize];
let new_state = bits.get_bits(self.table.accuracy_log);
self.state = self.table.decode[new_state as usize];

Ok(())
}
Expand Down
75 changes: 72 additions & 3 deletions src/fse/fse_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,66 @@ impl FSEEncoder {
writer.dump()
}

/// Encodes `data` using two interleaved FSE states and returns the finished
/// bitstream.
///
/// The table description is written first (`write_table`). The symbols are
/// then consumed back to front: the last symbol selects the initial first
/// state, the second to last symbol selects the initial second state, and the
/// remaining symbols are encoded alternating between the two states, starting
/// with the first one. After all symbols are consumed both state indexes are
/// written with `acc_log` bits each; the state that encoded `data[0]` is
/// written last. The stream is closed by writing the value `1` in as many bits
/// as `misaligned()` reports, or in 8 bits if it reports 0.
///
/// # Panics
///
/// Panics if `data` contains fewer than two symbols, because both states need
/// a symbol to start from.
pub fn encode_interleaved(&mut self, data: &[u8]) -> Vec<u8> {
    assert!(
        data.len() >= 2,
        "interleaved FSE encoding needs at least two symbols"
    );
    self.write_table();

    let mut state_1 = &self.table.states[data[data.len() - 1] as usize].states[0];
    let mut state_2 = &self.table.states[data[data.len() - 2] as usize].states[0];

    // Number of symbols at the front of `data` that still need to be encoded.
    // Counting what is left (instead of indexing from `len - 4`) keeps the
    // arithmetic from underflowing for inputs with only 2 or 3 symbols.
    let mut remaining = data.len() - 2;
    while remaining >= 2 {
        {
            let x = data[remaining - 1];
            let next = self.table.next_state(x, state_1.index);
            let diff = state_1.index - next.baseline;
            self.writer.write_bits(diff as u64, next.num_bits as usize);
            state_1 = next;
        }
        {
            let x = data[remaining - 2];
            let next = self.table.next_state(x, state_2.index);
            let diff = state_2.index - next.baseline;
            self.writer.write_bits(diff as u64, next.num_bits as usize);
            state_2 = next;
        }
        remaining -= 2;
    }

    let acc_log = self.acc_log() as usize;
    if remaining == 1 {
        // Odd number of symbols: `data[0]` is still pending and goes to the
        // first state, which therefore has to be written last.
        let next = self.table.next_state(data[0], state_1.index);
        let diff = state_1.index - next.baseline;
        self.writer.write_bits(diff as u64, next.num_bits as usize);
        state_1 = next;

        self.writer.write_bits(state_2.index as u64, acc_log);
        self.writer.write_bits(state_1.index as u64, acc_log);
    } else {
        // Even number of symbols: the second state encoded `data[0]`.
        self.writer.write_bits(state_1.index as u64, acc_log);
        self.writer.write_bits(state_2.index as u64, acc_log);
    }

    // Take the filled writer out and leave a fresh one behind for later use.
    let mut writer = core::mem::replace(&mut self.writer, BitWriter::new());
    let bits_to_fill = writer.misaligned();
    let marker_bits = if bits_to_fill == 0 { 8 } else { bits_to_fill };
    writer.write_bits(1u32, marker_bits);
    writer.dump()
}

fn write_table(&mut self) {
self.writer.write_bits(self.acc_log() - 5, 4);
let mut probability_counter = 0usize;
Expand Down Expand Up @@ -133,15 +193,15 @@ impl State {
}
}

pub fn build_table_from_data(data: &[u8]) -> FSETable {
pub fn build_table_from_data(data: &[u8], avoid_0_numbit: bool) -> FSETable {
let mut counts = [0; 256];
for x in data {
counts[*x as usize] += 1;
}
build_table_from_counts(&counts)
build_table_from_counts(&counts, avoid_0_numbit)
}

fn build_table_from_counts(counts: &[usize]) -> FSETable {
fn build_table_from_counts(counts: &[usize], avoid_0_numbit: bool) -> FSETable {
let mut probs = [0; 256];
let mut min_count = 0;
for (idx, count) in counts.iter().copied().enumerate() {
Expand Down Expand Up @@ -172,6 +232,15 @@ fn build_table_from_counts(counts: &[usize]) -> FSETable {
let max = probs.iter_mut().max().unwrap();
*max += diff as i32;

if avoid_0_numbit && *max > 1 << (acc_log - 1) {
let redistribute = *max - (1 << (acc_log - 1));
*max -= redistribute;
let max = *max;
let second_max = probs.iter_mut().filter(|x| **x != max).max().unwrap();
*second_max += redistribute;
assert!(*second_max <= max);
}

build_table_from_probabilities(&probs, acc_log)
}

Expand Down
2 changes: 1 addition & 1 deletion src/fse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub fn round_trip(data: &[u8]) {
return;
}

let mut encoder: FSEEncoder = FSEEncoder::new(fse_encoder::build_table_from_data(data));
let mut encoder: FSEEncoder = FSEEncoder::new(fse_encoder::build_table_from_data(data, false));
let mut dec_table = FSETable::new(255);

let encoded = encoder.encode(data);
Expand Down
12 changes: 1 addition & 11 deletions src/huff0/huff0_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ impl HuffmanTable {
bits: Vec::with_capacity(256),
bit_ranks: Vec::with_capacity(11),
rank_indexes: Vec::with_capacity(11),
fse_table: FSETable::new(100),
fse_table: FSETable::new(255),
}
}

Expand Down Expand Up @@ -586,14 +586,4 @@ impl HuffmanTable {

Ok(())
}

/// Test-only constructor: builds a table directly from a list of weights.
///
/// The final entry of `weights` is discarded before the table is built.
///
/// # Panics
///
/// Panics if `build_table_from_weights` returns an error.
pub(super) fn from_weights(mut weights: Vec<u8>) -> Self {
    // Last weight is inferred by build_table_from_weights
    let _inferred_last = weights.pop();

    let mut table = Self::new();
    table.weights = weights;
    table.build_table_from_weights().unwrap();
    table
}
}
39 changes: 38 additions & 1 deletion src/huff0/huff0_encoder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use alloc::vec::Vec;
use core::cmp::Ordering;

use crate::encoding::bit_writer::BitWriter;
use crate::{
encoding::bit_writer::BitWriter,
fse::fse_encoder::{self, FSEEncoder},
};

pub struct HuffmanEncoder {
table: HuffmanTable,
Expand All @@ -16,6 +19,7 @@ impl HuffmanEncoder {
}
}
pub fn encode(&mut self, data: &[u8]) {
self.write_table();
for symbol in data.iter().rev() {
let (code, num_bits) = self.table.codes[*symbol as usize];
self.writer.write_bits(code, num_bits as usize);
Expand Down Expand Up @@ -44,6 +48,39 @@ impl HuffmanEncoder {

weights
}

/// Writes the description of the Huffman table (its weights) to `self.writer`.
///
/// The last weight returned by `self.weights()` is not written. If 16 or
/// fewer weights remain, a header byte of `weights.len() + 127` is written,
/// followed by the weights as 4-bit values, two per byte. Otherwise the
/// weights are FSE encoded and written as a one-byte length header followed
/// by the encoded bytes.
///
/// # Panics
///
/// Panics if a directly written weight is 16 or larger, or if the FSE encoded
/// weights take 128 bytes or more.
fn write_table(&mut self) {
    let all_weights = self.weights();
    // Don't encode the last weight
    let weights = &all_weights[..all_weights.len() - 1];

    // TODO strategy for determining this?
    if weights.len() <= 16 {
        self.writer.write_bits(weights.len() as u8 + 127, 8);

        let pairs = weights.chunks_exact(2);
        let leftover = pairs.remainder();
        for pair in pairs {
            let (weight1, weight2) = (pair[0], pair[1]);
            assert!(weight1 < 16);
            assert!(weight2 < 16);
            // The second weight of the pair is written first
            self.writer.write_bits(weight2, 4);
            self.writer.write_bits(weight1, 4);
        }
        // An odd trailing weight is shifted into the upper 4 bits of its byte
        if let Some(&weight) = leftover.first() {
            assert!(weight < 16);
            self.writer.write_bits(weight << 4, 8);
        }
        return;
    }

    // TODO share output vec between encoders
    // TODO assert that no 0 num_bit states are generated here
    let fse_table = fse_encoder::build_table_from_data(weights, true);
    let mut encoder = FSEEncoder::new(fse_table);
    let encoded = encoder.encode_interleaved(weights);
    assert!(encoded.len() < 128);
    self.writer.write_bits(encoded.len() as u8, 8);
    self.writer.append_bytes(&encoded);
}
}

pub struct HuffmanTable {
Expand Down
5 changes: 3 additions & 2 deletions src/huff0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ pub fn round_trip(data: &[u8]) {

encoder.encode(data);
let encoded = encoder.dump();
let decoder_table = HuffmanTable::from_weights(encoder.weights());
let mut decoder_table = HuffmanTable::new();
let table_bytes = decoder_table.build_decoder(&encoded).unwrap();
let mut decoder = HuffmanDecoder::new(&decoder_table);

let mut br = BitReaderReversed::new(&encoded);
let mut br = BitReaderReversed::new(&encoded[table_bytes as usize..]);
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1);
Expand Down

0 comments on commit afa9b52

Please sign in to comment.