Skip to content

Commit

Permalink
[trie] Remove TrieNode::Empty to align it better with RawTrieNode (#1…
Browse files Browse the repository at this point in the history
…2394)

Step 1 of removing TrieNode

Part of #12361

The main difference between TrieNode and RawTrieNode was that TrieNode
had an enum variant, TrieNode::Empty, that RawTrieNode did not have.

This PR removes all occurrences of TrieNode::Empty and represents a missing
node with `Option::None` instead (e.g. `Option<TrieNodeWithSize>`).

In the next PR, we can change all occurrences of TrieNode to RawTrieNode.
  • Loading branch information
shreyan-gupta authored Jan 27, 2025
1 parent 5581f65 commit b810c4c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 46 deletions.
45 changes: 29 additions & 16 deletions core/store/src/trie/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use near_primitives::hash::CryptoHash;

use crate::trie::nibble_slice::NibbleSlice;
use crate::trie::{TrieNode, TrieNodeWithSize, ValueHandle};
use crate::{MissingTrieValueContext, StorageError, Trie};
use crate::{StorageError, Trie};

use super::mem::iter::STMemTrieIterator;

/// Crumb is a piece of trie iteration state. It describes a node on the trail and processing status of that node.
#[derive(Debug)]
struct Crumb {
node: TrieNodeWithSize,
node: Option<TrieNodeWithSize>,
status: CrumbStatus,
prefix_boundary: bool,
}
Expand All @@ -31,11 +31,16 @@ impl Crumb {
self.status = CrumbStatus::Exiting;
return;
}
self.status = match (&self.status, &self.node.node) {
(_, &TrieNode::Empty) => CrumbStatus::Exiting,

let Some(node) = &self.node else {
self.status = CrumbStatus::Exiting;
return;
};

self.status = match (&self.status, &node.node) {
(&CrumbStatus::Entering, _) => CrumbStatus::At,
(&CrumbStatus::At, &TrieNode::Branch(_, _)) => CrumbStatus::AtChild(0),
(&CrumbStatus::AtChild(x), &TrieNode::Branch(_, _)) if x < 15 => {
(&CrumbStatus::At, TrieNode::Branch(_, _)) => CrumbStatus::AtChild(0),
(&CrumbStatus::AtChild(x), TrieNode::Branch(_, _)) if x < 15 => {
CrumbStatus::AtChild(x + 1)
}
_ => CrumbStatus::Exiting,
Expand Down Expand Up @@ -154,8 +159,11 @@ impl<'a> DiskTrieIterator<'a> {
self.descend_into_node(&hash)?;
let Crumb { status, node, prefix_boundary } = self.trail.last_mut().unwrap();
prev_prefix_boundary = prefix_boundary;

let Some(node) = node else {
break;
};
match &node.node {
TrieNode::Empty => break,
TrieNode::Leaf(leaf_key, _) => {
let existing_key = NibbleSlice::from_encoded(leaf_key).0;
if !check_ext_key(&key, &existing_key) {
Expand Down Expand Up @@ -206,12 +214,13 @@ impl<'a> DiskTrieIterator<'a> {
/// with [`Self::remember_visited_nodes`]), the node will be added to the
/// list.
fn descend_into_node(&mut self, hash: &CryptoHash) -> Result<(), StorageError> {
let (bytes, node) = self.trie.retrieve_node(hash)?;
if let Some(ref mut visited) = self.visited_nodes {
visited.push(bytes.ok_or({
StorageError::MissingTrieValue(MissingTrieValueContext::TrieIterator, *hash)
})?);
}
let node = self.trie.retrieve_node(hash)?.map(|(bytes, node)| {
if let Some(ref mut visited) = self.visited_nodes {
visited.push(bytes);
}
node
});

self.trail.push(Crumb { status: CrumbStatus::Entering, node, prefix_boundary: false });
Ok(())
}
Expand All @@ -227,7 +236,9 @@ impl<'a> DiskTrieIterator<'a> {
fn has_value(&self) -> bool {
match self.trail.last() {
Some(b) => match &b.status {
CrumbStatus::At => b.node.node.has_value(),
CrumbStatus::At => {
b.node.as_ref().map(|node| node.node.has_value()).unwrap_or_default()
}
_ => false,
},
None => false, // Trail finished
Expand All @@ -237,7 +248,10 @@ impl<'a> DiskTrieIterator<'a> {
fn iter_step(&mut self) -> Option<IterStep> {
let last = self.trail.last_mut()?;
last.increment();
Some(match (last.status, &last.node.node) {
let Some(node) = &last.node else {
return Some(IterStep::PopTrail);
};
Some(match (last.status, &node.node) {
(CrumbStatus::Exiting, n) => {
match n {
TrieNode::Leaf(ref key, _) | TrieNode::Extension(ref key, _) => {
Expand All @@ -248,7 +262,6 @@ impl<'a> DiskTrieIterator<'a> {
TrieNode::Branch(_, _) => {
self.key_nibbles.pop();
}
_ => {}
}
IterStep::PopTrail
}
Expand Down
26 changes: 6 additions & 20 deletions core/store/src/trie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ impl std::fmt::Debug for ValueHandle {
// TODO(#12361): replace with `RawTrieNode`.
#[derive(Clone, Hash)]
enum TrieNode {
/// Null trie node. Could be an empty root or an empty branch entry.
Empty,
/// Key and value of the leaf node.
Leaf(Vec<u8>, ValueHandle),
/// Branch of 16 possible children and value if key ends here.
Expand All @@ -167,10 +165,6 @@ impl TrieNodeWithSize {
fn new(node: TrieNode, memory_usage: u64) -> TrieNodeWithSize {
TrieNodeWithSize { node, memory_usage }
}

fn empty() -> TrieNodeWithSize {
TrieNodeWithSize { node: TrieNode::Empty, memory_usage: 0 }
}
}

impl TrieNode {
Expand Down Expand Up @@ -288,12 +282,6 @@ impl TrieNode {
/// `GenericUpdatedTrieNode::memory_usage_direct`.
fn memory_usage_direct(&self) -> u64 {
match self {
TrieNode::Empty => {
// DEVNOTE: empty nodes don't exist in storage.
// In the in-memory implementation Some(TrieNode::Empty) and None are interchangeable as
// children of branch nodes which means cost has to be 0
0
}
TrieNode::Leaf(key, value) => {
TRIE_COSTS.node_cost
+ (key.len() as u64) * TRIE_COSTS.byte_of_key
Expand All @@ -318,7 +306,6 @@ impl std::fmt::Debug for TrieNode {
let empty = "";
let indent = fmtr.width().unwrap_or(0);
match self {
TrieNode::Empty => write!(fmtr, "{empty:indent$}Empty"),
TrieNode::Leaf(key, value) => write!(
fmtr,
"{empty:indent$}Leaf({:?}, {value:?})",
Expand Down Expand Up @@ -615,7 +602,7 @@ pub struct ApplyStatePartResult {

enum NodeOrValue {
Node,
Value(std::sync::Arc<[u8]>),
Value(Arc<[u8]>),
}

/// Like a ValueRef, but allows for optimized retrieval of the value if the
Expand Down Expand Up @@ -1263,7 +1250,7 @@ impl Trie {
hash: &CryptoHash,
use_accounting_cache: bool,
side_effects: bool,
) -> Result<Option<(std::sync::Arc<[u8]>, RawTrieNodeWithSize)>, StorageError> {
) -> Result<Option<(Arc<[u8]>, RawTrieNodeWithSize)>, StorageError> {
if hash == &Self::EMPTY_ROOT {
return Ok(None);
}
Expand Down Expand Up @@ -1313,11 +1300,10 @@ impl Trie {
fn retrieve_node(
&self,
hash: &CryptoHash,
) -> Result<(Option<std::sync::Arc<[u8]>>, TrieNodeWithSize), StorageError> {
match self.retrieve_raw_node(hash, true, true)? {
None => Ok((None, TrieNodeWithSize::empty())),
Some((bytes, node)) => Ok((Some(bytes), TrieNodeWithSize::from_raw(node))),
}
) -> Result<Option<(Arc<[u8]>, TrieNodeWithSize)>, StorageError> {
Ok(self
.retrieve_raw_node(hash, true, true)?
.map(|(bytes, node)| (bytes, TrieNodeWithSize::from_raw(node))))
}

pub fn retrieve_root_node(&self) -> Result<StateRootNode, StorageError> {
Expand Down
20 changes: 10 additions & 10 deletions core/store/src/trie/state_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ impl Trie {
if part_id == num_parts {
return Ok(LAST_STATE_PART_BOUNDARY.to_vec());
}
let root_node = self.retrieve_node(&self.root)?.1;
let root_node = self.retrieve_node(&self.root)?.map(|node| node.1);
if root_node.is_none() {
return Ok(vec![]);
}
let root_node = root_node.unwrap();
let total_size = root_node.memory_usage;
let size_start = total_size / num_parts * part_id + part_id.min(total_size % num_parts);
self.find_node_in_dfs_order(&root_node, size_start)
Expand Down Expand Up @@ -344,7 +348,6 @@ impl Trie {
*memory_skipped += node.node.memory_usage_direct();

match &node.node {
TrieNode::Empty => Ok(false),
TrieNode::Leaf(key, _) => {
let (slice, _) = NibbleSlice::from_encoded(key);
key_nibbles.extend(slice.iter());
Expand All @@ -361,7 +364,7 @@ impl Trie {
let mut iter = children.iter();
while let Some((index, child)) = iter.next() {
let NodeHandle::Hash(h) = child;
let child = self.retrieve_node(h)?.1;
let child = self.retrieve_node(h)?.unwrap().1;
if *memory_skipped + child.memory_usage > memory_threshold {
core::mem::drop(iter);
key_nibbles.push(index);
Expand All @@ -379,7 +382,7 @@ impl Trie {
}
TrieNode::Extension(key, child_handle) => {
let child = match child_handle {
NodeHandle::Hash(h) => self.retrieve_node(h)?.1,
NodeHandle::Hash(h) => self.retrieve_node(h)?.unwrap().1,
};
let (slice, _) = NibbleSlice::from_encoded(key);
key_nibbles.extend(slice.iter());
Expand Down Expand Up @@ -658,7 +661,7 @@ mod tests {
return Ok(());
}
let mut stack: Vec<(CryptoHash, TrieNodeWithSize, CrumbStatus)> = Vec::new();
let root_node = self.retrieve_node(&self.root)?.1;
let root_node = self.retrieve_node(&self.root)?.unwrap().1;
stack.push((self.root, root_node, CrumbStatus::Entering));
while let Some((hash, node, position)) = stack.pop() {
if let CrumbStatus::Entering = position {
Expand All @@ -672,9 +675,6 @@ mod tests {
}
};
match &node.node {
TrieNode::Empty => {
continue;
}
TrieNode::Leaf(_, value) => {
on_enter_value(value)?;
continue;
Expand All @@ -694,7 +694,7 @@ mod tests {
}
if let Some(NodeHandle::Hash(ref h)) = children[i] {
let h = *h;
let child = self.retrieve_node(&h)?.1;
let child = self.retrieve_node(&h)?.unwrap().1;
stack.push((hash, node, CrumbStatus::AtChild(i + 1)));
stack.push((h, child, CrumbStatus::Entering));
break;
Expand All @@ -711,7 +711,7 @@ mod tests {
TrieNode::Extension(_key, child) => {
if let CrumbStatus::Entering = position {
let NodeHandle::Hash(h) = child.clone();
let child = self.retrieve_node(&h)?.1;
let child = self.retrieve_node(&h)?.unwrap().1;
stack.push((h, node, CrumbStatus::Exiting));
stack.push((h, child, CrumbStatus::Entering));
}
Expand Down

0 comments on commit b810c4c

Please sign in to comment.