From 6e62db1d9e42dc0bb225f2c5ef864dbdfc00a439 Mon Sep 17 00:00:00 2001 From: Shreyan Gupta Date: Sat, 11 Jan 2025 10:37:10 +0530 Subject: [PATCH] [trie] Push tracking proof_size_limit to trie recorder (#12710) This PR is the first part of https://github.com/near/nearcore/issues/12701 The PR moves tracking of proof_size_limit from runtime to trie recorder. There should be no functional change. This is a more natural place to expose the check for proof_size_limit and sets the basis for future improvements like - Potentially moving compute_limit to recorder as well - Better tracking and checking of limits (required for resharding) - Potential setup to add better checks for limits (required for reading and managing buffered receipts) --- chain/chain/src/resharding/manager.rs | 2 +- chain/chain/src/runtime/mod.rs | 14 +++-- core/store/src/trie/mod.rs | 16 ++++- core/store/src/trie/ops/tests.rs | 2 +- core/store/src/trie/trie_recording.rs | 41 ++++++++----- core/store/src/trie/update.rs | 2 +- runtime/runtime/src/lib.rs | 85 +++++++++------------------ runtime/runtime/src/tests/apply.rs | 15 ++++- 8 files changed, 95 insertions(+), 82 deletions(-) diff --git a/chain/chain/src/resharding/manager.rs b/chain/chain/src/resharding/manager.rs index b5ee4a6652e..b99b1dfec52 100644 --- a/chain/chain/src/resharding/manager.rs +++ b/chain/chain/src/resharding/manager.rs @@ -216,7 +216,7 @@ impl ReshardingManager { "Creating child memtrie by retaining nodes in parent memtrie..." 
); let mut mem_tries = mem_tries.write().unwrap(); - let mut trie_recorder = TrieRecorder::new(); + let mut trie_recorder = TrieRecorder::new(None); let mode = TrackingMode::RefcountsAndAccesses(&mut trie_recorder); let mem_trie_update = mem_tries.update(*parent_chunk_extra.state_root(), mode)?; diff --git a/chain/chain/src/runtime/mod.rs b/chain/chain/src/runtime/mod.rs index 83baa60461c..42b4b03c9cc 100644 --- a/chain/chain/src/runtime/mod.rs +++ b/chain/chain/src/runtime/mod.rs @@ -614,6 +614,7 @@ impl RuntimeAdapter for NightshadeRuntime { let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&prev_block.block_hash)?; let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id)?; + let runtime_config = self.runtime_config_store.get_config(protocol_version); let next_epoch_id = self.epoch_manager.get_next_epoch_id_from_prev_block(&(&prev_block.block_hash))?; @@ -650,7 +651,9 @@ impl RuntimeAdapter for NightshadeRuntime { if ProtocolFeature::StatelessValidation.enabled(next_protocol_version) || cfg!(feature = "shadow_chunk_validation") { - trie = trie.recording_reads_new_recorder(); + let proof_size_limit = + runtime_config.witness_config.new_transactions_validation_state_size_soft_limit; + trie = trie.recording_reads_with_proof_size_limit(proof_size_limit); } let mut state_update = TrieUpdate::new(trie); @@ -658,8 +661,6 @@ impl RuntimeAdapter for NightshadeRuntime { let mut total_gas_burnt = 0; let mut total_size = 0u64; - let runtime_config = self.runtime_config_store.get_config(protocol_version); - let transactions_gas_limit = chunk_tx_gas_limit(protocol_version, runtime_config, &prev_block, shard_id, gas_limit); @@ -882,7 +883,12 @@ impl RuntimeAdapter for NightshadeRuntime { if ProtocolFeature::StatelessValidation.enabled(next_protocol_version) || cfg!(feature = "shadow_chunk_validation") { - trie = trie.recording_reads_new_recorder(); + let epoch_id = + self.epoch_manager.get_epoch_id_from_prev_block(&block.prev_block_hash)?; + 
let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id)?; + let config = self.runtime_config_store.get_config(protocol_version); + let proof_limit = config.witness_config.main_storage_proof_size_soft_limit; + trie = trie.recording_reads_with_proof_size_limit(proof_limit); } match self.process_state_update( diff --git a/core/store/src/trie/mod.rs b/core/store/src/trie/mod.rs index 67c6ce9f29b..a9471fa0c0b 100644 --- a/core/store/src/trie/mod.rs +++ b/core/store/src/trie/mod.rs @@ -723,11 +723,18 @@ impl Trie { /// Makes a new trie that has everything the same except that access /// through that trie accumulates a state proof for all nodes accessed. pub fn recording_reads_new_recorder(&self) -> Self { - self.recording_reads_with_recorder(RefCell::new(TrieRecorder::new())) + let recorder = RefCell::new(TrieRecorder::new(None)); + self.recording_reads_with_recorder(recorder) } /// Makes a new trie that has everything the same except that access /// through that trie accumulates a state proof for all nodes accessed. + /// We also supply a proof size limit to prevent the proof from growing too large. + pub fn recording_reads_with_proof_size_limit(&self, proof_size_limit: usize) -> Self { + let recorder = RefCell::new(TrieRecorder::new(Some(proof_size_limit))); + self.recording_reads_with_recorder(recorder) + } + pub fn recording_reads_with_recorder(&self, recorder: RefCell<TrieRecorder>) -> Self { let mut trie = Self::new_with_memtries( self.storage.clone(), @@ -766,6 +773,13 @@ impl Trie { .unwrap_or_default() } + pub fn check_proof_size_limit_exceed(&self) -> bool { + self.recorder + .as_ref() + .map(|recorder| recorder.borrow().check_proof_size_limit_exceed()) + .unwrap_or_default() + } + /// Constructs a Trie from the partial storage (i.e. state proof) that /// was returned from recorded_storage(). 
If used to access the same trie /// nodes as when the partial storage was generated, this trie will behave diff --git a/core/store/src/trie/ops/tests.rs b/core/store/src/trie/ops/tests.rs index ca8eb8dab88..b2afaeeb2aa 100644 --- a/core/store/src/trie/ops/tests.rs +++ b/core/store/src/trie/ops/tests.rs @@ -104,7 +104,7 @@ fn run(initial_entries: Vec<(Vec, Vec)>, retain_multi_ranges: Vec>, size: usize, + /// Size of the recorded state proof plus some additional size added to cover removals and contract code. + /// An upper-bound estimation of the true recorded size after finalization. + /// See https://github.com/near/nearcore/issues/10890 and https://github.com/near/nearcore/pull/11000 for details. + upper_bound_size: usize, + /// Soft limit on the maximum size of the state proof that can be recorded. + proof_size_limit: Option, /// Counts removals performed while recording. /// recorded_storage_size_upper_bound takes it into account when calculating the total size. removal_counter: usize, @@ -45,10 +51,12 @@ pub struct SubtreeSize { } impl TrieRecorder { - pub fn new() -> Self { + pub fn new(proof_size_limit: Option) -> Self { Self { recorded: HashMap::new(), size: 0, + upper_bound_size: 0, + proof_size_limit, removal_counter: 0, code_len_counter: 0, codes_to_record: Default::default(), @@ -66,16 +74,27 @@ impl TrieRecorder { pub fn record(&mut self, hash: &CryptoHash, node: Arc<[u8]>) { let size = node.len(); if self.recorded.insert(*hash, node).is_none() { - self.size += size; + self.size = self.size.checked_add(size).unwrap(); + self.upper_bound_size = self.upper_bound_size.checked_add(size).unwrap(); } } - pub fn record_removal(&mut self) { - self.removal_counter = self.removal_counter.saturating_add(1) + pub fn record_key_removal(&mut self) { + // Charge 2000 bytes for every removal + self.removal_counter = self.removal_counter.checked_add(1).unwrap(); + self.upper_bound_size = self.upper_bound_size.checked_add(2000).unwrap(); } pub fn record_code_len(&mut 
self, code_len: usize) { - self.code_len_counter = self.code_len_counter.saturating_add(code_len) + self.code_len_counter = self.code_len_counter.checked_add(code_len).unwrap(); + self.upper_bound_size = self.upper_bound_size.checked_add(code_len).unwrap(); + } + + pub fn check_proof_size_limit_exceed(&self) -> bool { + if let Some(proof_size_limit) = self.proof_size_limit { + return self.upper_bound_size > proof_size_limit; + } + false } pub fn recorded_storage(&mut self) -> PartialStorage { @@ -88,19 +107,11 @@ impl TrieRecorder { self.size } - /// Size of the recorded state proof plus some additional size added to cover removals - /// and contract codes. - /// An upper-bound estimation of the true recorded size after finalization. - /// See https://github.com/near/nearcore/issues/10890 and https://github.com/near/nearcore/pull/11000 for details. pub fn recorded_storage_size_upper_bound(&self) -> usize { - // Charge 2000 bytes for every removal - let removals_size = self.removal_counter.saturating_mul(2000); - self.recorded_storage_size() - .saturating_add(removals_size) - .saturating_add(self.code_len_counter) + self.upper_bound_size } - /// Get statisitics about the recorded trie. Useful for observability and debugging. + /// Get statistics about the recorded trie. Useful for observability and debugging. /// This scans all of the recorded data, so could potentially be expensive to run. pub fn get_stats(&self, trie_root: &CryptoHash) -> TrieRecorderStats { let mut trie_column_sizes = Vec::new(); diff --git a/core/store/src/trie/update.rs b/core/store/src/trie/update.rs index e2de65a3991..dd57a75cdaf 100644 --- a/core/store/src/trie/update.rs +++ b/core/store/src/trie/update.rs @@ -148,7 +148,7 @@ impl TrieUpdate { // by the runtime are assumed to be non-malicious and we don't charge extra for them. if let Some(recorder) = &self.trie.recorder { if matches!(trie_key, TrieKey::ContractData { .. 
}) { - recorder.borrow_mut().record_removal(); + recorder.borrow_mut().record_key_removal(); } } diff --git a/runtime/runtime/src/lib.rs b/runtime/runtime/src/lib.rs index 9321108039c..4073730d9af 100644 --- a/runtime/runtime/src/lib.rs +++ b/runtime/runtime/src/lib.rs @@ -1656,11 +1656,14 @@ impl Runtime { compute_usage = tracing::field::Empty, ) .entered(); + let state_update = &mut processing_state.state_update; - let node_counter_before = state_update.trie().get_trie_nodes_count(); - let recorded_storage_size_before = state_update.trie().recorded_storage_size(); - let storage_proof_size_upper_bound_before = - state_update.trie().recorded_storage_size_upper_bound(); + let trie = state_update.trie(); + let node_counter_before = trie.get_trie_nodes_count(); + let recorded_storage_size_before = trie.recorded_storage_size(); + let storage_proof_size_upper_bound_before = trie.recorded_storage_size_upper_bound(); + + // Main logic let result = self.process_receipt( processing_state, receipt, @@ -1668,42 +1671,38 @@ impl Runtime { &mut validator_proposals, ); - let total = &mut processing_state.total; - let state_update = &mut processing_state.state_update; - let node_counter_after = state_update.trie().get_trie_nodes_count(); - tracing::trace!(target: "runtime", ?node_counter_before, ?node_counter_after); - let recorded_storage_diff = state_update - .trie() - .recorded_storage_size() - .saturating_sub(recorded_storage_size_before) - as f64; - let recorded_storage_upper_bound_diff = state_update - .trie() - .recorded_storage_size_upper_bound() - .saturating_sub(storage_proof_size_upper_bound_before) - as f64; let shard_id_str = processing_state.apply_state.shard_id.to_string(); + let trie = processing_state.state_update.trie(); + + let node_counter_after = trie.get_trie_nodes_count(); + tracing::trace!(target: "runtime", ?node_counter_before, ?node_counter_after); + + let recorded_storage_diff = trie.recorded_storage_size() - recorded_storage_size_before; + let 
recorded_storage_upper_bound_diff = + trie.recorded_storage_size_upper_bound() - storage_proof_size_upper_bound_before; metrics::RECEIPT_RECORDED_SIZE .with_label_values(&[shard_id_str.as_str()]) - .observe(recorded_storage_diff); + .observe(recorded_storage_diff as f64); metrics::RECEIPT_RECORDED_SIZE_UPPER_BOUND .with_label_values(&[shard_id_str.as_str()]) - .observe(recorded_storage_upper_bound_diff); + .observe(recorded_storage_upper_bound_diff as f64); let recorded_storage_proof_ratio = - recorded_storage_upper_bound_diff / f64::max(1.0, recorded_storage_diff); + recorded_storage_upper_bound_diff as f64 / f64::max(1.0, recorded_storage_diff as f64); // Record the ratio only for large receipts, small receipts can have a very high ratio, // but the ratio is not that important for them. - if recorded_storage_upper_bound_diff > 100_000. { + if recorded_storage_upper_bound_diff > 100_000 { metrics::RECEIPT_RECORDED_SIZE_UPPER_BOUND_RATIO .with_label_values(&[shard_id_str.as_str()]) .observe(recorded_storage_proof_ratio); } + if let Some(outcome_with_id) = result? 
{ let gas_burnt = outcome_with_id.outcome.gas_burnt; let compute_usage = outcome_with_id .outcome .compute_usage .expect("`process_receipt` must populate compute usage"); + let total = &mut processing_state.total; total.add(gas_burnt, compute_usage)?; span.record("gas_burnt", gas_burnt); span.record("compute_usage", compute_usage); @@ -1726,7 +1725,6 @@ impl Runtime { mut processing_state: &mut ApplyProcessingReceiptState<'a>, receipt_sink: &mut ReceiptSink, compute_limit: u64, - proof_size_limit: Option, validator_proposals: &mut Vec, ) -> Result<(), RuntimeError> { let local_processing_start = std::time::Instant::now(); @@ -1750,9 +1748,7 @@ impl Runtime { for receipt in local_receipts.iter() { if processing_state.total.compute >= compute_limit - || proof_size_limit.is_some_and(|limit| { - processing_state.state_update.trie.recorded_storage_size_upper_bound() > limit - }) + || processing_state.state_update.trie.check_proof_size_limit_exceed() { processing_state.delayed_receipts.push( &mut processing_state.state_update, @@ -1808,7 +1804,6 @@ impl Runtime { mut processing_state: &mut ApplyProcessingReceiptState<'a>, receipt_sink: &mut ReceiptSink, compute_limit: u64, - proof_size_limit: Option, validator_proposals: &mut Vec, ) -> Result, RuntimeError> { let delayed_processing_start = std::time::Instant::now(); @@ -1828,9 +1823,7 @@ impl Runtime { loop { if processing_state.total.compute >= compute_limit - || proof_size_limit.is_some_and(|limit| { - processing_state.state_update.trie.recorded_storage_size_upper_bound() > limit - }) + || processing_state.state_update.trie.check_proof_size_limit_exceed() { break; } @@ -1910,7 +1903,6 @@ impl Runtime { mut processing_state: &mut ApplyProcessingReceiptState<'a>, receipt_sink: &mut ReceiptSink, compute_limit: u64, - proof_size_limit: Option, validator_proposals: &mut Vec, ) -> Result<(), RuntimeError> { let incoming_processing_start = std::time::Instant::now(); @@ -1940,9 +1932,7 @@ impl Runtime { ) 
.map_err(RuntimeError::ReceiptValidationError)?; if processing_state.total.compute >= compute_limit - || proof_size_limit.is_some_and(|limit| { - processing_state.state_update.trie.recorded_storage_size_upper_bound() > limit - }) + || processing_state.state_update.trie.check_proof_size_limit_exceed() { processing_state.delayed_receipts.push( &mut processing_state.state_update, @@ -1992,24 +1982,17 @@ impl Runtime { receipt_sink: &mut ReceiptSink, ) -> Result { let mut validator_proposals = vec![]; - let protocol_version = processing_state.protocol_version; let apply_state = &processing_state.apply_state; // TODO(#8859): Introduce a dedicated `compute_limit` for the chunk. // For now compute limit always matches the gas limit. let compute_limit = apply_state.gas_limit.unwrap_or(Gas::max_value()); - let proof_size_limit = if ProtocolFeature::StatelessValidation.enabled(protocol_version) { - Some(apply_state.config.witness_config.main_storage_proof_size_soft_limit) - } else { - None - }; // We first process local receipts. They contain staking, local contract calls, etc. 
self.process_local_receipts( processing_state, receipt_sink, compute_limit, - proof_size_limit, &mut validator_proposals, )?; @@ -2018,7 +2001,6 @@ impl Runtime { processing_state, receipt_sink, compute_limit, - proof_size_limit, &mut validator_proposals, )?; @@ -2027,26 +2009,19 @@ impl Runtime { processing_state, receipt_sink, compute_limit, - proof_size_limit, &mut validator_proposals, )?; // Resolve timed-out PromiseYield receipts - let promise_yield_result = resolve_promise_yield_timeouts( - processing_state, - receipt_sink, - compute_limit, - proof_size_limit, - )?; + let promise_yield_result = + resolve_promise_yield_timeouts(processing_state, receipt_sink, compute_limit)?; let shard_id_str = processing_state.apply_state.shard_id.to_string(); if processing_state.total.compute >= compute_limit { metrics::CHUNK_RECEIPTS_LIMITED_BY .with_label_values(&[shard_id_str.as_str(), "compute_limit"]) .inc(); - } else if proof_size_limit.is_some_and(|limit| { - processing_state.state_update.trie.recorded_storage_size_upper_bound() > limit - }) { + } else if processing_state.state_update.trie.check_proof_size_limit_exceed() { metrics::CHUNK_RECEIPTS_LIMITED_BY .with_label_values(&[shard_id_str.as_str(), "storage_proof_size_limit"]) .inc(); @@ -2351,7 +2326,6 @@ fn resolve_promise_yield_timeouts( processing_state: &mut ApplyProcessingReceiptState, receipt_sink: &mut ReceiptSink, compute_limit: u64, - proof_size_limit: Option, ) -> Result { let mut state_update = &mut processing_state.state_update; let total = &mut processing_state.total; @@ -2366,10 +2340,7 @@ fn resolve_promise_yield_timeouts( let mut timeout_receipts = vec![]; let yield_processing_start = std::time::Instant::now(); while promise_yield_indices.first_index < promise_yield_indices.next_available_index { - if total.compute >= compute_limit - || proof_size_limit - .is_some_and(|limit| state_update.trie.recorded_storage_size_upper_bound() > limit) - { + if total.compute >= compute_limit || 
state_update.trie.check_proof_size_limit_exceed() { break; } diff --git a/runtime/runtime/src/tests/apply.rs b/runtime/runtime/src/tests/apply.rs index 80cc41d1984..7a0783e2231 100644 --- a/runtime/runtime/src/tests/apply.rs +++ b/runtime/runtime/src/tests/apply.rs @@ -1151,9 +1151,14 @@ fn test_main_storage_proof_size_soft_limit() { ) }; + let trie = tries + .get_trie_for_shard(ShardUId::single_shard(), root) + .recording_reads_with_proof_size_limit( + apply_state.config.witness_config.main_storage_proof_size_soft_limit, + ); let apply_result = runtime .apply( - tries.get_trie_for_shard(ShardUId::single_shard(), root).recording_reads_new_recorder(), + trie, &None, &apply_state, &[ @@ -1192,10 +1197,16 @@ fn test_main_storage_proof_size_soft_limit() { ) }; + let trie = tries + .get_trie_for_shard(ShardUId::single_shard(), root) + .recording_reads_with_proof_size_limit( + apply_state.config.witness_config.main_storage_proof_size_soft_limit, + ); + // The function call to bob_account should hit the main_storage_proof_size_soft_limit let apply_result = runtime .apply( - tries.get_trie_for_shard(ShardUId::single_shard(), root).recording_reads_new_recorder(), + trie, &None, &apply_state, &[