refactor: reduce places where we assume shard ids are contiguous (#10230)

There are many places in the repo where we assume that the valid shard
ids are in the range [0, num_shards). This PR is an attempt to improve
the current state of affairs.

- ShardLayout introduces a new method, `fn shard_ids()`, which still
employs the above assumption (sketched just below).
- All instances of the above assumption are replaced with calls to this
function so that the assumption is centralised in a single place.
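
A minimal sketch of the new helper, assuming `ShardId` is the plain
integer alias used throughout the codebase (the exact body in the PR may
differ):

    impl ShardLayout {
        /// All valid shard ids in this layout. The contiguity assumption
        /// ([0, num_shards)) is now encoded in exactly one place.
        pub fn shard_ids(&self) -> impl Iterator<Item = ShardId> {
            0..self.num_shards()
        }
    }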

Future work:
- If we have a function that returns the list of shard ids, we no longer
need `fn num_shards()`: it can be derived from `fn shard_ids()` (see the
sketch after this list) and should be removed for consistency.
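
A hedged sketch of that follow-up (not part of this PR; the exact
signature is an assumption):

    impl ShardLayout {
        /// Hypothetical follow-up: derive the shard count from the id
        /// list so the two can never disagree.
        pub fn num_shards(&self) -> NumShards {
            self.shard_ids().count() as NumShards
        }
    }
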
akhi3030 authored Nov 23, 2023
1 parent ab1a459 commit 431ae41
Showing 26 changed files with 157 additions and 116 deletions.
11 changes: 6 additions & 5 deletions chain/chain/src/chain.rs
@@ -1686,7 +1686,7 @@ impl Chain {
         parent_hash: CryptoHash,
     ) -> Result<bool, Error> {
         let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&parent_hash)?;
-        for shard_id in 0..self.epoch_manager.num_shards(&epoch_id)? {
+        for shard_id in self.epoch_manager.shard_ids(&epoch_id)? {
             if self.shard_tracker.care_about_shard(me.as_ref(), &parent_hash, shard_id, true)
                 || self.shard_tracker.will_care_about_shard(
                     me.as_ref(),
@@ -2323,7 +2323,7 @@ impl Chain {
         // the last final block on chain, which is OK, because in the flat storage implementation
         // we don't assume that.
         let epoch_id = block.header().epoch_id();
-        for shard_id in 0..self.epoch_manager.num_shards(epoch_id)? {
+        for shard_id in self.epoch_manager.shard_ids(epoch_id)? {
             let need_flat_storage_update = if is_caught_up {
                 // If we already caught up this epoch, then flat storage exists for both shards which we already track
                 // and shards which will be tracked in next epoch, so we can update them.
@@ -2709,7 +2709,8 @@ impl Chain {
         parent_hash: &CryptoHash,
     ) -> Result<Vec<ShardId>, Error> {
         let epoch_id = epoch_manager.get_epoch_id_from_prev_block(parent_hash)?;
-        Ok((0..epoch_manager.num_shards(&epoch_id)?)
+        Ok((epoch_manager.shard_ids(&epoch_id)?)
+            .into_iter()
             .filter(|shard_id| {
                 Self::should_catch_up_shard(
                     epoch_manager,
@@ -3591,7 +3592,7 @@ impl Chain {
         chain_update.commit()?;
 
         let epoch_id = block.header().epoch_id();
-        for shard_id in 0..self.epoch_manager.num_shards(epoch_id)? {
+        for shard_id in self.epoch_manager.shard_ids(epoch_id)? {
             // Update flat storage for each shard being caught up. We catch up a shard if it is tracked in the next
             // epoch. If it is tracked in this epoch as well, it was updated during regular block processing.
             if !self.shard_tracker.care_about_shard(
@@ -4751,7 +4752,7 @@ impl Chain {
         }
         let mut account_id_to_shard_id_map = HashMap::new();
         let mut shard_receipts: Vec<_> =
-            (0..shard_layout.num_shards()).map(|i| (i, Vec::new())).collect();
+            shard_layout.shard_ids().into_iter().map(|shard_id| (shard_id, Vec::new())).collect();
         for receipt in receipts.iter() {
             let shard_id = match account_id_to_shard_id_map.get(&receipt.receiver_id) {
                 Some(id) => *id,

4 changes: 2 additions & 2 deletions chain/chain/src/flat_storage_creator.rs
@@ -435,13 +435,13 @@ impl FlatStorageCreator {
         num_threads: usize,
     ) -> Result<Option<Self>, Error> {
         let chain_head = chain_store.head()?;
-        let num_shards = epoch_manager.num_shards(&chain_head.epoch_id)?;
+        let shard_ids = epoch_manager.shard_ids(&chain_head.epoch_id)?;
         let mut shard_creators: HashMap<ShardUId, FlatStorageShardCreator> = HashMap::new();
         let mut creation_needed = false;
         let flat_storage_manager = runtime.get_flat_storage_manager();
         // Create flat storage for all shards.
         // TODO(nikurt): Choose which shards need to open the flat storage.
-        for shard_id in 0..num_shards {
+        for shard_id in shard_ids {
             // The node applies transactions from the shards it cares about this and the next epoch.
             let shard_uid = epoch_manager.shard_id_to_uid(shard_id, &chain_head.epoch_id)?;
             let status = flat_storage_manager.get_flat_storage_status(shard_uid);

2 changes: 1 addition & 1 deletion chain/chain/src/test_utils.rs
@@ -329,7 +329,7 @@ mod test {
         shard_layout: &ShardLayout,
     ) -> Vec<CryptoHash> {
         let mut receipts_hashes = vec![];
-        for shard_id in 0..shard_layout.num_shards() {
+        for shard_id in shard_layout.shard_ids() {
             let shard_receipts: Vec<Receipt> = receipts
                 .iter()
                 .filter(|&receipt| {

4 changes: 4 additions & 0 deletions chain/chain/src/test_utils/kv_runtime.rs
@@ -417,6 +417,10 @@ impl EpochManagerAdapter for MockEpochManager {
         Ok(self.num_shards)
     }
 
+    fn shard_ids(&self, _epoch_id: &EpochId) -> Result<Vec<ShardId>, EpochError> {
+        Ok((0..self.num_shards).collect())
+    }
+
     fn num_total_parts(&self) -> usize {
         12 + (self.num_shards as usize + 1) % 50
     }

20 changes: 9 additions & 11 deletions chain/chunks/src/client.rs
@@ -144,8 +144,7 @@ impl ShardedTransactionPool {
 
 #[cfg(test)]
 mod tests {
-    use std::{collections::HashMap, str::FromStr};
-
+    use crate::client::ShardedTransactionPool;
     use near_crypto::{InMemorySigner, KeyType};
     use near_o11y::testonly::init_test_logger;
     use near_pool::types::PoolIterator;
@@ -157,9 +156,8 @@ mod tests {
         types::AccountId,
     };
     use near_store::ShardUId;
-    use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
-
-    use crate::client::ShardedTransactionPool;
+    use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
+    use std::{collections::HashMap, str::FromStr};
 
     const TEST_SEED: RngSeed = [3; 32];

@@ -206,9 +204,9 @@ mod tests {
         let n = 100;
         tracing::info!("inserting {n} transactions into the pool using the old shard layout");
         for i in 0..n {
-            let num_shards = old_shard_layout.num_shards();
-            let signer_shard_id = rng.gen_range(0..num_shards);
-            let receiver_shard_id = rng.gen_range(0..num_shards);
+            let shard_ids: Vec<_> = old_shard_layout.shard_ids().collect();
+            let &signer_shard_id = shard_ids.choose(&mut rng).unwrap();
+            let &receiver_shard_id = shard_ids.choose(&mut rng).unwrap();
             let nonce = i as u64;
 
             let signer_id = *shard_id_to_accounts[&signer_shard_id].choose(&mut rng).unwrap();
@@ -242,8 +240,8 @@ mod tests {
 
         tracing::info!("checking the pool after resharding");
         {
-            let num_shards = new_shard_layout.num_shards();
-            for shard_id in 0..num_shards {
+            let shard_ids: Vec<_> = new_shard_layout.shard_ids().collect();
+            for &shard_id in shard_ids.iter() {
                 let shard_id = shard_id as u32;
                 let shard_uid = ShardUId { shard_id, version: new_shard_layout.version() };
                 let pool = pool.pool_for_shard(shard_uid);
@@ -253,7 +251,7 @@ mod tests {
         }
 
         let mut total = 0;
-        for shard_id in 0..num_shards {
+        for shard_id in shard_ids {
             let shard_id = shard_id as u32;
             let shard_uid = ShardUId { shard_id, version: new_shard_layout.version() };
             let mut pool_iter = pool.get_pool_iterator(shard_uid).unwrap();

16 changes: 10 additions & 6 deletions chain/chunks/src/lib.rs
@@ -505,7 +505,10 @@ impl ShardsManager {
 
     fn get_tracking_shards(&self, parent_hash: &CryptoHash) -> HashSet<ShardId> {
         let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(parent_hash).unwrap();
-        (0..self.epoch_manager.num_shards(&epoch_id).unwrap())
+        self.epoch_manager
+            .shard_ids(&epoch_id)
+            .unwrap()
+            .into_iter()
             .filter(|chunk_shard_id| {
                 cares_about_shard_this_or_next_epoch(
                     self.me.as_ref(),
@@ -1273,7 +1276,7 @@ impl ShardsManager {
             };
         }
 
-        if header.shard_id() >= self.epoch_manager.num_shards(&epoch_id)? {
+        if !self.epoch_manager.shard_ids(&epoch_id)?.contains(&header.shard_id()) {
             return if epoch_id_confirmed {
                 byzantine_assert!(false);
                 Err(Error::InvalidChunkShardId)
@@ -1718,8 +1721,10 @@ impl ShardsManager {
         let block_producers =
             self.epoch_manager.get_epoch_block_producers_ordered(&epoch_id, lastest_block_hash)?;
         let current_chunk_height = partial_encoded_chunk.header.height_created();
-        let num_shards = self.epoch_manager.num_shards(&epoch_id)?;
-        let mut next_chunk_producers = (0..num_shards)
+        let mut next_chunk_producers = self
+            .epoch_manager
+            .shard_ids(&epoch_id)?
+            .into_iter()
             .map(|shard_id| {
                 self.epoch_manager.get_chunk_producer(&epoch_id, current_chunk_height + 1, shard_id)
             })
@@ -1769,8 +1774,7 @@ impl ShardsManager {
         chunk_entry: &EncodedChunksCacheEntry,
     ) -> Result<bool, Error> {
         let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(prev_block_hash)?;
-        for shard_id in 0..self.epoch_manager.num_shards(&epoch_id)? {
-            let shard_id = shard_id as ShardId;
+        for shard_id in self.epoch_manager.shard_ids(&epoch_id)? {
             if !chunk_entry.receipts.contains_key(&shard_id) {
                 if need_receipt(prev_block_hash, shard_id, self.me.as_ref(), &self.shard_tracker) {
                     return Ok(false);

5 changes: 2 additions & 3 deletions chain/client/src/client.rs
@@ -1661,7 +1661,7 @@ impl Client {
     fn produce_chunks(&mut self, block: &Block, validator_id: AccountId) {
         let epoch_id =
             self.epoch_manager.get_epoch_id_from_prev_block(block.header().hash()).unwrap();
-        for shard_id in 0..self.epoch_manager.num_shards(&epoch_id).unwrap() {
+        for shard_id in self.epoch_manager.shard_ids(&epoch_id).unwrap() {
            let next_height = block.header().height() + 1;
            let epoch_manager = self.epoch_manager.as_ref();
            let chunk_proposer =
@@ -2544,8 +2544,7 @@ impl Client {
         let tracked_shards = if self.config.tracked_shards.is_empty() {
             vec![]
         } else {
-            let num_shards = self.epoch_manager.num_shards(&tip.epoch_id)?;
-            (0..num_shards).collect()
+            self.epoch_manager.shard_ids(&tip.epoch_id)?
         };
         let tier1_accounts = self.get_tier1_accounts(&tip)?;
         let block = self.chain.get_block(&tip.last_block_hash)?;

30 changes: 17 additions & 13 deletions chain/client/src/client_actor.rs
@@ -60,7 +60,7 @@ use near_primitives::epoch_manager::RngSeed;
 use near_primitives::hash::CryptoHash;
 use near_primitives::network::{AnnounceAccount, PeerId};
 use near_primitives::static_clock::StaticClock;
-use near_primitives::types::{BlockHeight, ShardId};
+use near_primitives::types::BlockHeight;
 use near_primitives::unwrap_or_return;
 use near_primitives::utils::{from_timestamp, MaybeValidated};
 use near_primitives::validator_signer::ValidatorSigner;
@@ -1665,18 +1665,22 @@ impl ClientActor {
                 .unwrap()
                 .epoch_id()
                 .clone();
-            let shards_to_sync: Vec<ShardId> =
-                (0..self.client.epoch_manager.num_shards(&epoch_id).unwrap())
-                    .filter(|x| {
-                        cares_about_shard_this_or_next_epoch(
-                            me.as_ref(),
-                            &prev_hash,
-                            *x,
-                            true,
-                            &self.client.shard_tracker,
-                        )
-                    })
-                    .collect();
+            let shards_to_sync: Vec<_> = self
+                .client
+                .epoch_manager
+                .shard_ids(&epoch_id)
+                .unwrap()
+                .into_iter()
+                .filter(|&shard_id| {
+                    cares_about_shard_this_or_next_epoch(
+                        me.as_ref(),
+                        &prev_hash,
+                        shard_id,
+                        true,
+                        &self.client.shard_tracker,
+                    )
+                })
+                .collect();
 
             let use_colour =
                 matches!(self.client.config.log_summary_style, LogSummaryStyle::Colored);

10 changes: 6 additions & 4 deletions chain/client/src/debug.rs
@@ -334,13 +334,15 @@ impl ClientActor {
         let epoch_id = self.client.chain.header_head()?.epoch_id;
         let fetch_hash = self.client.chain.header_head()?.last_block_hash;
         let me = self.client.validator_signer.as_ref().map(|x| x.validator_id().clone());
-        let num_shards = self.client.epoch_manager.num_shards(&epoch_id).unwrap();
-        let shards_tracked_this_epoch = (0..num_shards)
-            .map(|shard_id| {
+        let shard_ids = self.client.epoch_manager.shard_ids(&epoch_id).unwrap();
+        let shards_tracked_this_epoch = shard_ids
+            .iter()
+            .map(|&shard_id| {
                 self.client.shard_tracker.care_about_shard(me.as_ref(), &fetch_hash, shard_id, true)
             })
             .collect();
-        let shards_tracked_next_epoch = (0..num_shards)
+        let shards_tracked_next_epoch = shard_ids
+            .into_iter()
            .map(|shard_id| {
                self.client.shard_tracker.will_care_about_shard(
                    me.as_ref(),

12 changes: 6 additions & 6 deletions chain/client/src/info.rs
@@ -137,8 +137,8 @@ impl InfoHelper {
     /// Count which shards are tracked by the node in the epoch indicated by head parameter.
     fn record_tracked_shards(head: &Tip, client: &crate::client::Client) {
         let me = client.validator_signer.as_ref().map(|x| x.validator_id());
-        if let Ok(num_shards) = client.epoch_manager.num_shards(&head.epoch_id) {
-            for shard_id in 0..num_shards {
+        if let Ok(shard_ids) = client.epoch_manager.shard_ids(&head.epoch_id) {
+            for shard_id in shard_ids {
                 let tracked = client.shard_tracker.care_about_shard(
                     me,
                     &head.last_block_hash,
@@ -180,8 +180,8 @@ impl InfoHelper {
                     .with_label_values(&[&shard_id.to_string()])
                     .set(if is_chunk_producer_for_shard { 1 } else { 0 });
             }
-        } else if let Ok(num_shards) = client.epoch_manager.num_shards(&head.epoch_id) {
-            for shard_id in 0..num_shards {
+        } else if let Ok(shard_ids) = client.epoch_manager.shard_ids(&head.epoch_id) {
+            for shard_id in shard_ids {
                 metrics::IS_CHUNK_PRODUCER_FOR_SHARD
                     .with_label_values(&[&shard_id.to_string()])
                     .set(0);
@@ -196,7 +196,7 @@ impl InfoHelper {
     fn record_epoch_settlement_info(head: &Tip, client: &crate::client::Client) {
         let epoch_info = client.epoch_manager.get_epoch_info(&head.epoch_id);
         let blocks_in_epoch = client.config.epoch_length;
-        let number_of_shards = client.epoch_manager.num_shards(&head.epoch_id).unwrap_or_default();
+        let shard_ids = client.epoch_manager.shard_ids(&head.epoch_id).unwrap_or_default();
         if let Ok(epoch_info) = epoch_info {
             metrics::VALIDATORS_CHUNKS_EXPECTED_IN_EPOCH.reset();
             metrics::VALIDATORS_BLOCKS_EXPECTED_IN_EPOCH.reset();
@@ -236,7 +236,7 @@ impl InfoHelper {
                     .set(stake_to_blocks(stake, stake_sum))
             });
 
-            for shard_id in 0..number_of_shards {
+            for shard_id in shard_ids {
                 let mut stake_per_cp = HashMap::<ValidatorId, Balance>::new();
                 stake_sum = 0;
                 for &id in &epoch_info.chunk_producers_settlement()[shard_id as usize] {

7 changes: 4 additions & 3 deletions chain/client/src/view_client.rs
@@ -264,7 +264,7 @@ impl ViewClientActor {
         let head = self.chain.head()?;
         let epoch_id = self.epoch_manager.get_epoch_id(&head.last_block_hash)?;
         let epoch_info: Arc<EpochInfo> = self.epoch_manager.get_epoch_info(&epoch_id)?;
-        let num_shards = self.epoch_manager.num_shards(&epoch_id)?;
+        let shard_ids = self.epoch_manager.shard_ids(&epoch_id)?;
         let cur_block_info = self.epoch_manager.get_block_info(&head.last_block_hash)?;
         let next_epoch_start_height =
             self.epoch_manager.get_epoch_start_height(cur_block_info.hash())?
@@ -277,8 +277,9 @@ impl ViewClientActor {
         for block_height in head.height..next_epoch_start_height {
             let bp = epoch_info.sample_block_producer(block_height);
             let bp = epoch_info.get_validator(bp).account_id().clone();
-            let cps: Vec<AccountId> = (0..num_shards)
-                .map(|shard_id| {
+            let cps: Vec<AccountId> = shard_ids
+                .iter()
+                .map(|&shard_id| {
                     let cp = epoch_info.sample_chunk_producer(block_height, shard_id);
                     let cp = epoch_info.get_validator(cp).account_id().clone();
                     cp

8 changes: 8 additions & 0 deletions chain/epoch-manager/src/adapter.rs
@@ -37,6 +37,9 @@ pub trait EpochManagerAdapter: Send + Sync {
     /// Get current number of shards.
     fn num_shards(&self, epoch_id: &EpochId) -> Result<NumShards, EpochError>;
 
+    /// Get the list of shard ids
+    fn shard_ids(&self, epoch_id: &EpochId) -> Result<Vec<ShardId>, EpochError>;
+
     /// Number of Reed-Solomon parts we split each chunk into.
     ///
     /// Note: this shouldn't be too large, our Reed-Solomon supports at most 256
@@ -412,6 +415,11 @@ impl EpochManagerAdapter for EpochManagerHandle {
         Ok(epoch_manager.get_shard_layout(epoch_id)?.num_shards())
     }
 
+    fn shard_ids(&self, epoch_id: &EpochId) -> Result<Vec<ShardId>, EpochError> {
+        let epoch_manager = self.read();
+        Ok(epoch_manager.get_shard_layout(epoch_id)?.shard_ids().collect())
+    }
+
     fn num_total_parts(&self) -> usize {
         let seats = self.read().genesis_num_block_producer_seats;
         if seats > 1 {

5 changes: 3 additions & 2 deletions chain/epoch-manager/src/shard_tracker.rs
@@ -71,10 +71,11 @@ impl ShardTracker {
             TrackedConfig::Accounts(tracked_accounts) => {
                 let shard_layout = self.epoch_manager.get_shard_layout(epoch_id)?;
                 let tracking_mask = self.tracking_shards_cache.get_or_put(epoch_id.clone(), |_| {
-                    let mut tracking_mask = vec![false; shard_layout.num_shards() as usize];
+                    let mut tracking_mask: Vec<_> =
+                        shard_layout.shard_ids().map(|_| false).collect();
                     for account_id in tracked_accounts {
                         let shard_id = account_id_to_shard_id(account_id, &shard_layout);
-                        *tracking_mask.get_mut(shard_id as usize).unwrap() = true;
+                        tracking_mask[shard_id as usize] = true;
                     }
                     tracking_mask
                 });
