diff --git a/Cargo.lock b/Cargo.lock
index a00f0da12..56ee426f4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -328,6 +328,7 @@ dependencies = [
"async-trait",
"hashbrown",
"serde",
+ "tokio",
"unsigned-varint",
]
diff --git a/crates/derive/Cargo.toml b/crates/derive/Cargo.toml
index 9f26eaf77..38e640b31 100644
--- a/crates/derive/Cargo.toml
+++ b/crates/derive/Cargo.toml
@@ -23,5 +23,8 @@ unsigned-varint = "0.8.0"
# Optional
serde = { version = "1.0.197", default-features = false, features = ["derive"], optional = true }
+[dev-dependencies]
+tokio = { version = "1.36", features = ["full"] }
+
[features]
serde = ["dep:serde", "alloy-primitives/serde"]
diff --git a/crates/derive/src/stages/channel_bank.rs b/crates/derive/src/stages/channel_bank.rs
index 6293a6d14..c872c9ce1 100644
--- a/crates/derive/src/stages/channel_bank.rs
+++ b/crates/derive/src/stages/channel_bank.rs
@@ -60,12 +60,15 @@ where
self.prev.origin()
}
+ /// Returns the size of the channel bank by accumulating over all channels.
+ pub fn size(&self) -> usize {
+ self.channels.iter().fold(0, |acc, (_, c)| acc + c.size())
+ }
+
/// Prunes the Channel bank, until it is below [MAX_CHANNEL_BANK_SIZE].
+ /// Prunes from the high-priority channel since it failed to be read.
pub fn prune(&mut self) -> StageResult<()> {
- // Check total size
- let mut total_size = self.channels.iter().fold(0, |acc, (_, c)| acc + c.size());
- // Prune until it is reasonable again. The high-priority channel failed to be read,
- // so we prune from there.
+ let mut total_size = self.size();
while total_size > MAX_CHANNEL_BANK_SIZE {
let id = self
.channel_queue
@@ -122,16 +125,17 @@ where
.ok_or(anyhow!("Channel not found"))?;
let origin = self.origin().ok_or(anyhow!("No origin present"))?;
+ // Remove all timed out channels from the front of the `channel_queue`.
if channel.open_block_number() + self.cfg.channel_timeout < origin.number {
self.channels.remove(&first);
self.channel_queue.pop_front();
return Ok(None);
}
- // At the point we have removed all timed out channels from the front of the `channel_queue`.
+ // At this point we have removed all timed out channels from the front of the `channel_queue`.
// Pre-Canyon we simply check the first index.
- // Post-Canyon we read the entire channelQueue for the first ready channel. If no channel is
- // available, we return `nil, io.EOF`.
+ // Post-Canyon we read the entire channelQueue for the first ready channel.
+ // If no channel is available, we return StageError::Eof.
// Canyon is activated when the first L1 block whose time >= CanyonTime, not on the L2 timestamp.
if !self.cfg.is_canyon_active(origin.timestamp) {
return self.try_read_channel_at_index(0).map(Some);
@@ -201,3 +205,68 @@ where
Err(StageError::Eof)
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::stages::frame_queue::tests::new_test_frames;
+ use crate::stages::l1_retrieval::L1Retrieval;
+ use crate::stages::l1_traversal::tests::new_test_traversal;
+ use crate::traits::test_utils::TestDAP;
+ use alloc::vec;
+
+ #[test]
+ fn test_ingest_empty_origin() {
+ let mut traversal = new_test_traversal(false, false);
+ traversal.block = None;
+ let dap = TestDAP::default();
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let frame_queue = FrameQueue::new(retrieval);
+ let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue);
+ let frame = Frame::default();
+ let err = channel_bank.ingest_frame(frame).unwrap_err();
+ assert_eq!(err, StageError::Custom(anyhow!("No origin")));
+ }
+
+ #[test]
+ fn test_ingest_and_prune_channel_bank() {
+ let traversal = new_test_traversal(true, true);
+ let results = vec![Ok(Bytes::from(vec![0x00]))];
+ let dap = TestDAP { results };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let frame_queue = FrameQueue::new(retrieval);
+ let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue);
+ let mut frames = new_test_frames(100000);
+ // Ingest frames until the channel bank is full and it stops increasing in size
+ let mut current_size = 0;
+ let next_frame = frames.pop().unwrap();
+ channel_bank.ingest_frame(next_frame).unwrap();
+ while channel_bank.size() > current_size {
+ current_size = channel_bank.size();
+ let next_frame = frames.pop().unwrap();
+ channel_bank.ingest_frame(next_frame).unwrap();
+ assert!(channel_bank.size() <= MAX_CHANNEL_BANK_SIZE);
+ }
+ // There should be a bunch of frames leftover
+ assert!(!frames.is_empty());
+ // If we ingest one more frame, the channel bank should prune
+ // and the size should be the same
+ let next_frame = frames.pop().unwrap();
+ channel_bank.ingest_frame(next_frame).unwrap();
+ assert_eq!(channel_bank.size(), current_size);
+ }
+
+ #[tokio::test]
+ async fn test_read_empty_channel_bank() {
+ let traversal = new_test_traversal(true, true);
+ let results = vec![Ok(Bytes::from(vec![0x00]))];
+ let dap = TestDAP { results };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let frame_queue = FrameQueue::new(retrieval);
+ let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue);
+ let err = channel_bank.read().unwrap_err();
+ assert_eq!(err, StageError::Eof);
+ let err = channel_bank.next_data().await.unwrap_err();
+ assert_eq!(err, StageError::Custom(anyhow!("Not Enough Data")));
+ }
+}
diff --git a/crates/derive/src/stages/frame_queue.rs b/crates/derive/src/stages/frame_queue.rs
index f0900a91b..51221eb63 100644
--- a/crates/derive/src/stages/frame_queue.rs
+++ b/crates/derive/src/stages/frame_queue.rs
@@ -47,6 +47,7 @@ where
if self.queue.is_empty() {
match self.prev.next_data().await {
Ok(data) => {
+ // TODO: what do we do with frame parsing errors?
if let Ok(frames) = Frame::parse_frames(data.as_ref()) {
self.queue.extend(frames);
}
@@ -78,3 +79,112 @@ where
Err(StageError::Eof)
}
}
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use super::*;
+ use crate::stages::l1_traversal::tests::new_test_traversal;
+ use crate::traits::test_utils::TestDAP;
+ use crate::DERIVATION_VERSION_0;
+ use alloc::vec;
+ use alloc::vec::Vec;
+ use alloy_primitives::Bytes;
+
+ pub(crate) fn new_test_frames(count: usize) -> Vec {
+ (0..count)
+ .map(|i| Frame {
+ id: [0xFF; 16],
+ number: i as u16,
+ data: vec![0xDD; 50],
+ is_last: i == count - 1,
+ })
+ .collect()
+ }
+
+ pub(crate) fn new_encoded_test_frames(count: usize) -> Bytes {
+ let frames = new_test_frames(count);
+ let mut bytes = Vec::new();
+ bytes.extend_from_slice(&[DERIVATION_VERSION_0]);
+ for frame in frames.iter() {
+ bytes.extend_from_slice(&frame.encode());
+ }
+ Bytes::from(bytes)
+ }
+
+ #[tokio::test]
+ async fn test_frame_queue_empty_bytes() {
+ let traversal = new_test_traversal(true, true);
+ let results = vec![Ok(Bytes::from(vec![0x00]))];
+ let dap = TestDAP { results };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let mut frame_queue = FrameQueue::new(retrieval);
+ let err = frame_queue.next_frame().await.unwrap_err();
+ assert_eq!(err, anyhow!("Not enough data").into());
+ }
+
+ #[tokio::test]
+ async fn test_frame_queue_no_frames_decoded() {
+ let traversal = new_test_traversal(true, true);
+ let results = vec![Err(StageError::Eof), Ok(Bytes::default())];
+ let dap = TestDAP { results };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let mut frame_queue = FrameQueue::new(retrieval);
+ let err = frame_queue.next_frame().await.unwrap_err();
+ assert_eq!(err, anyhow!("Not enough data").into());
+ }
+
+ #[tokio::test]
+ async fn test_frame_queue_wrong_derivation_version() {
+ let traversal = new_test_traversal(true, true);
+ let results = vec![Ok(Bytes::from(vec![0x01]))];
+ let dap = TestDAP { results };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let mut frame_queue = FrameQueue::new(retrieval);
+ let err = frame_queue.next_frame().await.unwrap_err();
+ assert_eq!(err, anyhow!("Unsupported derivation version").into());
+ }
+
+ #[tokio::test]
+ async fn test_frame_queue_frame_too_short() {
+ let traversal = new_test_traversal(true, true);
+ let results = vec![Ok(Bytes::from(vec![0x00, 0x01]))];
+ let dap = TestDAP { results };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let mut frame_queue = FrameQueue::new(retrieval);
+ let err = frame_queue.next_frame().await.unwrap_err();
+ assert_eq!(err, anyhow!("Frame too short to decode").into());
+ }
+
+ #[tokio::test]
+ async fn test_frame_queue_single_frame() {
+ let data = new_encoded_test_frames(1);
+ let traversal = new_test_traversal(true, true);
+ let dap = TestDAP {
+ results: vec![Ok(data)],
+ };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let mut frame_queue = FrameQueue::new(retrieval);
+ let frame_decoded = frame_queue.next_frame().await.unwrap();
+ let frame = new_test_frames(1);
+ assert_eq!(frame[0], frame_decoded);
+ let err = frame_queue.next_frame().await.unwrap_err();
+ assert_eq!(err, anyhow!("Not enough data").into());
+ }
+
+ #[tokio::test]
+ async fn test_frame_queue_multiple_frames() {
+ let data = new_encoded_test_frames(3);
+ let traversal = new_test_traversal(true, true);
+ let dap = TestDAP {
+ results: vec![Ok(data)],
+ };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let mut frame_queue = FrameQueue::new(retrieval);
+ for i in 0..3 {
+ let frame_decoded = frame_queue.next_frame().await.unwrap();
+ assert_eq!(frame_decoded.number, i);
+ }
+ let err = frame_queue.next_frame().await.unwrap_err();
+ assert_eq!(err, anyhow!("Not enough data").into());
+ }
+}
diff --git a/crates/derive/src/stages/l1_retrieval.rs b/crates/derive/src/stages/l1_retrieval.rs
index ca4ecc5b9..c69dea9c3 100644
--- a/crates/derive/src/stages/l1_retrieval.rs
+++ b/crates/derive/src/stages/l1_retrieval.rs
@@ -22,7 +22,7 @@ where
/// The data availability provider to use for the L1 retrieval stage.
pub provider: DAP,
/// The current data iterator.
- data: Option>,
+ pub(crate) data: Option,
}
impl L1Retrieval
@@ -83,3 +83,82 @@ where
Ok(())
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::stages::l1_traversal::tests::new_test_traversal;
+ use crate::traits::test_utils::{TestDAP, TestIter};
+ use alloc::vec;
+ use alloy_primitives::Address;
+
+ #[tokio::test]
+ async fn test_l1_retrieval_origin() {
+ let traversal = new_test_traversal(true, true);
+ let dap = TestDAP { results: vec![] };
+ let retrieval = L1Retrieval::new(traversal, dap);
+ let expected = BlockInfo::default();
+ assert_eq!(retrieval.origin(), Some(&expected));
+ }
+
+ #[tokio::test]
+ async fn test_l1_retrieval_next_data() {
+ let traversal = new_test_traversal(true, true);
+ let results = vec![Err(StageError::Eof), Ok(Bytes::default())];
+ let dap = TestDAP { results };
+ let mut retrieval = L1Retrieval::new(traversal, dap);
+ assert_eq!(retrieval.data, None);
+ let data = retrieval.next_data().await.unwrap();
+ assert_eq!(data, Bytes::default());
+ assert!(retrieval.data.is_some());
+ let retrieval_data = retrieval.data.as_ref().unwrap();
+ assert_eq!(retrieval_data.open_data_calls.len(), 1);
+ assert_eq!(retrieval_data.open_data_calls[0].0, BlockInfo::default());
+ assert_eq!(retrieval_data.open_data_calls[0].1, Address::default());
+ // Data should be reset to none and the error should be bubbled up.
+ let data = retrieval.next_data().await.unwrap_err();
+ assert_eq!(data, StageError::Eof);
+ assert!(retrieval.data.is_none());
+ }
+
+ #[tokio::test]
+ async fn test_l1_retrieval_existing_data_is_respected() {
+ let data = TestIter {
+ open_data_calls: vec![(BlockInfo::default(), Address::default())],
+ results: vec![Ok(Bytes::default())],
+ };
+ // Create a new traversal with no blocks or receipts.
+ // This would bubble up an error if the prev stage
+ // (traversal) is called in the retrieval stage.
+ let traversal = new_test_traversal(false, false);
+ let dap = TestDAP { results: vec![] };
+ let mut retrieval = L1Retrieval {
+ prev: traversal,
+ provider: dap,
+ data: Some(data),
+ };
+ let data = retrieval.next_data().await.unwrap();
+ assert_eq!(data, Bytes::default());
+ assert!(retrieval.data.is_some());
+ let retrieval_data = retrieval.data.as_ref().unwrap();
+ assert_eq!(retrieval_data.open_data_calls.len(), 1);
+ }
+
+ #[tokio::test]
+ async fn test_l1_retrieval_existing_data_errors() {
+ let data = TestIter {
+ open_data_calls: vec![(BlockInfo::default(), Address::default())],
+ results: vec![Err(StageError::Eof)],
+ };
+ let traversal = new_test_traversal(true, true);
+ let dap = TestDAP { results: vec![] };
+ let mut retrieval = L1Retrieval {
+ prev: traversal,
+ provider: dap,
+ data: Some(data),
+ };
+ let data = retrieval.next_data().await.unwrap_err();
+ assert_eq!(data, StageError::Eof);
+ assert!(retrieval.data.is_none());
+ }
+}
diff --git a/crates/derive/src/stages/l1_traversal.rs b/crates/derive/src/stages/l1_traversal.rs
index 253611050..37656ca49 100644
--- a/crates/derive/src/stages/l1_traversal.rs
+++ b/crates/derive/src/stages/l1_traversal.rs
@@ -12,7 +12,7 @@ use async_trait::async_trait;
#[derive(Debug, Clone, Copy)]
pub struct L1Traversal {
/// The current block in the traversal stage.
- block: Option,
+ pub(crate) block: Option,
/// The data source for the traversal stage.
data_source: Provider,
/// Signals whether or not the traversal stage has been completed.
@@ -27,7 +27,7 @@ impl L1Traversal {
/// Creates a new [L1Traversal] instance.
pub fn new(data_source: F, cfg: RollupConfig) -> Self {
Self {
- block: None,
+ block: Some(BlockInfo::default()),
data_source,
done: false,
system_config: SystemConfig::default(),
@@ -35,8 +35,14 @@ impl L1Traversal {
}
}
- /// Returns the next L1 block in the traversal stage, if the stage has not been completed. This function can only
- /// be called once, and will return `None` on subsequent calls unless the stage is reset.
+ /// Retrieves a reference to the inner data source of the [L1Traversal] stage.
+ pub fn data_source(&self) -> &F {
+ &self.data_source
+ }
+
+ /// Returns the next L1 block in the traversal stage, if the stage has not been completed.
+ /// This function can only be called once, and will return `None` on subsequent calls
+ /// unless the stage is reset.
pub fn next_l1_block(&mut self) -> StageResult