refactor: various cleanups and fixes in chain.rs (#10327)
get_state_response_part():
- remove the contiguous shard_ids assumption. Instead of checking the
shard id against the number of chunks, look up the epoch's list of shard
ids and check membership (see the sketch after this list).
- since the epoch id remains the same for the current and the prev
block, the set of shard ids also remains the same, so we only need to
do the check once.
- remove some unnecessary clones
- finally, rename some variables to make them shorter. Mainly remove the
`sync_` prefix. I am not sure how much value this prefix adds, and the
shorter names convey the same information.
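
A minimal sketch of the membership-based check described above. The types and
helpers here are illustrative stand-ins, not nearcore's API; the real code asks
the EpochManager for the epoch's shard ids.

    // Hedged sketch: validate a shard id by membership in an explicit list of
    // shard ids instead of assuming ids are contiguous in 0..num_shards.
    type ShardId = u64;

    fn validate_shard_id(shard_id: ShardId, shard_ids: &[ShardId]) -> Result<(), String> {
        if shard_ids.contains(&shard_id) {
            Ok(())
        } else {
            Err(format!("shard id {shard_id} out of bounds"))
        }
    }

    fn main() {
        // With non-contiguous ids (e.g. after resharding), an index-based check
        // such as `shard_id < num_shards` would wrongly accept 2 and reject 5.
        let shard_ids = [0, 1, 5, 6];
        assert!(validate_shard_id(5, &shard_ids).is_ok());
        assert!(validate_shard_id(2, &shard_ids).is_err());
    }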


compute_state_response_header():
- Similar to the above, remove the contiguous shard_ids assumption.
- Remove some unnecessary clones and dereferences (a small sketch of the
clone-avoidance pattern follows).
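
As a rough illustration of the clone cleanup (hypothetical types, not the real
BlockHeader/EpochId): returning a borrow and comparing references avoids
cloning whole values just to compare them.

    // Hedged sketch: borrow instead of clone when only a comparison is needed.
    #[derive(Clone, PartialEq)]
    struct EpochId(u64);

    struct BlockHeader {
        epoch_id: EpochId,
    }

    impl BlockHeader {
        // Returning &EpochId means callers do not have to clone the field.
        fn epoch_id(&self) -> &EpochId {
            &self.epoch_id
        }
    }

    fn main() {
        let current = BlockHeader { epoch_id: EpochId(7) };
        let prev = BlockHeader { epoch_id: EpochId(7) };
        // Before: `current.epoch_id().clone() == *prev.epoch_id()` forces a clone.
        // After: comparing the borrows directly is enough.
        assert!(current.epoch_id() == prev.epoch_id());
    }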


save_receipt_id_to_shard_id_for_block():
- there is a potential bug in the function: if we encounter an error in
the middle of the iteration, the changes written so far are not rolled
back. Hence, pre-compute the list of hashmaps before saving any receipts
(see the sketch after this list).
- Take a list of shard ids instead of num_shards to remove the
contiguous shard ids assumption
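
A minimal sketch of the "collect first, then write" pattern described above. The
store and lookup types are made up for illustration; the real code collects the
hashmaps returned by get_receipt_id_to_shard_id before writing to the
ChainStoreUpdate.

    // Hedged sketch: run every fallible lookup before applying any writes, so an
    // error part-way through leaves the store untouched.
    use std::collections::HashMap;

    struct Store {
        saved: HashMap<u64, u64>,
    }

    // Stand-in for the fallible per-shard lookup.
    fn receipt_map_for_shard(shard_id: u64) -> Result<HashMap<u64, u64>, String> {
        if shard_id == 13 {
            Err("lookup failed".to_string())
        } else {
            Ok(HashMap::from([(shard_id * 100, shard_id)]))
        }
    }

    fn save_all(store: &mut Store, shard_ids: &[u64]) -> Result<(), String> {
        // Phase 1: all fallible work happens up front; `?` returns before any write.
        let mut list = Vec::new();
        for &shard_id in shard_ids {
            list.push(receipt_map_for_shard(shard_id)?);
        }
        // Phase 2: apply the writes, which can no longer fail half-way.
        for map in list {
            store.saved.extend(map);
        }
        Ok(())
    }

    fn main() {
        let mut store = Store { saved: HashMap::new() };
        // The failing shard id aborts before anything is written.
        assert!(save_all(&mut store, &[1, 13, 2]).is_err());
        assert!(store.saved.is_empty());
    }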
akhi3030 authored Dec 13, 2023
1 parent a31cf71 commit 12ebc03
Showing 1 changed file with 33 additions and 45 deletions.
chain/chain/src/chain.rs
@@ -2918,24 +2918,22 @@ impl Chain {
let sync_block = self
.get_block(&sync_hash)
.log_storage_error("block has already been checked for existence")?;
- let sync_block_header = sync_block.header().clone();
- let sync_block_epoch_id = sync_block.header().epoch_id().clone();
- if shard_id as usize >= sync_block.chunks().len() {
+ let sync_block_header = sync_block.header();
+ let sync_block_epoch_id = sync_block_header.epoch_id();
+ let shard_ids = self.epoch_manager.shard_ids(sync_block_epoch_id)?;
+ if !shard_ids.contains(&shard_id) {
return Err(shard_id_out_of_bounds(shard_id));
}

// The chunk was applied at height `chunk_header.height_included`.
// Getting the `current` state.
let sync_prev_block = self.get_block(sync_block_header.prev_hash())?;
- if &sync_block_epoch_id == sync_prev_block.header().epoch_id() {
+ if sync_block_epoch_id == sync_prev_block.header().epoch_id() {
return Err(sync_hash_not_first_hash(sync_hash));
}
- if shard_id as usize >= sync_prev_block.chunks().len() {
- return Err(shard_id_out_of_bounds(shard_id));
- }
// Chunk header here is the same chunk header as at the `current` height.
- let sync_prev_hash = *sync_prev_block.hash();
- let chunk_header = sync_prev_block.chunks()[shard_id as usize].clone();
+ let sync_prev_hash = sync_prev_block.hash();
+ let chunk_header = &sync_prev_block.chunks()[shard_id as usize];
let (chunk_headers_root, chunk_proofs) = merklize(
&sync_prev_block
.chunks()
@@ -2947,7 +2945,7 @@ impl Chain {
);
assert_eq!(&chunk_headers_root, sync_prev_block.header().chunk_headers_root());

- let chunk = self.get_chunk_clone_from_header(&chunk_header)?;
+ let chunk = self.get_chunk_clone_from_header(chunk_header)?;
let chunk_proof = chunk_proofs[shard_id as usize].clone();
let block_header =
self.get_block_header_on_chain_by_height(&sync_hash, chunk_header.height_included())?;
@@ -2957,9 +2955,6 @@ impl Chain {
.get_block(block_header.prev_hash())
{
Ok(prev_block) => {
- if shard_id as usize >= prev_block.chunks().len() {
- return Err(shard_id_out_of_bounds(shard_id));
- }
let prev_chunk_header = prev_block.chunks()[shard_id as usize].clone();
let (prev_chunk_headers_root, prev_chunk_proofs) = merklize(
&prev_block
@@ -3042,7 +3037,7 @@ impl Chain {

let state_root_node = self.runtime_adapter.get_state_root_node(
shard_id,
- &sync_prev_hash,
+ sync_prev_hash,
&chunk_header.prev_state_root(),
)?;

@@ -3116,30 +3111,27 @@ impl Chain {
return Ok(state_part.into());
}

- let sync_block = self
+ let block = self
.get_block(&sync_hash)
.log_storage_error("block has already been checked for existence")?;
- let sync_block_header = sync_block.header().clone();
- let sync_block_epoch_id = sync_block.header().epoch_id().clone();
- if shard_id as usize >= sync_block.chunks().len() {
+ let header = block.header();
+ let epoch_id = block.header().epoch_id();
+ let shard_ids = self.epoch_manager.shard_ids(epoch_id)?;
+ if !shard_ids.contains(&shard_id) {
return Err(shard_id_out_of_bounds(shard_id));
}
- let sync_prev_block = self.get_block(sync_block_header.prev_hash())?;
- if &sync_block_epoch_id == sync_prev_block.header().epoch_id() {
+ let prev_block = self.get_block(header.prev_hash())?;
+ if epoch_id == prev_block.header().epoch_id() {
return Err(sync_hash_not_first_hash(sync_hash));
}
- if shard_id as usize >= sync_prev_block.chunks().len() {
- return Err(shard_id_out_of_bounds(shard_id));
- }
- let state_root = sync_prev_block.chunks()[shard_id as usize].prev_state_root();
- let sync_prev_hash = *sync_prev_block.hash();
- let sync_prev_prev_hash = *sync_prev_block.header().prev_hash();
+ let state_root = prev_block.chunks()[shard_id as usize].prev_state_root();
+ let prev_hash = *prev_block.hash();
+ let prev_prev_hash = *prev_block.header().prev_hash();
let state_root_node = self
.runtime_adapter
- .get_state_root_node(shard_id, &sync_prev_hash, &state_root)
+ .get_state_root_node(shard_id, &prev_hash, &state_root)
.log_storage_error("get_state_root_node fail")?;
let num_parts = get_num_state_parts(state_root_node.memory_usage);

if part_id >= num_parts {
return Err(shard_id_out_of_bounds(shard_id));
}
@@ -3148,7 +3140,7 @@ impl Chain {
.runtime_adapter
.obtain_state_part(
shard_id,
- &sync_prev_prev_hash,
+ &prev_prev_hash,
&state_root,
PartId::new(part_id, num_parts),
)
@@ -5290,27 +5282,22 @@ impl<'a> ChainUpdate<'a> {
/// EpochManager, otherwise it will return an error.
fn save_receipt_id_to_shard_id_for_block(
&mut self,
- me: &Option<AccountId>,
+ account_id: Option<&AccountId>,
hash: &CryptoHash,
prev_hash: &CryptoHash,
- num_shards: NumShards,
+ shard_ids: &[ShardId],
) -> Result<(), Error> {
- for shard_id in 0..num_shards {
- let care_about_shard = self.shard_tracker.care_about_shard(
- me.as_ref(),
- &prev_hash,
- shard_id as ShardId,
- true,
- );
- if !care_about_shard {
- continue;
+ let mut list = vec![];
+ for &shard_id in shard_ids {
+ if self.shard_tracker.care_about_shard(account_id, prev_hash, shard_id, true) {
+ list.push(self.get_receipt_id_to_shard_id(hash, shard_id)?);
}
- let receipt_id_to_shard_id = self.get_receipt_id_to_shard_id(hash, shard_id)?;
- for (receipt_id, shard_id) in receipt_id_to_shard_id {
+ }
+ for map in list {
+ for (receipt_id, shard_id) in map {
self.chain_store_update.save_receipt_id_to_shard_id(receipt_id, shard_id);
}
}

Ok(())
}

@@ -5615,6 +5602,7 @@ impl<'a> ChainUpdate<'a> {
block_preprocess_info: BlockPreprocessInfo,
apply_chunks_results: Vec<(ShardId, Result<ShardUpdateResult, Error>)>,
) -> Result<Option<Tip>, Error> {
+ let shard_ids = self.epoch_manager.shard_ids(block.header().epoch_id())?;
let prev_hash = block.header().prev_hash();
let results = apply_chunks_results.into_iter().map(|(shard_id, x)| {
if let Err(err) = &x {
@@ -5682,10 +5670,10 @@ impl<'a> ChainUpdate<'a> {

// Save receipt_id_to_shard_id for all outgoing receipts generated in this block
self.save_receipt_id_to_shard_id_for_block(
- me,
+ me.as_ref(),
block.hash(),
prev_hash,
- block.chunks().len() as NumShards,
+ &shard_ids,
)?;

// Update the chain head if it's the new tip