Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky committed Oct 18, 2024
1 parent f22ac25 commit 4d42995
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 35 deletions.
2 changes: 1 addition & 1 deletion crates/chia-datalayer/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mod merkle;

pub use merkle::MerkleBlob;
pub use merkle::{InsertLocation, MerkleBlob, Node, Side};
197 changes: 166 additions & 31 deletions crates/chia-datalayer/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
use pyo3::{buffer::PyBuffer, exceptions::PyValueError, pyclass, pymethods, PyResult};

use clvmr::sha2::Sha256;
use clvmr::SExp;
use num_traits::ToBytes;
use pyo3::types::PyTuple;
use pyo3::{PyObject, Python};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet, VecDeque};
use std::iter::{zip, IntoIterator};
Expand Down Expand Up @@ -97,16 +100,21 @@ fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash {
hasher.finalize()
}

/// Selects the left or the right side.
///
/// Used as the `side` field of `InsertLocation::Leaf`. With the `py-bindings`
/// feature enabled it is also exposed to Python as the class `Side`.
#[cfg_attr(feature = "py-bindings", pyclass(name = "Side", eq, eq_int))]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub enum Side {
/// The left side.
Left,
/// The right side.
Right,
}

#[cfg_attr(feature = "py-bindings", pyclass(name = "InsertLocation"))]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub enum InsertLocation {
Auto,
AsRoot,
// PyO3 does not yet support unit variants in a "complex" enum (an enum with
// at least one non-unit variant, here `Leaf`); the compiler error suggests
// changing them to struct variants with no fields. `Auto` and `AsRoot` are
// therefore written as `Auto {}` and `AsRoot {}`.
Auto {},
AsRoot {},
Leaf { index: TreeIndex, side: Side },
}

Expand Down Expand Up @@ -148,16 +156,28 @@ impl NodeMetadata {
}
}

/// A single node of the tree: its parent reference, its hash, and the
/// internal-or-leaf payload in `specific`.
///
/// With the `py-bindings` feature enabled this is exposed to Python as the
/// class `Node`. `get_all` asks PyO3 to generate a read-only getter for every
/// field. It is used instead of per-field `#[pyo3(get)]` attributes because
/// those attributes only exist while `#[pyclass]` is applied, so leaving them
/// un-gated breaks builds that do not enable `py-bindings`.
#[cfg_attr(feature = "py-bindings", pyclass(name = "Node", get_all))]
#[derive(Debug, PartialEq)]
pub struct Node {
    /// Reference to this node's parent.
    parent: Parent,
    /// Hash stored for this node.
    hash: Hash,
    /// Payload that differs between internal nodes and leaves.
    specific: NodeSpecific,
}

#[derive(Debug, PartialEq)]
// #[cfg_attr(feature = "py-bindings", pyclass(name = "NodeSpecific"))]
/// The part of a node that differs between internal nodes and leaves.
// NOTE(review): the `#[pyo3(constructor = ...)]` attributes below are not gated
// on the `py-bindings` feature, unlike the `pyclass` attribute — confirm the
// crate still builds with that feature disabled.
#[cfg_attr(feature = "py-bindings", pyclass(name = "NodeSpecific"))]
#[derive(Clone, Debug, PartialEq)]
pub enum NodeSpecific {
// #[cfg_attr(feature = "py-bindings", pyo3(constructor = (left, right)))]
/// An internal node, holding its `left` and `right` tree indexes.
#[pyo3(constructor = (left, right))]
Internal { left: TreeIndex, right: TreeIndex },
/// A leaf node, holding a `key` id and a `value` id.
#[pyo3(constructor = (key, value))]
Leaf { key: KvId, value: KvId },
}

Expand Down Expand Up @@ -244,6 +264,46 @@ impl Node {
}
}

/// Python-facing convenience getters that read fields out of `specific`.
///
/// Each getter is only meaningful for one node kind and panics with a
/// descriptive message when called on the other kind. Through PyO3 a Rust
/// panic surfaces in Python as a `PanicException` carrying that message.
#[cfg(feature = "py-bindings")]
#[pymethods]
impl Node {
    /// The `left` tree index of an internal node.
    ///
    /// # Panics
    ///
    /// Panics if this node is a leaf.
    #[getter(left)]
    pub fn left(&self) -> TreeIndex {
        match self.specific {
            NodeSpecific::Internal { left, .. } => left,
            NodeSpecific::Leaf { .. } => panic!("`left` requested on a leaf node"),
        }
    }

