From 45e16b8da39237becf2e73082eaa0868ba398332 Mon Sep 17 00:00:00 2001 From: Pia Date: Wed, 27 Dec 2023 17:56:41 +0900 Subject: [PATCH] refactor stores + fix functions requires result --- benches/incremental_benchmark.rs | 4 +- src/merkle_tree/merkle_trees/incremental.rs | 37 +++++++------- src/mmr/core.rs | 20 ++++---- src/store/core.rs | 12 +++++ src/store/counter.rs | 16 +++++-- src/store/stores/memory.rs | 1 + src/store/stores/sqlite.rs | 1 + src/store/table.rs | 53 ++++++++++++--------- tests/merkle_tree/incremental.rs | 41 ++++++++++------ tests/mmr/core.rs | 2 +- tests/store/memory.rs | 21 ++++---- tests/store/sqlite.rs | 21 ++++---- 12 files changed, 145 insertions(+), 84 deletions(-) diff --git a/benches/incremental_benchmark.rs b/benches/incremental_benchmark.rs index feca202..49ce04f 100644 --- a/benches/incremental_benchmark.rs +++ b/benches/incremental_benchmark.rs @@ -14,7 +14,9 @@ async fn prepare_incremental(count: usize) -> IncrementalMerkleTree, mmr_id: Option, - ) -> Self { + ) -> Result { let tree = IncrementalMerkleTree::new(size, null_value, hasher, store, mmr_id); let nodes = tree.render_empty_tree(); let nodes_hashmap: HashMap = @@ -81,16 +81,21 @@ where acc }); - tree.nodes.set_many(nodes_hashmap).await; + tree.nodes.set_many(nodes_hashmap).await?; tree.root_hash .set(&nodes[nodes.len() - 1][0].hash, SubKey::None) - .await; - tree + .await?; + Ok(tree) } - pub async fn get_root(&self) -> String { - self.root_hash.get(SubKey::None).await.unwrap() + pub async fn get_root(&self) -> Result { + self.root_hash.get(SubKey::None).await?.ok_or_else(|| { + anyhow!( + "Root hash not found for mmr_id: {}", + self.mmr_id.to_string() + ) + }) } pub async fn get_inclusion_proof(&self, index: usize) -> Result> { @@ -114,7 +119,7 @@ where .map(|(height, index)| SubKey::String(format!("{}:{}", height, index))) .collect(); - let nodes_hash_map = self.nodes.get_many(kv_entries).await; + let nodes_hash_map = self.nodes.get_many(kv_entries).await?; let mut ordered_nodes = Vec::with_capacity(required_nodes_by_height.len()); for (height, index) in required_nodes_by_height { @@ -148,7 +153,7 @@ where current_index /= 2; } - let root = self.root_hash.get(SubKey::None).await.unwrap(); + let root = self.get_root().await?; Ok(root == current_value) } @@ -197,8 +202,8 @@ where ); } - self.nodes.set_many(kv_updates).await; - self.root_hash.set(¤t_value, SubKey::None).await; + self.nodes.set_many(kv_updates).await?; + self.root_hash.set(¤t_value, SubKey::None).await?; Ok(current_value) } @@ -264,7 +269,7 @@ where }) .collect(); - let nodes_hash_map = self.nodes.get_many(kv_entries.clone()).await; + let nodes_hash_map = self.nodes.get_many(kv_entries.clone()).await?; let mut nodes_values: Vec = Vec::with_capacity(kv_entries.len()); for kv in kv_entries { @@ -281,13 +286,11 @@ where indexes: &mut Vec, values: &mut Vec, proof: &mut Vec, - ) -> bool { - let root = self.root_hash.get(SubKey::None).await.unwrap(); - let calculated_root = self - .calculate_multiproof_root_hash(indexes, values, proof) - .unwrap(); + ) -> Result { + let root = self.get_root().await?; + let calculated_root = self.calculate_multiproof_root_hash(indexes, values, proof)?; - root == calculated_root + Ok(root == calculated_root) } fn calculate_multiproof_root_hash( diff --git a/src/mmr/core.rs b/src/mmr/core.rs index ed3c165..3aa50e8 100644 --- a/src/mmr/core.rs +++ b/src/mmr/core.rs @@ -145,7 +145,7 @@ impl MMR { //? Store the hash in the database self.hashes .set(&value, SubKey::Usize(last_element_idx)) - .await; + .await?; peaks.push(value); @@ -161,7 +161,7 @@ impl MMR { self.hashes .set(&parent_hash, SubKey::Usize(last_element_idx)) - .await; + .await?; peaks.push(parent_hash); } @@ -171,7 +171,7 @@ impl MMR { // Compute the new root hash let root_hash = self.calculate_root_hash(&bag, last_element_idx)?; - self.root_hash.set(&root_hash, SubKey::None).await; + self.root_hash.set(&root_hash, SubKey::None).await?; let leaves = self.leaves_count.increment().await?; @@ -221,7 +221,7 @@ impl MMR { .map(SubKey::Usize) .collect::>(), ) - .await; + .await?; let mut siblings_hashes_vec: Vec = siblings .iter() @@ -233,7 +233,11 @@ impl MMR { format_proof(siblings_hashes_vec, formatting_opts.proof.clone()).unwrap(); } - let element_hash = self.hashes.get(SubKey::Usize(element_index)).await.unwrap(); + let element_hash = self + .hashes + .get(SubKey::Usize(element_index)) + .await? + .unwrap(); Ok(Proof { element_index, @@ -277,11 +281,11 @@ impl MMR { .into_iter() .map(SubKey::Usize) .collect(); - let all_siblings_hashes = self.hashes.get_many(sibling_hashes_to_get).await; + let all_siblings_hashes = self.hashes.get_many(sibling_hashes_to_get).await?; let elements_ids_str: Vec = elements_indexes.iter().map(|&x| SubKey::Usize(x)).collect(); - let element_hashes = self.hashes.get_many(elements_ids_str).await; + let element_hashes = self.hashes.get_many(elements_ids_str).await?; let mut proofs: Vec = Vec::new(); for &element_id in &elements_indexes { @@ -404,7 +408,7 @@ impl MMR { let hashes_result = self .hashes .get_many(peak_idxs.clone().into_iter().map(SubKey::Usize).collect()) - .await; + .await?; // Assuming hashes_result is a HashMap let hashes: Vec = peak_idxs .iter() diff --git a/src/store/core.rs b/src/store/core.rs index efee317..c986ad7 100644 --- a/src/store/core.rs +++ b/src/store/core.rs @@ -2,12 +2,24 @@ use anyhow::Result; use async_trait::async_trait; use std::{collections::HashMap, fmt::Debug}; +/// Define common behavior for all stores #[async_trait] pub trait Store: Send + Sync + Debug { + /// Get a value from the store async fn get(&self, key: &str) -> Result>; + + /// Get many values from the store async fn get_many(&self, keys: Vec<&str>) -> Result>; + + /// Set a value in the store async fn set(&self, key: &str, value: &str) -> Result<()>; + + /// Set many values in the store async fn set_many(&self, entries: HashMap) -> Result<()>; + + /// Delete a value from the store async fn delete(&self, key: &str) -> Result<()>; + + /// Delete many values from the store async fn delete_many(&self, keys: Vec<&str>) -> Result<()>; } diff --git a/src/store/counter.rs b/src/store/counter.rs index 22714ab..6addc25 100644 --- a/src/store/counter.rs +++ b/src/store/counter.rs @@ -3,34 +3,42 @@ use std::sync::Arc; use super::Store; use anyhow::Result; +/// A counter that is stored in a store +/// +/// It is used to keep track of the number of times for specific keys are used #[derive(Debug)] pub struct InStoreCounter { + /// The store that the counter is stored in pub store: Arc, + + /// The key of the counter pub key: String, } impl InStoreCounter { + /// Create a new counter pub fn new(store: Arc, key: String) -> Self { Self { store, key } } + /// Get the count of the key pub async fn get(&self) -> usize { - let current_count = self - .store + self.store .get(&self.key) .await .expect("Failed to get count") - .unwrap_or("0".to_string()); - current_count + .unwrap_or("0".to_string()) .parse::() .expect("Failed to parse count") } + /// Set the count of the key pub async fn set(&self, count: usize) -> Result<()> { self.store.set(&self.key, &count.to_string()).await?; Ok(()) } + /// Increment the count of the key pub async fn increment(&self) -> Result { let current_count = self.get().await; let new_count = current_count + 1; diff --git a/src/store/stores/memory.rs b/src/store/stores/memory.rs index 627e21e..4dbbe4c 100644 --- a/src/store/stores/memory.rs +++ b/src/store/stores/memory.rs @@ -6,6 +6,7 @@ use parking_lot::RwLock; use crate::store::Store; +/// A store that is stored in memory #[derive(Debug)] pub struct InMemoryStore { pub store: RwLock>, diff --git a/src/store/stores/sqlite.rs b/src/store/stores/sqlite.rs index ee73130..ecf76b6 100644 --- a/src/store/stores/sqlite.rs +++ b/src/store/stores/sqlite.rs @@ -7,6 +7,7 @@ use tokio::sync::Mutex; use super::super::Store; +/// A store that is stored in SQLite #[derive(Debug)] pub struct SQLiteStore { db: Mutex>, diff --git a/src/store/table.rs b/src/store/table.rs index d25af53..95248e3 100644 --- a/src/store/table.rs +++ b/src/store/table.rs @@ -2,9 +2,12 @@ use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; +use anyhow::{bail, Ok, Result}; + use super::Store; -#[derive(Debug, Clone)] +/// A sub key that is used to get a value from a store +#[derive(Debug, Clone, PartialEq, Eq)] pub enum SubKey { String(String), Usize(usize), @@ -21,21 +24,15 @@ impl ToString for SubKey { } } -impl PartialEq for SubKey { - fn eq(&self, other: &Self) -> bool { - self.to_string() == other.to_string() - } -} - -impl Eq for SubKey {} - impl Hash for SubKey { fn hash(&self, state: &mut H) { self.to_string().hash(state); } } +/// A function that is used to get the full key and store for a given sub_key pub type GetFullKeyAndStoreFn = fn(&InStoreTable, SubKey) -> (Arc, String); +/// A function that is used to get the full keys and stores for a given list of sub_keys pub type GetFullKeysAndStoresFn = fn(&InStoreTable, Vec) -> Vec<(Arc, Vec)>; @@ -47,6 +44,7 @@ pub struct SubMMR { pub store: Arc, } +/// A table that is stored in a store #[derive(Debug, Clone)] pub struct InStoreTable { /// Always use this store for setters @@ -70,6 +68,7 @@ pub struct InStoreTable { } impl InStoreTable { + /// Create a new table pub fn new(store: Arc, key: String) -> Self { Self { store, @@ -81,10 +80,12 @@ impl InStoreTable { } } + /// Get the full key for a given sub_key/key pub fn get_full_key(key: &str, sub_key: &str) -> String { format!("{}{}", key, sub_key) } + /// Get the full key and store for a given sub_key pub fn default_get_store_and_full_key(&self, sub_key: SubKey) -> (Arc, String) { let new_sub_key = sub_key.to_string(); ( @@ -93,6 +94,7 @@ impl InStoreTable { ) } + /// Get the full keys and stores for a given list of sub_keys pub fn default_get_stores_and_full_keys( &self, sub_keys: Vec, @@ -104,12 +106,14 @@ impl InStoreTable { vec![(self.store.clone(), sub_keys)] } - pub async fn get(&self, sub_key: SubKey) -> Option { + /// Get the value from full key that retrieved from the sub_key + pub async fn get(&self, sub_key: SubKey) -> Result> { let (store, full_key) = (self.get_store_and_full_key)(self, sub_key); - store.get(&full_key).await.unwrap_or_default() + Ok(store.get(&full_key).await.unwrap_or_default()) } - pub async fn get_many(&self, sub_keys: Vec) -> HashMap { + /// Get the values from full keys that retrieved from the sub_keys + pub async fn get_many(&self, sub_keys: Vec) -> Result> { let requested_len = sub_keys.len(); let stores_and_keys = (self.get_stores_and_full_keys)(self, sub_keys); @@ -117,7 +121,7 @@ impl InStoreTable { for store_and_keys in stores_and_keys { let (store, keys) = store_and_keys; let keys_ref: Vec<&str> = keys.iter().map(AsRef::as_ref).collect(); - let fetched = store.get_many(keys_ref).await.unwrap_or_default(); // Assuming get_many is async and returns a Result + let fetched = store.get_many(keys_ref).await?; for (key, value) in fetched.iter() { let new_key: String = if key.contains(':') { @@ -129,21 +133,22 @@ impl InStoreTable { } } - assert_eq!( - keyless.len(), - requested_len, - "Some keys were not found {:?}", - keyless - ); - keyless + if keyless.len() != requested_len { + bail!("Some keys were not found {:?}", keyless); + } + + Ok(keyless) } - pub async fn set(&self, value: &str, sub_key: SubKey) { + /// Set the value from full key that retrieved from the sub_key + pub async fn set(&self, value: &str, sub_key: SubKey) -> Result<()> { let (store, key) = (self.get_store_and_full_key)(self, sub_key); - store.set(&key, value).await.expect("Failed to set value") + store.set(&key, value).await.expect("Failed to set value"); + Ok(()) } - pub async fn set_many(&self, entries: HashMap) { + /// Set the values from full keys that retrieved from the sub_keys + pub async fn set_many(&self, entries: HashMap) -> Result<()> { let mut store_entries = HashMap::new(); for (key, value) in entries.into_iter() { @@ -155,5 +160,7 @@ impl InStoreTable { .set_many(store_entries) .await .expect("Failed to set many values"); + + Ok(()) } } diff --git a/tests/merkle_tree/incremental.rs b/tests/merkle_tree/incremental.rs index b3c125b..83458af 100644 --- a/tests/merkle_tree/incremental.rs +++ b/tests/merkle_tree/incremental.rs @@ -13,10 +13,11 @@ async fn initialize() { let store = Arc::new(store); - let tree = - IncrementalMerkleTree::initialize(1024, "0x0".to_string(), hasher, store, None).await; + let tree = IncrementalMerkleTree::initialize(1024, "0x0".to_string(), hasher, store, None) + .await + .unwrap(); assert_eq!( - tree.get_root().await, + tree.get_root().await.unwrap(), "0x4a21358c3e754766216b4c93ecfae222e86822f746e706e563f3a05ef398959" ); } @@ -27,7 +28,9 @@ async fn get_path() { let hasher = StarkPoseidonHasher::new(Some(false)); let store = Arc::new(store); - let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None).await; + let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None) + .await + .unwrap(); let path = tree.get_inclusion_proof(10).await.unwrap(); let expected_nodes = vec![ @@ -38,7 +41,7 @@ async fn get_path() { ]; // Await the async call and store the result - let node_map = tree.nodes.get_many(expected_nodes.clone()).await; + let node_map = tree.nodes.get_many(expected_nodes.clone()).await.unwrap(); // Now use the resulting HashMap let expected_path: Vec = expected_nodes @@ -55,7 +58,9 @@ async fn verify_proof() { let hasher = StarkPoseidonHasher::new(Some(false)); let store = Arc::new(store); - let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None).await; + let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None) + .await + .unwrap(); let path = tree.get_inclusion_proof(10).await.unwrap(); let valid_proof = tree.verify_proof(10, "0x0", &path).await.unwrap(); @@ -71,7 +76,9 @@ async fn update() { let hasher = StarkPoseidonHasher::new(Some(false)); let store = Arc::new(store); - let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None).await; + let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None) + .await + .unwrap(); let path = tree.get_inclusion_proof(7).await.unwrap(); let valid_proof = tree.verify_proof(7, "0x0", &path).await.unwrap(); @@ -88,7 +95,7 @@ async fn update() { assert!(updated_proof); assert_eq!( - tree.get_root().await, + tree.get_root().await.unwrap(), "0x53228c039bc23bffa7a0ba7a864088f98c92dbc41c3737b681cdd7b1bcfe1f2" ); } @@ -99,14 +106,16 @@ async fn invalid_update() { let hasher = StarkPoseidonHasher::new(Some(false)); let store = Arc::new(store); - let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None).await; + let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None) + .await + .unwrap(); let path = tree.get_inclusion_proof(7).await.unwrap(); - let empty_root = tree.get_root().await; + let empty_root = tree.get_root().await.unwrap(); let result = tree .update(7, "0x1".to_string(), "0x2".to_string(), path.clone()) .await; assert!(result.is_err()); - assert_eq!(tree.get_root().await, empty_root); + assert_eq!(tree.get_root().await.unwrap(), empty_root); } #[tokio::test] @@ -120,7 +129,8 @@ async fn generate_and_verify_multi_proof() { let default_hash = "0x0".to_string(); let tree = IncrementalMerkleTree::initialize(tree_size, default_hash.clone(), hasher, store, None) - .await; + .await + .unwrap(); for i in 0..tree_size { let path = tree.get_inclusion_proof(i).await.unwrap(); @@ -157,7 +167,8 @@ async fn generate_and_verify_multi_proof() { let is_valid = tree .verify_multi_proof(&mut test, &mut test_values, &mut multiproof) - .await; + .await + .unwrap(); assert!(is_valid); } @@ -173,7 +184,9 @@ async fn example() { let store = Arc::new(store); let hasher = StarkPoseidonHasher::new(Some(false)); - let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None).await; + let tree = IncrementalMerkleTree::initialize(16, "0x0".to_string(), hasher, store, None) + .await + .unwrap(); let path = tree.get_inclusion_proof(10).await.unwrap(); let valid_proof = tree.verify_proof(10, "0x0", &path).await.unwrap(); diff --git a/tests/mmr/core.rs b/tests/mmr/core.rs index 83b84a9..406ff55 100644 --- a/tests/mmr/core.rs +++ b/tests/mmr/core.rs @@ -219,7 +219,7 @@ async fn test_new() { .unwrap(); assert_eq!( - core_mmr.root_hash.get(SubKey::None).await.unwrap(), + core_mmr.root_hash.get(SubKey::None).await.unwrap().unwrap(), hasher .hash(vec!["1".to_string(), hasher.get_genesis()]) .unwrap() diff --git a/tests/store/memory.rs b/tests/store/memory.rs index 0b3b7c6..439268a 100644 --- a/tests/store/memory.rs +++ b/tests/store/memory.rs @@ -89,8 +89,8 @@ async fn test_get_none_in_store_table() { // Create an in-store counter let table = InStoreTable::new(store.clone(), "table".to_string()); - table.set("value1", SubKey::None).await; - let value = table.get(SubKey::None).await; + table.set("value1", SubKey::None).await.unwrap(); + let value = table.get(SubKey::None).await.unwrap(); assert_eq!(value.unwrap(), "value1".to_string()); } @@ -104,10 +104,10 @@ async fn test_get_many_none_in_store_table() { let mut entries = HashMap::new(); entries.insert(SubKey::String("key1".to_string()), "value1".to_string()); entries.insert(SubKey::String("key2".to_string()), "value2".to_string()); - table.set_many(entries).await; - let value = table.get(SubKey::String("key1".to_string())).await; + table.set_many(entries).await.unwrap(); + let value = table.get(SubKey::String("key1".to_string())).await.unwrap(); assert_eq!(value.unwrap(), "value1".to_string()); - let value = table.get(SubKey::String("key2".to_string())).await; + let value = table.get(SubKey::String("key2".to_string())).await.unwrap(); assert_eq!(value.unwrap(), "value2".to_string()); let values = table @@ -115,7 +115,8 @@ async fn test_get_many_none_in_store_table() { SubKey::String("key1".to_string()), SubKey::String("key2".to_string()), ]) - .await; + .await + .unwrap(); assert_eq!(values.get("tablekey1"), Some(&"value1".to_string())); assert_eq!(values.get("tablekey2"), Some(&"value2".to_string())); } @@ -129,7 +130,11 @@ async fn test_get_some_in_store_table() { let table = InStoreTable::new(store.clone(), "table".to_string()); table .set("value1", SubKey::String("suffix1".to_string())) - .await; - let value = table.get(SubKey::String("suffix1".to_string())).await; + .await + .unwrap(); + let value = table + .get(SubKey::String("suffix1".to_string())) + .await + .unwrap(); assert_eq!(value.unwrap(), "value1".to_string()); } diff --git a/tests/store/sqlite.rs b/tests/store/sqlite.rs index 125bd16..97266bc 100644 --- a/tests/store/sqlite.rs +++ b/tests/store/sqlite.rs @@ -92,8 +92,8 @@ async fn test_get_none_in_store_table() { // Create an in-store counter let table = InStoreTable::new(store.clone(), "table".to_string()); - table.set("value1", SubKey::None).await; - let value = table.get(SubKey::None).await; + table.set("value1", SubKey::None).await.unwrap(); + let value = table.get(SubKey::None).await.unwrap(); assert_eq!(value.unwrap(), "value1".to_string()); } @@ -108,10 +108,10 @@ async fn test_get_many_none_in_store_table() { let mut entries = HashMap::new(); entries.insert(SubKey::String("key1".to_string()), "value1".to_string()); entries.insert(SubKey::String("key2".to_string()), "value2".to_string()); - table.set_many(entries).await; - let value = table.get(SubKey::String("key1".to_string())).await; + table.set_many(entries).await.unwrap(); + let value = table.get(SubKey::String("key1".to_string())).await.unwrap(); assert_eq!(value.unwrap(), "value1".to_string()); - let value = table.get(SubKey::String("key2".to_string())).await; + let value = table.get(SubKey::String("key2".to_string())).await.unwrap(); assert_eq!(value.unwrap(), "value2".to_string()); let values = table @@ -119,7 +119,8 @@ async fn test_get_many_none_in_store_table() { SubKey::String("key1".to_string()), SubKey::String("key2".to_string()), ]) - .await; + .await + .unwrap(); assert_eq!(values.get("tablekey1"), Some(&"value1".to_string())); assert_eq!(values.get("tablekey2"), Some(&"value2".to_string())); } @@ -134,7 +135,11 @@ async fn test_get_some_in_store_table() { let table = InStoreTable::new(store.clone(), "table".to_string()); table .set("value1", SubKey::String("suffix1".to_string())) - .await; - let value = table.get(SubKey::String("suffix1".to_string())).await; + .await + .unwrap(); + let value = table + .get(SubKey::String("suffix1".to_string())) + .await + .unwrap(); assert_eq!(value.unwrap(), "value1".to_string()); }