Skip to content

Commit

Permalink
feat: parallel partial witness handling in the partial witness actor (#12656)
Browse files Browse the repository at this point in the history

The PR unblocks the main thread of the `PartialWitnessActor` by
detaching the handling of the partial witnesses to separate threads.

This results in a considerable reduction in the distribution latency of
the state witness:

![image](https://github.com/user-attachments/assets/53723050-a341-43d8-b15c-5dea7a61f2b9)
The image originates from a forknet experiment with 50 nodes, with each
state witness artificially padded to reach a size of 30 MB.
  • Loading branch information
stedfn authored Jan 13, 2025
1 parent 69ba684 commit 692fa5c
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub struct PartialWitnessActor {
epoch_manager: Arc<dyn EpochManagerAdapter>,
runtime: Arc<dyn RuntimeAdapter>,
/// Tracks the parts of the state witness sent from chunk producers to chunk validators.
partial_witness_tracker: PartialEncodedStateWitnessTracker,
partial_witness_tracker: Arc<PartialEncodedStateWitnessTracker>,
partial_deploys_tracker: PartialEncodedContractDeploysTracker,
/// Tracks a collection of state witnesses sent from chunk producers to chunk validators.
state_witness_tracker: ChunkStateWitnessTracker,
Expand All @@ -75,6 +75,7 @@ pub struct PartialWitnessActor {
/// Same as above for contract deploys.
contract_deploys_encoders: ReedSolomonEncoderCache,
compile_contracts_spawner: Arc<dyn AsyncComputationSpawner>,
partial_witness_spawner: Arc<dyn AsyncComputationSpawner>,
/// AccountId in the key corresponds to the requester (chunk validator).
processed_contract_code_requests: LruCache<(ChunkProductionKey, AccountId), ()>,
}
Expand Down Expand Up @@ -166,9 +167,10 @@ impl PartialWitnessActor {
epoch_manager: Arc<dyn EpochManagerAdapter>,
runtime: Arc<dyn RuntimeAdapter>,
compile_contracts_spawner: Arc<dyn AsyncComputationSpawner>,
partial_witness_spawner: Arc<dyn AsyncComputationSpawner>,
) -> Self {
let partial_witness_tracker =
PartialEncodedStateWitnessTracker::new(client_sender, epoch_manager.clone());
Arc::new(PartialEncodedStateWitnessTracker::new(client_sender, epoch_manager.clone()));
Self {
network_adapter,
my_signer,
Expand All @@ -182,6 +184,7 @@ impl PartialWitnessActor {
CONTRACT_DEPLOYS_RATIO_DATA_PARTS,
),
compile_contracts_spawner,
partial_witness_spawner,
processed_contract_code_requests: LruCache::new(
NonZeroUsize::new(PROCESSED_CONTRACT_CODE_REQUESTS_CACHE_SIZE).unwrap(),
),
Expand Down Expand Up @@ -365,13 +368,20 @@ impl PartialWitnessActor {
));
}

/// Sends the witness part to the chunk validators, except the chunk producer that generated the witness part.
fn forward_state_witness_part(
&self,
/// Function to handle receiving partial_encoded_state_witness message from chunk producer.
fn handle_partial_encoded_state_witness(
&mut self,
partial_witness: PartialEncodedStateWitness,
) -> Result<(), Error> {
tracing::debug!(target: "client", ?partial_witness, "Receive PartialEncodedStateWitnessMessage");
let signer = self.my_validator_signer()?;
let validator_account_id = signer.validator_id().clone();
let epoch_manager = self.epoch_manager.clone();
let runtime_adapter = self.runtime.clone();

let ChunkProductionKey { shard_id, epoch_id, height_created } =
partial_witness.chunk_production_key();

let chunk_producer = self
.epoch_manager
.get_chunk_producer_info(&ChunkProductionKey { epoch_id, height_created, shard_id })?
Expand All @@ -386,32 +396,40 @@ impl PartialWitnessActor {
.filter(|validator| validator != &chunk_producer)
.collect();

self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests(
NetworkRequests::PartialEncodedStateWitnessForward(
target_chunk_validators,
partial_witness,
),
));
Ok(())
}

/// Function to handle receiving partial_encoded_state_witness message from chunk producer.
fn handle_partial_encoded_state_witness(
&mut self,
partial_witness: PartialEncodedStateWitness,
) -> Result<(), Error> {
tracing::debug!(target: "client", ?partial_witness, "Receive PartialEncodedStateWitnessMessage");

let signer = self.my_validator_signer()?;
// Validate the partial encoded state witness and forward the part to all the chunk validators.
if validate_partial_encoded_state_witness(
self.epoch_manager.as_ref(),
&partial_witness,
&signer,
self.runtime.store(),
)? {
self.forward_state_witness_part(partial_witness)?;
}
let network_adapter = self.network_adapter.clone();

self.partial_witness_spawner.spawn("handle_partial_encoded_state_witness", move || {
// Validate the partial encoded state witness and forward the part to all the chunk validators.
match validate_partial_encoded_state_witness(
epoch_manager.as_ref(),
&partial_witness,
&validator_account_id,
runtime_adapter.store(),
) {
Ok(true) => {
network_adapter.send(PeerManagerMessageRequest::NetworkRequests(
NetworkRequests::PartialEncodedStateWitnessForward(
target_chunk_validators,
partial_witness,
),
));
}
Ok(false) => {
// TODO: ban sending peer
tracing::warn!(
target: "client",
"Received invalid partial encoded state witness"
);
}
Err(err) => {
tracing::warn!(
target: "client",
"Encountered error during validation: {}",
err
);
}
}
});

Ok(())
}
Expand All @@ -424,15 +442,42 @@ impl PartialWitnessActor {
tracing::debug!(target: "client", ?partial_witness, "Receive PartialEncodedStateWitnessForwardMessage");

let signer = self.my_validator_signer()?;
// Validate the partial encoded state witness and store the partial encoded state witness.
if validate_partial_encoded_state_witness(
self.epoch_manager.as_ref(),
&partial_witness,
&signer,
self.runtime.store(),
)? {
self.partial_witness_tracker.store_partial_encoded_state_witness(partial_witness)?;
}
let validator_account_id = signer.validator_id().clone();
let partial_witness_tracker = self.partial_witness_tracker.clone();
let epoch_manager = self.epoch_manager.clone();
let runtime_adapter = self.runtime.clone();
self.partial_witness_spawner.spawn(
"handle_partial_encoded_state_witness_forward",
move || {
// Validate the partial encoded state witness and store the partial encoded state witness.
match validate_partial_encoded_state_witness(
epoch_manager.as_ref(),
&partial_witness,
&validator_account_id,
runtime_adapter.store(),
) {
Ok(true) => {
if let Err(err) = partial_witness_tracker.store_partial_encoded_state_witness(partial_witness) {
tracing::error!(target: "client", "Failed to store partial encoded state witness: {}", err);
}
}
Ok(false) => {
// TODO: ban sending peer
tracing::warn!(
target: "client",
"Received invalid partial encoded state witness"
);
}
Err(err) => {
tracing::warn!(
target: "client",
"Encountered error during validation: {}",
err
);
}
}
},
);

Ok(())
}
Expand Down Expand Up @@ -596,7 +641,7 @@ impl PartialWitnessActor {

/// Sends the contract accesses to the same chunk validators
/// (except for the chunk producers that track the same shard),
/// which will receive the state witness for the new chunk.
/// which will receive the state witness for the new chunk.
fn send_contract_accesses_to_chunk_validators(
&self,
key: ChunkProductionKey,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use lru::LruCache;
use near_async::messaging::CanSend;
use near_async::time::Instant;
use near_cache::SyncLruCache;
use near_chain::chain::ChunkStateWitnessMessage;
use near_chain::Error;
use near_epoch_manager::EpochManagerAdapter;
Expand Down Expand Up @@ -308,13 +309,13 @@ pub struct PartialEncodedStateWitnessTracker {
/// Epoch manager to get the set of chunk validators
epoch_manager: Arc<dyn EpochManagerAdapter>,
/// Keeps track of state witness parts received from chunk producers.
parts_cache: LruCache<ChunkProductionKey, CacheEntry>,
parts_cache: Mutex<LruCache<ChunkProductionKey, CacheEntry>>,
/// Keeps track of the already decoded witnesses. This is needed
/// to protect chunk validator from processing the same witness multiple
/// times.
processed_witnesses: LruCache<ChunkProductionKey, ()>,
processed_witnesses: SyncLruCache<ChunkProductionKey, ()>,
/// Reed Solomon encoder for decoding state witness parts.
encoders: ReedSolomonEncoderCache,
encoders: Mutex<ReedSolomonEncoderCache>,
}

impl PartialEncodedStateWitnessTracker {
Expand All @@ -325,16 +326,16 @@ impl PartialEncodedStateWitnessTracker {
Self {
client_sender,
epoch_manager,
parts_cache: LruCache::new(NonZeroUsize::new(WITNESS_PARTS_CACHE_SIZE).unwrap()),
processed_witnesses: LruCache::new(
NonZeroUsize::new(PROCESSED_WITNESSES_CACHE_SIZE).unwrap(),
),
encoders: ReedSolomonEncoderCache::new(WITNESS_RATIO_DATA_PARTS),
parts_cache: Mutex::new(LruCache::new(
NonZeroUsize::new(WITNESS_PARTS_CACHE_SIZE).unwrap(),
)),
processed_witnesses: SyncLruCache::new(PROCESSED_WITNESSES_CACHE_SIZE),
encoders: Mutex::new(ReedSolomonEncoderCache::new(WITNESS_RATIO_DATA_PARTS)),
}
}

pub fn store_partial_encoded_state_witness(
&mut self,
&self,
partial_witness: PartialEncodedStateWitness,
) -> Result<(), Error> {
tracing::debug!(target: "client", ?partial_witness, "store_partial_encoded_state_witness");
Expand All @@ -345,7 +346,7 @@ impl PartialEncodedStateWitnessTracker {
}

pub fn store_accessed_contract_hashes(
&mut self,
&self,
key: ChunkProductionKey,
hashes: HashSet<CodeHash>,
) -> Result<(), Error> {
Expand All @@ -355,7 +356,7 @@ impl PartialEncodedStateWitnessTracker {
}

pub fn store_accessed_contract_codes(
&mut self,
&self,
key: ChunkProductionKey,
codes: Vec<CodeBytes>,
) -> Result<(), Error> {
Expand All @@ -365,7 +366,7 @@ impl PartialEncodedStateWitnessTracker {
}

fn process_update(
&mut self,
&self,
key: ChunkProductionKey,
create_if_not_exists: bool,
update: CacheUpdate,
Expand All @@ -382,17 +383,23 @@ impl PartialEncodedStateWitnessTracker {
if create_if_not_exists {
self.maybe_insert_new_entry_in_parts_cache(&key);
}
let Some(entry) = self.parts_cache.get_mut(&key) else {
let mut parts_cache = self.parts_cache.lock().unwrap();
let Some(entry) = parts_cache.get_mut(&key) else {
return Ok(());
};
if let Some((decode_result, accessed_contracts)) = entry.update(update) {
let total_size: usize = if let Some((decode_result, accessed_contracts)) =
entry.update(update)
{
// Record the time taken from receiving first part to decoding partial witness.
let time_to_last_part = Instant::now().signed_duration_since(entry.created_at);
metrics::PARTIAL_WITNESS_TIME_TO_LAST_PART
.with_label_values(&[key.shard_id.to_string().as_str()])
.observe(time_to_last_part.as_seconds_f64());

self.parts_cache.pop(&key);
parts_cache.pop(&key);
let total_size = parts_cache.iter().map(|(_, entry)| entry.total_size()).sum();
drop(parts_cache);

self.processed_witnesses.push(key.clone(), ());

let encoded_witness = match decode_result {
Expand Down Expand Up @@ -428,26 +435,33 @@ impl PartialEncodedStateWitnessTracker {

tracing::debug!(target: "client", ?key, "Sending encoded witness to client.");
self.client_sender.send(ChunkStateWitnessMessage { witness, raw_witness_size });
}
self.record_total_parts_cache_size_metric();

total_size
} else {
parts_cache.iter().map(|(_, entry)| entry.total_size()).sum()
};
metrics::PARTIAL_WITNESS_CACHE_SIZE.set(total_size as f64);

Ok(())
}

fn get_encoder(&mut self, key: &ChunkProductionKey) -> Result<Arc<ReedSolomonEncoder>, Error> {
fn get_encoder(&self, key: &ChunkProductionKey) -> Result<Arc<ReedSolomonEncoder>, Error> {
// The expected number of parts for the Reed Solomon encoding is the number of chunk validators.
let num_parts = self
.epoch_manager
.get_chunk_validator_assignments(&key.epoch_id, key.shard_id, key.height_created)?
.len();
Ok(self.encoders.entry(num_parts))
let mut encoders = self.encoders.lock().unwrap();
Ok(encoders.entry(num_parts))
}

// Function to insert a new entry into the cache for the chunk hash if it does not already exist
// We additionally check if an evicted entry has been fully decoded and processed.
fn maybe_insert_new_entry_in_parts_cache(&mut self, key: &ChunkProductionKey) {
if !self.parts_cache.contains(key) {
fn maybe_insert_new_entry_in_parts_cache(&self, key: &ChunkProductionKey) {
let mut parts_cache = self.parts_cache.lock().unwrap();
if !parts_cache.contains(key) {
if let Some((evicted_key, evicted_entry)) =
self.parts_cache.push(key.clone(), CacheEntry::new(key.shard_id))
parts_cache.push(key.clone(), CacheEntry::new(key.shard_id))
{
tracing::warn!(
target: "client",
Expand All @@ -460,11 +474,6 @@ impl PartialEncodedStateWitnessTracker {
}
}

fn record_total_parts_cache_size_metric(&self) {
let total_size: usize = self.parts_cache.iter().map(|(_, entry)| entry.total_size()).sum();
metrics::PARTIAL_WITNESS_CACHE_SIZE.set(total_size as f64);
}

fn decode_state_witness(
&self,
encoded_witness: &EncodedChunkStateWitness,
Expand Down
4 changes: 2 additions & 2 deletions chain/client/src/stateless_validation/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const MAX_HEIGHTS_AHEAD: BlockHeightDelta = 5;
pub fn validate_partial_encoded_state_witness(
epoch_manager: &dyn EpochManagerAdapter,
partial_witness: &PartialEncodedStateWitness,
signer: &ValidatorSigner,
validator_account_id: &AccountId,
store: &Store,
) -> Result<bool, Error> {
let ChunkProductionKey { shard_id, epoch_id, height_created } =
Expand Down Expand Up @@ -56,7 +56,7 @@ pub fn validate_partial_encoded_state_witness(
if !validate_chunk_relevant_as_validator(
epoch_manager,
&partial_witness.chunk_production_key(),
signer.validator_id(),
validator_account_id,
store,
)? {
return Ok(false);
Expand Down
1 change: 1 addition & 0 deletions chain/client/src/test_utils/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ pub fn setup(
epoch_manager.clone(),
runtime.clone(),
Arc::new(RayonAsyncComputationSpawner),
Arc::new(RayonAsyncComputationSpawner),
));
let partial_witness_adapter = partial_witness_addr.with_auto_span_context();

Expand Down
1 change: 1 addition & 0 deletions integration-tests/src/test_loop/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ impl TestLoopBuilder {
epoch_manager.clone(),
runtime_adapter.clone(),
Arc::new(self.test_loop.async_computation_spawner(|_| Duration::milliseconds(80))),
Arc::new(self.test_loop.async_computation_spawner(|_| Duration::milliseconds(80))),
);

let gc_actor = GCActor::new(
Expand Down
1 change: 1 addition & 0 deletions integration-tests/src/tests/network/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ fn setup_network_node(
epoch_manager,
runtime,
Arc::new(RayonAsyncComputationSpawner),
Arc::new(RayonAsyncComputationSpawner),
));
shards_manager_adapter.bind(shards_manager_actor.with_auto_span_context());
let peer_manager = PeerManagerActor::spawn(
Expand Down
1 change: 1 addition & 0 deletions nearcore/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ pub fn start_with_config_and_synchronization(
epoch_manager.clone(),
runtime.clone(),
Arc::new(RayonAsyncComputationSpawner),
Arc::new(RayonAsyncComputationSpawner),
));

let (_gc_actor, gc_arbiter) = spawn_actix_actor(GCActor::new(
Expand Down
Loading

0 comments on commit 692fa5c

Please sign in to comment.