Skip to content

Commit

Permalink
refactor stores + fix functions requires result
Browse files Browse the repository at this point in the history
  • Loading branch information
rkdud007 committed Dec 27, 2023
1 parent a4a69e2 commit 45e16b8
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 84 deletions.
4 changes: 3 additions & 1 deletion benches/incremental_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ async fn prepare_incremental(count: usize) -> IncrementalMerkleTree<StarkPoseido

let store = Arc::new(store);

IncrementalMerkleTree::initialize(count, "0x0".to_string(), hasher, store, None).await
IncrementalMerkleTree::initialize(count, "0x0".to_string(), hasher, store, None)
.await
.unwrap()
}

fn bench(c: &mut Criterion) {
Expand Down
37 changes: 20 additions & 17 deletions src/merkle_tree/merkle_trees/incremental.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ where
hasher: H,
store: Arc<dyn Store>,
mmr_id: Option<String>,
) -> Self {
) -> Result<Self> {
let tree = IncrementalMerkleTree::new(size, null_value, hasher, store, mmr_id);
let nodes = tree.render_empty_tree();
let nodes_hashmap: HashMap<SubKey, String> =
Expand All @@ -81,16 +81,21 @@ where
acc
});

tree.nodes.set_many(nodes_hashmap).await;
tree.nodes.set_many(nodes_hashmap).await?;

tree.root_hash
.set(&nodes[nodes.len() - 1][0].hash, SubKey::None)
.await;
tree
.await?;
Ok(tree)
}

pub async fn get_root(&self) -> String {
self.root_hash.get(SubKey::None).await.unwrap()
pub async fn get_root(&self) -> Result<String> {
self.root_hash.get(SubKey::None).await?.ok_or_else(|| {
anyhow!(
"Root hash not found for mmr_id: {}",
self.mmr_id.to_string()
)
})
}

pub async fn get_inclusion_proof(&self, index: usize) -> Result<Vec<String>> {
Expand All @@ -114,7 +119,7 @@ where
.map(|(height, index)| SubKey::String(format!("{}:{}", height, index)))
.collect();

let nodes_hash_map = self.nodes.get_many(kv_entries).await;
let nodes_hash_map = self.nodes.get_many(kv_entries).await?;

let mut ordered_nodes = Vec::with_capacity(required_nodes_by_height.len());
for (height, index) in required_nodes_by_height {
Expand Down Expand Up @@ -148,7 +153,7 @@ where
current_index /= 2;
}

let root = self.root_hash.get(SubKey::None).await.unwrap();
let root = self.get_root().await?;
Ok(root == current_value)
}

Expand Down Expand Up @@ -197,8 +202,8 @@ where
);
}

self.nodes.set_many(kv_updates).await;
self.root_hash.set(&current_value, SubKey::None).await;
self.nodes.set_many(kv_updates).await?;
self.root_hash.set(&current_value, SubKey::None).await?;
Ok(current_value)
}

Expand Down Expand Up @@ -264,7 +269,7 @@ where
})
.collect();

let nodes_hash_map = self.nodes.get_many(kv_entries.clone()).await;
let nodes_hash_map = self.nodes.get_many(kv_entries.clone()).await?;

let mut nodes_values: Vec<String> = Vec::with_capacity(kv_entries.len());
for kv in kv_entries {
Expand All @@ -281,13 +286,11 @@ where
indexes: &mut Vec<usize>,
values: &mut Vec<String>,
proof: &mut Vec<String>,
) -> bool {
let root = self.root_hash.get(SubKey::None).await.unwrap();
let calculated_root = self
.calculate_multiproof_root_hash(indexes, values, proof)
.unwrap();
) -> Result<bool> {
let root = self.get_root().await?;
let calculated_root = self.calculate_multiproof_root_hash(indexes, values, proof)?;

root == calculated_root
Ok(root == calculated_root)
}

