feat(state-sync): implement peer selection for state part requests (#10441)

This implements the ideas laid out in #10213 by adding a select_host()
function to the SnapshotHostsCache. This function selects a peer to request
a part from based on the up-to-date SnapshotHostInfos we've seen, ranking
candidates by a priority score calculated from the Peer Id and Part Id, and
remembering who we've already requested from so that subsequent calls try
different peers.

A goal of the implementation is to have as little impact as possible on
nodes not running state sync, which will be most nodes. So we lazily store
the data associated with peer selection for a given part only after
select_host() has been called for that part, and the added part_received()
function serves as the signal that this data can be freed. A minimal sketch
of the intended call pattern follows.
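The sketch assumes a hypothetical send_request_to() helper: select_host()
and part_received() are the functions added in this commit; everything else
is illustrative only.

    fn request_part(
        cache: &SnapshotHostsCache,
        sync_hash: &CryptoHash,
        shard_id: ShardId,
        part_id: &PartId,
    ) {
        // Each call returns the current best host for this part and remembers
        // it, so retries naturally rotate through the peers we know about.
        if let Some(peer_id) = cache.select_host(sync_hash, shard_id, part_id) {
            send_request_to(&peer_id, sync_hash, shard_id, part_id);
        }
        // ...later, once the part has been downloaded and validated:
        cache.part_received(sync_hash, shard_id, part_id);
    }
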
marcelo-gonzalez authored Jan 23, 2024
1 parent 64e11a5 commit b0df7f6
Showing 5 changed files with 420 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions chain/network/Cargo.toml
@@ -40,6 +40,7 @@ rand.workspace = true
rayon.workspace = true
serde.workspace = true
smart-default.workspace = true
sha2.workspace = true
strum.workspace = true
stun.workspace = true
thiserror.workspace = true
6 changes: 5 additions & 1 deletion chain/network/src/config.rs
@@ -289,6 +289,7 @@ impl NetworkConfig {
},
snapshot_hosts: snapshot_hosts::Config {
snapshot_hosts_cache_size: cfg.snapshot_hosts_cache_size,
part_selection_cache_batch_size: 10,
},
whitelist_nodes: if cfg.whitelist_nodes.is_empty() {
vec![]
@@ -372,7 +373,10 @@ impl NetworkConfig {
peer_expiration_duration: time::Duration::seconds(60 * 60),
connect_only_to_boot_nodes: false,
},
snapshot_hosts: snapshot_hosts::Config { snapshot_hosts_cache_size: 1000 },
snapshot_hosts: snapshot_hosts::Config {
snapshot_hosts_cache_size: 1000,
part_selection_cache_batch_size: 10,
},
whitelist_nodes: vec![],
handshake_timeout: time::Duration::seconds(5),
connect_to_reliable_peers_on_startup: true,
242 changes: 240 additions & 2 deletions chain/network/src/snapshot_hosts/mod.rs
@@ -8,10 +8,14 @@ use crate::concurrency;
use crate::network_protocol::SnapshotHostInfo;
use crate::network_protocol::SnapshotHostInfoVerificationError;
use lru::LruCache;
use near_primitives::hash::CryptoHash;
use near_primitives::network::PeerId;
use near_primitives::state_part::PartId;
use near_primitives::types::ShardId;
use parking_lot::Mutex;
use rayon::iter::ParallelBridge;
use std::collections::HashMap;
use sha2::{Digest, Sha256};
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;

#[cfg(test)]
@@ -32,11 +36,160 @@ pub struct Config {
/// least-recently-used cache. In the future, we may wish to
/// implement something more sophisticated.
pub snapshot_hosts_cache_size: u32,
/// The number of hosts we'll add to the structures related to state part
/// peer selection each time we need to request parts from a new peer.
pub part_selection_cache_batch_size: u32,
}

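/// Deterministic per-host score for a given part: the SHA-256 hash of the
/// peer's public key bytes followed by the part index in little-endian form.
/// Lower scores are preferred; insert_part_hosts() keeps the lowest-scoring hosts.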
pub(crate) fn priority_score(peer_id: &PeerId, part_id: &PartId) -> [u8; 32] {
let mut h = Sha256::new();
h.update(peer_id.public_key().key_data());
h.update(part_id.idx.to_le_bytes());
h.finalize().into()
}

#[derive(Clone, Debug, PartialEq, Eq)]
struct PartPriority {
peer_id: PeerId,
score: [u8; 32],
// TODO: consider storing this on disk, so we can remember who hasn't
// been able to provide us with the parts across restarts
times_returned: usize,
}

impl PartPriority {
fn inc(&mut self) {
self.times_returned += 1;
}
}

impl PartialOrd for PartPriority {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for PartPriority {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
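// BinaryHeap::pop() returns the greatest element per this Ord, so both
// comparisons are reversed: fewer times_returned wins first, then the lower
// score; peer_id is only a final deterministic tie-breaker.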
self.times_returned
.cmp(&other.times_returned)
.reverse()
.then_with(|| self.score.cmp(&other.score).reverse())
.then_with(|| self.peer_id.cmp(&other.peer_id))
}
}

impl From<ReversePartPriority> for PartPriority {
fn from(ReversePartPriority { peer_id, score }: ReversePartPriority) -> Self {
Self { peer_id, score, times_returned: 0 }
}
}

// Used in insert_part_hosts() to iterate over the list of unseen hosts,
// keeping the top N hosts as we go. We use this struct there instead of
// PartPriority because we need the opposite comparator: the max-heap must
// surface the worst-scoring candidate for eviction rather than the best one
// for selection.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ReversePartPriority {
peer_id: PeerId,
score: [u8; 32],
}

impl PartialOrd for ReversePartPriority {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for ReversePartPriority {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
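// Unreversed score order: the max-heap's peek()/pop() yield the highest
// (worst) score, which is the candidate to evict.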
self.score.cmp(&other.score).then_with(|| self.peer_id.cmp(&other.peer_id))
}
}

#[derive(Default)]
struct PartPeerSelector {
peers: BinaryHeap<PartPriority>,
}

impl PartPeerSelector {
fn next(&mut self) -> Option<PeerId> {
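// Pop the best-priority host, record that we returned it, then push it back
// so that subsequent calls rotate through the other hosts before repeating it.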
match self.peers.pop() {
Some(mut priority) => {
priority.inc();
let peer_id = priority.peer_id.clone();
self.peers.push(priority);
Some(peer_id)
}
None => None,
}
}

fn insert_peers<T: IntoIterator<Item = PartPriority>>(&mut self, peers: T) {
self.peers.extend(peers)
}

fn len(&self) -> usize {
self.peers.len()
}

fn tried_everybody(&self) -> bool {
self.peers.iter().all(|priority| priority.times_returned > 0)
}
}

#[derive(Default)]
struct PeerSelector {
selectors: HashMap<u64, PartPeerSelector>,
}

impl PeerSelector {
fn next(&mut self, part_id: &PartId) -> Option<PeerId> {
self.selectors.entry(part_id.idx).or_default().next()
}

fn len(&self, part_id: &PartId) -> usize {
match self.selectors.get(&part_id.idx) {
Some(s) => s.len(),
None => 0,
}
}

fn insert_peers<T: IntoIterator<Item = PartPriority>>(&mut self, part_id: &PartId, peers: T) {
self.selectors.entry(part_id.idx).or_default().insert_peers(peers);
}

fn seen_peers(&self, part_id: &PartId) -> HashSet<PeerId> {
match self.selectors.get(&part_id.idx) {
Some(s) => {
let mut ret = HashSet::new();
for p in s.peers.iter() {
ret.insert(p.peer_id.clone());
}
ret
}
None => HashSet::new(),
}
}

// have we already returned every peer we know about?
fn tried_everybody(&self, part_id: &PartId) -> bool {
match self.selectors.get(&part_id.idx) {
Some(s) => s.tried_everybody(),
None => true,
}
}

fn clear(&mut self, part_id: &PartId) {
self.selectors.remove(&part_id.idx);
}
}

struct Inner {
/// The latest known SnapshotHostInfo for each node in the network
hosts: LruCache<PeerId, Arc<SnapshotHostInfo>>,
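/// Per-shard peer selection state for state part requests, populated lazily
/// by select_host() and freed when part_received() is called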
state_part_selectors: HashMap<ShardId, PeerSelector>,
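/// How many hosts to add to a part's selector at a time; set from
/// Config::part_selection_cache_batch_size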
part_selection_cache_batch_size: usize,
}

impl Inner {
@@ -57,14 +210,56 @@ impl Inner {
self.hosts.push(d.peer_id.clone(), d.clone());
Some(d)
}

// Try to insert up to `max_entries_added` more peers into the state part
// selector for this part ID. This looks for the best-priority
// `max_entries_added` peers that we haven't yet added to the set of peers
// to ask for this part, and adds them to the heap so that one of them can
// be returned the next time select_host() is called.
fn insert_part_hosts(
&mut self,
sync_hash: &CryptoHash,
shard_id: ShardId,
part_id: &PartId,
max_entries_added: usize,
) {
let selector = self.state_part_selectors.get(&shard_id).unwrap();
let seen_peers = selector.seen_peers(part_id);

let mut new_peers = BinaryHeap::new();
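// Max-heap of ReversePartPriority: peek() yields the worst (highest) score
// kept so far, so evicting it preserves the max_entries_added best hosts.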
for (peer_id, info) in self.hosts.iter() {
if seen_peers.contains(peer_id)
|| info.sync_hash != *sync_hash
|| !info.shards.contains(&shard_id)
{
continue;
}
let score = priority_score(peer_id, part_id);
if new_peers.len() < max_entries_added {
new_peers.push(ReversePartPriority { peer_id: peer_id.clone(), score });
} else if score < new_peers.peek().unwrap().score {
new_peers.pop();
new_peers.push(ReversePartPriority { peer_id: peer_id.clone(), score });
}
}
let selector = self.state_part_selectors.get_mut(&shard_id).unwrap();
selector.insert_peers(part_id, new_peers.into_iter().map(Into::into));
}
}

pub(crate) struct SnapshotHostsCache(Mutex<Inner>);

impl SnapshotHostsCache {
pub fn new(config: Config) -> Self {
debug_assert!(config.part_selection_cache_batch_size > 0);
let hosts = LruCache::new(config.snapshot_hosts_cache_size as usize);
Self(Mutex::new(Inner { hosts }))
let state_part_selectors = HashMap::new();
Self(Mutex::new(Inner {
hosts,
state_part_selectors,
part_selection_cache_batch_size: config.part_selection_cache_batch_size as usize,
}))
}

/// Selects new data and verifies the signatures.
@@ -133,4 +328,47 @@ impl SnapshotHostsCache {
pub fn get_hosts(&self) -> Vec<Arc<SnapshotHostInfo>> {
self.0.lock().hosts.iter().map(|(_, v)| v.clone()).collect()
}

// Selects a peer to send the request for this part ID to, based on a priority
// score calculated as a hash of the Peer ID plus the part ID. Subsequent calls
// return different hosts, eventually iterating over all valid
// SnapshotHostInfos we know about.
// TODO: get rid of the dead_code and hook this up to the decentralized state sync
#[allow(dead_code)]
pub fn select_host(
&self,
sync_hash: &CryptoHash,
shard_id: ShardId,
part_id: &PartId,
) -> Option<PeerId> {
let mut inner = self.0.lock();
let num_hosts = inner.hosts.len();
let selector = inner.state_part_selectors.entry(shard_id).or_default();

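// If we've already returned every host in this part's selector but know of
// more hosts overall, top the selector up with the next batch of candidates.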
if selector.tried_everybody(part_id) && selector.len(part_id) < num_hosts {
let max_entries_added = inner.part_selection_cache_batch_size;
inner.insert_part_hosts(sync_hash, shard_id, part_id, max_entries_added);
}
let selector = inner.state_part_selectors.get_mut(&shard_id).unwrap();
selector.next(part_id)
}

// Lets us know that we have successfully retrieved this part, so we can free
// any data associated with it that we would otherwise have used to respond to
// future calls to select_host().
// TODO: get rid of the dead_code and hook this up to the decentralized state sync
#[allow(dead_code)]
pub fn part_received(&self, _sync_hash: &CryptoHash, shard_id: ShardId, part_id: &PartId) {
let mut inner = self.0.lock();
let selector = inner.state_part_selectors.entry(shard_id).or_default();
selector.clear(part_id);
}

// Used for testing purposes only, to check that we clear state after part_received() is called
#[allow(dead_code)]
pub(crate) fn part_peer_state_len(&self, shard_id: ShardId, part_id: &PartId) -> usize {
let inner = self.0.lock();
match inner.state_part_selectors.get(&shard_id) {
Some(s) => s.len(part_id),
None => 0,
}
}
}