    /// The `right` tree index of an internal node.
    ///
    /// # Panics
    ///
    /// Panics if this node is a leaf.
    #[getter(right)]
    pub fn right(&self) -> TreeIndex {
        match self.specific {
            NodeSpecific::Internal { right, .. } => right,
            NodeSpecific::Leaf { .. } => panic!("`right` requested on a leaf node"),
        }
    }

    /// The `key` id of a leaf node.
    ///
    /// # Panics
    ///
    /// Panics if this node is an internal node.
    #[getter(key)]
    pub fn key(&self) -> KvId {
        match self.specific {
            NodeSpecific::Leaf { key, .. } => key,
            NodeSpecific::Internal { .. } => panic!("`key` requested on an internal node"),
        }
    }

    /// The `value` id of a leaf node.
    ///
    /// # Panics
    ///
    /// Panics if this node is an internal node.
    #[getter(value)]
    pub fn value(&self) -> KvId {
        match self.specific {
            NodeSpecific::Leaf { value, .. } => value,
            NodeSpecific::Internal { .. } => panic!("`value` requested on an internal node"),
        }
    }
}

fn block_range(index: TreeIndex) -> Range<usize> {
let block_start = index as usize * BLOCK_SIZE;
block_start..block_start + BLOCK_SIZE
Expand Down Expand Up @@ -311,6 +371,8 @@ fn get_free_indexes_and_keys_values_indexes(
pub struct MerkleBlob {
/// Raw bytes backing the tree.
blob: Vec<u8>,
/// Set of tree indexes tracked as free.
// presumably slots available for reuse; verify against the insert/delete code
free_indexes: HashSet<TreeIndex>,
// NOTE(review): this `#[pyo3(get)]` is not gated on the `py-bindings` feature —
// confirm the crate still builds with that feature disabled.
#[pyo3(get)]
// #[cfg_attr(feature = "py-bindings", pyo3(get))]
/// Maps each key id to the tree index recorded for it; consulted to reject
/// duplicate keys on insert and to look up a reference leaf by key.
key_to_index: HashMap<KvId, TreeIndex>,
}

Expand Down Expand Up @@ -346,16 +408,20 @@ impl MerkleBlob {
hash: &Hash,
insert_location: InsertLocation,
) -> Result<(), String> {
if self.key_to_index.contains_key(&key) {
return Err("Key already present".to_string());
}

let insert_location = match insert_location {
InsertLocation::Auto => self.get_random_insert_location_by_kvid(key)?,
InsertLocation::Auto {} => self.get_random_insert_location_by_kvid(key)?,
_ => insert_location,
};

match insert_location {
InsertLocation::Auto => {
InsertLocation::Auto {} => {
panic!("this should have been caught and processed above")
}
InsertLocation::AsRoot => {
InsertLocation::AsRoot {} => {
if !self.key_to_index.is_empty() {
return Err("requested insertion at root but tree not empty".to_string());
};
Expand Down Expand Up @@ -630,7 +696,7 @@ impl MerkleBlob {

pub fn upsert(&mut self, key: KvId, value: KvId, new_hash: &Hash) -> Result<(), String> {
let Some(leaf_index) = self.key_to_index.get(&key) else {
self.insert(key, value, new_hash, InsertLocation::Auto)?;
self.insert(key, value, new_hash, InsertLocation::Auto {})?;
return Ok(());
};

Expand Down Expand Up @@ -763,7 +829,7 @@ impl MerkleBlob {
let mut seed_bytes = Vec::from(seed_bytes);

if self.blob.is_empty() {
return Ok(InsertLocation::AsRoot);
return Ok(InsertLocation::AsRoot {});
}

let side = if (seed_bytes
Expand Down Expand Up @@ -1047,28 +1113,88 @@ impl MerkleBlob {
let slice =
unsafe { std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes()) };

match Self::new(Vec::from(slice)) {
Ok(blob) => Ok(blob),
Err(message) => Err(PyValueError::new_err(message)),
}
Self::new(Vec::from(slice)).map_err(|e| PyValueError::new_err(e))
}

#[pyo3(name = "insert")]
pub fn py_insert(&mut self, key: KvId, value: KvId, hash: Hash) -> PyResult<()> {
if let Err(message) = self.insert(key, value, &hash, InsertLocation::Auto) {
Err(PyValueError::new_err(message))
} else {
Ok(())
}
/// Python-facing `insert`: add a key/value pair with the given hash.
///
/// * With neither `reference_kid` nor `side`, the location is
///   `InsertLocation::Auto {}`.
/// * With both, the location is `InsertLocation::Leaf` built from the tree
///   index recorded for `reference_kid` and from `side`, where `0` maps to
///   `Side::Left` and `1` maps to `Side::Right`.
///
/// # Errors
///
/// Returns a `ValueError` when only one of `reference_kid` / `side` is given,
/// when `reference_kid` is not a known key, when `side` is neither `0` nor
/// `1`, or when the underlying `insert` reports an error.
#[pyo3(name = "insert", signature = (key, value, hash, reference_kid = None, side = None))]
pub fn py_insert(
    &mut self,
    key: KvId,
    value: KvId,
    hash: Hash,
    reference_kid: Option<KvId>,
    // TODO: should be a Side, but python has a different Side right now
    side: Option<u8>,
) -> PyResult<()> {
    let insert_location = match (reference_kid, side) {
        (None, None) => InsertLocation::Auto {},
        (Some(reference_key), Some(side)) => {
            // Build the error lazily so nothing is allocated on the success path.
            let index = *self.key_to_index.get(&reference_key).ok_or_else(|| {
                PyValueError::new_err(format!("reference key not found: {reference_key}"))
            })?;
            // TODO: if this sticks around, we gotta get more formal about the mapping
            let side = match side {
                0 => Side::Left,
                1 => Side::Right,
                // Caller-supplied input: report it as an error instead of panicking.
                _ => {
                    return Err(PyValueError::new_err(format!(
                        "invalid side, expected 0 (left) or 1 (right): {side}"
                    )));
                }
            };
            InsertLocation::Leaf { index, side }
        }
        _ => {
            return Err(PyValueError::new_err(
                "must specify neither or both of reference_kid and side",
            ));
        }
    };

    self.insert(key, value, &hash, insert_location)
        .map_err(PyValueError::new_err)
}

#[pyo3(name = "delete")]
pub fn py_delete(&mut self, key: KvId) -> PyResult<()> {
if let Err(message) = self.delete(key) {
Err(PyValueError::new_err(message))
} else {
Ok(())
self.delete(key).map_err(|e| PyValueError::new_err(e))
}

/// Python-facing `get_raw_node`: return the node at `index`, converting any
/// error from `get_node` into a `ValueError`.
#[pyo3(name = "get_raw_node")]
pub fn py_get_raw_node(&mut self, index: TreeIndex) -> PyResult<Node> {
    let node = self.get_node(index).map_err(PyValueError::new_err)?;
    Ok(node)
}

/// Python-facing `calculate_lazy_hashes`: run `calculate_lazy_hashes` and
/// convert any error it reports into a `ValueError`.
#[pyo3(name = "calculate_lazy_hashes")]
pub fn py_calculate_lazy_hashes(&mut self) -> PyResult<()> {
    match self.calculate_lazy_hashes() {
        Ok(()) => Ok(()),
        Err(message) => Err(PyValueError::new_err(message)),
    }
}

/// Python-facing `get_lineage`: return a Python list holding every node
/// yielded by `get_lineage(index)`, in the order yielded.
///
/// An error from `get_lineage` is converted into a `ValueError`.
#[pyo3(name = "get_lineage")]
pub fn py_get_lineage(&self, index: TreeIndex, py: Python<'_>) -> PyResult<pyo3::PyObject> {
    use pyo3::conversion::IntoPy;
    use pyo3::types::PyListMethods;

    let result = pyo3::types::PyList::empty_bound(py);

    let lineage = self.get_lineage(index).map_err(PyValueError::new_err)?;
    for node in lineage {
        result.append(node.into_py(py))?;
    }

    Ok(result.into())
}

/// Python-facing `get_nodes`: return a Python list holding the `node` of every
/// block produced by iterating over this blob.
///
/// `index` is accepted (defaulting to `0`) but is currently ignored.
#[pyo3(name = "get_nodes", signature=(index=0))]
pub fn py_get_nodes(&self, py: Python<'_>, index: TreeIndex) -> PyResult<pyo3::PyObject> {
    use pyo3::conversion::IntoPy;
    use pyo3::types::PyListMethods;

    // TODO: use the index parameter

    let result = pyo3::types::PyList::empty_bound(py);

    for (_, block) in self {
        let node = block.node.into_py(py);
        result.append(node)?;
    }

    Ok(result.into())
}

#[pyo3(name = "__len__")]
Expand Down Expand Up @@ -1289,15 +1415,15 @@ mod tests {
0x0001_0203_0405_0607,
0x1011_1213_1415_1617,
&sha256_num(0x1020),
InsertLocation::Auto,
InsertLocation::Auto {},
)
.unwrap();

blob.insert(
0x2021_2223_2425_2627,
0x3031_3233_3435_3637,
&sha256_num(0x2030),
InsertLocation::Auto,
InsertLocation::Auto {},
)
.unwrap();

Expand Down Expand Up @@ -1400,7 +1526,7 @@ mod tests {
let start = Instant::now();
merkle_blob
// NOTE: yeah this hash is garbage
.insert(i, i, &sha256_num(i), InsertLocation::Auto)
.insert(i, i, &sha256_num(i), InsertLocation::Auto {})
.unwrap();
let end = Instant::now();
total_time += end.duration_since(start);
Expand Down Expand Up @@ -1431,7 +1557,7 @@ mod tests {
merkle_blob.calculate_lazy_hashes().unwrap();
reference_blobs.push(MerkleBlob::new(merkle_blob.blob.clone()).unwrap());
merkle_blob
.insert(key_value_id, key_value_id, &hash, InsertLocation::Auto)
.insert(key_value_id, key_value_id, &hash, InsertLocation::Auto {})
.unwrap();
dots.push(merkle_blob.to_dot().dump());
}
Expand Down Expand Up @@ -1460,7 +1586,7 @@ mod tests {
key_value_id,
key_value_id,
&sha256_num(key_value_id),
InsertLocation::Auto,
InsertLocation::Auto {},
)
.unwrap();
open_dot(merkle_blob.to_dot().set_note("first after"));
Expand All @@ -1481,7 +1607,7 @@ mod tests {
let key: KvId = i as KvId;
open_dot(merkle_blob.to_dot().set_note("empty"));
merkle_blob
.insert(key, key, &sha256_num(key), InsertLocation::Auto)
.insert(key, key, &sha256_num(key), InsertLocation::Auto {})
.unwrap();
last_key = key;
}
Expand Down Expand Up @@ -1539,7 +1665,7 @@ mod tests {
key_value_id,
key_value_id,
&sha256_num(key_value_id),
InsertLocation::Auto,
InsertLocation::Auto {},
)
.unwrap();
open_dot(merkle_blob.to_dot().set_note("first after"));
Expand Down Expand Up @@ -1626,7 +1752,7 @@ mod tests {

let mut insert_blob = MerkleBlob::new(small_blob.blob.clone()).unwrap();
insert_blob
.insert(key, value, &sha256_num(key), InsertLocation::Auto)
.insert(key, value, &sha256_num(key), InsertLocation::Auto {})
.unwrap();
open_dot(insert_blob.to_dot().set_note("first after"));

Expand Down Expand Up @@ -1683,4 +1809,13 @@ mod tests {
}
}
}

/// Inserting a key that is already present must be rejected with the
/// "Key already present" error rather than silently succeeding.
#[test]
fn test_double_insert_fails() {
    let mut blob = MerkleBlob::new(vec![]).unwrap();
    blob.insert(0, 0, &[0u8; 32], InsertLocation::Auto {})
        .expect("first insert into an empty blob should succeed");

    let error = blob
        .insert(0, 0, &[0u8; 32], InsertLocation::Auto {})
        .expect_err("second insert of the same key should fail");
    // Pin the reason so an unrelated failure cannot make this test pass.
    assert_eq!(error, "Key already present");
}
}
4 changes: 3 additions & 1 deletion wheel/generate_type_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,10 @@ def __init__(
blob: bytes,
) -> None: ...
def insert(self, key: uint64, value: uint64, hash: bytes32) -> None: ...
# reference_kid and side default to None in the Rust binding; pass neither or both.
def insert(self, key: uint64, value: uint64, hash: bytes32, reference_kid: Optional[uint64] = None, side: Optional[Side] = None) -> None: ...
def delete(self, key: uint64) -> None: ...
# def get_raw_node(self, index: int) -> Node: ...
def calculate_lazy_hashes(self) -> None: ...
def __len__(self) -> int: ...
@final
Expand Down
4 changes: 3 additions & 1 deletion wheel/python/chia_rs/chia_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,10 @@ class MerkleBlob:
blob: bytes,
) -> None: ...

def insert(self, key: uint64, value: uint64, hash: bytes32) -> None: ...
# reference_kid and side default to None in the Rust binding; pass neither or both.
def insert(self, key: uint64, value: uint64, hash: bytes32, reference_kid: Optional[uint64] = None, side: Optional[Side] = None) -> None: ...
def delete(self, key: uint64) -> None: ...
# def get_raw_node(self, index: int) -> Node: ...
def calculate_lazy_hashes(self) -> None: ...
def __len__(self) -> int: ...

@final
Expand Down
Loading

0 comments on commit 4d42995

Please sign in to comment.