Skip to content

Commit

Permalink
Use a cache eviction logic that matches our BlsCache usage pattern.
Browse files Browse the repository at this point in the history
  • Loading branch information
AmineKhaldi committed Oct 17, 2024
1 parent beca98c commit 74b8c95
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 19 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ hex = "0.4.3"
thiserror = "1.0.63"
pyo3 = "0.21.2"
arbitrary = "1.3.2"
lru = "0.12.4"
rand = "0.8.5"
criterion = "0.5.1"
rstest = "0.22.0"
Expand Down
1 change: 0 additions & 1 deletion crates/chia-bls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ hex = { workspace = true }
thiserror = { workspace = true }
pyo3 = { workspace = true, features = ["multiple-pymethods"], optional = true }
arbitrary = { workspace = true, optional = true }
lru = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
Expand Down
79 changes: 63 additions & 16 deletions crates/chia-bls/src/bls_cache.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::borrow::Borrow;
use std::collections::{HashMap, HashSet, VecDeque};
use std::num::NonZeroUsize;

use chia_sha2::Sha256;
use lru::LruCache;
use std::sync::Mutex;

use crate::{aggregate_verify_gt, hash_to_g2};
Expand All @@ -17,16 +17,37 @@ use crate::{GTElement, PublicKey, Signature};
/// However, validating a signature where we have no cached GT elements, the
/// aggregate_verify() primitive is faster. When long-syncing, that's
/// preferable.

#[derive(Debug, Clone)]
struct BlsCacheData {
// sha256(pubkey + message) -> GTElement
items: HashMap<[u8; 32], GTElement>,
insertions_order: VecDeque<[u8; 32]>,
capacity: NonZeroUsize,
}

impl BlsCacheData {
pub fn put(&mut self, hash: [u8; 32], pairing: GTElement) {
// If the cache is full, remove the oldest item.
if self.items.len() == self.capacity.get() {
if let Some(oldest_key) = self.insertions_order.pop_front() {
self.items.remove(&oldest_key);
}
}
self.items.insert(hash, pairing);
self.insertions_order.push_back(hash);
}
}

#[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
#[derive(Debug)]
pub struct BlsCache {
// sha256(pubkey + message) -> GTElement
cache: Mutex<LruCache<[u8; 32], GTElement>>,
cache: Mutex<BlsCacheData>,
}

impl Default for BlsCache {
fn default() -> Self {
Self::new(NonZeroUsize::new(50000).unwrap())
Self::new(NonZeroUsize::new(50_000).unwrap())
}
}

Expand All @@ -39,25 +60,30 @@ impl Clone for BlsCache {
}

impl BlsCache {
pub fn new(cache_size: NonZeroUsize) -> Self {
pub fn new(capacity: NonZeroUsize) -> Self {
Self {
cache: Mutex::new(LruCache::new(cache_size)),
cache: Mutex::new(BlsCacheData {
items: HashMap::new(),
insertions_order: VecDeque::new(),
capacity,
}),
}
}

pub fn len(&self) -> usize {
self.cache.lock().expect("cache").len()
self.cache.lock().expect("cache").items.len()
}

pub fn is_empty(&self) -> bool {
self.cache.lock().expect("cache").is_empty()
self.cache.lock().expect("cache").items.is_empty()
}

pub fn aggregate_verify<Pk: Borrow<PublicKey>, Msg: AsRef<[u8]>>(
&mut self,
pks_msgs: impl IntoIterator<Item = (Pk, Msg)>,
sig: &Signature,
) -> bool {
let mut hashes_to_remove = HashSet::new();
let iter = pks_msgs.into_iter().map(|(pk, msg)| -> GTElement {
// Hash pubkey + message
let mut hasher = Sha256::new();
Expand All @@ -67,7 +93,8 @@ impl BlsCache {
let hash: [u8; 32] = hasher.finalize();

// If the pairing is in the cache, we don't need to recalculate it.
if let Some(pairing) = self.cache.lock().expect("cache").get(&hash).cloned() {
if let Some(pairing) = self.cache.lock().expect("cache").items.get(&hash).cloned() {
hashes_to_remove.insert(hash);
return pairing;
}

Expand All @@ -79,7 +106,16 @@ impl BlsCache {
pairing
});

aggregate_verify_gt(sig, iter)
let is_valid = aggregate_verify_gt(sig, iter);
if is_valid {
// Evict cache hit entries on successful validation.
let mut c = self.cache.lock().expect("cache");
for hash in &hashes_to_remove {
c.items.remove(hash);
c.insertions_order.retain(|h| h != hash);
}
}
is_valid
}

pub fn update(&mut self, aug_msg: &[u8], gt: GTElement) {
Expand Down Expand Up @@ -148,7 +184,7 @@ impl BlsCache {
use pyo3::types::PyBytes;
let ret = PyList::empty_bound(py);
let c = self.cache.lock().expect("cache");
for (key, value) in &*c {
for (key, value) in &c.items {
ret.append((PyBytes::new_bound(py, key), value.clone().into_py(py)))?;
}
Ok(ret.into())
Expand Down Expand Up @@ -194,9 +230,9 @@ pub mod tests {
assert!(bls_cache.aggregate_verify(pks_msgs, &sig));
assert_eq!(bls_cache.len(), 1);

// Now that it's cached, it shouldn't cache it again.
// Now that it's cached, if we hit it, it gets removed.
assert!(bls_cache.aggregate_verify(pks_msgs, &sig));
assert_eq!(bls_cache.len(), 1);
assert_eq!(bls_cache.len(), 0);
}

#[test]
Expand Down Expand Up @@ -226,7 +262,8 @@ pub mod tests {
pks_msgs.push((pk2, msg2));

assert!(bls_cache.aggregate_verify(pks_msgs.clone(), &agg_sig));
assert_eq!(bls_cache.len(), 2);
// We should have added the second and removed the first (cache hit)
assert_eq!(bls_cache.len(), 1);

// Try reusing a public key.
let msg3 = [108; 32];
Expand All @@ -236,6 +273,14 @@ pub mod tests {

// Verify this signature and add to the cache as well (since it's still a different aggregate).
assert!(bls_cache.aggregate_verify(pks_msgs, &agg_sig));
assert_eq!(bls_cache.len(), 2);

// Verify that cache hits are not removed when verification fails.
assert!(!bls_cache.aggregate_verify(
vec![(pk2, msg3), (PublicKey::default(), msg3)],
&Signature::default()
));
// We added the new one but didn't remove the old despite its cache hit.
assert_eq!(bls_cache.len(), 3);
}

Expand All @@ -261,7 +306,7 @@ pub mod tests {
}

// The cache should be full now.
assert_eq!(bls_cache.cache.lock().expect("cache").len(), 3);
assert_eq!(bls_cache.len(), 3);

// Recreate first key.
let sk = SecretKey::from_seed(&[1; 32]);
Expand All @@ -275,7 +320,9 @@ pub mod tests {
let hash: [u8; 32] = hasher.finalize();

// The first key should have been removed, since it's the oldest that's been accessed.
assert!(!bls_cache.cache.lock().expect("cache").contains(&hash));
let c = bls_cache.cache.lock().expect("cache");
assert!(!c.items.contains_key(&hash));
assert!(!c.insertions_order.contains(&hash));
}

#[test]
Expand Down

0 comments on commit 74b8c95

Please sign in to comment.