diff --git a/Cargo.lock b/Cargo.lock index ed9c651d2..153dde12c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -314,7 +314,6 @@ dependencies = [ "criterion", "hex", "hkdf", - "lru", "pyo3", "rand", "rstest", diff --git a/Cargo.toml b/Cargo.toml index 0864aa2c7..0ba1765e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -126,7 +126,6 @@ hex = "0.4.3" thiserror = "1.0.63" pyo3 = "0.21.2" arbitrary = "1.3.2" -lru = "0.12.4" rand = "0.8.5" criterion = "0.5.1" rstest = "0.22.0" diff --git a/crates/chia-bls/Cargo.toml b/crates/chia-bls/Cargo.toml index 97a4d9bd0..314f6a0e7 100644 --- a/crates/chia-bls/Cargo.toml +++ b/crates/chia-bls/Cargo.toml @@ -26,7 +26,6 @@ hex = { workspace = true } thiserror = { workspace = true } pyo3 = { workspace = true, features = ["multiple-pymethods"], optional = true } arbitrary = { workspace = true, optional = true } -lru = { workspace = true } [dev-dependencies] rand = { workspace = true } diff --git a/crates/chia-bls/src/bls_cache.rs b/crates/chia-bls/src/bls_cache.rs index e3b8cdcf9..6abc421cc 100644 --- a/crates/chia-bls/src/bls_cache.rs +++ b/crates/chia-bls/src/bls_cache.rs @@ -1,8 +1,8 @@ use std::borrow::Borrow; +use std::collections::{HashMap, HashSet, VecDeque}; use std::num::NonZeroUsize; use chia_sha2::Sha256; -use lru::LruCache; use std::sync::Mutex; use crate::{aggregate_verify_gt, hash_to_g2}; @@ -17,16 +17,37 @@ use crate::{GTElement, PublicKey, Signature}; /// However, validating a signature where we have no cached GT elements, the /// aggregate_verify() primitive is faster. When long-syncing, that's /// preferable. + +#[derive(Debug, Clone)] +struct BlsCacheData { + // sha256(pubkey + message) -> GTElement + items: HashMap<[u8; 32], GTElement>, + insertions_order: VecDeque<[u8; 32]>, + capacity: NonZeroUsize, +} + +impl BlsCacheData { + pub fn put(&mut self, hash: [u8; 32], pairing: GTElement) { + // If the cache is full, remove the oldest item. + if self.items.len() == self.capacity.get() { + if let Some(oldest_key) = self.insertions_order.pop_front() { + self.items.remove(&oldest_key); + } + } + self.items.insert(hash, pairing); + self.insertions_order.push_back(hash); + } +} + #[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))] #[derive(Debug)] pub struct BlsCache { - // sha256(pubkey + message) -> GTElement - cache: Mutex>, + cache: Mutex, } impl Default for BlsCache { fn default() -> Self { - Self::new(NonZeroUsize::new(50000).unwrap()) + Self::new(NonZeroUsize::new(50_000).unwrap()) } } @@ -39,18 +60,22 @@ impl Clone for BlsCache { } impl BlsCache { - pub fn new(cache_size: NonZeroUsize) -> Self { + pub fn new(capacity: NonZeroUsize) -> Self { Self { - cache: Mutex::new(LruCache::new(cache_size)), + cache: Mutex::new(BlsCacheData { + items: HashMap::new(), + insertions_order: VecDeque::new(), + capacity, + }), } } pub fn len(&self) -> usize { - self.cache.lock().expect("cache").len() + self.cache.lock().expect("cache").items.len() } pub fn is_empty(&self) -> bool { - self.cache.lock().expect("cache").is_empty() + self.cache.lock().expect("cache").items.is_empty() } pub fn aggregate_verify, Msg: AsRef<[u8]>>( @@ -58,6 +83,7 @@ impl BlsCache { pks_msgs: impl IntoIterator, sig: &Signature, ) -> bool { + let mut hashes_to_remove = HashSet::new(); let iter = pks_msgs.into_iter().map(|(pk, msg)| -> GTElement { // Hash pubkey + message let mut hasher = Sha256::new(); @@ -67,7 +93,8 @@ impl BlsCache { let hash: [u8; 32] = hasher.finalize(); // If the pairing is in the cache, we don't need to recalculate it. - if let Some(pairing) = self.cache.lock().expect("cache").get(&hash).cloned() { + if let Some(pairing) = self.cache.lock().expect("cache").items.get(&hash).cloned() { + hashes_to_remove.insert(hash); return pairing; } @@ -79,7 +106,16 @@ impl BlsCache { pairing }); - aggregate_verify_gt(sig, iter) + let is_valid = aggregate_verify_gt(sig, iter); + if is_valid { + // Evict cache hit entries on successful validation. + let mut c = self.cache.lock().expect("cache"); + for hash in &hashes_to_remove { + c.items.remove(hash); + c.insertions_order.retain(|h| h != hash); + } + } + is_valid } pub fn update(&mut self, aug_msg: &[u8], gt: GTElement) { @@ -148,7 +184,7 @@ impl BlsCache { use pyo3::types::PyBytes; let ret = PyList::empty_bound(py); let c = self.cache.lock().expect("cache"); - for (key, value) in &*c { + for (key, value) in &c.items { ret.append((PyBytes::new_bound(py, key), value.clone().into_py(py)))?; } Ok(ret.into()) @@ -194,9 +230,9 @@ pub mod tests { assert!(bls_cache.aggregate_verify(pks_msgs, &sig)); assert_eq!(bls_cache.len(), 1); - // Now that it's cached, it shouldn't cache it again. + // Now that it's cached, if we hit it, it gets removed. assert!(bls_cache.aggregate_verify(pks_msgs, &sig)); - assert_eq!(bls_cache.len(), 1); + assert_eq!(bls_cache.len(), 0); } #[test] @@ -226,7 +262,8 @@ pub mod tests { pks_msgs.push((pk2, msg2)); assert!(bls_cache.aggregate_verify(pks_msgs.clone(), &agg_sig)); - assert_eq!(bls_cache.len(), 2); + // We should have added the second and removed the first (cache hit) + assert_eq!(bls_cache.len(), 1); // Try reusing a public key. let msg3 = [108; 32]; @@ -236,6 +273,14 @@ pub mod tests { // Verify this signature and add to the cache as well (since it's still a different aggregate). assert!(bls_cache.aggregate_verify(pks_msgs, &agg_sig)); + assert_eq!(bls_cache.len(), 2); + + // Verify that cache hits are not removed when verification fails. + assert!(!bls_cache.aggregate_verify( + vec![(pk2, msg3), (PublicKey::default(), msg3)], + &Signature::default() + )); + // We added the new one but didn't remove the old despite its cache hit. assert_eq!(bls_cache.len(), 3); } @@ -261,7 +306,7 @@ pub mod tests { } // The cache should be full now. - assert_eq!(bls_cache.cache.lock().expect("cache").len(), 3); + assert_eq!(bls_cache.len(), 3); // Recreate first key. let sk = SecretKey::from_seed(&[1; 32]); @@ -275,7 +320,9 @@ pub mod tests { let hash: [u8; 32] = hasher.finalize(); // The first key should have been removed, since it's the oldest that's been accessed. - assert!(!bls_cache.cache.lock().expect("cache").contains(&hash)); + let c = bls_cache.cache.lock().expect("cache"); + assert!(!c.items.contains_key(&hash)); + assert!(!c.insertions_order.contains(&hash)); } #[test]