diff --git a/tests/merkle_set.py b/tests/merkle_set.py index e71ffad9..75fa456b 100644 --- a/tests/merkle_set.py +++ b/tests/merkle_set.py @@ -45,7 +45,7 @@ MIDDLE = bytes([2]) TRUNCATED = bytes([3]) -BLANK = bytes32([0] * 32) +BLANK = bytes32.zeros prehashed: Dict[bytes, _Hash] = {} diff --git a/tests/test_merkle_set.py b/tests/test_merkle_set.py index 08034a28..3f42d58c 100644 --- a/tests/test_merkle_set.py +++ b/tests/test_merkle_set.py @@ -66,7 +66,7 @@ def check_tree(leafs: List[bytes32]) -> None: ) for i in range(256): - item = bytes32([i] + [2] * 31) + item = bytes32.fill(bytes([i]), fill=b"\x02", align="<") py_included, py_proof = py_tree.is_included_already_hashed(item) assert not py_included ru_included, ru_proof = ru_tree.is_included_already_hashed(item) diff --git a/tests/test_sized_bytes.py b/tests/test_sized_bytes.py new file mode 100644 index 00000000..ec357600 --- /dev/null +++ b/tests/test_sized_bytes.py @@ -0,0 +1,42 @@ +import pytest + +from chia_rs.sized_bytes import bytes8 + + +def test_fill_empty() -> None: + assert bytes8.fill(b"", b"\x01") == bytes8([1, 1, 1, 1, 1, 1, 1, 1]) + + +def test_fill_non_empty_with_single() -> None: + assert bytes8.fill(b"\x02", b"\x01") == bytes8([1, 1, 1, 1, 1, 1, 1, 2]) + + +def test_fill_non_empty_with_double() -> None: + assert bytes8.fill(b"\x02\x02", b"\x01\x01") == bytes8([1, 1, 1, 1, 1, 1, 2, 2]) + + +def test_fill_needed_with_0_length_fill_raises() -> None: + with pytest.raises(ValueError): + bytes8.fill(b"\x00", fill=b"") + + +def test_fill_not_needed_with_0_length_fill_works() -> None: + blob = b"\x00" * 8 + assert bytes8.fill(blob, fill=b"") == bytes8(blob) + + +def test_fill_not_multiple_raises() -> None: + with pytest.raises(ValueError): + bytes8.fill(b"\x00", fill=b"\x01\x01") + + +def test_align_left() -> None: + assert bytes8.fill(b"\x01", fill=b"\x02", align="<") == bytes8( + [1, 2, 2, 2, 2, 2, 2, 2] + ) + + +def test_invalid_alignment() -> None: + with pytest.raises(ValueError): + # type ignore since we are intentionally testing a bad case + bytes8.fill(b"", fill=b"\x00", align="|") # type: ignore[arg-type] diff --git a/wheel/python/chia_rs/sized_byte_class.py b/wheel/python/chia_rs/sized_byte_class.py index 6dcbeb11..d81d5449 100644 --- a/wheel/python/chia_rs/sized_byte_class.py +++ b/wheel/python/chia_rs/sized_byte_class.py @@ -5,6 +5,7 @@ from typing import ( BinaryIO, Iterable, + Literal, Optional, SupportsBytes, SupportsIndex, @@ -78,6 +79,27 @@ def random( def secret(cls: Type[_T_SizedBytes]) -> _T_SizedBytes: return cls.random(r=system_random) + @classmethod + def fill(cls: Type[_T_SizedBytes], blob: bytes, fill: bytes, align: Literal["<", ">"] = ">") -> _T_SizedBytes: + if len(blob) == cls._size: + return cls(blob) + + fill_length = len(fill) + if fill_length == 0: + raise ValueError("fill required but length is zero") + + div, mod = divmod(cls._size - len(blob), fill_length) + if mod != 0: + raise ValueError("invalid fill value, range to be filled must be multiple of fil size") + + all_fill = fill * div + if align == "<": + return cls(blob + all_fill) + elif align == ">": + return cls(all_fill + blob) + + raise ValueError(f"invalid alignment: {align!r}") + def __str__(self) -> str: return self.hex()