fn calculate_multiproof_root_hash(
Expand Down
20 changes: 12 additions & 8 deletions src/mmr/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl MMR {
//? Store the hash in the database
self.hashes
.set(&value, SubKey::Usize(last_element_idx))
.await;
.await?;

peaks.push(value);

Expand All @@ -161,7 +161,7 @@ impl MMR {

self.hashes
.set(&parent_hash, SubKey::Usize(last_element_idx))
.await;
.await?;
peaks.push(parent_hash);
}

Expand All @@ -171,7 +171,7 @@ impl MMR {

// Compute the new root hash
let root_hash = self.calculate_root_hash(&bag, last_element_idx)?;
self.root_hash.set(&root_hash, SubKey::None).await;
self.root_hash.set(&root_hash, SubKey::None).await?;

let leaves = self.leaves_count.increment().await?;

Expand Down Expand Up @@ -221,7 +221,7 @@ impl MMR {
.map(SubKey::Usize)
.collect::<Vec<SubKey>>(),
)
.await;
.await?;

let mut siblings_hashes_vec: Vec<String> = siblings
.iter()
Expand All @@ -233,7 +233,11 @@ impl MMR {
format_proof(siblings_hashes_vec, formatting_opts.proof.clone()).unwrap();
}

let element_hash = self.hashes.get(SubKey::Usize(element_index)).await.unwrap();
let element_hash = self
.hashes
.get(SubKey::Usize(element_index))
.await?
.unwrap();

Ok(Proof {
element_index,
Expand Down Expand Up @@ -277,11 +281,11 @@ impl MMR {
.into_iter()
.map(SubKey::Usize)
.collect();
let all_siblings_hashes = self.hashes.get_many(sibling_hashes_to_get).await;
let all_siblings_hashes = self.hashes.get_many(sibling_hashes_to_get).await?;

let elements_ids_str: Vec<SubKey> =
elements_indexes.iter().map(|&x| SubKey::Usize(x)).collect();
let element_hashes = self.hashes.get_many(elements_ids_str).await;
let element_hashes = self.hashes.get_many(elements_ids_str).await?;

let mut proofs: Vec<Proof> = Vec::new();
for &element_id in &elements_indexes {
Expand Down Expand Up @@ -404,7 +408,7 @@ impl MMR {
let hashes_result = self
.hashes
.get_many(peak_idxs.clone().into_iter().map(SubKey::Usize).collect())
.await;
.await?;
// Assuming hashes_result is a HashMap<String, String>
let hashes: Vec<String> = peak_idxs
.iter()
Expand Down
12 changes: 12 additions & 0 deletions src/store/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,24 @@ use anyhow::Result;
use async_trait::async_trait;
use std::{collections::HashMap, fmt::Debug};

/// Define common behavior for all stores
#[async_trait]
pub trait Store: Send + Sync + Debug {
/// Get a value from the store
async fn get(&self, key: &str) -> Result<Option<String>>;

/// Get many values from the store
async fn get_many(&self, keys: Vec<&str>) -> Result<HashMap<String, String>>;

/// Set a value in the store
async fn set(&self, key: &str, value: &str) -> Result<()>;

/// Set many values in the store
async fn set_many(&self, entries: HashMap<String, String>) -> Result<()>;

/// Delete a value from the store
async fn delete(&self, key: &str) -> Result<()>;

/// Delete many values from the store
async fn delete_many(&self, keys: Vec<&str>) -> Result<()>;
}
16 changes: 12 additions & 4 deletions src/store/counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,42 @@ use std::sync::Arc;
use super::Store;
use anyhow::Result;

/// A counter that is stored in a store
///
/// It is used to keep track of the number of times for specific keys are used
#[derive(Debug)]
pub struct InStoreCounter {
/// The store that the counter is stored in
pub store: Arc<dyn Store>,

/// The key of the counter
pub key: String,
}

impl InStoreCounter {
/// Create a new counter
pub fn new(store: Arc<dyn Store>, key: String) -> Self {
Self { store, key }
}

/// Get the count of the key
pub async fn get(&self) -> usize {
let current_count = self
.store
self.store
.get(&self.key)
.await
.expect("Failed to get count")
.unwrap_or("0".to_string());
current_count
.unwrap_or("0".to_string())
.parse::<usize>()
.expect("Failed to parse count")
}

/// Set the count of the key
pub async fn set(&self, count: usize) -> Result<()> {
self.store.set(&self.key, &count.to_string()).await?;
Ok(())
}

/// Increment the count of the key
pub async fn increment(&self) -> Result<usize> {
let current_count = self.get().await;
let new_count = current_count + 1;
Expand Down
1 change: 1 addition & 0 deletions src/store/stores/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use parking_lot::RwLock;

use crate::store::Store;

/// A store that is stored in memory
#[derive(Debug)]
pub struct InMemoryStore {
pub store: RwLock<HashMap<String, String>>,
Expand Down
1 change: 1 addition & 0 deletions src/store/stores/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tokio::sync::Mutex;

use super::super::Store;

/// A store that is stored in SQLite
#[derive(Debug)]
pub struct SQLiteStore {
db: Mutex<Pool<Sqlite>>,
Expand Down
Loading

0 comments on commit 45e16b8

Please sign in to comment.