From 4d429958caf5862deb07c8e6925e8a8c7940c398 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 18 Oct 2024 13:12:55 -0400 Subject: [PATCH] more --- crates/chia-datalayer/src/lib.rs | 2 +- crates/chia-datalayer/src/merkle.rs | 197 +++++++++++++++++++++++----- wheel/generate_type_stubs.py | 4 +- wheel/python/chia_rs/chia_rs.pyi | 4 +- wheel/src/api.rs | 5 +- 5 files changed, 177 insertions(+), 35 deletions(-) diff --git a/crates/chia-datalayer/src/lib.rs b/crates/chia-datalayer/src/lib.rs index 7e997f3c..f7512434 100644 --- a/crates/chia-datalayer/src/lib.rs +++ b/crates/chia-datalayer/src/lib.rs @@ -1,3 +1,3 @@ mod merkle; -pub use merkle::MerkleBlob; +pub use merkle::{InsertLocation, MerkleBlob, Node, Side}; diff --git a/crates/chia-datalayer/src/merkle.rs b/crates/chia-datalayer/src/merkle.rs index 9911118a..0d7961bd 100644 --- a/crates/chia-datalayer/src/merkle.rs +++ b/crates/chia-datalayer/src/merkle.rs @@ -2,7 +2,10 @@ use pyo3::{buffer::PyBuffer, exceptions::PyValueError, pyclass, pymethods, PyResult}; use clvmr::sha2::Sha256; +use clvmr::SExp; use num_traits::ToBytes; +use pyo3::types::PyTuple; +use pyo3::{PyObject, Python}; use std::cmp::Ordering; use std::collections::{HashMap, HashSet, VecDeque}; use std::iter::{zip, IntoIterator}; @@ -97,16 +100,21 @@ fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash { hasher.finalize() } +#[cfg_attr(feature = "py-bindings", pyclass(name = "Side", eq, eq_int))] #[derive(Clone, Debug, Hash, Eq, PartialEq)] pub enum Side { Left, Right, } +#[cfg_attr(feature = "py-bindings", pyclass(name = "InsertLocation"))] #[derive(Clone, Debug, Hash, Eq, PartialEq)] pub enum InsertLocation { - Auto, - AsRoot, + // error: Unit variant `Auto` is not yet supported in a complex enum + // = help: change to a struct variant with no fields: `Auto { }` + // = note: the enum is complex because of non-unit variant `Leaf` + Auto {}, + AsRoot {}, Leaf { index: TreeIndex, side: Side }, } @@ -148,16 +156,28 @@ impl NodeMetadata { } } +#[cfg_attr(feature = "py-bindings", pyclass(name = "Node"))] #[derive(Debug, PartialEq)] pub struct Node { + // #[cfg_attr(feature = "py-bindings", pyo3(get))] + #[pyo3(get)] parent: Parent, + // #[cfg_attr(feature = "py-bindings", pyo3(get))] + #[pyo3(get)] hash: Hash, + // #[cfg_attr(feature = "py-bindings", pyo3(get))] + #[pyo3(get)] specific: NodeSpecific, } -#[derive(Debug, PartialEq)] +// #[cfg_attr(feature = "py-bindings", pyclass(name = "NodeSpecific"))] +#[cfg_attr(feature = "py-bindings", pyclass(name = "NodeSpecific"))] +#[derive(Clone, Debug, PartialEq)] pub enum NodeSpecific { + // #[cfg_attr(feature = "py-bindings", pyo3(constructor = (left, right)))] + #[pyo3(constructor = (left, right))] Internal { left: TreeIndex, right: TreeIndex }, + #[pyo3(constructor = (key, value))] Leaf { key: KvId, value: KvId }, } @@ -244,6 +264,46 @@ impl Node { } } +#[cfg(feature = "py-bindings")] +#[pymethods] +impl Node { + #[getter(left)] + pub fn left(&self) -> TreeIndex { + let NodeSpecific::Internal { left, .. } = self.specific else { + panic!(); + }; + + left + } + + #[getter(right)] + pub fn right(&self) -> TreeIndex { + let NodeSpecific::Internal { right, .. } = self.specific else { + panic!(); + }; + + right + } + + #[getter(key)] + pub fn key(&self) -> KvId { + let NodeSpecific::Leaf { key, .. } = self.specific else { + panic!(); + }; + + key + } + + #[getter(value)] + pub fn value(&self) -> KvId { + let NodeSpecific::Leaf { value, .. } = self.specific else { + panic!(); + }; + + value + } +} + fn block_range(index: TreeIndex) -> Range { let block_start = index as usize * BLOCK_SIZE; block_start..block_start + BLOCK_SIZE @@ -311,6 +371,8 @@ fn get_free_indexes_and_keys_values_indexes( pub struct MerkleBlob { blob: Vec, free_indexes: HashSet, + #[pyo3(get)] + // #[cfg_attr(feature = "py-bindings", pyo3(get))] key_to_index: HashMap, } @@ -346,16 +408,20 @@ impl MerkleBlob { hash: &Hash, insert_location: InsertLocation, ) -> Result<(), String> { + if self.key_to_index.contains_key(&key) { + return Err("Key already present".to_string()); + } + let insert_location = match insert_location { - InsertLocation::Auto => self.get_random_insert_location_by_kvid(key)?, + InsertLocation::Auto {} => self.get_random_insert_location_by_kvid(key)?, _ => insert_location, }; match insert_location { - InsertLocation::Auto => { + InsertLocation::Auto {} => { panic!("this should have been caught and processed above") } - InsertLocation::AsRoot => { + InsertLocation::AsRoot {} => { if !self.key_to_index.is_empty() { return Err("requested insertion at root but tree not empty".to_string()); }; @@ -630,7 +696,7 @@ impl MerkleBlob { pub fn upsert(&mut self, key: KvId, value: KvId, new_hash: &Hash) -> Result<(), String> { let Some(leaf_index) = self.key_to_index.get(&key) else { - self.insert(key, value, new_hash, InsertLocation::Auto)?; + self.insert(key, value, new_hash, InsertLocation::Auto {})?; return Ok(()); }; @@ -763,7 +829,7 @@ impl MerkleBlob { let mut seed_bytes = Vec::from(seed_bytes); if self.blob.is_empty() { - return Ok(InsertLocation::AsRoot); + return Ok(InsertLocation::AsRoot {}); } let side = if (seed_bytes @@ -1047,28 +1113,88 @@ impl MerkleBlob { let slice = unsafe { std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes()) }; - match Self::new(Vec::from(slice)) { - Ok(blob) => Ok(blob), - Err(message) => Err(PyValueError::new_err(message)), - } + Self::new(Vec::from(slice)).map_err(|e| PyValueError::new_err(e)) } - #[pyo3(name = "insert")] - pub fn py_insert(&mut self, key: KvId, value: KvId, hash: Hash) -> PyResult<()> { - if let Err(message) = self.insert(key, value, &hash, InsertLocation::Auto) { - Err(PyValueError::new_err(message)) - } else { - Ok(()) - } + #[pyo3(name = "insert", signature = (key, value, hash, reference_kid = None, side = None))] + pub fn py_insert( + &mut self, + key: KvId, + value: KvId, + hash: Hash, + reference_kid: Option, + // TODO: should be a Side, but python has a different Side right now + side: Option, + ) -> PyResult<()> { + let insert_location = match (reference_kid, side) { + (None, None) => InsertLocation::Auto {}, + (Some(key), Some(side)) => InsertLocation::Leaf { + index: *self + .key_to_index + .get(&key) + .ok_or(PyValueError::new_err("TODO: better message here"))?, + side: match side { + // TODO: if this sticks around, we gotta get more formal about the mapping + 0 => Side::Left, + 1 => Side::Right, + _ => panic!(), + }, + }, + _ => { + return Err(PyValueError::new_err( + "must specify neither or both of reference_kid and side", + )); + } + }; + self.insert(key, value, &hash, insert_location) + .map_err(|e| PyValueError::new_err(e)) } #[pyo3(name = "delete")] pub fn py_delete(&mut self, key: KvId) -> PyResult<()> { - if let Err(message) = self.delete(key) { - Err(PyValueError::new_err(message)) - } else { - Ok(()) + self.delete(key).map_err(|e| PyValueError::new_err(e)) + } + + #[pyo3(name = "get_raw_node")] + pub fn py_get_raw_node(&mut self, index: TreeIndex) -> PyResult { + self.get_node(index).map_err(|e| PyValueError::new_err(e)) + } + + #[pyo3(name = "calculate_lazy_hashes")] + pub fn py_calculate_lazy_hashes(&mut self) -> PyResult<()> { + self.calculate_lazy_hashes() + .map_err(|e| PyValueError::new_err(e)) + } + + #[pyo3(name = "get_lineage")] + pub fn py_get_lineage(&self, index: TreeIndex, py: Python<'_>) -> PyResult { + let list = pyo3::types::PyList::empty_bound(py); + + for node in self + .get_lineage(index) + .map_err(|e| PyValueError::new_err(e))? + { + use pyo3::conversion::IntoPy; + use pyo3::types::PyListMethods; + list.append(node.into_py(py))?; } + + Ok(list.into()) + } + + #[pyo3(name = "get_nodes", signature=(index=0))] + pub fn py_get_nodes(&self, py: Python<'_>, index: TreeIndex) -> PyResult { + // TODO: use the index parameter + + let list = pyo3::types::PyList::empty_bound(py); + + for (_, block) in self { + use pyo3::conversion::IntoPy; + use pyo3::types::PyListMethods; + list.append(block.node.into_py(py))?; + } + + Ok(list.into()) } #[pyo3(name = "__len__")] @@ -1289,7 +1415,7 @@ mod tests { 0x0001_0203_0405_0607, 0x1011_1213_1415_1617, &sha256_num(0x1020), - InsertLocation::Auto, + InsertLocation::Auto {}, ) .unwrap(); @@ -1297,7 +1423,7 @@ mod tests { 0x2021_2223_2425_2627, 0x3031_3233_3435_3637, &sha256_num(0x2030), - InsertLocation::Auto, + InsertLocation::Auto {}, ) .unwrap(); @@ -1400,7 +1526,7 @@ mod tests { let start = Instant::now(); merkle_blob // NOTE: yeah this hash is garbage - .insert(i, i, &sha256_num(i), InsertLocation::Auto) + .insert(i, i, &sha256_num(i), InsertLocation::Auto {}) .unwrap(); let end = Instant::now(); total_time += end.duration_since(start); @@ -1431,7 +1557,7 @@ mod tests { merkle_blob.calculate_lazy_hashes().unwrap(); reference_blobs.push(MerkleBlob::new(merkle_blob.blob.clone()).unwrap()); merkle_blob - .insert(key_value_id, key_value_id, &hash, InsertLocation::Auto) + .insert(key_value_id, key_value_id, &hash, InsertLocation::Auto {}) .unwrap(); dots.push(merkle_blob.to_dot().dump()); } @@ -1460,7 +1586,7 @@ mod tests { key_value_id, key_value_id, &sha256_num(key_value_id), - InsertLocation::Auto, + InsertLocation::Auto {}, ) .unwrap(); open_dot(merkle_blob.to_dot().set_note("first after")); @@ -1481,7 +1607,7 @@ mod tests { let key: KvId = i as KvId; open_dot(merkle_blob.to_dot().set_note("empty")); merkle_blob - .insert(key, key, &sha256_num(key), InsertLocation::Auto) + .insert(key, key, &sha256_num(key), InsertLocation::Auto {}) .unwrap(); last_key = key; } @@ -1539,7 +1665,7 @@ mod tests { key_value_id, key_value_id, &sha256_num(key_value_id), - InsertLocation::Auto, + InsertLocation::Auto {}, ) .unwrap(); open_dot(merkle_blob.to_dot().set_note("first after")); @@ -1626,7 +1752,7 @@ mod tests { let mut insert_blob = MerkleBlob::new(small_blob.blob.clone()).unwrap(); insert_blob - .insert(key, value, &sha256_num(key), InsertLocation::Auto) + .insert(key, value, &sha256_num(key), InsertLocation::Auto {}) .unwrap(); open_dot(insert_blob.to_dot().set_note("first after")); @@ -1683,4 +1809,13 @@ mod tests { } } } + + #[test] + fn test_double_insert_fails() { + let mut blob = MerkleBlob::new(vec![]).unwrap(); + blob.insert(0, 0, &[0u8; 32], InsertLocation::Auto {}) + .unwrap(); + blob.insert(0, 0, &[0u8; 32], InsertLocation::Auto {}) + .expect_err(""); + } } diff --git a/wheel/generate_type_stubs.py b/wheel/generate_type_stubs.py index 03580c9c..df91a6cb 100644 --- a/wheel/generate_type_stubs.py +++ b/wheel/generate_type_stubs.py @@ -386,8 +386,10 @@ def __init__( blob: bytes, ) -> None: ... - def insert(self, key: uint64, value: uint64, hash: bytes32) -> None: ... + def insert(self, key: uint64, value: uint64, hash: bytes32, reference_kid: uint64, side: Side) -> None: ... def delete(self, key: uint64) -> None: ... + # def get_raw_node(self, index: int) -> Node: ... + def calculate_lazy_hashes(self) -> None: ... def __len__(self) -> int: ... @final diff --git a/wheel/python/chia_rs/chia_rs.pyi b/wheel/python/chia_rs/chia_rs.pyi index 6edcac16..d6332e13 100644 --- a/wheel/python/chia_rs/chia_rs.pyi +++ b/wheel/python/chia_rs/chia_rs.pyi @@ -128,8 +128,10 @@ class MerkleBlob: blob: bytes, ) -> None: ... - def insert(self, key: uint64, value: uint64, hash: bytes32) -> None: ... + def insert(self, key: uint64, value: uint64, hash: bytes32, reference_kid: uint64, side: Side) -> None: ... def delete(self, key: uint64) -> None: ... + # def get_raw_node(self, index: int) -> Node: ... + def calculate_lazy_hashes(self) -> None: ... def __len__(self) -> int: ... @final diff --git a/wheel/src/api.rs b/wheel/src/api.rs index c2246282..91eab80f 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -75,7 +75,7 @@ use chia_bls::{ Signature, }; -use chia_datalayer::MerkleBlob; +use chia_datalayer::{MerkleBlob, Node}; #[pyfunction] pub fn compute_merkle_set_root<'p>( @@ -477,6 +477,9 @@ pub fn chia_rs(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { // datalayer m.add_class::()?; + // m.add_class::()?; + // m.add_class::()?; + m.add_class::()?; // merkle tree m.add_class::()?;