Skip to content

Commit

Permalink
Shorter blocks and consistent file (#112)
Browse files Browse the repository at this point in the history
* Shorter blocks

* Carousseling of different thread's blocks into Lepton file
Makes Lepton file consistent between runs and helps for decoding of files transferred over slow channels

* Carouseling from 0

* Optimization of writer
Carouseling effectively mixes blocks, so we can exclude old mixing mechanism that checks length after each stream bit pushing - ~3 % faster

* Assert and some comments

* Faster writing during decoding - flush_non_final_data each row

* Apply review round 1

* Revew round 1 - revert bool writer changes

* Removed flush_non_final_data chack after each bit

* Fix for 0-sized buffer
  • Loading branch information
Melirius authored Nov 15, 2024
1 parent 7fbb320 commit 82d6547
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 30 deletions.
2 changes: 2 additions & 0 deletions src/structs/lepton_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ pub fn lepton_encode_row_range<W: Write>(
features,
)
.context()?;

bool_writer.flush_non_final_data().context()?;
}

if is_last_thread && full_file_compression {
Expand Down
104 changes: 86 additions & 18 deletions src/structs/multiplexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ use crate::{helpers::*, lepton_error::err_exit_code, structs::partial_buffer::Pa

/// The message that is sent between the threads
enum Message {
Eof,
WriteBlock(u8, Vec<u8>),
Eof(usize),
WriteBlock(usize, Vec<u8>),
}

pub struct MultiplexWriter {
thread_id: u8,
thread_id: usize,
sender: Sender<Message>,
buffer: Vec<u8>,
}
Expand Down Expand Up @@ -143,7 +143,7 @@ where
let cloned_sender = tx.clone();

let mut thread_writer = MultiplexWriter {
thread_id: thread_id as u8,
thread_id: thread_id,
sender: cloned_sender,
buffer: Vec::with_capacity(WRITE_BUFFER_SIZE),
};
Expand All @@ -155,7 +155,10 @@ where

thread_writer.flush().context()?;

thread_writer.sender.send(Message::Eof).context()?;
thread_writer
.sender
.send(Message::Eof(thread_id))
.context()?;
Ok(r)
};

Expand All @@ -169,20 +172,85 @@ where

// wait to collect work and done messages from all the threads
let mut threads_left = num_threads;
// carouseling to write data packets from all threads
let mut packets = vec![];
packets.resize(num_threads, VecDeque::<Vec<u8>>::new());
let mut eot = vec![false; num_threads]; // end of threads's packets
let mut curr_write_thread: usize = 0; // invariant is `packets[curr_write_thread].len() == 0`

let mut write_block = |thread_id: usize, a: Vec<u8>| -> Result<()> {
// block length and thread header
let tid = thread_id as u8;
let l = a.len() - 1;
if l == 4095 || l == 16383 || l == 65535 {
// length is a special power of 2 - standard block length is 2^16
writer.write_u8(tid | ((l.ilog2() as u8 >> 1) - 4) << 4)?;
} else {
writer.write_u8(tid)?;
writer.write_u8((l & 0xff) as u8)?;
writer.write_u8(((l >> 8) & 0xff) as u8)?;
}
// block itself
writer.write_all(&a[..])?;

Ok(())
};

while threads_left > 0 {
let value = rx.recv().context();
match value {
Ok(Message::Eof) => {
Ok(Message::Eof(thread_id)) => {
threads_left -= 1;
eot[thread_id] = true;

if threads_left == 0 {
// last phase - write down all remaining packets
let mut packets_left = 0;
for a in &packets {
packets_left += a.len();
}

while packets_left > 0 {
curr_write_thread = (curr_write_thread + 1) % num_threads;

if let Some(packet) = packets[curr_write_thread].pop_front() {
write_block(curr_write_thread, packet).context()?;

packets_left -= 1;
}
}
} else if thread_id == curr_write_thread {
// no more this thread's packets - continue to other threads
debug_assert_eq!(packets[curr_write_thread].len(), 0);
loop {
curr_write_thread = (curr_write_thread + 1) % num_threads;

if let Some(packet) = packets[curr_write_thread].pop_front() {
write_block(curr_write_thread, packet).context()?;
} else if !eot[curr_write_thread] {
break;
}
}
}
}
Ok(Message::WriteBlock(thread_id, b)) => {
let l = b.len() - 1;

writer.write_u8(thread_id).context()?;
writer.write_u8((l & 0xff) as u8).context()?;
writer.write_u8(((l >> 8) & 0xff) as u8).context()?;
writer.write_all(&b[..]).context()?;
debug_assert!(b.len() <= WRITE_BUFFER_SIZE);
if thread_id == curr_write_thread {
debug_assert_eq!(packets[curr_write_thread].len(), 0);
write_block(thread_id, b).context()?;
// this thread's packet written - continue to other threads
loop {
curr_write_thread = (curr_write_thread + 1) % num_threads;

if let Some(packet) = packets[curr_write_thread].pop_front() {
write_block(curr_write_thread, packet).context()?;
} else if !eot[curr_write_thread] {
break;
}
}
} else {
packets[thread_id].push_back(b);
}
}
Err(_) => {
// if we get a receiving error here, this means that one of the threads broke
Expand Down Expand Up @@ -221,7 +289,7 @@ where
/// getting the data that we are expecting.
pub struct MultiplexReader {
/// the multiplexed thread stream we are processing
thread_id: u8,
thread_id: usize,

/// the receiver part of the channel to get more buffers
receiver: Receiver<Message>,
Expand Down Expand Up @@ -262,7 +330,7 @@ impl MultiplexReader {

match self.receiver.recv() {
Ok(r) => match r {
Message::Eof => {
Message::Eof(_tid) => {
self.end_of_file = true;
}
Message::WriteBlock(tid, block) => {
Expand Down Expand Up @@ -354,7 +422,7 @@ impl<RESULT> MultiplexReaderState<RESULT> {
if let Some((thread_id, rx, cloned_processor, cloned_result_sender)) = w {
// get the appropriate receiver so we can read out data from it
let mut proc_reader = MultiplexReader {
thread_id: thread_id as u8,
thread_id: thread_id,
current_buffer: Cursor::new(Vec::new()),
receiver: rx,
end_of_file: false,
Expand Down Expand Up @@ -424,8 +492,8 @@ impl<RESULT> MultiplexReaderState<RESULT> {
// ignore if we get error sending because channel died since we will collect
// the error later. We don't want to interrupt the other threads that are processing
// so we only get the error from the thread that actually errored out.
let _ = self.sender_channels[usize::from(thread_id)]
.send(Message::WriteBlock(thread_id, a));
let tid = usize::from(thread_id);
let _ = self.sender_channels[tid].send(Message::WriteBlock(tid, a));
self.current_state = State::StartBlock;
} else {
break;
Expand All @@ -443,7 +511,7 @@ impl<RESULT> MultiplexReaderState<RESULT> {
let mut results = Vec::new();
for thread_id in 0..self.sender_channels.len() {
// send eof to all threads (ignore results since they might be dead already)
_ = self.sender_channels[thread_id].send(Message::Eof);
_ = self.sender_channels[thread_id].send(Message::Eof(thread_id));
results.push(None);
}

Expand Down
22 changes: 10 additions & 12 deletions src/structs/vpx_bool_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,6 @@ impl<W: Write> VPXBoolWriter<W> {

*tmp_value <<= shift;

// check if we're out of buffer space, if yes - send the buffer to output
if self.buffer.len() > 65536 - 128 {
self.flush_non_final_data()?;
}

Ok(())
}

Expand Down Expand Up @@ -302,16 +297,19 @@ impl<W: Write> VPXBoolWriter<W> {

/// When buffer is full and is going to be sent to output, preserve buffer data that
/// is not final and should carried over to the next buffer.
fn flush_non_final_data(&mut self) -> Result<()> {
pub fn flush_non_final_data(&mut self) -> Result<()> {
// carry over buffer data that might be not final
let mut i = self.buffer.len() - 1;
while self.buffer[i] == 0xFF {
assert!(i > 0);
let mut i = self.buffer.len();
if i > 0 {
i -= 1;
}
while self.buffer[i] == 0xFF {
assert!(i > 0);
i -= 1;
}

self.writer.write_all(&self.buffer[..i])?;
self.buffer.drain(..i);
self.writer.write_all(&self.buffer[..i])?;
self.buffer.drain(..i);
}

Ok(())
}
Expand Down

0 comments on commit 82d6547

Please sign in to comment.