diff --git a/Cargo.lock b/Cargo.lock index 22304b806..453632c5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,6 +362,7 @@ dependencies = [ "miniz_oxide", "proptest", "serde", + "spin", "tokio", "unsigned-varint", ] @@ -762,6 +763,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "spinning_top" version = "0.2.5" diff --git a/crates/derive/Cargo.toml b/crates/derive/Cargo.toml index 5adf8b0cd..a06cbade8 100644 --- a/crates/derive/Cargo.toml +++ b/crates/derive/Cargo.toml @@ -27,6 +27,7 @@ serde = { version = "1.0.197", default-features = false, features = ["derive"], [dev-dependencies] tokio = { version = "1.36", features = ["full"] } proptest = "1.4.0" +spin = { version = "0.9.8", features = ["mutex"] } # Spin is used for testing synchronization primitives [features] serde = ["dep:serde", "alloy-primitives/serde"] diff --git a/crates/derive/src/stages/batch_queue.rs b/crates/derive/src/stages/batch_queue.rs index 06a769d2f..a6f3172ee 100644 --- a/crates/derive/src/stages/batch_queue.rs +++ b/crates/derive/src/stages/batch_queue.rs @@ -359,7 +359,7 @@ where DAP: DataAvailabilityProvider + Send + Debug, CP: ChainProvider + Send + Debug, BF: SafeBlockFetcher + Send + Debug, - T: TelemetryProvider + Send + Debug, + T: TelemetryProvider + Send + Debug + Sync, { async fn reset(&mut self, base: BlockInfo, _: SystemConfig) -> StageResult<()> { // Copy over the Origin from the next stage. diff --git a/crates/derive/src/stages/channel_bank.rs b/crates/derive/src/stages/channel_bank.rs index fd9518328..75a5eec9e 100644 --- a/crates/derive/src/stages/channel_bank.rs +++ b/crates/derive/src/stages/channel_bank.rs @@ -36,7 +36,7 @@ where /// The rollup configuration. 
cfg: Arc, /// Telemetry - telemetry: T, + telemetry: Arc, /// Map of channels by ID. channels: HashMap, /// Channels in FIFO order. @@ -52,7 +52,7 @@ where T: TelemetryProvider + Debug, { /// Create a new [ChannelBank] stage. - pub fn new(cfg: Arc, prev: FrameQueue, telemetry: T) -> Self { + pub fn new(cfg: Arc, prev: FrameQueue, telemetry: Arc) -> Self { Self { cfg, telemetry, channels: HashMap::new(), channel_queue: VecDeque::new(), prev } } @@ -71,8 +71,8 @@ where pub fn prune(&mut self) -> StageResult<()> { let mut total_size = self.size(); while total_size > MAX_CHANNEL_BANK_SIZE { - let id = self.channel_queue.pop_front().ok_or(anyhow!("No channel to prune"))?; - let channel = self.channels.remove(&id).ok_or(anyhow!("Could not find channel"))?; + let id = self.channel_queue.pop_front().ok_or(StageError::NoChannelsAvailable)?; + let channel = self.channels.remove(&id).ok_or(StageError::ChannelNotFound)?; total_size -= channel.size(); } Ok(()) @@ -80,7 +80,7 @@ where /// Adds new L1 data to the channel bank. Should only be called after all data has been read. pub fn ingest_frame(&mut self, frame: Frame) -> StageResult<()> { - let origin = *self.origin().ok_or(anyhow!("No origin"))?; + let origin = *self.origin().ok_or(StageError::MissingOrigin)?; // Get the channel for the frame, or create a new one if it doesn't exist. let current_channel = self.channels.entry(frame.id).or_insert_with(|| { @@ -128,8 +128,8 @@ where // Return an `Ok(None)` if the first channel is timed out. There may be more timed // out channels at the head of the queue and we want to remove them all. let first = self.channel_queue[0]; - let channel = self.channels.get(&first).ok_or(anyhow!("Channel not found"))?; - let origin = self.origin().ok_or(anyhow!("No origin present"))?; + let channel = self.channels.get(&first).ok_or(StageError::ChannelNotFound)?; + let origin = self.origin().ok_or(StageError::MissingOrigin)?; // Remove all timed out channels from the front of the `channel_queue`. 
if channel.open_block_number() + self.cfg.channel_timeout < origin.number { @@ -186,8 +186,8 @@ where /// If the channel read was successful, it will remove the channel from the channel queue. fn try_read_channel_at_index(&mut self, index: usize) -> StageResult { let channel_id = self.channel_queue[index]; - let channel = self.channels.get(&channel_id).ok_or(anyhow!("Channel not found"))?; - let origin = self.origin().ok_or(anyhow!("No origin present"))?; + let channel = self.channels.get(&channel_id).ok_or(StageError::ChannelNotFound)?; + let origin = self.origin().ok_or(StageError::MissingOrigin)?; let timed_out = channel.open_block_number() + self.cfg.channel_timeout < origin.number; if timed_out || !channel.is_ready() { @@ -207,7 +207,7 @@ impl ResettableStage for ChannelBank where DAP: DataAvailabilityProvider + Send + Debug, CP: ChainProvider + Send + Debug, - T: TelemetryProvider + Send + Debug, + T: TelemetryProvider + Send + Sync + Debug, { async fn reset(&mut self, _: BlockInfo, _: SystemConfig) -> StageResult<()> { self.channels.clear(); @@ -234,11 +234,38 @@ mod tests { let dap = TestDAP::default(); let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); let frame_queue = FrameQueue::new(retrieval, TestTelemetry::new()); - let mut channel_bank = - ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, TestTelemetry::new()); + let telemetry = Arc::new(TestTelemetry::new()); + let mut channel_bank = ChannelBank::new( + Arc::new(RollupConfig::default()), + frame_queue, + Arc::clone(&telemetry), + ); let frame = Frame::default(); let err = channel_bank.ingest_frame(frame).unwrap_err(); - assert_eq!(err, StageError::Custom(anyhow!("No origin"))); + assert_eq!(err, StageError::MissingOrigin); + } + + #[test] + fn test_ingest_invalid_frame() { + let traversal = new_test_traversal(vec![], vec![]); + let dap = TestDAP::default(); + let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); + let frame_queue = 
FrameQueue::new(retrieval, TestTelemetry::new()); + let telem = Arc::new(TestTelemetry::new()); + let mut channel_bank = + ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, Arc::clone(&telem)); + let frame = Frame { id: [0xFF; 16], ..Default::default() }; + assert_eq!(channel_bank.size(), 0); + assert!(channel_bank.channels.is_empty()); + assert_eq!(telem.count_calls(LogLevel::Warning), 0); + assert_eq!(channel_bank.ingest_frame(frame.clone()), Ok(())); + assert_eq!(channel_bank.size(), crate::params::FRAME_OVERHEAD); + assert_eq!(channel_bank.channels.len(), 1); + // A duplicate frame still returns `Ok(())`, but is dropped and logs a warning. + assert_eq!(channel_bank.ingest_frame(frame), Ok(())); + assert_eq!(channel_bank.size(), crate::params::FRAME_OVERHEAD); + assert_eq!(channel_bank.channels.len(), 1); + assert_eq!(telem.count_calls(LogLevel::Warning), 1); } #[test] @@ -248,8 +275,12 @@ mod tests { let dap = TestDAP { results }; let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); let frame_queue = FrameQueue::new(retrieval, TestTelemetry::new()); - let mut channel_bank = - ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, TestTelemetry::new()); + let telemetry = Arc::new(TestTelemetry::new()); + let mut channel_bank = ChannelBank::new( + Arc::new(RollupConfig::default()), + frame_queue, + Arc::clone(&telemetry), + ); let mut frames = new_test_frames(100000); // Ingest frames until the channel bank is full and it stops increasing in size let mut current_size = 0; @@ -277,8 +308,12 @@ mod tests { let dap = TestDAP { results }; let retrieval = L1Retrieval::new(traversal, dap, TestTelemetry::new()); let frame_queue = FrameQueue::new(retrieval, TestTelemetry::new()); - let mut channel_bank = - ChannelBank::new(Arc::new(RollupConfig::default()), frame_queue, TestTelemetry::new()); + let telemetry = Arc::new(TestTelemetry::new()); + let mut channel_bank = ChannelBank::new( + Arc::new(RollupConfig::default()), + frame_queue, 
Arc::clone(&telemetry), + ); let err = channel_bank.read().unwrap_err(); assert_eq!(err, StageError::Eof); let err = channel_bank.next_data().await.unwrap_err(); diff --git a/crates/derive/src/traits/test_utils/telemetry.rs b/crates/derive/src/traits/test_utils/telemetry.rs index bb1448cfd..a7d32300a 100644 --- a/crates/derive/src/traits/test_utils/telemetry.rs +++ b/crates/derive/src/traits/test_utils/telemetry.rs @@ -1,15 +1,15 @@ //! Test Utilities for Telemetry use crate::traits::{LogLevel, TelemetryProvider}; -use alloc::{rc::Rc, vec::Vec}; +use alloc::{sync::Arc, vec::Vec}; use alloy_primitives::Bytes; -use core::cell::RefCell; +use spin::mutex::Mutex; /// Mock telemetry provider #[derive(Debug, Default)] pub struct TestTelemetry { /// Holds telemetry data with log levels for assertions. - pub(crate) telemetry_calls: Rc>>, + pub(crate) telemetry_calls: Arc>>, } impl TestTelemetry { @@ -17,14 +17,25 @@ impl TestTelemetry { pub fn new() -> Self { Self::default() } + + /// Checks the existence of a given ([Bytes], [LogLevel]) call. + pub fn exists(&self, data: Bytes, level: LogLevel) -> bool { + let guard = self.telemetry_calls.lock(); + guard.iter().filter(|(d, l)| *d == data && *l == level).count() > 0 + } + + /// Counts the number of telemetry calls with the given [LogLevel]. 
+ pub fn count_calls(&self, level: LogLevel) -> usize { + let guard = self.telemetry_calls.lock(); + guard.iter().filter(|(_, l)| *l == level).count() + } } impl TelemetryProvider for TestTelemetry { fn write>(&self, data: I, level: LogLevel) { let data = (data.into(), level); - { - let mut calls = self.telemetry_calls.borrow_mut(); - (*calls).push(data); - } + let binding = self.telemetry_calls.clone(); + let mut guard = binding.lock(); + guard.push(data); } } diff --git a/crates/derive/src/types/errors.rs b/crates/derive/src/types/errors.rs index 2b88cf86a..f35f20ae4 100644 --- a/crates/derive/src/types/errors.rs +++ b/crates/derive/src/types/errors.rs @@ -15,6 +15,12 @@ pub enum StageError { /// There is not enough data progress, but if we wait, the stage will eventually return data /// or produce an EOF error. NotEnoughData, + /// No channels are available in the channel bank. + NoChannelsAvailable, + /// Failed to find channel. + ChannelNotFound, + /// Missing L1 origin. + MissingOrigin, /// Failed to build the [super::PayloadAttributes] for the next batch. AttributesBuild(anyhow::Error), /// Reset the pipeline. 
@@ -46,6 +52,9 @@ impl PartialEq for StageError { (self, other), (StageError::Eof, StageError::Eof) | (StageError::NotEnoughData, StageError::NotEnoughData) | + (StageError::NoChannelsAvailable, StageError::NoChannelsAvailable) | + (StageError::ChannelNotFound, StageError::ChannelNotFound) | + (StageError::MissingOrigin, StageError::MissingOrigin) | (StageError::AttributesBuild(_), StageError::AttributesBuild(_)) | (StageError::ReceiptFetch(_), StageError::ReceiptFetch(_)) | (StageError::BlockInfoFetch(_), StageError::BlockInfoFetch(_)) | @@ -66,6 +75,9 @@ impl Display for StageError { match self { StageError::Eof => write!(f, "End of file"), StageError::NotEnoughData => write!(f, "Not enough data"), + StageError::NoChannelsAvailable => write!(f, "No channels available"), + StageError::ChannelNotFound => write!(f, "Channel not found"), + StageError::MissingOrigin => write!(f, "Missing L1 origin"), StageError::AttributesBuild(e) => write!(f, "Attributes build error: {}", e), StageError::Reset(e) => write!(f, "Reset error: {}", e), StageError::ReceiptFetch(e) => write!(f, "Receipt fetch error: {}", e),