diff --git a/firewood/src/hashednode.rs b/firewood/src/hashednode.rs index a4c87bbfb..4574edf6d 100644 --- a/firewood/src/hashednode.rs +++ b/firewood/src/hashednode.rs @@ -56,8 +56,8 @@ impl From> for HashedNodeStore { impl HashedNodeStore { pub fn read_node(&self, addr: LinearAddress) -> Result, Error> { - if let Some(modified_node) = self.modified.get(&addr) { - Ok(modified_node.0.clone()) + if let Some((modified_node, _)) = self.modified.get(&addr) { + Ok(modified_node.clone()) } else { Ok(self.nodestore.read_node(addr)?) } @@ -158,6 +158,7 @@ impl HashedNodeStore { Node::Leaf(_) => Ok(self.hash_internal(node, path_prefix)), } } + pub fn freeze(mut self) -> Result, Error> { // fill in all remaining hashes, including the root hash if let Some(root_address) = self.root_address() { @@ -174,12 +175,17 @@ impl HashedNodeStore { root_hash: Default::default(), }) } + pub fn create_node(&mut self, node: Node) -> Result { let (addr, size) = self.nodestore.allocate_node(&node)?; self.modified.insert(addr, (Arc::new(node), size)); Ok(addr) } + pub fn delete_node(&mut self, addr: LinearAddress) -> Result<(), Error> { + self.nodestore.delete_node(addr) + } + /// Fixes the trie after a node is updated. /// When a node is updated, it might move to a different LinearAddress. /// If it does, we need to update its parent so that it points to the @@ -199,7 +205,7 @@ impl HashedNodeStore { new_addr: LinearAddress, ) -> Result<(), MerkleError> { let Some(parent) = ancestors.next_back() else { - self.set_root(new_addr)?; + self.set_root(Some(new_addr))?; return Ok(()); }; @@ -285,7 +291,7 @@ impl HashedNodeStore { Ok(new_address) } - pub fn set_root(&mut self, root_addr: LinearAddress) -> Result<(), Error> { + pub fn set_root(&mut self, root_addr: Option) -> Result<(), Error> { self.nodestore.set_root(root_addr) } } @@ -424,7 +430,7 @@ mod test { value: Box::new(*b"abc"), }); let addr = hns.create_node(node).unwrap(); - hns.set_root(addr).unwrap(); + hns.set_root(Some(addr)).unwrap(); let frozen = hns.freeze().unwrap(); assert_ne!(frozen.root_hash().unwrap(), Default::default()); diff --git a/firewood/src/merkle.rs b/firewood/src/merkle.rs index 5fd8bef90..98ea22178 100644 --- a/firewood/src/merkle.rs +++ b/firewood/src/merkle.rs @@ -359,7 +359,7 @@ impl Merkle { value, }); let root_addr = self.create_node(root)?; - self.set_root(root_addr)?; + self.set_root(Some(root_addr))?; return Ok(()); }; // There is a root but it's not a prefix of `path`. @@ -416,7 +416,7 @@ impl Merkle { } let new_root_addr = self.create_node(Node::Branch(Box::new(new_root)))?; - self.set_root(new_root_addr)?; + self.set_root(Some(new_root_addr))?; return Ok(()); }; // `greatest_prefix_node` is a prefix of `path` @@ -588,11 +588,164 @@ impl Merkle { } } - pub fn remove(&mut self, _key: &[u8]) -> Result>, MerkleError> { - // let Some(root_address) = self.root_address() else { - // return Ok(None); - // }; - todo!() + /// Removes the value associated with the given `key`. + /// Returns the value that was removed, if any. + /// Otherwise returns `None`. + pub fn remove(&mut self, key: &[u8]) -> Result>, MerkleError> { + let path = Path::from_nibbles_iterator(NibblesIterator::new(key)); + + let ancestors = + PathIterator::new(self, key)?.collect::, MerkleError>>()?; + + // If the last element of ancestors doesn't contain the key/value pair we are looking for, + // then we can return early without making any changes + + // The path from the root down to and including the node with the greatest prefix of `path` + let mut ancestors = ancestors.iter(); + + let Some(greatest_prefix_node) = ancestors.next_back() else { + // There is no node which is a prefix of `path`. + // Therefore `path` is not in the trie. + return Ok(None); + }; + + if &*greatest_prefix_node.key_nibbles != path.0.as_ref() { + // `greatest_prefix_node` is a prefix of `path` but not equal to `path`. + // Therefore `path` is not in the trie. + return Ok(None); + } + + let removed = greatest_prefix_node; + + match &*removed.node { + Node::Branch(branch) => { + let Some(removed_value) = &branch.value else { + // The node at `path` has no value so there's nothing to remove. + return Ok(None); + }; + + // We know the key exists, so see if the branch has more than 1 child. + let mut branch_children = branch + .children + .iter() + .enumerate() + .filter_map(|(index, addr)| addr.map(|addr| (index as u8, addr))); + + let (child_index, child_addr) = + branch_children.next().expect("branch must have children"); + + if branch_children.next().is_some() { + // The branch has more than 1 child. Remove the value but keep the branch. + // ... ... + // | --> | + // branch branch (no value now) + // / \ / \ + // ... child ... child + // + // Note in the after diagram `child` may be the only child of `branch`. + let branch = BranchNode { + children: branch.children, + partial_path: branch.partial_path.clone(), + value: None, + child_hashes: branch.child_hashes.clone(), + }; + self.update_node(ancestors, removed.addr, Node::Branch(Box::new(branch)))?; + return Ok(Some(removed_value.clone())); + } + + // The branch has only 1 child. + // The branch must be combined with its child since it now has no value. + // ... ... + // | --> | + // branch combined + // | + // child + // + // where combined has `child`'s value. + // Note that child/combined may be a leaf or a branch. + + let combined = { + let child = self.read_node(child_addr)?; + + // `combined`'s partial path is the concatenation of `branch`'s partial path, + // `child_index` and `child`'s partial path. + let partial_path = Path::from_nibbles_iterator( + branch + .partial_path + .iter() + .chain(once(&child_index)) + .chain(child.partial_path().iter()) + .copied(), + ); + + child.new_with_partial_path(partial_path) + }; + + self.delete_node(child_addr)?; + + self.update_node(ancestors, removed.addr, combined)?; + + Ok(Some(removed_value.clone())) + } + Node::Leaf(leaf) => { + self.delete_node(removed.addr)?; + + while let Some(ancestor) = ancestors.next_back() { + // Remove all ancestors until we find one that has a value + // or multiple children. + let ancestor_addr = ancestor.addr; + let child_index = ancestor.next_nibble.expect("parent has a child"); + let ancestor = ancestor.node.as_branch().expect("parent must be a branch"); + + let num_children = ancestor + .children + .iter() + .filter(|addr| addr.is_some()) + .count(); + + if num_children > 1 { + #[rustfmt::skip] + // Update `ancestor` to remove its child. + // ... ... + // | --> | + // ancestor ancestor + // / \ | + // ... child ... + let mut ancestor = BranchNode { + children: ancestor.children, + partial_path: ancestor.partial_path.clone(), + value: ancestor.value.clone(), + child_hashes: ancestor.child_hashes.clone(), + }; + ancestor.update_child(child_index, None); + self.update_node( + ancestors, + ancestor_addr, + Node::Branch(Box::new(ancestor)), + )?; + return Ok(Some(leaf.value.clone())); + } + + // The ancestor has only 1 child, which is now deleted. + if let Some(ancestor_value) = &ancestor.value { + // Turn the ancestor into a leaf. + let ancestor = Node::Leaf(LeafNode { + value: ancestor_value.clone(), + partial_path: ancestor.partial_path.clone(), + }); + self.update_node(ancestors, ancestor_addr, ancestor)?; + return Ok(Some(leaf.value.clone())); + } + + // The ancestor had 1 child and no value so it should be removed. + self.delete_node(ancestor_addr)?; + } + + // The trie is now empty. + self.set_root(None)?; + Ok(Some(leaf.value.clone())) + } + } } pub fn freeze(self) -> Result, MerkleError> { @@ -604,10 +757,6 @@ impl Merkle { pub fn put_node(&mut self, node: Node) -> Result { self.create_node(node).map_err(MerkleError::Format) } - - fn _delete_node(&mut self, _addr: LinearAddress) -> Result<(), MerkleError> { - todo!() - } } /// Returns an iterator where each element is the result of combining @@ -788,61 +937,97 @@ mod tests { } } - // #[test] - // fn remove_one() { - // let key = b"hello"; - // let val = b"world"; - - // let mut merkle = create_in_memory_merkle(); - // let root_addr = merkle.init_sentinel().unwrap(); - - // merkle.insert(key, val.to_vec(), root_addr).unwrap(); - - // assert_eq!( - // merkle.get(key, root_addr).unwrap().as_deref(), - // val.as_slice().into() - // ); - - // let removed_val = merkle.remove(key, root_addr).unwrap(); - // assert_eq!(removed_val.as_deref(), val.as_slice().into()); - - // let fetched_val = merkle.get(key, root_addr).unwrap(); - // assert!(fetched_val.is_none()); - // } + #[test] + fn remove_root() { + let key0 = vec![0]; + let val0 = [0]; + let key1 = vec![0, 1]; + let val1 = [0, 1]; + let key2 = vec![0, 1, 2]; + let val2 = [0, 1, 2]; + let key3 = vec![0, 1, 15]; + let val3 = [0, 1, 15]; - // #[test] - // fn remove_many() { - // let mut merkle = create_in_memory_merkle(); - // let root_addr = merkle.init_sentinel().unwrap(); + let mut merkle = create_in_memory_merkle(); - // // insert values - // for key_val in u8::MIN..=u8::MAX { - // let key = &[key_val]; - // let val = &[key_val]; + merkle.insert(&key0, Box::from(val0)).unwrap(); + merkle.insert(&key1, Box::from(val1)).unwrap(); + merkle.insert(&key2, Box::from(val2)).unwrap(); + merkle.insert(&key3, Box::from(val3)).unwrap(); + // Trie is: + // key0 + // | + // key1 + // / \ + // key2 key3 + + // Test removal of root when it's a branch with 1 branch child + let removed_val = merkle.remove(&key0).unwrap(); + assert_eq!(removed_val, Some(Box::from(val0))); + assert!(merkle.get(&key0).unwrap().is_none()); + // Removing an already removed key is a no-op + assert!(merkle.remove(&key0).unwrap().is_none()); + + // Trie is: + // key1 + // / \ + // key2 key3 + // Test removal of root when it's a branch with multiple children + assert_eq!(merkle.remove(&key1).unwrap(), Some(Box::from(val1))); + assert!(merkle.get(&key1).unwrap().is_none()); + assert!(merkle.remove(&key1).unwrap().is_none()); + + // Trie is: + // key1 (now has no value) + // / \ + // key2 key3 + let removed_val = merkle.remove(&key2).unwrap(); + assert_eq!(removed_val, Some(Box::from(val2))); + assert!(merkle.get(&key2).unwrap().is_none()); + assert!(merkle.remove(&key2).unwrap().is_none()); + + // Trie is: + // key1 (now has no value) + // | + // key3 + let removed_val = merkle.remove(&key3).unwrap(); + assert_eq!(removed_val, Some(Box::from(val3))); + assert!(merkle.get(&key3).unwrap().is_none()); + assert!(merkle.remove(&key3).unwrap().is_none()); + + assert!(merkle.root_address().is_none()); + } - // merkle.insert(key, val.to_vec(), root_addr).unwrap(); + #[test] + fn remove_many() { + let mut merkle = create_in_memory_merkle(); - // let fetched_val = merkle.get(key, root_addr).unwrap(); + // insert key-value pairs + for key_val in u8::MIN..=u8::MAX { + let key = [key_val]; + let val = [key_val]; - // // make sure the value was inserted - // assert_eq!(fetched_val.as_deref(), val.as_slice().into()); - // } + merkle.insert(&key, Box::new(val)).unwrap(); + let got = merkle.get(&key).unwrap().unwrap(); + assert_eq!(&*got, val); + } - // // remove values - // for key_val in u8::MIN..=u8::MAX { - // let key = &[key_val]; - // let val = &[key_val]; + // remove key-value pairs + for key_val in u8::MIN..=u8::MAX { + let key = [key_val]; + let val = [key_val]; - // let Ok(removed_val) = merkle.remove(key, root_addr) else { - // panic!("({key_val}, {key_val}) missing"); - // }; + let got = merkle.remove(&key).unwrap().unwrap(); + assert_eq!(&*got, val); - // assert_eq!(removed_val.as_deref(), val.as_slice().into()); + // Removing an already removed key is a no-op + assert!(merkle.remove(&key).unwrap().is_none()); - // let fetched_val = merkle.get(key, root_addr).unwrap(); - // assert!(fetched_val.is_none()); - // } - // } + let got = merkle.get(&key).unwrap(); + assert!(got.is_none()); + } + assert!(merkle.root_address().is_none()); + } // #[test] // fn get_empty_proof() { diff --git a/storage/src/nodestore.rs b/storage/src/nodestore.rs index 3964694db..bb113b094 100644 --- a/storage/src/nodestore.rs +++ b/storage/src/nodestore.rs @@ -416,8 +416,8 @@ impl NodeStore { } /// Write the root [LinearAddress] of the [NodeStore] - pub fn set_root(&mut self, addr: LinearAddress) -> Result<(), Error> { - self.header.root_address = Some(addr); + pub fn set_root(&mut self, addr: Option) -> Result<(), Error> { + self.header.root_address = addr; self.write_header() } }