From 0c9ecebafd8c38643635fd6540af5ef086b02209 Mon Sep 17 00:00:00 2001 From: refcell Date: Wed, 3 Apr 2024 19:53:45 -0400 Subject: [PATCH 1/3] fix(derive): clean up the channel bank and add tests --- crates/derive/src/stages/channel_bank.rs | 35 ++++++++++++++++++------ crates/derive/src/types/errors.rs | 12 ++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/crates/derive/src/stages/channel_bank.rs b/crates/derive/src/stages/channel_bank.rs index 0ad391f03..6916885f0 100644 --- a/crates/derive/src/stages/channel_bank.rs +++ b/crates/derive/src/stages/channel_bank.rs @@ -65,8 +65,8 @@ where pub fn prune(&mut self) -> StageResult<()> { let mut total_size = self.size(); while total_size > MAX_CHANNEL_BANK_SIZE { - let id = self.channel_queue.pop_front().ok_or(anyhow!("No channel to prune"))?; - let channel = self.channels.remove(&id).ok_or(anyhow!("Could not find channel"))?; + let id = self.channel_queue.pop_front().ok_or(StageError::NoChannelsAvailable)?; + let channel = self.channels.remove(&id).ok_or(StageError::ChannelNotFound)?; total_size -= channel.size(); } Ok(()) @@ -74,7 +74,7 @@ where /// Adds new L1 data to the channel bank. Should only be called after all data has been read. pub fn ingest_frame(&mut self, frame: Frame) -> StageResult<()> { - let origin = *self.origin().ok_or(anyhow!("No origin"))?; + let origin = *self.origin().ok_or(StageError::MissingOrigin)?; // Get the channel for the frame, or create a new one if it doesn't exist. let current_channel = self.channels.entry(frame.id).or_insert_with(|| { @@ -108,8 +108,8 @@ where // Return an `Ok(None)` if the first channel is timed out. There may be more timed // out channels at the head of the queue and we want to remove them all. 
let first = self.channel_queue[0]; - let channel = self.channels.get(&first).ok_or(anyhow!("Channel not found"))?; - let origin = self.origin().ok_or(anyhow!("No origin present"))?; + let channel = self.channels.get(&first).ok_or(StageError::ChannelNotFound)?; + let origin = self.origin().ok_or(StageError::MissingOrigin)?; // Remove all timed out channels from the front of the `channel_queue`. if channel.open_block_number() + self.cfg.channel_timeout < origin.number { @@ -162,8 +162,8 @@ where /// If the channel read was successful, it will remove the channel from the channel queue. fn try_read_channel_at_index(&mut self, index: usize) -> StageResult { let channel_id = self.channel_queue[index]; - let channel = self.channels.get(&channel_id).ok_or(anyhow!("Channel not found"))?; - let origin = self.origin().ok_or(anyhow!("No origin present"))?; + let channel = self.channels.get(&channel_id).ok_or(StageError::ChannelNotFound)?; + let origin = self.origin().ok_or(StageError::MissingOrigin)?; let timed_out = channel.open_block_number() + self.cfg.channel_timeout < origin.number; if timed_out || !channel.is_ready() { @@ -213,7 +213,26 @@ mod tests { let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue); let frame = Frame::default(); let err = channel_bank.ingest_frame(frame).unwrap_err(); - assert_eq!(err, StageError::Custom(anyhow!("No origin"))); + assert_eq!(err, StageError::MissingOrigin); + } + + #[test] + fn test_ingest_invalid_frame() { + let traversal = new_test_traversal(false, false); + let dap = TestDAP::default(); + let retrieval = L1Retrieval::new(traversal, dap); + let frame_queue = FrameQueue::new(retrieval); + let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue); + let frame = Frame { id: [0xFF; 16], ..Default::default() }; + assert_eq!(channel_bank.size(), 0); + assert!(channel_bank.channels.is_empty()); + assert_eq!(channel_bank.ingest_frame(frame.clone()), Ok(())); + assert_eq!(channel_bank.size(), 
crate::params::FRAME_OVERHEAD); + assert_eq!(channel_bank.channels.len(), 1); + // Re-ingesting the same frame still returns `Ok(())` and must not grow the channel bank. + assert_eq!(channel_bank.ingest_frame(frame), Ok(())); + assert_eq!(channel_bank.size(), crate::params::FRAME_OVERHEAD); + assert_eq!(channel_bank.channels.len(), 1); } #[test] diff --git a/crates/derive/src/types/errors.rs b/crates/derive/src/types/errors.rs index 1b97a6249..4bb74e63f 100644 --- a/crates/derive/src/types/errors.rs +++ b/crates/derive/src/types/errors.rs @@ -12,6 +12,12 @@ pub enum StageError { /// There is not enough data progress, but if we wait, the stage will eventually return data /// or produce an EOF error. NotEnoughData, + /// No channels are available in the channel bank. + NoChannelsAvailable, + /// Failed to find channel. + ChannelNotFound, + /// Missing L1 origin. + MissingOrigin, /// Other wildcard error. Custom(anyhow::Error), } @@ -22,6 +28,9 @@ impl PartialEq for StageError { (self, other), (StageError::Eof, StageError::Eof) | (StageError::NotEnoughData, StageError::NotEnoughData) | + (StageError::NoChannelsAvailable, StageError::NoChannelsAvailable) | + (StageError::ChannelNotFound, StageError::ChannelNotFound) | + (StageError::MissingOrigin, StageError::MissingOrigin) | (StageError::Custom(_), StageError::Custom(_)) ) } @@ -41,6 +50,9 @@ impl Display for StageError { match self { StageError::Eof => write!(f, "End of file"), StageError::NotEnoughData => write!(f, "Not enough data"), + StageError::NoChannelsAvailable => write!(f, "No channels available"), + StageError::ChannelNotFound => write!(f, "Channel not found"), + StageError::MissingOrigin => write!(f, "Missing L1 origin"), StageError::Custom(e) => write!(f, "Custom error: {}", e), } } From df76da5ab6143da1fb502ec381bae0f52c552dac Mon Sep 17 00:00:00 2001 From: refcell Date: Wed, 3 Apr 2024 22:24:01 -0400 Subject: [PATCH 2/3] fix(derive): channel bank tests --- crates/derive/src/stages/channel_bank.rs | 9 +++++---- 1 file changed, 5
insertions(+), 4 deletions(-) diff --git a/crates/derive/src/stages/channel_bank.rs b/crates/derive/src/stages/channel_bank.rs index 29000ff5f..bd3ce708c 100644 --- a/crates/derive/src/stages/channel_bank.rs +++ b/crates/derive/src/stages/channel_bank.rs @@ -243,11 +243,12 @@ mod tests { #[test] fn test_ingest_invalid_frame() { - let traversal = new_test_traversal(false, false); + let traversal = new_test_traversal(vec![], vec![]); let dap = TestDAP::default(); - let retrieval = L1Retrieval::new(traversal, dap); - let frame_queue = FrameQueue::new(retrieval); - let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue); + let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); + let frame_queue = FrameQueue::new(retrieval, TestTelemetry::new()); + let mut channel_bank = + ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, TestTelemetry::new()); let frame = Frame { id: [0xFF; 16], ..Default::default() }; assert_eq!(channel_bank.size(), 0); assert!(channel_bank.channels.is_empty()); From 382d8a545acf8a04fd69034b412c5a59e1f0a626 Mon Sep 17 00:00:00 2001 From: refcell Date: Thu, 4 Apr 2024 08:09:23 -0400 Subject: [PATCH 3/3] fix(derive): channel bank testing with spinlocked primitives --- Cargo.lock | 10 ++++++ crates/derive/Cargo.toml | 1 + crates/derive/src/stages/batch_queue.rs | 2 +- crates/derive/src/stages/channel_bank.rs | 35 +++++++++++++------ .../derive/src/traits/test_utils/telemetry.rs | 25 +++++++++---- 5 files changed, 55 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 22304b806..453632c5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,6 +362,7 @@ dependencies = [ "miniz_oxide", "proptest", "serde", + "spin", "tokio", "unsigned-varint", ] @@ -762,6 +763,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "spinning_top" version = "0.2.5" diff --git a/crates/derive/Cargo.toml b/crates/derive/Cargo.toml index 5adf8b0cd..a06cbade8 100644 --- a/crates/derive/Cargo.toml +++ b/crates/derive/Cargo.toml @@ -27,6 +27,7 @@ serde = { version = "1.0.197", default-features = false, features = ["derive"], [dev-dependencies] tokio = { version = "1.36", features = ["full"] } proptest = "1.4.0" +spin = { version = "0.9.8", features = ["mutex"] } # Spin is used for testing synchronization primitives [features] serde = ["dep:serde", "alloy-primitives/serde"] diff --git a/crates/derive/src/stages/batch_queue.rs b/crates/derive/src/stages/batch_queue.rs index d7fc4498b..cfc980edd 100644 --- a/crates/derive/src/stages/batch_queue.rs +++ b/crates/derive/src/stages/batch_queue.rs @@ -354,7 +354,7 @@ where DAP: DataAvailabilityProvider + Send + Debug, CP: ChainProvider + Send + Debug, BF: SafeBlockFetcher + Send + Debug, - T: TelemetryProvider + Send + Debug, + T: TelemetryProvider + Send + Debug + Sync, { async fn reset(&mut self, base: BlockInfo, _: SystemConfig) -> StageResult<()> { // Copy over the Origin from the next stage. diff --git a/crates/derive/src/stages/channel_bank.rs b/crates/derive/src/stages/channel_bank.rs index bd3ce708c..75a5eec9e 100644 --- a/crates/derive/src/stages/channel_bank.rs +++ b/crates/derive/src/stages/channel_bank.rs @@ -36,7 +36,7 @@ where /// The rollup configuration. cfg: Arc, /// Telemetry - telemetry: T, + telemetry: Arc, /// Map of channels by ID. channels: HashMap, /// Channels in FIFO order. @@ -52,7 +52,7 @@ where T: TelemetryProvider + Debug, { /// Create a new [ChannelBank] stage. 
- pub fn new(cfg: Arc, prev: FrameQueue, telemetry: T) -> Self { + pub fn new(cfg: Arc, prev: FrameQueue, telemetry: Arc) -> Self { Self { cfg, telemetry, channels: HashMap::new(), channel_queue: VecDeque::new(), prev } } @@ -207,7 +207,7 @@ impl ResettableStage for ChannelBank where DAP: DataAvailabilityProvider + Send + Debug, CP: ChainProvider + Send + Debug, - T: TelemetryProvider + Send + Debug, + T: TelemetryProvider + Send + Sync + Debug, { async fn reset(&mut self, _: BlockInfo, _: SystemConfig) -> StageResult<()> { self.channels.clear(); @@ -234,8 +234,12 @@ mod tests { let dap = TestDAP::default(); let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); let frame_queue = FrameQueue::new(retrieval, TestTelemetry::new()); - let mut channel_bank = - ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, TestTelemetry::new()); + let telemetry = Arc::new(TestTelemetry::new()); + let mut channel_bank = ChannelBank::new( + Arc::new(RollupConfig::default()), + frame_queue, + Arc::clone(&telemetry), + ); let frame = Frame::default(); let err = channel_bank.ingest_frame(frame).unwrap_err(); assert_eq!(err, StageError::MissingOrigin); @@ -247,11 +251,13 @@ mod tests { let dap = TestDAP::default(); let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); let frame_queue = FrameQueue::new(retrieval, TestTelemetry::new()); + let telem = Arc::new(TestTelemetry::new()); let mut channel_bank = - ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, TestTelemetry::new()); + ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, Arc::clone(&telem)); let frame = Frame { id: [0xFF; 16], ..Default::default() }; assert_eq!(channel_bank.size(), 0); assert!(channel_bank.channels.is_empty()); + assert_eq!(telem.count_calls(LogLevel::Warning), 0); assert_eq!(channel_bank.ingest_frame(frame.clone()), Ok(())); assert_eq!(channel_bank.size(), crate::params::FRAME_OVERHEAD); assert_eq!(channel_bank.channels.len(), 1); @@ 
-259,6 +265,7 @@ mod tests { assert_eq!(channel_bank.ingest_frame(frame), Ok(())); assert_eq!(channel_bank.size(), crate::params::FRAME_OVERHEAD); assert_eq!(channel_bank.channels.len(), 1); + assert_eq!(telem.count_calls(LogLevel::Warning), 1); } #[test] @@ -268,8 +275,12 @@ mod tests { let dap = TestDAP { results }; let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); let frame_queue = FrameQueue::new(retrieval, TestTelemetry::new()); - let mut channel_bank = - ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, TestTelemetry::new()); + let telemetry = Arc::new(TestTelemetry::new()); + let mut channel_bank = ChannelBank::new( + Arc::new(RollupConfig::default()), + frame_queue, + Arc::clone(&telemetry), + ); let mut frames = new_test_frames(100000); // Ingest frames until the channel bank is full and it stops increasing in size let mut current_size = 0; @@ -297,8 +308,12 @@ mod tests { let dap = TestDAP { results }; let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); let frame_queue = FrameQueue::new(retrieval, TestTelemetry::new()); - let mut channel_bank = - ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, TestTelemetry::new()); + let telemetry = Arc::new(TestTelemetry::new()); + let mut channel_bank = ChannelBank::new( + Arc::new(RollupConfig::default()), + frame_queue, + Arc::clone(&telemetry), + ); let err = channel_bank.read().unwrap_err(); assert_eq!(err, StageError::Eof); let err = channel_bank.next_data().await.unwrap_err(); diff --git a/crates/derive/src/traits/test_utils/telemetry.rs b/crates/derive/src/traits/test_utils/telemetry.rs index bb1448cfd..a7d32300a 100644 --- a/crates/derive/src/traits/test_utils/telemetry.rs +++ b/crates/derive/src/traits/test_utils/telemetry.rs @@ -1,15 +1,15 @@ //! 
Test Utilities for Telemetry use crate::traits::{LogLevel, TelemetryProvider}; -use alloc::{rc::Rc, vec::Vec}; +use alloc::{sync::Arc, vec::Vec}; use alloy_primitives::Bytes; -use core::cell::RefCell; +use spin::mutex::Mutex; /// Mock telemetry provider #[derive(Debug, Default)] pub struct TestTelemetry { /// Holds telemetry data with log levels for assertions. - pub(crate) telemetry_calls: Rc>>, + pub(crate) telemetry_calls: Arc>>, } impl TestTelemetry { @@ -17,14 +17,25 @@ impl TestTelemetry { pub fn new() -> Self { Self::default() } + + /// Checks the existence of a given ([Bytes], [LogLevel]) call. + pub fn exists(&self, data: Bytes, level: LogLevel) -> bool { + let guard = self.telemetry_calls.lock(); + guard.iter().filter(|(d, l)| *d == data && *l == level).count() > 0 + } + + /// Counts the number of telemetry calls with the given [LogLevel]. + pub fn count_calls(&self, level: LogLevel) -> usize { + let guard = self.telemetry_calls.lock(); + guard.iter().filter(|(_, l)| *l == level).count() + } } impl TelemetryProvider for TestTelemetry { fn write>(&self, data: I, level: LogLevel) { let data = (data.into(), level); - { - let mut calls = self.telemetry_calls.borrow_mut(); - (*calls).push(data); - } + let binding = self.telemetry_calls.clone(); + let mut guard = binding.lock(); + guard.push(data); } }