Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bool writer like reader #118

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 63 additions & 37 deletions src/structs/vpx_bool_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ use crate::metrics::{Metrics, ModelComponent};
use crate::structs::branch::Branch;
use crate::structs::simple_hash::SimpleHash;

// MAX_STREAM_BITS should be a multiple of 8 larger than 8,
// and (MAX_STREAM_BITS + 1 bit of carry + 1 bit of divider)
// should fit into 64 bits of `low_value`
const MAX_STREAM_BITS: i32 = 56; //48; //40;// 32;// 24;// 16;//

pub struct VPXBoolWriter<W> {
low_value: u64,
range: u32,
Expand All @@ -45,18 +40,18 @@ pub struct VPXBoolWriter<W> {

impl<W: Write> VPXBoolWriter<W> {
pub fn new(writer: W) -> Result<Self> {
let mut retval = VPXBoolWriter {
let retval = VPXBoolWriter {
low_value: 1 << 9, // this divider bit keeps track of stream bits number
range: 255,
range: 128, // this is value after putting initial false bit
buffer: Vec::new(),
writer: writer,
model_statistics: Metrics::default(),
hash: SimpleHash::new(),
};

let mut dummy_branch = Branch::new();
// initial false bit is put to not get carry out of stream bits
retval.put_bit(false, &mut dummy_branch, ModelComponent::Dummy)?;
// initial false bit with dummy branch is put into stream
// to not get carry out of stream bits,
// but it is just equivalent to change `range` from initial 255 to 128

Ok(retval)
}
Expand Down Expand Up @@ -95,7 +90,6 @@ impl<W: Write> VPXBoolWriter<W> {

let split = 1 + (((*tmp_range - 1) * probability) >> 8);

let mut shift;
branch.record_and_update_bit(value);

if value {
Expand All @@ -105,7 +99,7 @@ impl<W: Write> VPXBoolWriter<W> {
*tmp_range = split;
}

shift = (*tmp_range as u8).leading_zeros() as i32;
let shift = (*tmp_range as u8).leading_zeros() as i32;

#[cfg(feature = "compression_stats")]
{
Expand All @@ -114,28 +108,6 @@ impl<W: Write> VPXBoolWriter<W> {
}

*tmp_range <<= shift;

// check whether we have more than MAX_STREAM_BITS stream bits after shift
let stream_bits = 64 - (*tmp_value).leading_zeros() as i32 - 2;
let count = shift + stream_bits - MAX_STREAM_BITS;
if count >= 0 {
// check carry
*tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (*tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
self.carry();
}
// write all full bytes
let mut sh = MAX_STREAM_BITS - 8;
while sh > 0 {
self.buffer.push((*tmp_value >> sh) as u8);
sh -= 8;
}
*tmp_value &= (1 << 8) - 1; // exclude written bytes
*tmp_value |= 1 << 9; // restore divider bit

shift = count;
}

*tmp_value <<= shift;

Ok(())
Expand All @@ -159,6 +131,12 @@ impl<W: Write> VPXBoolWriter<W> {
self.buffer[x] += 1;
}

// each added bit can extend stream for up to 7 bits
#[inline(always)]
fn cannot_put_bits(tmp_value: u64, num_bits: u32) -> bool {
tmp_value & (u64::MAX << (64 - num_bits * 7)) != 0
}

#[inline(always)]
pub fn put_grid<const A: usize>(
&mut self,
Expand All @@ -173,6 +151,11 @@ impl<W: Write> VPXBoolWriter<W> {

let mut index = A.ilog2() - 1;
let mut serialized_so_far = 1;
// grid is 3 or 6 bits long, so single flash is enough
debug_assert!(A <= 64);
if Self::cannot_put_bits(tmp_value, A.ilog2()) {
tmp_value = self.flush_buffer(tmp_value);
}

loop {
let cur_bit = (v & (1 << index)) != 0;
Expand Down Expand Up @@ -213,6 +196,10 @@ impl<W: Write> VPXBoolWriter<W> {

let mut i: i32 = (num_bits - 1) as i32;
while i >= 0 {
if Self::cannot_put_bits(tmp_value, 1) {
tmp_value = self.flush_buffer(tmp_value);
}

self.put(
(bits & (1 << i)) != 0,
&mut branches[i as usize],
Expand Down Expand Up @@ -243,6 +230,11 @@ impl<W: Write> VPXBoolWriter<W> {

for i in 0..A {
let cur_bit = v != i;
debug_assert!(A <= 12);
// ensure we can put 6 bits into the stream
if (i == 0 || i == 6) && Self::cannot_put_bits(tmp_value, 6) {
tmp_value = self.flush_buffer(tmp_value);
}

self.put(
cur_bit,
Expand Down Expand Up @@ -271,6 +263,9 @@ impl<W: Write> VPXBoolWriter<W> {
) -> Result<()> {
let mut tmp_value = self.low_value;
let mut tmp_range = self.range;
if Self::cannot_put_bits(tmp_value, 1) {
tmp_value = self.flush_buffer(tmp_value);
}

self.put(value, branch, &mut tmp_value, &mut tmp_range, _cmp)?;

Expand All @@ -280,24 +275,55 @@ impl<W: Write> VPXBoolWriter<W> {
Ok(())
}

// After `flush_buffer` we have max 15 stream bits and can put there 6 bits,
// that is adding max 6*7 stream bits, as 15 + 42 < 62
fn flush_buffer(&mut self, mut tmp_value: u64) -> u64 {
let stream_bits = 64 - tmp_value.leading_zeros() as i32 - 2;
let low_value = tmp_value << 63 - stream_bits;
if low_value & (1 << 63) != 0 {
self.carry();
}

let mut sh = 55;
let mut stream_bytes = (stream_bits >> 3) - 1;
while stream_bytes > 0 {
self.buffer.push((low_value >> sh) as u8);
sh -= 8;
stream_bytes -= 1;
}

let remaining_bits = 8 + (stream_bits & 7);
tmp_value &= (1 << remaining_bits) - 1;
tmp_value |= 1 << (remaining_bits + 1);

tmp_value
}

// Here we write down only bytes of the stream necessary for decoding -
// opposite to initial Lepton implementation that writes down all the buffer.
pub fn finish(&mut self) -> Result<()> {
let mut tmp_value = self.low_value;
let stream_bits = 64 - tmp_value.leading_zeros() as i32 - 2;

tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
tmp_value <<= 63 - stream_bits;
if tmp_value & (1 << 63) != 0 {
self.carry();
}

let mut shift = MAX_STREAM_BITS - 8;
tmp_value <<= 1; // needed for 8 stream_bytes
let mut shift = 56;
let mut stream_bytes = (stream_bits + 7) >> 3;
while stream_bytes > 0 {
self.buffer.push((tmp_value >> shift) as u8);
shift -= 8;
stream_bytes -= 1;
}
// check that no stream bits remain in the buffer
debug_assert!(if shift == 56 {
tmp_value == 0
} else {
!(u64::MAX << (shift + 8)) & tmp_value == 0
});

self.writer.write_all(&self.buffer[..])?;
Ok(())
Expand Down
Loading