diff --git a/bitcoin_client/ledger_bitcoin/__init__.py b/bitcoin_client/ledger_bitcoin/__init__.py index 4bbd22e14..4c4bd82be 100644 --- a/bitcoin_client/ledger_bitcoin/__init__.py +++ b/bitcoin_client/ledger_bitcoin/__init__.py @@ -7,7 +7,7 @@ from .wallet import AddressType, WalletPolicy, MultisigWallet, WalletType -__version__ = '0.2.2' +__version__ = '0.3.0' __all__ = [ "Client", diff --git a/bitcoin_client/ledger_bitcoin/bip380/README b/bitcoin_client/ledger_bitcoin/bip380/README deleted file mode 100644 index 3609525a4..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/README +++ /dev/null @@ -1,4 +0,0 @@ -This folder is based on https://github.com/Eunovo/python-bip380/tree/4226b7f2b70211d696155f6fd39edc611761ed0b, in turn built on https://github.com/darosior/python-bip380/commit/d2f5d8f5b41cba189bd793c1081e9d61d2d160c1. - -The library is "not ready for any real world use", however we _only_ use it in order to generate addresses for descriptors containing miniscript, and compare the result with the address computed by the device. -This is a generic mitigation for any bug related to address generation on the device, like [this](https://donjon.ledger.com/lsb/019/). 
diff --git a/bitcoin_client/ledger_bitcoin/bip380/__init__.py b/bitcoin_client/ledger_bitcoin/bip380/__init__.py deleted file mode 100644 index 27fdca497..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.0.3" diff --git a/bitcoin_client/ledger_bitcoin/bip380/descriptors/__init__.py b/bitcoin_client/ledger_bitcoin/bip380/descriptors/__init__.py deleted file mode 100644 index bc4eaac4d..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/descriptors/__init__.py +++ /dev/null @@ -1,220 +0,0 @@ -from ...bip380.key import DescriptorKey -from ...bip380.miniscript import Node -from ...bip380.utils.hashes import sha256, hash160 -from ...bip380.utils.script import ( - CScript, - OP_1, - OP_DUP, - OP_HASH160, - OP_EQUALVERIFY, - OP_CHECKSIG, -) - -from .checksum import descsum_create -from .errors import DescriptorParsingError -from .parsing import descriptor_from_str -from .utils import taproot_tweak - - -class Descriptor: - """A Bitcoin Output Script Descriptor.""" - - def from_str(desc_str, strict=False): - """Parse a Bitcoin Output Script Descriptor from its string representation. - - :param strict: whether to require the presence of a checksum. - """ - desc = descriptor_from_str(desc_str, strict) - - # BIP389 prescribes that no two multipath key expressions in a single descriptor - # have different length. - multipath_len = None - for key in desc.keys: - if key.is_multipath(): - m_len = len(key.path.paths) - if multipath_len is None: - multipath_len = m_len - elif multipath_len != m_len: - raise DescriptorParsingError( - f"Descriptor contains multipath key expressions with varying length: '{desc_str}'." 
- ) - - return desc - - @property - def script_pubkey(self): - """Get the ScriptPubKey (output 'locking' Script) for this descriptor.""" - # To be implemented by derived classes - raise NotImplementedError - - @property - def script_sighash(self): - """Get the Script to be committed to by the signature hash of a spending transaction.""" - # To be implemented by derived classes - raise NotImplementedError - - @property - def keys(self): - """Get the list of all keys from this descriptor, in order of apparition.""" - # To be implemented by derived classes - raise NotImplementedError - - def derive(self, index): - """Derive the key at the given derivation index. - - A no-op if the key isn't a wildcard. Will start from 2**31 if the key is a "hardened - wildcard". - """ - assert isinstance(index, int) - for key in self.keys: - key.derive(index) - - def satisfy(self, *args, **kwargs): - """Get the witness stack to spend from this descriptor. - - Various data may need to be passed as parameters to meet the locking - conditions set by the Script. - """ - # To be implemented by derived classes - raise NotImplementedError - - def copy(self): - """Get a copy of this descriptor.""" - # FIXME: do something nicer than roundtripping through string ser - return Descriptor.from_str(str(self)) - - def is_multipath(self): - """Whether this descriptor contains multipath key expression(s).""" - return any(k.is_multipath() for k in self.keys) - - def singlepath_descriptors(self): - """Get a list of descriptors that only contain keys that don't have multiple - derivation paths. 
- """ - singlepath_descs = [self.copy()] - - # First figure out the number of descriptors there will be - for key in self.keys: - if key.is_multipath(): - singlepath_descs += [ - self.copy() for _ in range(len(key.path.paths) - 1) - ] - break - - # Return early if there was no multipath key expression - if len(singlepath_descs) == 1: - return singlepath_descs - - # Then use one path for each - for i, desc in enumerate(singlepath_descs): - for key in desc.keys: - if key.is_multipath(): - assert len(key.path.paths) == len(singlepath_descs) - key.path.paths = key.path.paths[i: i + 1] - - assert all(not d.is_multipath() for d in singlepath_descs) - return singlepath_descs - - -# TODO: add methods to give access to all the Miniscript analysis -class WshDescriptor(Descriptor): - """A Segwit v0 P2WSH Output Script Descriptor.""" - - def __init__(self, witness_script): - assert isinstance(witness_script, Node) - self.witness_script = witness_script - - def __repr__(self): - return descsum_create(f"wsh({self.witness_script})") - - @property - def script_pubkey(self): - witness_program = sha256(self.witness_script.script) - return CScript([0, witness_program]) - - @property - def script_sighash(self): - return self.witness_script.script - - @property - def keys(self): - return self.witness_script.keys - - def satisfy(self, sat_material=None): - """Get the witness stack to spend from this descriptor. - - :param sat_material: a miniscript.satisfaction.SatisfactionMaterial with data - available to fulfill the conditions set by the Script. 
- """ - sat = self.witness_script.satisfy(sat_material) - if sat is not None: - return sat + [self.witness_script.script] - - -class WpkhDescriptor(Descriptor): - """A Segwit v0 P2WPKH Output Script Descriptor.""" - - def __init__(self, pubkey): - assert isinstance(pubkey, DescriptorKey) - self.pubkey = pubkey - - def __repr__(self): - return descsum_create(f"wpkh({self.pubkey})") - - @property - def script_pubkey(self): - witness_program = hash160(self.pubkey.bytes()) - return CScript([0, witness_program]) - - @property - def script_sighash(self): - key_hash = hash160(self.pubkey.bytes()) - return CScript([OP_DUP, OP_HASH160, key_hash, OP_EQUALVERIFY, OP_CHECKSIG]) - - @property - def keys(self): - return [self.pubkey] - - def satisfy(self, signature): - """Get the witness stack to spend from this descriptor. - - :param signature: a signature (in bytes) for the pubkey from the descriptor. - """ - assert isinstance(signature, bytes) - return [signature, self.pubkey.bytes()] - - -class TrDescriptor(Descriptor): - """A Pay-to-Taproot Output Script Descriptor.""" - - def __init__(self, internal_key): - assert isinstance(internal_key, DescriptorKey) and internal_key.x_only - self.internal_key = internal_key - - def __repr__(self): - return descsum_create(f"tr({self.internal_key})") - - def output_key(self): - # "If the spending conditions do not require a script path, the output key - # should commit to an unspendable script path" (see BIP341, BIP386) - return taproot_tweak(self.internal_key.bytes(), b"").format() - - @property - def script_pubkey(self): - return CScript([OP_1, self.output_key()]) - - @property - def keys(self): - return [self.internal_key] - - def satisfy(self, sat_material=None): - """Get the witness stack to spend from this descriptor. - - :param sat_material: a miniscript.satisfaction.SatisfactionMaterial with data - available to spend from the key path or any of the leaves. 
- """ - out_key = self.output_key() - if out_key in sat_material.signatures: - return [sat_material.signatures[out_key]] - - return diff --git a/bitcoin_client/ledger_bitcoin/bip380/descriptors/checksum.py b/bitcoin_client/ledger_bitcoin/bip380/descriptors/checksum.py deleted file mode 100644 index 9f3e01326..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/descriptors/checksum.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2019 Pieter Wuille -# Distributed under the MIT software license, see the accompanying -# file COPYING or http://www.opensource.org/licenses/mit-license.php. -"""Utility functions related to output descriptors""" - -import re - -INPUT_CHARSET = "0123456789()[],'/*abcdefgh@:$%{}IJKLMNOPQRSTUVWXYZ&+-.;<=>?!^_|~ijklmnopqrstuvwxyzABCDEFGH`#\"\\ " -CHECKSUM_CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" -GENERATOR = [0xF5DEE51989, 0xA9FDCA3312, 0x1BAB10E32D, 0x3706B1677A, 0x644D626FFD] - - -def descsum_polymod(symbols): - """Internal function that computes the descriptor checksum.""" - chk = 1 - for value in symbols: - top = chk >> 35 - chk = (chk & 0x7FFFFFFFF) << 5 ^ value - for i in range(5): - chk ^= GENERATOR[i] if ((top >> i) & 1) else 0 - return chk - - -def descsum_expand(s): - """Internal function that does the character to symbol expansion""" - groups = [] - symbols = [] - for c in s: - if not c in INPUT_CHARSET: - return None - v = INPUT_CHARSET.find(c) - symbols.append(v & 31) - groups.append(v >> 5) - if len(groups) == 3: - symbols.append(groups[0] * 9 + groups[1] * 3 + groups[2]) - groups = [] - if len(groups) == 1: - symbols.append(groups[0]) - elif len(groups) == 2: - symbols.append(groups[0] * 3 + groups[1]) - return symbols - - -def descsum_create(s): - """Add a checksum to a descriptor without""" - symbols = descsum_expand(s) + [0, 0, 0, 0, 0, 0, 0, 0] - checksum = descsum_polymod(symbols) ^ 1 - return ( - s - + "#" - + "".join(CHECKSUM_CHARSET[(checksum >> (5 * (7 - i))) & 31] for i in range(8)) - ) - 
- -def descsum_check(s): - """Verify that the checksum is correct in a descriptor""" - if s[-9] != "#": - return False - if not all(x in CHECKSUM_CHARSET for x in s[-8:]): - return False - symbols = descsum_expand(s[:-9]) + [CHECKSUM_CHARSET.find(x) for x in s[-8:]] - return descsum_polymod(symbols) == 1 - - -def drop_origins(s): - """Drop the key origins from a descriptor""" - desc = re.sub(r"\[.+?\]", "", s) - if "#" in s: - desc = desc[: desc.index("#")] - return descsum_create(desc) diff --git a/bitcoin_client/ledger_bitcoin/bip380/descriptors/errors.py b/bitcoin_client/ledger_bitcoin/bip380/descriptors/errors.py deleted file mode 100644 index f7b58483a..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/descriptors/errors.py +++ /dev/null @@ -1,5 +0,0 @@ -class DescriptorParsingError(ValueError): - """Error while parsing a Bitcoin Output Descriptor from its string representation""" - - def __init__(self, message): - self.message = message diff --git a/bitcoin_client/ledger_bitcoin/bip380/descriptors/parsing.py b/bitcoin_client/ledger_bitcoin/bip380/descriptors/parsing.py deleted file mode 100644 index 1d18bffdd..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/descriptors/parsing.py +++ /dev/null @@ -1,56 +0,0 @@ -from ...bip380 import descriptors -from ...bip380.key import DescriptorKey, DescriptorKeyError -from ...bip380.miniscript import Node -from ...bip380.descriptors.checksum import descsum_check - -from .errors import DescriptorParsingError - - -def split_checksum(desc_str, strict=False): - """Removes and check the provided checksum. - If not told otherwise, this won't fail on a missing checksum. - - :param strict: whether to require the presence of the checksum. 
- """ - desc_split = desc_str.split("#") - if len(desc_split) != 2: - if strict: - raise DescriptorParsingError("Missing checksum") - return desc_split[0] - - descriptor, checksum = desc_split - if not descsum_check(desc_str): - raise DescriptorParsingError( - f"Checksum '{checksum}' is invalid for '{descriptor}'" - ) - - return descriptor - - -def descriptor_from_str(desc_str, strict=False): - """Parse a Bitcoin Output Script Descriptor from its string representation. - - :param strict: whether to require the presence of a checksum. - """ - desc_str = split_checksum(desc_str, strict=strict) - - if desc_str.startswith("wsh(") and desc_str.endswith(")"): - # TODO: decent errors in the Miniscript module to be able to catch them here. - ms = Node.from_str(desc_str[4:-1]) - return descriptors.WshDescriptor(ms) - - if desc_str.startswith("wpkh(") and desc_str.endswith(")"): - try: - pubkey = DescriptorKey(desc_str[5:-1]) - except DescriptorKeyError as e: - raise DescriptorParsingError(str(e)) - return descriptors.WpkhDescriptor(pubkey) - - if desc_str.startswith("tr(") and desc_str.endswith(")"): - try: - pubkey = DescriptorKey(desc_str[3:-1], x_only=True) - except DescriptorKeyError as e: - raise DescriptorParsingError(str(e)) - return descriptors.TrDescriptor(pubkey) - - raise DescriptorParsingError(f"Unknown descriptor fragment: {desc_str}") diff --git a/bitcoin_client/ledger_bitcoin/bip380/descriptors/utils.py b/bitcoin_client/ledger_bitcoin/bip380/descriptors/utils.py deleted file mode 100644 index 25dbfe94f..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/descriptors/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Utilities for working with descriptors.""" -import coincurve -import hashlib - - -def tagged_hash(tag, data): - ss = hashlib.sha256(tag.encode("utf-8")).digest() - ss += ss - ss += data - return hashlib.sha256(ss).digest() - - -def taproot_tweak(pubkey_bytes, merkle_root): - assert isinstance(pubkey_bytes, bytes) and len(pubkey_bytes) == 32 - assert 
isinstance(merkle_root, bytes) - - t = tagged_hash("TapTweak", pubkey_bytes + merkle_root) - xonly_pubkey = coincurve.PublicKeyXOnly(pubkey_bytes) - xonly_pubkey.tweak_add(t) # TODO: error handling - - return xonly_pubkey diff --git a/bitcoin_client/ledger_bitcoin/bip380/key.py b/bitcoin_client/ledger_bitcoin/bip380/key.py deleted file mode 100644 index 3e05b61d5..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/key.py +++ /dev/null @@ -1,338 +0,0 @@ -import coincurve -import copy - -from bip32 import BIP32, HARDENED_INDEX -from bip32.utils import _deriv_path_str_to_list -from .utils.hashes import hash160 -from enum import Enum, auto - - -def is_raw_key(obj): - return isinstance(obj, (coincurve.PublicKey, coincurve.PublicKeyXOnly)) - - -class DescriptorKeyError(Exception): - def __init__(self, message): - self.message = message - - -class DescriporKeyOrigin: - """The origin of a key in a descriptor. - - See https://github.com/bitcoin/bips/blob/master/bip-0380.mediawiki#key-expressions. - """ - - def __init__(self, fingerprint, path): - assert isinstance(fingerprint, bytes) and isinstance(path, list) - - self.fingerprint = fingerprint - self.path = path - - def from_str(origin_str): - # Origing starts and ends with brackets - if not origin_str.startswith("[") or not origin_str.endswith("]"): - raise DescriptorKeyError(f"Insane origin: '{origin_str}'") - # At least 8 hex characters + brackets - if len(origin_str) < 10: - raise DescriptorKeyError(f"Insane origin: '{origin_str}'") - - # For the fingerprint, just read the 4 bytes. - try: - fingerprint = bytes.fromhex(origin_str[1:9]) - except ValueError: - raise DescriptorKeyError(f"Insane fingerprint in origin: '{origin_str}'") - # For the path, we (how bad) reuse an internal helper from python-bip32. - path = [] - if len(origin_str) > 10: - if origin_str[9] != "/": - raise DescriptorKeyError(f"Insane path in origin: '{origin_str}'") - # The helper operates on "m/10h/11/12'/13", so give it a "m". 
- dummy = "m" - try: - path = _deriv_path_str_to_list(dummy + origin_str[9:-1]) - except ValueError: - raise DescriptorKeyError(f"Insane path in origin: '{origin_str}'") - - return DescriporKeyOrigin(fingerprint, path) - - -class KeyPathKind(Enum): - FINAL = auto() - WILDCARD_UNHARDENED = auto() - WILDCARD_HARDENED = auto() - - def is_wildcard(self): - return self in [KeyPathKind.WILDCARD_HARDENED, KeyPathKind.WILDCARD_UNHARDENED] - - -def parse_index(index_str): - """Parse a derivation index, as contained in a derivation path.""" - assert isinstance(index_str, str) - - try: - # if HARDENED - if index_str[-1:] in ["'", "h", "H"]: - return int(index_str[:-1]) + HARDENED_INDEX - else: - return int(index_str) - except ValueError as e: - raise DescriptorKeyError(f"Invalid derivation index {index_str}: '{e}'") - - -class DescriptorKeyPath: - """The derivation path of a key in a descriptor. - - See https://github.com/bitcoin/bips/blob/master/bip-0380.mediawiki#key-expressions - as well as BIP389 for multipath expressions. - """ - - def __init__(self, paths, kind): - assert ( - isinstance(paths, list) - and isinstance(kind, KeyPathKind) - and len(paths) > 0 - and all(isinstance(p, list) for p in paths) - ) - - self.paths = paths - self.kind = kind - - def is_multipath(self): - """Whether this derivation path actually contains multiple of them.""" - return len(self.paths) > 1 - - def from_str(path_str): - if len(path_str) < 2: - raise DescriptorKeyError(f"Insane key path: '{path_str}'") - if path_str[0] != "/": - raise DescriptorKeyError(f"Insane key path: '{path_str}'") - - # Determine whether this key may be derived. 
- kind = KeyPathKind.FINAL - if len(path_str) > 2 and path_str[-3:] in ["/*'", "/*h", "/*H"]: - kind = KeyPathKind.WILDCARD_HARDENED - path_str = path_str[:-3] - elif len(path_str) > 1 and path_str[-2:] == "/*": - kind = KeyPathKind.WILDCARD_UNHARDENED - path_str = path_str[:-2] - - paths = [[]] - if len(path_str) == 0: - return DescriptorKeyPath(paths, kind) - - for index in path_str[1:].split("/"): - # If this is a multipath expression, of the form '' - if ( - index.startswith("<") - and index.endswith(">") - and ";" in index - and len(index) >= 5 - ): - # Can't have more than one multipath expression - if len(paths) > 1: - raise DescriptorKeyError( - f"May only have a single multipath step in derivation path: '{path_str}'" - ) - indexes = index[1:-1].split(";") - paths = [copy.copy(paths[0]) for _ in indexes] - for i, der_index in enumerate(indexes): - paths[i].append(parse_index(der_index)) - else: - # This is a "single index" expression. - for path in paths: - path.append(parse_index(index)) - return DescriptorKeyPath(paths, kind) - - -class DescriptorKey: - """A Bitcoin key to be used in Output Script Descriptors. - - May be an extended or raw public key. - """ - - def __init__(self, key, x_only=False): - # Information about the origin of this key. - self.origin = None - # If it is an xpub, a path toward a child key of that xpub. - self.path = None - # Whether to only create x-only public keys. - self.x_only = x_only - # Whether to serialize to string representation without the sign byte. - # This is necessary to roundtrip 33-bytes keys under Taproot context. 
- self.ser_x_only = x_only - - if isinstance(key, bytes): - if len(key) == 32: - key_cls = coincurve.PublicKeyXOnly - self.x_only = True - self.ser_x_only = True - elif len(key) == 33: - key_cls = coincurve.PublicKey - self.ser_x_only = False - else: - raise DescriptorKeyError( - "Only compressed and x-only keys are supported" - ) - try: - self.key = key_cls(key) - except ValueError as e: - raise DescriptorKeyError(f"Public key parsing error: '{str(e)}'") - - elif isinstance(key, BIP32): - self.key = key - - elif isinstance(key, str): - # Try parsing an optional origin prepended to the key - splitted_key = key.split("]", maxsplit=1) - if len(splitted_key) == 2: - origin, key = splitted_key - self.origin = DescriporKeyOrigin.from_str(origin + "]") - - # Is it a raw key? - if len(key) in (64, 66): - pk_cls = coincurve.PublicKey - if len(key) == 64: - pk_cls = coincurve.PublicKeyXOnly - self.x_only = True - self.ser_x_only = True - else: - self.ser_x_only = False - try: - self.key = pk_cls(bytes.fromhex(key)) - except ValueError as e: - raise DescriptorKeyError(f"Public key parsing error: '{str(e)}'") - # If not it must be an xpub. - else: - # There may be an optional path appended to the xpub. - splitted_key = key.split("/", maxsplit=1) - if len(splitted_key) == 2: - key, path = splitted_key - self.path = DescriptorKeyPath.from_str("/" + path) - - try: - self.key = BIP32.from_xpub(key) - except ValueError as e: - raise DescriptorKeyError(f"Xpub parsing error: '{str(e)}'") - - else: - raise DescriptorKeyError( - "Invalid parameter type: expecting bytes, hex str or BIP32 instance." - ) - - def __repr__(self): - key = "" - - def ser_index(key, der_index): - # If this a hardened step, deduce the threshold and mark it. 
- if der_index < HARDENED_INDEX: - return str(der_index) - else: - return f"{der_index - 2**31}'" - - def ser_paths(key, paths): - assert len(paths) > 0 - - for i, der_index in enumerate(paths[0]): - # If this is a multipath expression, write the multi-index step accordingly - if len(paths) > 1 and paths[1][i] != der_index: - key += "/<" - for j, path in enumerate(paths): - key += ser_index(key, path[i]) - if j < len(paths) - 1: - key += ";" - key += ">" - else: - key += "/" + ser_index(key, der_index) - - return key - - if self.origin is not None: - key += f"[{self.origin.fingerprint.hex()}" - key = ser_paths(key, [self.origin.path]) - key += "]" - - if isinstance(self.key, BIP32): - key += self.key.get_xpub() - else: - assert is_raw_key(self.key) - raw_key = self.key.format() - if len(raw_key) == 33 and self.ser_x_only: - raw_key = raw_key[1:] - key += raw_key.hex() - - if self.path is not None: - key = ser_paths(key, self.path.paths) - if self.path.kind == KeyPathKind.WILDCARD_UNHARDENED: - key += "/*" - elif self.path.kind == KeyPathKind.WILDCARD_HARDENED: - key += "/*'" - - return key - - def is_multipath(self): - """Whether this key contains more than one derivation path.""" - return self.path is not None and self.path.is_multipath() - - def derivation_path(self): - """Get the single derivation path for this key. - - Will raise if it has multiple, and return None if it doesn't have any. - """ - if self.path is None: - return None - if self.path.is_multipath(): - raise DescriptorKeyError( - f"Key has multiple derivation paths: {self.path.paths}" - ) - return self.path.paths[0] - - def bytes(self): - """Get this key as raw bytes. - - Will raise if this key contains multiple derivation paths. 
- """ - if is_raw_key(self.key): - raw = self.key.format() - if self.x_only and len(raw) == 33: - return raw[1:] - assert len(raw) == 32 or not self.x_only - return raw - else: - assert isinstance(self.key, BIP32) - path = self.derivation_path() - if path is None: - return self.key.pubkey - assert not self.path.kind.is_wildcard() # TODO: real errors - return self.key.get_pubkey_from_path(path) - - def derive(self, index): - """Derive the key at the given index. - - Will raise if this key contains multiple derivation paths. - A no-op if the key isn't a wildcard. Will start from 2**31 if the key is a "hardened - wildcard". - """ - assert isinstance(index, int) - if ( - self.path is None - or self.path.is_multipath() - or self.path.kind == KeyPathKind.FINAL - ): - return - assert isinstance(self.key, BIP32) - - if self.path.kind == KeyPathKind.WILDCARD_HARDENED: - index += 2 ** 31 - assert index < 2 ** 32 - - if self.origin is None: - fingerprint = hash160(self.key.pubkey)[:4] - self.origin = DescriporKeyOrigin(fingerprint, [index]) - else: - self.origin.path.append(index) - - # This can't fail now. - path = self.derivation_path() - # TODO(bip32): have a way to derive without roundtripping through string ser. - self.key = BIP32.from_xpub(self.key.get_xpub_from_path(path + [index])) - self.path = None diff --git a/bitcoin_client/ledger_bitcoin/bip380/miniscript/__init__.py b/bitcoin_client/ledger_bitcoin/bip380/miniscript/__init__.py deleted file mode 100644 index b0de1f9c7..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/miniscript/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -Miniscript -========== - -Miniscript is an extension to Bitcoin Output Script descriptors. It is a language for \ -writing (a subset of) Bitcoin Scripts in a structured way, enabling analysis, composition, \ -generic signing and more. - -For more information about Miniscript, see https://bitcoin.sipa.be/miniscript. 
-""" - -from .fragments import Node -from .satisfaction import SatisfactionMaterial diff --git a/bitcoin_client/ledger_bitcoin/bip380/miniscript/errors.py b/bitcoin_client/ledger_bitcoin/bip380/miniscript/errors.py deleted file mode 100644 index 7ccd98f4e..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/miniscript/errors.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -All the exceptions raised when dealing with Miniscript. -""" - - -class MiniscriptMalformed(ValueError): - def __init__(self, message): - self.message = message - - -class MiniscriptNodeCreationError(ValueError): - def __init__(self, message): - self.message = message - - -class MiniscriptPropertyError(ValueError): - def __init__(self, message): - self.message = message - -# TODO: errors for type errors, parsing errors, etc.. diff --git a/bitcoin_client/ledger_bitcoin/bip380/miniscript/fragments.py b/bitcoin_client/ledger_bitcoin/bip380/miniscript/fragments.py deleted file mode 100644 index d0e572eeb..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/miniscript/fragments.py +++ /dev/null @@ -1,1225 +0,0 @@ -""" -Miniscript AST elements. - -Each element correspond to a Bitcoin Script fragment, and has various type properties. -See the Miniscript website for the specification of the type system: https://bitcoin.sipa.be/miniscript/. 
-""" - -import copy -from ...bip380.miniscript import parsing - -from ...bip380.key import DescriptorKey -from ...bip380.utils.hashes import hash160 -from ...bip380.utils.script import ( - CScript, - OP_1, - OP_0, - OP_ADD, - OP_BOOLAND, - OP_BOOLOR, - OP_DUP, - OP_ELSE, - OP_ENDIF, - OP_EQUAL, - OP_EQUALVERIFY, - OP_FROMALTSTACK, - OP_IFDUP, - OP_IF, - OP_CHECKLOCKTIMEVERIFY, - OP_CHECKMULTISIG, - OP_CHECKMULTISIGVERIFY, - OP_CHECKSEQUENCEVERIFY, - OP_CHECKSIG, - OP_CHECKSIGVERIFY, - OP_HASH160, - OP_HASH256, - OP_NOTIF, - OP_RIPEMD160, - OP_SHA256, - OP_SIZE, - OP_SWAP, - OP_TOALTSTACK, - OP_VERIFY, - OP_0NOTEQUAL, -) - -from .errors import MiniscriptNodeCreationError -from .property import Property -from .satisfaction import ExecutionInfo, Satisfaction - - -# Threshold for nLockTime: below this value it is interpreted as block number, -# otherwise as UNIX timestamp. -LOCKTIME_THRESHOLD = 500000000 # Tue Nov 5 00:53:20 1985 UTC - -# If CTxIn::nSequence encodes a relative lock-time and this flag -# is set, the relative lock-time has units of 512 seconds, -# otherwise it specifies blocks with a granularity of 1. -SEQUENCE_LOCKTIME_TYPE_FLAG = 1 << 22 - - -class Node: - """A Miniscript fragment.""" - - # The fragment's type and properties - p = None - # List of all sub fragments - subs = [] - # A list of Script elements, a CScript is created all at once in the script() method. - _script = [] - # Whether any satisfaction for this fragment require a signature - needs_sig = None - # Whether any dissatisfaction for this fragment requires a signature - is_forced = None - # Whether this fragment has a unique unconditional satisfaction, and all conditional - # ones require a signature. - is_expressive = None - # Whether for any possible way to satisfy this fragment (may be none), a - # non-malleable satisfaction exists. 
- is_nonmalleable = None - # Whether this node or any of its subs contains an absolute heightlock - abs_heightlocks = None - # Whether this node or any of its subs contains a relative heightlock - rel_heightlocks = None - # Whether this node or any of its subs contains an absolute timelock - abs_timelocks = None - # Whether this node or any of its subs contains a relative timelock - rel_timelocks = None - # Whether this node does not contain a mix of timelock or heightlock of different types. - # That is, not (abs_heightlocks and rel_heightlocks or abs_timelocks and abs_timelocks) - no_timelock_mix = None - # Information about this Miniscript execution (satisfaction cost, etc..) - exec_info = None - - def __init__(self, *args, **kwargs): - # Needs to be implemented by derived classes. - raise NotImplementedError - - def from_str(ms_str): - """Parse a Miniscript fragment from its string representation.""" - assert isinstance(ms_str, str) - return parsing.miniscript_from_str(ms_str) - - def from_script(script, pkh_preimages={}): - """Decode a Miniscript fragment from its Script representation.""" - assert isinstance(script, CScript) - return parsing.miniscript_from_script(script, pkh_preimages) - - # TODO: have something like BuildScript from Core and get rid of the _script member. - @property - def script(self): - return CScript(self._script) - - @property - def keys(self): - """Get the list of all keys from this Miniscript, in order of apparition.""" - # Overriden by fragments that actually have keys. - return [key for sub in self.subs for key in sub.keys] - - def satisfy(self, sat_material): - """Get the witness of the smallest non-malleable satisfaction for this fragment, - if one exists. - - :param sat_material: a SatisfactionMaterial containing available data to satisfy - challenges. 
- """ - sat = self.satisfaction(sat_material) - if not sat.has_sig: - return None - return sat.witness - - def satisfaction(self, sat_material): - """Get the satisfaction for this fragment. - - :param sat_material: a SatisfactionMaterial containing available data to satisfy - challenges. - """ - # Needs to be implemented by derived classes. - raise NotImplementedError - - def dissatisfaction(self): - """Get the dissatisfaction for this fragment.""" - # Needs to be implemented by derived classes. - raise NotImplementedError - - -class Just0(Node): - def __init__(self): - - self._script = [OP_0] - - self.p = Property("Bzud") - self.needs_sig = False - self.is_forced = False - self.is_expressive = True - self.is_nonmalleable = True - self.abs_heightlocks = False - self.rel_heightlocks = False - self.abs_timelocks = False - self.rel_timelocks = False - self.no_timelock_mix = True - self.exec_info = ExecutionInfo(0, 0, None, 0) - - def satisfaction(self, sat_material): - return Satisfaction.unavailable() - - def dissatisfaction(self): - return Satisfaction(witness=[]) - - def __repr__(self): - return "0" - - -class Just1(Node): - def __init__(self): - - self._script = [OP_1] - - self.p = Property("Bzu") - self.needs_sig = False - self.is_forced = True # No dissat - self.is_expressive = False # No dissat - self.is_nonmalleable = True # FIXME: how comes? Standardness rules? - self.abs_heightlocks = False - self.rel_heightlocks = False - self.abs_timelocks = False - self.rel_timelocks = False - self.no_timelock_mix = True - self.exec_info = ExecutionInfo(0, 0, 0, None) - - def satisfaction(self, sat_material): - return Satisfaction(witness=[]) - - def dissatisfaction(self): - return Satisfaction.unavailable() - - def __repr__(self): - return "1" - - -class PkNode(Node): - """A virtual class for nodes containing a single public key. - - Should not be instanced directly, use Pk() or Pkh(). 
- """ - - def __init__(self, pubkey): - - if isinstance(pubkey, bytes) or isinstance(pubkey, str): - self.pubkey = DescriptorKey(pubkey) - elif isinstance(pubkey, DescriptorKey): - self.pubkey = pubkey - else: - raise MiniscriptNodeCreationError("Invalid public key") - - self.needs_sig = True # FIXME: think about having it in 'c:' instead - self.is_forced = False - self.is_expressive = True - self.is_nonmalleable = True - self.abs_heightlocks = False - self.rel_heightlocks = False - self.abs_timelocks = False - self.rel_timelocks = False - self.no_timelock_mix = True - - @property - def keys(self): - return [self.pubkey] - - -class Pk(PkNode): - def __init__(self, pubkey): - PkNode.__init__(self, pubkey) - - self.p = Property("Konud") - self.exec_info = ExecutionInfo(0, 0, 0, 0) - - @property - def _script(self): - return [self.pubkey.bytes()] - - def satisfaction(self, sat_material): - sig = sat_material.signatures.get(self.pubkey.bytes()) - if sig is None: - return Satisfaction.unavailable() - return Satisfaction([sig], has_sig=True) - - def dissatisfaction(self): - return Satisfaction(witness=[b""]) - - def __repr__(self): - return f"pk_k({self.pubkey})" - - -class Pkh(PkNode): - # FIXME: should we support a hash here, like rust-bitcoin? I don't think it's safe. 
- def __init__(self, pubkey): - PkNode.__init__(self, pubkey) - - self.p = Property("Knud") - self.exec_info = ExecutionInfo(3, 0, 1, 1) - - @property - def _script(self): - return [OP_DUP, OP_HASH160, self.pk_hash(), OP_EQUALVERIFY] - - def satisfaction(self, sat_material): - sig = sat_material.signatures.get(self.pubkey.bytes()) - if sig is None: - return Satisfaction.unavailable() - return Satisfaction(witness=[sig, self.pubkey.bytes()], has_sig=True) - - def dissatisfaction(self): - return Satisfaction(witness=[b"", self.pubkey.bytes()]) - - def __repr__(self): - return f"pk_h({self.pubkey})" - - def pk_hash(self): - assert isinstance(self.pubkey, DescriptorKey) - return hash160(self.pubkey.bytes()) - - -class Older(Node): - def __init__(self, value): - assert value > 0 and value < 2 ** 31 - - self.value = value - self._script = [self.value, OP_CHECKSEQUENCEVERIFY] - - self.p = Property("Bz") - self.needs_sig = False - self.is_forced = True - self.is_expressive = False # No dissat - self.is_nonmalleable = True - self.rel_timelocks = bool(value & SEQUENCE_LOCKTIME_TYPE_FLAG) - self.rel_heightlocks = not self.rel_timelocks - self.abs_heightlocks = False - self.abs_timelocks = False - self.no_timelock_mix = True - self.exec_info = ExecutionInfo(1, 0, 0, None) - - def satisfaction(self, sat_material): - if sat_material.max_sequence < self.value: - return Satisfaction.unavailable() - return Satisfaction(witness=[]) - - def dissatisfaction(self): - return Satisfaction.unavailable() - - def __repr__(self): - return f"older({self.value})" - - -class After(Node): - def __init__(self, value): - assert value > 0 and value < 2 ** 31 - - self.value = value - self._script = [self.value, OP_CHECKLOCKTIMEVERIFY] - - self.p = Property("Bz") - self.needs_sig = False - self.is_forced = True - self.is_expressive = False # No dissat - self.is_nonmalleable = True - self.abs_heightlocks = value < LOCKTIME_THRESHOLD - self.abs_timelocks = not self.abs_heightlocks - 
self.rel_heightlocks = False - self.rel_timelocks = False - self.no_timelock_mix = True - self.exec_info = ExecutionInfo(1, 0, 0, None) - - def satisfaction(self, sat_material): - if sat_material.max_lock_time < self.value: - return Satisfaction.unavailable() - return Satisfaction(witness=[]) - - def dissatisfaction(self): - return Satisfaction.unavailable() - - def __repr__(self): - return f"after({self.value})" - - -class HashNode(Node): - """A virtual class for fragments with hashlock semantics. - - Should not be instanced directly, use concrete fragments instead. - """ - - def __init__(self, digest, hash_op): - assert isinstance(digest, bytes) # TODO: real errors - - self.digest = digest - self._script = [OP_SIZE, 32, OP_EQUALVERIFY, hash_op, digest, OP_EQUAL] - - self.p = Property("Bonud") - self.needs_sig = False - self.is_forced = False - self.is_expressive = False - self.is_nonmalleable = True - self.abs_heightlocks = False - self.rel_heightlocks = False - self.abs_timelocks = False - self.rel_timelocks = False - self.no_timelock_mix = True - self.exec_info = ExecutionInfo(4, 0, 1, None) - - def satisfaction(self, sat_material): - preimage = sat_material.preimages.get(self.digest) - if preimage is None: - return Satisfaction.unavailable() - return Satisfaction(witness=[preimage]) - - def dissatisfaction(self): - return Satisfaction.unavailable() - return Satisfaction(witness=[b""]) - - -class Sha256(HashNode): - def __init__(self, digest): - assert len(digest) == 32 # TODO: real errors - HashNode.__init__(self, digest, OP_SHA256) - - def __repr__(self): - return f"sha256({self.digest.hex()})" - - -class Hash256(HashNode): - def __init__(self, digest): - assert len(digest) == 32 # TODO: real errors - HashNode.__init__(self, digest, OP_HASH256) - - def __repr__(self): - return f"hash256({self.digest.hex()})" - - -class Ripemd160(HashNode): - def __init__(self, digest): - assert len(digest) == 20 # TODO: real errors - HashNode.__init__(self, digest, 
OP_RIPEMD160) - - def __repr__(self): - return f"ripemd160({self.digest.hex()})" - - -class Hash160(HashNode): - def __init__(self, digest): - assert len(digest) == 20 # TODO: real errors - HashNode.__init__(self, digest, OP_HASH160) - - def __repr__(self): - return f"hash160({self.digest.hex()})" - - -class Multi(Node): - def __init__(self, k, keys): - assert 1 <= k <= len(keys) - assert all(isinstance(k, DescriptorKey) for k in keys) - - self.k = k - self.pubkeys = keys - - self.p = Property("Bndu") - self.needs_sig = True - self.is_forced = False - self.is_expressive = True - self.is_nonmalleable = True - self.abs_heightlocks = False - self.rel_heightlocks = False - self.abs_timelocks = False - self.rel_timelocks = False - self.no_timelock_mix = True - self.exec_info = ExecutionInfo(1, len(keys), 1 + k, 1 + k) - - @property - def keys(self): - return self.pubkeys - - @property - def _script(self): - return [ - self.k, - *[k.bytes() for k in self.keys], - len(self.keys), - OP_CHECKMULTISIG, - ] - - def satisfaction(self, sat_material): - sigs = [] - for key in self.keys: - sig = sat_material.signatures.get(key.bytes()) - if sig is not None: - assert isinstance(sig, bytes) - sigs.append(sig) - if len(sigs) == self.k: - break - if len(sigs) < self.k: - return Satisfaction.unavailable() - return Satisfaction(witness=[b""] + sigs, has_sig=True) - - def dissatisfaction(self): - return Satisfaction(witness=[b""] * (self.k + 1)) - - def __repr__(self): - return f"multi({','.join([str(self.k)] + [str(k) for k in self.keys])})" - - -class AndV(Node): - def __init__(self, sub_x, sub_y): - assert sub_x.p.V - assert sub_y.p.has_any("BKV") - - self.subs = [sub_x, sub_y] - - self.p = Property( - sub_y.p.type() - + ("z" if sub_x.p.z and sub_y.p.z else "") - + ("o" if sub_x.p.z and sub_y.p.o or sub_x.p.o and sub_y.p.z else "") - + ("n" if sub_x.p.n or sub_x.p.z and sub_y.p.n else "") - + ("u" if sub_y.p.u else "") - ) - self.needs_sig = any(sub.needs_sig for sub in self.subs) - 
self.is_forced = any(sub.needs_sig for sub in self.subs) - self.is_expressive = False # Not 'd' - self.is_nonmalleable = all(sub.is_nonmalleable for sub in self.subs) - self.abs_heightlocks = any(sub.abs_heightlocks for sub in self.subs) - self.rel_heightlocks = any(sub.rel_heightlocks for sub in self.subs) - self.abs_timelocks = any(sub.abs_timelocks for sub in self.subs) - self.rel_timelocks = any(sub.rel_timelocks for sub in self.subs) - self.no_timelock_mix = not ( - self.abs_heightlocks - and self.abs_timelocks - or self.rel_heightlocks - and self.rel_timelocks - ) - - @property - def _script(self): - return sum((sub._script for sub in self.subs), start=[]) - - @property - def exec_info(self): - exec_info = ExecutionInfo.from_concat( - self.subs[0].exec_info, self.subs[1].exec_info - ) - exec_info.set_undissatisfiable() # it's V. - return exec_info - - def satisfaction(self, sat_material): - return Satisfaction.from_concat(sat_material, *self.subs) - - def dissatisfaction(self): - return Satisfaction.unavailable() # it's V. 
- - def __repr__(self): - return f"and_v({','.join(map(str, self.subs))})" - - -class AndB(Node): - def __init__(self, sub_x, sub_y): - assert sub_x.p.B and sub_y.p.W - - self.subs = [sub_x, sub_y] - - self.p = Property( - "Bu" - + ("z" if sub_x.p.z and sub_y.p.z else "") - + ("o" if sub_x.p.z and sub_y.p.o or sub_x.p.o and sub_y.p.z else "") - + ("n" if sub_x.p.n or sub_x.p.z and sub_y.p.n else "") - + ("d" if sub_x.p.d and sub_y.p.d else "") - + ("u" if sub_y.p.u else "") - ) - self.needs_sig = any(sub.needs_sig for sub in self.subs) - self.is_forced = ( - sub_x.is_forced - and sub_y.is_forced - or any(sub.is_forced and sub.needs_sig for sub in self.subs) - ) - self.is_expressive = all(sub.is_forced and sub.needs_sig for sub in self.subs) - self.is_nonmalleable = all(sub.is_nonmalleable for sub in self.subs) - self.abs_heightlocks = any(sub.abs_heightlocks for sub in self.subs) - self.rel_heightlocks = any(sub.rel_heightlocks for sub in self.subs) - self.abs_timelocks = any(sub.abs_timelocks for sub in self.subs) - self.rel_timelocks = any(sub.rel_timelocks for sub in self.subs) - self.no_timelock_mix = not ( - self.abs_heightlocks - and self.abs_timelocks - or self.rel_heightlocks - and self.rel_timelocks - ) - - @property - def _script(self): - return sum((sub._script for sub in self.subs), start=[]) + [OP_BOOLAND] - - @property - def exec_info(self): - return ExecutionInfo.from_concat( - self.subs[0].exec_info, self.subs[1].exec_info, ops_count=1 - ) - - def satisfaction(self, sat_material): - return Satisfaction.from_concat(sat_material, self.subs[0], self.subs[1]) - - def dissatisfaction(self): - return self.subs[1].dissatisfaction() + self.subs[0].dissatisfaction() - - def __repr__(self): - return f"and_b({','.join(map(str, self.subs))})" - - -class OrB(Node): - def __init__(self, sub_x, sub_z): - assert sub_x.p.has_all("Bd") - assert sub_z.p.has_all("Wd") - - self.subs = [sub_x, sub_z] - - self.p = Property( - "Bdu" - + ("z" if sub_x.p.z and sub_z.p.z else 
"") - + ("o" if sub_x.p.z and sub_z.p.o or sub_x.p.o and sub_z.p.z else "") - ) - self.needs_sig = all(sub.needs_sig for sub in self.subs) - self.is_forced = False # Both subs are 'd' - self.is_expressive = all(sub.is_expressive for sub in self.subs) - self.is_nonmalleable = all( - sub.is_nonmalleable and sub.is_expressive for sub in self.subs - ) and any(sub.needs_sig for sub in self.subs) - self.abs_heightlocks = any(sub.abs_heightlocks for sub in self.subs) - self.rel_heightlocks = any(sub.rel_heightlocks for sub in self.subs) - self.abs_timelocks = any(sub.abs_timelocks for sub in self.subs) - self.rel_timelocks = any(sub.rel_timelocks for sub in self.subs) - self.no_timelock_mix = all(sub.no_timelock_mix for sub in self.subs) - - @property - def _script(self): - return sum((sub._script for sub in self.subs), start=[]) + [OP_BOOLOR] - - @property - def exec_info(self): - return ExecutionInfo.from_concat( - self.subs[0].exec_info, - self.subs[1].exec_info, - ops_count=1, - disjunction=True, - ) - - def satisfaction(self, sat_material): - return Satisfaction.from_concat( - sat_material, self.subs[0], self.subs[1], disjunction=True - ) - - def dissatisfaction(self): - return self.subs[1].dissatisfaction() + self.subs[0].dissatisfaction() - - def __repr__(self): - return f"or_b({','.join(map(str, self.subs))})" - - -class OrC(Node): - def __init__(self, sub_x, sub_z): - assert sub_x.p.has_all("Bdu") and sub_z.p.V - - self.subs = [sub_x, sub_z] - - self.p = Property( - "V" - + ("z" if sub_x.p.z and sub_z.p.z else "") - + ("o" if sub_x.p.o and sub_z.p.z else "") - ) - self.needs_sig = all(sub.needs_sig for sub in self.subs) - self.is_forced = True # Because sub_z is 'V' - self.is_expressive = False # V - self.is_nonmalleable = ( - all(sub.is_nonmalleable for sub in self.subs) - and any(sub.needs_sig for sub in self.subs) - and sub_x.is_expressive - ) - self.abs_heightlocks = any(sub.abs_heightlocks for sub in self.subs) - self.rel_heightlocks = 
any(sub.rel_heightlocks for sub in self.subs) - self.abs_timelocks = any(sub.abs_timelocks for sub in self.subs) - self.rel_timelocks = any(sub.rel_timelocks for sub in self.subs) - self.no_timelock_mix = all(sub.no_timelock_mix for sub in self.subs) - - @property - def _script(self): - return self.subs[0]._script + [OP_NOTIF] + self.subs[1]._script + [OP_ENDIF] - - @property - def exec_info(self): - exec_info = ExecutionInfo.from_or_uneven( - self.subs[0].exec_info, self.subs[1].exec_info, ops_count=2 - ) - exec_info.set_undissatisfiable() # it's V. - return exec_info - - def satisfaction(self, sat_material): - return Satisfaction.from_or_uneven(sat_material, self.subs[0], self.subs[1]) - - def dissatisfaction(self): - return Satisfaction.unavailable() # it's V. - - def __repr__(self): - return f"or_c({','.join(map(str, self.subs))})" - - -class OrD(Node): - def __init__(self, sub_x, sub_z): - assert sub_x.p.has_all("Bdu") - assert sub_z.p.has_all("B") - - self.subs = [sub_x, sub_z] - - self.p = Property( - "B" - + ("z" if sub_x.p.z and sub_z.p.z else "") - + ("o" if sub_x.p.o and sub_z.p.z else "") - + ("d" if sub_z.p.d else "") - + ("u" if sub_z.p.u else "") - ) - self.needs_sig = all(sub.needs_sig for sub in self.subs) - self.is_forced = all(sub.is_forced for sub in self.subs) - self.is_expressive = all(sub.is_expressive for sub in self.subs) - self.is_nonmalleable = ( - all(sub.is_nonmalleable for sub in self.subs) - and any(sub.needs_sig for sub in self.subs) - and sub_x.is_expressive - ) - self.abs_heightlocks = any(sub.abs_heightlocks for sub in self.subs) - self.rel_heightlocks = any(sub.rel_heightlocks for sub in self.subs) - self.abs_timelocks = any(sub.abs_timelocks for sub in self.subs) - self.rel_timelocks = any(sub.rel_timelocks for sub in self.subs) - self.no_timelock_mix = all(sub.no_timelock_mix for sub in self.subs) - - @property - def _script(self): - return ( - self.subs[0]._script - + [OP_IFDUP, OP_NOTIF] - + self.subs[1]._script - + 
[OP_ENDIF] - ) - - @property - def exec_info(self): - return ExecutionInfo.from_or_uneven( - self.subs[0].exec_info, self.subs[1].exec_info, ops_count=3 - ) - - def satisfaction(self, sat_material): - return Satisfaction.from_or_uneven(sat_material, self.subs[0], self.subs[1]) - - def dissatisfaction(self): - return self.subs[1].dissatisfaction() + self.subs[0].dissatisfaction() - - def __repr__(self): - return f"or_d({','.join(map(str, self.subs))})" - - -class OrI(Node): - def __init__(self, sub_x, sub_z): - assert sub_x.p.type() == sub_z.p.type() and sub_x.p.has_any("BKV") - - self.subs = [sub_x, sub_z] - - self.p = Property( - sub_x.p.type() - + ("o" if sub_x.p.z and sub_z.p.z else "") - + ("d" if sub_x.p.d or sub_z.p.d else "") - + ("u" if sub_x.p.u and sub_z.p.u else "") - ) - self.needs_sig = all(sub.needs_sig for sub in self.subs) - self.is_forced = all(sub.is_forced for sub in self.subs) - self.is_expressive = ( - sub_x.is_expressive - and sub_z.is_forced - or sub_x.is_forced - and sub_z.is_expressive - ) - self.is_nonmalleable = all(sub.is_nonmalleable for sub in self.subs) and any( - sub.needs_sig for sub in self.subs - ) - self.abs_heightlocks = any(sub.abs_heightlocks for sub in self.subs) - self.rel_heightlocks = any(sub.rel_heightlocks for sub in self.subs) - self.abs_timelocks = any(sub.abs_timelocks for sub in self.subs) - self.rel_timelocks = any(sub.rel_timelocks for sub in self.subs) - self.no_timelock_mix = all(sub.no_timelock_mix for sub in self.subs) - - @property - def _script(self): - return ( - [OP_IF] - + self.subs[0]._script - + [OP_ELSE] - + self.subs[1]._script - + [OP_ENDIF] - ) - - @property - def exec_info(self): - return ExecutionInfo.from_or_even( - self.subs[0].exec_info, self.subs[1].exec_info, ops_count=3 - ) - - def satisfaction(self, sat_material): - return (self.subs[0].satisfaction(sat_material) + Satisfaction([b"\x01"])) | ( - self.subs[1].satisfaction(sat_material) + Satisfaction([b""]) - ) - - def dissatisfaction(self): 
- return (self.subs[0].dissatisfaction() + Satisfaction(witness=[b"\x01"])) | ( - self.subs[1].dissatisfaction() + Satisfaction(witness=[b""]) - ) - - def __repr__(self): - return f"or_i({','.join(map(str, self.subs))})" - - -class AndOr(Node): - def __init__(self, sub_x, sub_y, sub_z): - assert sub_x.p.has_all("Bdu") - assert sub_y.p.type() == sub_z.p.type() and sub_y.p.has_any("BKV") - - self.subs = [sub_x, sub_y, sub_z] - - self.p = Property( - sub_y.p.type() - + ("z" if sub_x.p.z and sub_y.p.z and sub_z.p.z else "") - + ( - "o" - if sub_x.p.z - and sub_y.p.o - and sub_z.p.o - or sub_x.p.o - and sub_y.p.z - and sub_z.p.z - else "" - ) - + ("d" if sub_z.p.d else "") - + ("u" if sub_y.p.u and sub_z.p.u else "") - ) - self.needs_sig = sub_x.needs_sig and (sub_y.needs_sig or sub_z.needs_sig) - self.is_forced = sub_z.is_forced and (sub_x.needs_sig or sub_y.is_forced) - self.is_expressive = ( - sub_x.is_expressive - and sub_z.is_expressive - and (sub_x.needs_sig or sub_y.is_forced) - ) - self.is_nonmalleable = ( - all(sub.is_nonmalleable for sub in self.subs) - and any(sub.needs_sig for sub in self.subs) - and sub_x.is_expressive - ) - self.abs_heightlocks = any(sub.abs_heightlocks for sub in self.subs) - self.rel_heightlocks = any(sub.rel_heightlocks for sub in self.subs) - self.abs_timelocks = any(sub.abs_timelocks for sub in self.subs) - self.rel_timelocks = any(sub.rel_timelocks for sub in self.subs) - # X and Y, or Z. So we have a mix if any contain a timelock mix, or - # there is a mix between X and Y. 
- self.no_timelock_mix = all(sub.no_timelock_mix for sub in self.subs) and not ( - any(sub.rel_timelocks for sub in [sub_x, sub_y]) - and any(sub.rel_heightlocks for sub in [sub_x, sub_y]) - or any(sub.abs_timelocks for sub in [sub_x, sub_y]) - and any(sub.abs_heightlocks for sub in [sub_x, sub_y]) - ) - - @property - def _script(self): - return ( - self.subs[0]._script - + [OP_NOTIF] - + self.subs[2]._script - + [OP_ELSE] - + self.subs[1]._script - + [OP_ENDIF] - ) - - @property - def exec_info(self): - return ExecutionInfo.from_andor_uneven( - self.subs[0].exec_info, - self.subs[1].exec_info, - self.subs[2].exec_info, - ops_count=3, - ) - - def satisfaction(self, sat_material): - # (A and B) or (!A and C) - return ( - self.subs[1].satisfaction(sat_material) - + self.subs[0].satisfaction(sat_material) - ) | (self.subs[2].satisfaction(sat_material) + self.subs[0].dissatisfaction()) - - def dissatisfaction(self): - # Dissatisfy X and Z - return self.subs[2].dissatisfaction() + self.subs[0].dissatisfaction() - - def __repr__(self): - return f"andor({','.join(map(str, self.subs))})" - - -class AndN(AndOr): - def __init__(self, sub_x, sub_y): - AndOr.__init__(self, sub_x, sub_y, Just0()) - - def __repr__(self): - return f"and_n({self.subs[0]},{self.subs[1]})" - - -class Thresh(Node): - def __init__(self, k, subs): - n = len(subs) - assert 1 <= k <= n - - self.k = k - self.subs = subs - - all_z = True - all_z_but_one_odu = False - all_e = True - all_m = True - s_count = 0 - # If k == 1, just check each child for k - if k > 1: - self.abs_heightlocks = subs[0].abs_heightlocks - self.rel_heightlocks = subs[0].rel_heightlocks - self.abs_timelocks = subs[0].abs_timelocks - self.rel_timelocks = subs[0].rel_timelocks - else: - self.no_timelock_mix = True - - assert subs[0].p.has_all("Bdu") - for sub in subs[1:]: - assert sub.p.has_all("Wdu") - if not sub.p.z: - if all_z_but_one_odu: - # Fails "all 'z' but one" - all_z_but_one_odu = False - if all_z and sub.p.has_all("odu"): - 
# They were all 'z' up to now. - all_z_but_one_odu = True - all_z = False - all_e = all_e and sub.is_expressive - all_m = all_m and sub.is_nonmalleable - if sub.needs_sig: - s_count += 1 - if k > 1: - self.abs_heightlocks |= sub.abs_heightlocks - self.rel_heightlocks |= sub.rel_heightlocks - self.abs_timelocks |= sub.abs_timelocks - self.rel_timelocks |= sub.rel_timelocks - else: - self.no_timelock_mix &= sub.no_timelock_mix - - self.p = Property( - "Bdu" + ("z" if all_z else "") + ("o" if all_z_but_one_odu else "") - ) - self.needs_sig = s_count >= n - k - self.is_forced = False # All subs need to be 'd' - self.is_expressive = all_e and s_count == n - self.is_nonmalleable = all_e and s_count >= n - k - if k > 1: - self.no_timelock_mix = not ( - self.abs_heightlocks - and self.abs_timelocks - or self.rel_heightlocks - and self.rel_timelocks - ) - - @property - def _script(self): - return ( - self.subs[0]._script - + sum(((sub._script + [OP_ADD]) for sub in self.subs[1:]), start=[]) - + [self.k, OP_EQUAL] - ) - - @property - def exec_info(self): - return ExecutionInfo.from_thresh(self.k, [sub.exec_info for sub in self.subs]) - - def satisfaction(self, sat_material): - return Satisfaction.from_thresh(sat_material, self.k, self.subs) - - def dissatisfaction(self): - return sum( - [sub.dissatisfaction() for sub in self.subs], start=Satisfaction(witness=[]) - ) - - def __repr__(self): - return f"thresh({self.k},{','.join(map(str, self.subs))})" - - -class WrapperNode(Node): - """A virtual base class for wrappers. - - Don't instanciate it directly, use concret wrapper fragments instead. - """ - - def __init__(self, sub): - self.subs = [sub] - - # Properties for most wrappers are directly inherited. When it's not, they - # are overriden in the fragment's __init__. 
- self.needs_sig = sub.needs_sig - self.is_forced = sub.is_forced - self.is_expressive = sub.is_expressive - self.is_nonmalleable = sub.is_nonmalleable - self.abs_heightlocks = sub.abs_heightlocks - self.rel_heightlocks = sub.rel_heightlocks - self.abs_timelocks = sub.abs_timelocks - self.rel_timelocks = sub.rel_timelocks - self.no_timelock_mix = not ( - self.abs_heightlocks - and self.abs_timelocks - or self.rel_heightlocks - and self.rel_timelocks - ) - - @property - def sub(self): - # Wrapper have a single sub - return self.subs[0] - - def satisfaction(self, sat_material): - # Most wrappers are satisfied this way, for special cases it's overriden. - return self.subs[0].satisfaction(sat_material) - - def dissatisfaction(self): - # Most wrappers are satisfied this way, for special cases it's overriden. - return self.subs[0].dissatisfaction() - - def skip_colon(self): - # We need to check this because of the pk() and pkh() aliases. - if isinstance(self.subs[0], WrapC) and isinstance( - self.subs[0].subs[0], (Pk, Pkh) - ): - return False - return isinstance(self.subs[0], WrapperNode) - - -class WrapA(WrapperNode): - def __init__(self, sub): - assert sub.p.B - WrapperNode.__init__(self, sub) - - self.p = Property("W" + "".join(c for c in "ud" if getattr(sub.p, c))) - - @property - def _script(self): - return [OP_TOALTSTACK] + self.sub._script + [OP_FROMALTSTACK] - - @property - def exec_info(self): - return ExecutionInfo.from_wrap(self.sub.exec_info, ops_count=2) - - def __repr__(self): - # Don't duplicate colons - if self.skip_colon(): - return f"a{self.subs[0]}" - return f"a:{self.subs[0]}" - - -class WrapS(WrapperNode): - def __init__(self, sub): - assert sub.p.has_all("Bo") - WrapperNode.__init__(self, sub) - - self.p = Property("W" + "".join(c for c in "ud" if getattr(sub.p, c))) - - @property - def _script(self): - return [OP_SWAP] + self.sub._script - - @property - def exec_info(self): - return ExecutionInfo.from_wrap(self.sub.exec_info, ops_count=1) - - def 
__repr__(self): - # Avoid duplicating colons - if self.skip_colon(): - return f"s{self.subs[0]}" - return f"s:{self.subs[0]}" - - -class WrapC(WrapperNode): - def __init__(self, sub): - assert sub.p.K - WrapperNode.__init__(self, sub) - - # FIXME: shouldn't n and d be default props on the website? - self.p = Property("Bu" + "".join(c for c in "dno" if getattr(sub.p, c))) - - @property - def _script(self): - return self.sub._script + [OP_CHECKSIG] - - @property - def exec_info(self): - # FIXME: should need_sig be set to True here instead of in keys? - return ExecutionInfo.from_wrap(self.sub.exec_info, ops_count=1, sat=1, dissat=1) - - def __repr__(self): - # Special case of aliases - if isinstance(self.subs[0], Pk): - return f"pk({self.subs[0].pubkey})" - if isinstance(self.subs[0], Pkh): - return f"pkh({self.subs[0].pubkey})" - # Avoid duplicating colons - if self.skip_colon(): - return f"c{self.subs[0]}" - return f"c:{self.subs[0]}" - - -class WrapT(AndV, WrapperNode): - def __init__(self, sub): - AndV.__init__(self, sub, Just1()) - - def is_wrapper(self): - return True - - def __repr__(self): - # Avoid duplicating colons - if self.skip_colon(): - return f"t{self.subs[0]}" - return f"t:{self.subs[0]}" - - -class WrapD(WrapperNode): - def __init__(self, sub): - assert sub.p.has_all("Vz") - WrapperNode.__init__(self, sub) - - self.p = Property("Bond") - self.is_forced = True # sub is V - self.is_expressive = True # sub is V, and we add a single dissat - - @property - def _script(self): - return [OP_DUP, OP_IF] + self.sub._script + [OP_ENDIF] - - @property - def exec_info(self): - return ExecutionInfo.from_wrap_dissat( - self.sub.exec_info, ops_count=3, sat=1, dissat=1 - ) - - def satisfaction(self, sat_material): - return Satisfaction(witness=[b"\x01"]) + self.subs[0].satisfaction(sat_material) - - def dissatisfaction(self): - return Satisfaction(witness=[b""]) - - def __repr__(self): - # Avoid duplicating colons - if self.skip_colon(): - return f"d{self.subs[0]}" - 
return f"d:{self.subs[0]}" - - -class WrapV(WrapperNode): - def __init__(self, sub): - assert sub.p.B - WrapperNode.__init__(self, sub) - - self.p = Property("V" + "".join(c for c in "zon" if getattr(sub.p, c))) - self.is_forced = True # V - self.is_expressive = False # V - - @property - def _script(self): - if self.sub._script[-1] == OP_CHECKSIG: - return self.sub._script[:-1] + [OP_CHECKSIGVERIFY] - elif self.sub._script[-1] == OP_CHECKMULTISIG: - return self.sub._script[:-1] + [OP_CHECKMULTISIGVERIFY] - elif self.sub._script[-1] == OP_EQUAL: - return self.sub._script[:-1] + [OP_EQUALVERIFY] - return self.sub._script + [OP_VERIFY] - - @property - def exec_info(self): - verify_cost = int(self._script[-1] == OP_VERIFY) - return ExecutionInfo.from_wrap(self.sub.exec_info, ops_count=verify_cost) - - def dissatisfaction(self): - return Satisfaction.unavailable() # It's V. - - def __repr__(self): - # Avoid duplicating colons - if self.skip_colon(): - return f"v{self.subs[0]}" - return f"v:{self.subs[0]}" - - -class WrapJ(WrapperNode): - def __init__(self, sub): - assert sub.p.has_all("Bn") - WrapperNode.__init__(self, sub) - - self.p = Property("Bnd" + "".join(c for c in "ou" if getattr(sub.p, c))) - self.is_forced = False # d - self.is_expressive = sub.is_forced - - @property - def _script(self): - return [OP_SIZE, OP_0NOTEQUAL, OP_IF, *self.sub._script, OP_ENDIF] - - @property - def exec_info(self): - return ExecutionInfo.from_wrap_dissat(self.sub.exec_info, ops_count=4, dissat=1) - - def dissatisfaction(self): - return Satisfaction(witness=[b""]) - - def __repr__(self): - # Avoid duplicating colons - if self.skip_colon(): - return f"j{self.subs[0]}" - return f"j:{self.subs[0]}" - - -class WrapN(WrapperNode): - def __init__(self, sub): - assert sub.p.B - WrapperNode.__init__(self, sub) - - self.p = Property("Bu" + "".join(c for c in "zond" if getattr(sub.p, c))) - - @property - def _script(self): - return [*self.sub._script, OP_0NOTEQUAL] - - @property - def 
exec_info(self): - return ExecutionInfo.from_wrap(self.sub.exec_info, ops_count=1) - - def __repr__(self): - # Avoid duplicating colons - if self.skip_colon(): - return f"n{self.subs[0]}" - return f"n:{self.subs[0]}" - - -class WrapL(OrI, WrapperNode): - def __init__(self, sub): - OrI.__init__(self, Just0(), sub) - - def __repr__(self): - # Avoid duplicating colons - if self.skip_colon(): - return f"l{self.subs[1]}" - return f"l:{self.subs[1]}" - - -class WrapU(OrI, WrapperNode): - def __init__(self, sub): - OrI.__init__(self, sub, Just0()) - - def __repr__(self): - # Avoid duplicating colons - if self.skip_colon(): - return f"u{self.subs[0]}" - return f"u:{self.subs[0]}" diff --git a/bitcoin_client/ledger_bitcoin/bip380/miniscript/parsing.py b/bitcoin_client/ledger_bitcoin/bip380/miniscript/parsing.py deleted file mode 100644 index 2058b7b6b..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/miniscript/parsing.py +++ /dev/null @@ -1,736 +0,0 @@ -""" -Utilities to parse Miniscript from string and Script representations. -""" - -from ...bip380.miniscript import fragments - -from ...bip380.key import DescriptorKey -from ...bip380.miniscript.errors import MiniscriptMalformed -from ...bip380.utils.script import ( - CScriptOp, - OP_ADD, - OP_BOOLAND, - OP_BOOLOR, - OP_CHECKSIGVERIFY, - OP_CHECKMULTISIGVERIFY, - OP_EQUALVERIFY, - OP_DUP, - OP_ELSE, - OP_ENDIF, - OP_EQUAL, - OP_FROMALTSTACK, - OP_IFDUP, - OP_IF, - OP_CHECKLOCKTIMEVERIFY, - OP_CHECKMULTISIG, - OP_CHECKSEQUENCEVERIFY, - OP_CHECKSIG, - OP_HASH160, - OP_HASH256, - OP_NOTIF, - OP_RIPEMD160, - OP_SHA256, - OP_SIZE, - OP_SWAP, - OP_TOALTSTACK, - OP_VERIFY, - OP_0NOTEQUAL, - ScriptNumError, - read_script_number, -) - - -def stack_item_to_int(item): - """ - Convert a stack item to an integer depending on its type. - May raise an exception if the item is bytes, otherwise return None if it - cannot perform the conversion. 
- """ - if isinstance(item, bytes): - return read_script_number(item) - - if isinstance(item, fragments.Node): - if isinstance(item, fragments.Just1): - return 1 - if isinstance(item, fragments.Just0): - return 0 - - if isinstance(item, int): - return item - - return None - - -def decompose_script(script): - """Create a list of Script element from a CScript, decomposing the compact - -VERIFY opcodes into the non-VERIFY OP and an OP_VERIFY. - """ - elems = [] - for elem in script: - if elem == OP_CHECKSIGVERIFY: - elems += [OP_CHECKSIG, OP_VERIFY] - elif elem == OP_CHECKMULTISIGVERIFY: - elems += [OP_CHECKMULTISIG, OP_VERIFY] - elif elem == OP_EQUALVERIFY: - elems += [OP_EQUAL, OP_VERIFY] - else: - elems.append(elem) - return elems - - -def parse_term_single_elem(expr_list, idx): - """ - Try to parse a terminal node from the element of {expr_list} at {idx}. - """ - # Match against pk_k(key). - if ( - isinstance(expr_list[idx], bytes) - and len(expr_list[idx]) == 33 - and expr_list[idx][0] in [2, 3] - ): - expr_list[idx] = fragments.Pk(expr_list[idx]) - - # Match against JUST_1 and JUST_0. - if expr_list[idx] == 1: - expr_list[idx] = fragments.Just1() - if expr_list[idx] == b"": - expr_list[idx] = fragments.Just0() - - -def parse_term_2_elems(expr_list, idx): - """ - Try to parse a terminal node from two elements of {expr_list}, starting - from {idx}. - Return the new expression list on success, None if there was no match. 
- """ - elem_a = expr_list[idx] - elem_b = expr_list[idx + 1] - - # Only older() and after() as term with 2 stack items - if not isinstance(elem_b, CScriptOp): - return - try: - n = stack_item_to_int(elem_a) - if n is None: - return - except ScriptNumError: - return - - if n <= 0 or n >= 2 ** 31: - return - - if elem_b == OP_CHECKSEQUENCEVERIFY: - node = fragments.Older(n) - expr_list[idx: idx + 2] = [node] - return expr_list - - if elem_b == OP_CHECKLOCKTIMEVERIFY: - node = fragments.After(n) - expr_list[idx: idx + 2] = [node] - return expr_list - - -def parse_term_5_elems(expr_list, idx, pkh_preimages={}): - """ - Try to parse a terminal node from five elements of {expr_list}, starting - from {idx}. - Return the new expression list on success, None if there was no match. - """ - # The only 3 items node is pk_h - if expr_list[idx: idx + 2] != [OP_DUP, OP_HASH160]: - return - if not isinstance(expr_list[idx + 2], bytes): - return - if len(expr_list[idx + 2]) != 20: - return - if expr_list[idx + 3: idx + 5] != [OP_EQUAL, OP_VERIFY]: - return - - key_hash = expr_list[idx + 2] - key = pkh_preimages.get(key_hash) - assert key is not None # TODO: have a real error here - node = fragments.Pkh(key) - expr_list[idx: idx + 5] = [node] - return expr_list - - -def parse_term_7_elems(expr_list, idx): - """ - Try to parse a terminal node from seven elements of {expr_list}, starting - from {idx}. - Return the new expression list on success, None if there was no match. - """ - # Note how all the hashes are 7 elems because the VERIFY was decomposed - # Match against sha256. - if ( - expr_list[idx: idx + 5] == [OP_SIZE, b"\x20", OP_EQUAL, OP_VERIFY, OP_SHA256] - and isinstance(expr_list[idx + 5], bytes) - and len(expr_list[idx + 5]) == 32 - and expr_list[idx + 6] == OP_EQUAL - ): - node = fragments.Sha256(expr_list[idx + 5]) - expr_list[idx: idx + 7] = [node] - return expr_list - - # Match against hash256. 
- if ( - expr_list[idx: idx + 5] == [OP_SIZE, b"\x20", OP_EQUAL, OP_VERIFY, OP_HASH256] - and isinstance(expr_list[idx + 5], bytes) - and len(expr_list[idx + 5]) == 32 - and expr_list[idx + 6] == OP_EQUAL - ): - node = fragments.Hash256(expr_list[idx + 5]) - expr_list[idx: idx + 7] = [node] - return expr_list - - # Match against ripemd160. - if ( - expr_list[idx: idx + 5] - == [OP_SIZE, b"\x20", OP_EQUAL, OP_VERIFY, OP_RIPEMD160] - and isinstance(expr_list[idx + 5], bytes) - and len(expr_list[idx + 5]) == 20 - and expr_list[idx + 6] == OP_EQUAL - ): - node = fragments.Ripemd160(expr_list[idx + 5]) - expr_list[idx: idx + 7] = [node] - return expr_list - - # Match against hash160. - if ( - expr_list[idx: idx + 5] == [OP_SIZE, b"\x20", OP_EQUAL, OP_VERIFY, OP_HASH160] - and isinstance(expr_list[idx + 5], bytes) - and len(expr_list[idx + 5]) == 20 - and expr_list[idx + 6] == OP_EQUAL - ): - node = fragments.Hash160(expr_list[idx + 5]) - expr_list[idx: idx + 7] = [node] - return expr_list - - -def parse_nonterm_2_elems(expr_list, idx): - """ - Try to parse a non-terminal node from two elements of {expr_list}, starting - from {idx}. - Return the new expression list on success, None if there was no match. - """ - elem_a = expr_list[idx] - elem_b = expr_list[idx + 1] - - if isinstance(elem_a, fragments.Node): - # Match against and_v. - if isinstance(elem_b, fragments.Node) and elem_a.p.V and elem_b.p.has_any("BKV"): - # Is it a special case of t: wrapper? - if isinstance(elem_b, fragments.Just1): - node = fragments.WrapT(elem_a) - else: - node = fragments.AndV(elem_a, elem_b) - expr_list[idx: idx + 2] = [node] - return expr_list - - # Match against c wrapper. - if elem_b == OP_CHECKSIG and elem_a.p.K: - node = fragments.WrapC(elem_a) - expr_list[idx: idx + 2] = [node] - return expr_list - - # Match against v wrapper. - if elem_b == OP_VERIFY and elem_a.p.B: - node = fragments.WrapV(elem_a) - expr_list[idx: idx + 2] = [node] - return expr_list - - # Match against n wrapper. 
- if elem_b == OP_0NOTEQUAL and elem_a.p.B: - node = fragments.WrapN(elem_a) - expr_list[idx: idx + 2] = [node] - return expr_list - - # Match against s wrapper. - if isinstance(elem_b, fragments.Node) and elem_a == OP_SWAP and elem_b.p.has_all("Bo"): - node = fragments.WrapS(elem_b) - expr_list[idx: idx + 2] = [node] - return expr_list - - -def parse_nonterm_3_elems(expr_list, idx): - """ - Try to parse a non-terminal node from *at least* three elements of - {expr_list}, starting from {idx}. - Return the new expression list on success, None if there was no match. - """ - elem_a = expr_list[idx] - elem_b = expr_list[idx + 1] - elem_c = expr_list[idx + 2] - - if isinstance(elem_a, fragments.Node) and isinstance(elem_b, fragments.Node): - # Match against and_b. - if elem_c == OP_BOOLAND and elem_a.p.B and elem_b.p.W: - node = fragments.AndB(elem_a, elem_b) - expr_list[idx: idx + 3] = [node] - return expr_list - - # Match against or_b. - if elem_c == OP_BOOLOR and elem_a.p.has_all("Bd") and elem_b.p.has_all("Wd"): - node = fragments.OrB(elem_a, elem_b) - expr_list[idx: idx + 3] = [node] - return expr_list - - # Match against a wrapper. - if ( - elem_a == OP_TOALTSTACK - and isinstance(elem_b, fragments.Node) - and elem_b.p.B - and elem_c == OP_FROMALTSTACK - ): - node = fragments.WrapA(elem_b) - expr_list[idx: idx + 3] = [node] - return expr_list - - # FIXME: multi is a terminal! - # Match against a multi. 
- try: - k = stack_item_to_int(expr_list[idx]) - except ScriptNumError: - return - if k is None: - return - # ()* CHECKMULTISIG - if k > len(expr_list[idx + 1:]) - 2: - return - # Get the keys - keys = [] - i = idx + 1 - while idx < len(expr_list) - 2: - if not isinstance(expr_list[i], fragments.Pk): - break - keys.append(expr_list[i].pubkey) - i += 1 - if expr_list[i + 1] == OP_CHECKMULTISIG: - if k > len(keys): - return - try: - m = stack_item_to_int(expr_list[i]) - except ScriptNumError: - return - if m is None or m != len(keys): - return - node = fragments.Multi(k, keys) - expr_list[idx: i + 2] = [node] - return expr_list - - -def parse_nonterm_4_elems(expr_list, idx): - """ - Try to parse a non-terminal node from at least four elements of {expr_list}, - starting from {idx}. - Return the new expression list on success, None if there was no match. - """ - (it_a, it_b, it_c, it_d) = expr_list[idx: idx + 4] - - # Match against thresh. It's of the form [X] ([X] ADD)* k EQUAL - if isinstance(it_a, fragments.Node) and it_a.p.has_all("Bdu"): - subs = [it_a] - # The first matches, now do all the ([X] ADD)s and return - # if a pair is of the form (k, EQUAL). - for i in range(idx + 1, len(expr_list) - 1, 2): - if ( - isinstance(expr_list[i], fragments.Node) - and expr_list[i].p.has_all("Wdu") - and expr_list[i + 1] == OP_ADD - ): - subs.append(expr_list[i]) - continue - elif expr_list[i + 1] == OP_EQUAL: - try: - k = stack_item_to_int(expr_list[i]) - if len(subs) >= k >= 1: - node = fragments.Thresh(k, subs) - expr_list[idx: i + 1 + 1] = [node] - return expr_list - except ScriptNumError: - break - else: - break - - # Match against or_c. - if ( - isinstance(it_a, fragments.Node) - and it_a.p.has_all("Bdu") - and it_b == OP_NOTIF - and isinstance(it_c, fragments.Node) - and it_c.p.V - and it_d == OP_ENDIF - ): - node = fragments.OrC(it_a, it_c) - expr_list[idx: idx + 4] = [node] - return expr_list - - # Match against d wrapper. 
- if ( - [it_a, it_b] == [OP_DUP, OP_IF] - and isinstance(it_c, fragments.Node) - and it_c.p.has_all("Vz") - and it_d == OP_ENDIF - ): - node = fragments.WrapD(it_c) - expr_list[idx: idx + 4] = [node] - return expr_list - - -def parse_nonterm_5_elems(expr_list, idx): - """ - Try to parse a non-terminal node from five elements of {expr_list}, starting - from {idx}. - Return the new expression list on success, None if there was no match. - """ - (it_a, it_b, it_c, it_d, it_e) = expr_list[idx: idx + 5] - - # Match against or_d. - if ( - isinstance(it_a, fragments.Node) - and it_a.p.has_all("Bdu") - and [it_b, it_c] == [OP_IFDUP, OP_NOTIF] - and isinstance(it_d, fragments.Node) - and it_d.p.B - and it_e == OP_ENDIF - ): - node = fragments.OrD(it_a, it_d) - expr_list[idx: idx + 5] = [node] - return expr_list - - # Match against or_i. - if ( - it_a == OP_IF - and isinstance(it_b, fragments.Node) - and it_b.p.has_any("BKV") - and it_c == OP_ELSE - and isinstance(it_d, fragments.Node) - and it_d.p.has_any("BKV") - and it_e == OP_ENDIF - ): - if isinstance(it_b, fragments.Just0): - node = fragments.WrapL(it_d) - elif isinstance(it_d, fragments.Just0): - node = fragments.WrapU(it_b) - else: - node = fragments.OrI(it_b, it_d) - expr_list[idx: idx + 5] = [node] - return expr_list - - # Match against j wrapper. - if ( - [it_a, it_b, it_c] == [OP_SIZE, OP_0NOTEQUAL, OP_IF] - and isinstance(it_d, fragments.Node) - and it_e == OP_ENDIF - ): - node = fragments.WrapJ(expr_list[idx + 3]) - expr_list[idx: idx + 5] = [node] - return expr_list - - -def parse_nonterm_6_elems(expr_list, idx): - """ - Try to parse a non-terminal node from six elements of {expr_list}, starting - from {idx}. - Return the new expression list on success, None if there was no match. - """ - (it_a, it_b, it_c, it_d, it_e, it_f) = expr_list[idx: idx + 6] - - # Match against andor. 
- if ( - isinstance(it_a, fragments.Node) - and it_a.p.has_all("Bdu") - and it_b == OP_NOTIF - and isinstance(it_c, fragments.Node) - and it_c.p.has_any("BKV") - and it_d == OP_ELSE - and isinstance(it_e, fragments.Node) - and it_e.p.has_any("BKV") - and it_f == OP_ENDIF - ): - if isinstance(it_c, fragments.Just0): - node = fragments.AndN(it_a, it_e) - else: - node = fragments.AndOr(it_a, it_e, it_c) - expr_list[idx: idx + 6] = [node] - return expr_list - - -def parse_expr_list(expr_list): - """Parse a node from a list of Script elements.""" - # Every recursive call must progress the AST construction, - # until it is complete (single root node remains). - expr_list_len = len(expr_list) - - # Root node reached. - if expr_list_len == 1 and isinstance(expr_list[0], fragments.Node): - return expr_list[0] - - # Step through each list index and match against templates. - idx = expr_list_len - 1 - while idx >= 0: - if expr_list_len - idx >= 2: - new_expr_list = parse_nonterm_2_elems(expr_list, idx) - if new_expr_list is not None: - return parse_expr_list(new_expr_list) - - if expr_list_len - idx >= 3: - new_expr_list = parse_nonterm_3_elems(expr_list, idx) - if new_expr_list is not None: - return parse_expr_list(new_expr_list) - - if expr_list_len - idx >= 4: - new_expr_list = parse_nonterm_4_elems(expr_list, idx) - if new_expr_list is not None: - return parse_expr_list(new_expr_list) - - if expr_list_len - idx >= 5: - new_expr_list = parse_nonterm_5_elems(expr_list, idx) - if new_expr_list is not None: - return parse_expr_list(new_expr_list) - - if expr_list_len - idx >= 6: - new_expr_list = parse_nonterm_6_elems(expr_list, idx) - if new_expr_list is not None: - return parse_expr_list(new_expr_list) - - # Right-to-left parsing. - # Step one position left. - idx -= 1 - - # No match found. - raise MiniscriptMalformed(f"{expr_list}") - - -def miniscript_from_script(script, pkh_preimages={}): - """Construct miniscript node from script. 
- - :param script: The Bitcoin Script to decode. - :param pkh_preimage: A mapping from keyhash to key to decode pk_h() fragments. - """ - expr_list = decompose_script(script) - expr_list_len = len(expr_list) - - # We first parse terminal expressions. - idx = 0 - while idx < expr_list_len: - parse_term_single_elem(expr_list, idx) - - if expr_list_len - idx >= 2: - new_expr_list = parse_term_2_elems(expr_list, idx) - if new_expr_list is not None: - expr_list = new_expr_list - expr_list_len = len(expr_list) - - if expr_list_len - idx >= 5: - new_expr_list = parse_term_5_elems(expr_list, idx, pkh_preimages) - if new_expr_list is not None: - expr_list = new_expr_list - expr_list_len = len(expr_list) - - if expr_list_len - idx >= 7: - new_expr_list = parse_term_7_elems(expr_list, idx) - if new_expr_list is not None: - expr_list = new_expr_list - expr_list_len = len(expr_list) - - idx += 1 - - # fragments.And then recursively parse non-terminal ones. - return parse_expr_list(expr_list) - - -def split_params(string): - """Read a list of values before the next ')'. Split the result by comma.""" - i = string.find(")") - assert i >= 0 - - params, remaining = string[:i], string[i:] - if len(remaining) > 0: - return params.split(","), remaining[1:] - else: - return params.split(","), "" - - -def parse_many(string): - """Read a list of nodes before the next ')'.""" - subs = [] - remaining = string - while True: - sub, remaining = parse_one(remaining) - subs.append(sub) - if remaining[0] == ")": - return subs, remaining[1:] - assert remaining[0] == "," # TODO: real errors - remaining = remaining[1:] - - -def parse_one_num(string): - """Read an integer before the next comma.""" - i = string.find(",") - assert i >= 0 - - return int(string[:i]), string[i + 1:] - - -def parse_one(string): - """Read a node and its subs recursively from a string. - Returns the node and the part of the string not consumed. 
- """ - - # We special case fragments.Just1 and fragments.Just0 since they are the only one which don't - # have a function syntax. - if string[0] == "0": - return fragments.Just0(), string[1:] - if string[0] == "1": - return fragments.Just1(), string[1:] - - # Now, find the separator for all functions. - for i, char in enumerate(string): - if char in ["(", ":"]: - break - # For wrappers, we may have many of them. - if char == ":" and i > 1: - tag, remaining = string[0], string[1:] - else: - tag, remaining = string[:i], string[i + 1:] - - # fragments.Wrappers - if char == ":": - sub, remaining = parse_one(remaining) - if tag == "a": - return fragments.WrapA(sub), remaining - - if tag == "s": - return fragments.WrapS(sub), remaining - - if tag == "c": - return fragments.WrapC(sub), remaining - - if tag == "t": - return fragments.WrapT(sub), remaining - - if tag == "d": - return fragments.WrapD(sub), remaining - - if tag == "v": - return fragments.WrapV(sub), remaining - - if tag == "j": - return fragments.WrapJ(sub), remaining - - if tag == "n": - return fragments.WrapN(sub), remaining - - if tag == "l": - return fragments.WrapL(sub), remaining - - if tag == "u": - return fragments.WrapU(sub), remaining - - assert False, (tag, sub, remaining) # TODO: real errors - - # Terminal elements other than 0 and 1 - if tag in [ - "pk", - "pkh", - "pk_k", - "pk_h", - "sha256", - "hash256", - "ripemd160", - "hash160", - "older", - "after", - "multi", - ]: - params, remaining = split_params(remaining) - - if tag == "0": - return fragments.Just0(), remaining - - if tag == "1": - return fragments.Just1(), remaining - - if tag == "pk": - return fragments.WrapC(fragments.Pk(params[0])), remaining - - if tag == "pk_k": - return fragments.Pk(params[0]), remaining - - if tag == "pkh": - return fragments.WrapC(fragments.Pkh(params[0])), remaining - - if tag == "pk_h": - return fragments.Pkh(params[0]), remaining - - if tag == "older": - value = int(params[0]) - return 
fragments.Older(value), remaining - - if tag == "after": - value = int(params[0]) - return fragments.After(value), remaining - - if tag in ["sha256", "hash256", "ripemd160", "hash160"]: - digest = bytes.fromhex(params[0]) - if tag == "sha256": - return fragments.Sha256(digest), remaining - if tag == "hash256": - return fragments.Hash256(digest), remaining - if tag == "ripemd160": - return fragments.Ripemd160(digest), remaining - return fragments.Hash160(digest), remaining - - if tag == "multi": - k = int(params.pop(0)) - key_n = [] - for param in params: - key_obj = DescriptorKey(param) - key_n.append(key_obj) - return fragments.Multi(k, key_n), remaining - - assert False, (tag, params, remaining) - - # Non-terminal elements (connectives) - # We special case fragments.Thresh, as its first sub is an integer. - if tag == "thresh": - k, remaining = parse_one_num(remaining) - # TODO: real errors in place of unpacking - subs, remaining = parse_many(remaining) - - if tag == "and_v": - return fragments.AndV(*subs), remaining - - if tag == "and_b": - return fragments.AndB(*subs), remaining - - if tag == "and_n": - return fragments.AndN(*subs), remaining - - if tag == "or_b": - return fragments.OrB(*subs), remaining - - if tag == "or_c": - return fragments.OrC(*subs), remaining - - if tag == "or_d": - return fragments.OrD(*subs), remaining - - if tag == "or_i": - return fragments.OrI(*subs), remaining - - if tag == "andor": - return fragments.AndOr(*subs), remaining - - if tag == "thresh": - return fragments.Thresh(k, subs), remaining - - assert False, (tag, subs, remaining) # TODO - - -def miniscript_from_str(ms_str): - """Construct miniscript node from string representation""" - node, remaining = parse_one(ms_str) - assert remaining == "" - return node diff --git a/bitcoin_client/ledger_bitcoin/bip380/miniscript/property.py b/bitcoin_client/ledger_bitcoin/bip380/miniscript/property.py deleted file mode 100644 index 5cff50b79..000000000 --- 
a/bitcoin_client/ledger_bitcoin/bip380/miniscript/property.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2020 The Bitcoin Core developers -# Copyright (c) 2021 Antoine Poinsot -# Distributed under the MIT software license, see the accompanying -# file LICENSE or http://www.opensource.org/licenses/mit-license.php. - -from .errors import MiniscriptPropertyError - - -# TODO: implement __eq__ -class Property: - """Miniscript expression property""" - - # "B": Base type - # "V": Verify type - # "K": Key type - # "W": Wrapped type - # "z": Zero-arg property - # "o": One-arg property - # "n": Nonzero arg property - # "d": Dissatisfiable property - # "u": Unit property - types = "BVKW" - props = "zondu" - - def __init__(self, property_str=""): - """Create a property, optionally from a str of property and types""" - allowed = self.types + self.props - invalid = set(property_str).difference(set(allowed)) - - if invalid: - raise MiniscriptPropertyError( - f"Invalid property/type character(s) '{''.join(invalid)}'" - f" (allowed: '{allowed}')" - ) - - for literal in allowed: - setattr(self, literal, literal in property_str) - - self.check_valid() - - def __repr__(self): - """Generate string representation of property""" - return "".join([c for c in self.types + self.props if getattr(self, c)]) - - def has_all(self, properties): - """Given a str of types and properties, return whether we have all of them""" - return all([getattr(self, pt) for pt in properties]) - - def has_any(self, properties): - """Given a str of types and properties, return whether we have at least one of them""" - return any([getattr(self, pt) for pt in properties]) - - def check_valid(self): - """Raises a MiniscriptPropertyError if the types/properties conflict""" - # Can only be of a single type. - if len(self.type()) > 1: - raise MiniscriptPropertyError(f"A Miniscript fragment can only be of a single type, got '{self.type()}'") - - # Check for conflicts in type & properties. 
- checks = [ - # (type/property, must_be, must_not_be) - ("K", "u", ""), - ("V", "", "du"), - ("z", "", "o"), - ("n", "", "z"), - ] - conflicts = [] - - for (attr, must_be, must_not_be) in checks: - if not getattr(self, attr): - continue - if not self.has_all(must_be): - conflicts.append(f"{attr} must be {must_be}") - if self.has_any(must_not_be): - conflicts.append(f"{attr} must not be {must_not_be}") - if conflicts: - raise MiniscriptPropertyError(f"Conflicting types and properties: {', '.join(conflicts)}") - - def type(self): - return "".join(filter(lambda x: x in self.types, str(self))) - - def properties(self): - return "".join(filter(lambda x: x in self.props, str(self))) diff --git a/bitcoin_client/ledger_bitcoin/bip380/miniscript/satisfaction.py b/bitcoin_client/ledger_bitcoin/bip380/miniscript/satisfaction.py deleted file mode 100644 index 67e878060..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/miniscript/satisfaction.py +++ /dev/null @@ -1,409 +0,0 @@ -""" -Miniscript satisfaction. - -This module contains logic for "signing for" a Miniscript (constructing a valid witness -that meets the conditions set by the Script) and analysis of such satisfaction(s) (eg the -maximum cost in a given resource). -This is currently focused on non-malleable satisfaction. We take shortcuts to not care about -non-canonical (dis)satisfactions. -""" - - -def add_optional(a, b): - """Add two numbers that may be None together.""" - if a is None or b is None: - return None - return a + b - - -def max_optional(a, b): - """Return the maximum of two numbers that may be None.""" - if a is None: - return b - if b is None: - return a - return max(a, b) - - -class SatisfactionMaterial: - """Data that may be needed in order to satisfy a Minsicript fragment.""" - - def __init__( - self, preimages={}, signatures={}, max_sequence=2 ** 32, max_lock_time=2 ** 32 - ): - """ - :param preimages: Mapping from a hash (as bytes), to its 32-bytes preimage. 
- :param signatures: Mapping from a public key (as bytes), to a signature for this key. - :param max_sequence: The maximum relative timelock possible (coin age). - :param max_lock_time: The maximum absolute timelock possible (block height). - """ - self.preimages = preimages - self.signatures = signatures - self.max_sequence = max_sequence - self.max_lock_time = max_lock_time - - def clear(self): - self.preimages.clear() - self.signatures.clear() - self.max_sequence = 0 - self.max_lock_time = 0 - - def __repr__(self): - return ( - f"SatisfactionMaterial(preimages: {self.preimages}, signatures: " - f"{self.signatures}, max_sequence: {self.max_sequence}, max_lock_time: " - f"{self.max_lock_time}" - ) - - -class Satisfaction: - """All information about a satisfaction.""" - - def __init__(self, witness, has_sig=False): - assert isinstance(witness, list) or witness is None - self.witness = witness - self.has_sig = has_sig - # TODO: we probably need to take into account non-canon sats, as the algorithm - # described on the website mandates it: - # > Iterate over all the valid satisfactions/dissatisfactions in the table above - # > (including the non-canonical ones), - - def __add__(self, other): - """Concatenate two satisfactions together.""" - witness = add_optional(self.witness, other.witness) - has_sig = self.has_sig or other.has_sig - return Satisfaction(witness, has_sig) - - def __or__(self, other): - """Choose between two (dis)satisfactions.""" - assert isinstance(other, Satisfaction) - - # If one isn't available, return the other one. - if self.witness is None: - return other - if other.witness is None: - return self - - # > If among all valid solutions (including DONTUSE ones) more than one does not - # > have the HASSIG marker, return DONTUSE, as this is malleable because of reason - # > 1. 
- # TODO - # if not (self.has_sig or other.has_sig): - # return Satisfaction.unavailable() - - # > If instead exactly one does not have the HASSIG marker, return that solution - # > because of reason 2. - if self.has_sig and not other.has_sig: - return other - if not self.has_sig and other.has_sig: - return self - - # > Otherwise, all not-DONTUSE options are valid, so return the smallest one (in - # > terms of witness size). - if self.size() > other.size(): - return other - - # > If all valid solutions have the HASSIG marker, but all of them are DONTUSE, return DONTUSE-HASSIG. - # TODO - - return self - - def unavailable(): - return Satisfaction(witness=None) - - def is_unavailable(self): - return self.witness is None - - def size(self): - return len(self.witness) + sum(len(elem) for elem in self.witness) - - def from_concat(sat_material, sub_a, sub_b, disjunction=False): - """Get the satisfaction for a Miniscript whose Script corresponds to a - concatenation of two subscripts A and B. - - :param sub_a: The sub-fragment A. - :param sub_b: The sub-fragment B. - :param disjunction: Whether this fragment has an 'or()' semantic. - """ - if disjunction: - return (sub_b.dissatisfaction() + sub_a.satisfaction(sat_material)) | ( - sub_b.satisfaction(sat_material) + sub_a.dissatisfaction() - ) - return sub_b.satisfaction(sat_material) + sub_a.satisfaction(sat_material) - - def from_or_uneven(sat_material, sub_a, sub_b): - """Get the satisfaction for a Miniscript which unconditionally executes a first - sub A and only executes B if A was dissatisfied. - - :param sub_a: The sub-fragment A. - :param sub_b: The sub-fragment B. - """ - return sub_a.satisfaction(sat_material) | ( - sub_b.satisfaction(sat_material) + sub_a.dissatisfaction() - ) - - def from_thresh(sat_material, k, subs): - """Get the satisfaction for a Miniscript which satisfies k of the given subs, - and dissatisfies all the others. - - :param sat_material: The material to satisfy the challenges. 
- :param k: The number of subs that need to be satisfied. - :param subs: The list of all subs of the threshold. - """ - # Pick the k sub-fragments to satisfy, prefering (in order): - # 1. Fragments that don't require a signature to be satisfied - # 2. Fragments whose satisfaction's size is smaller - # Record the unavailable (in either way) ones as we go. - arbitrage, unsatisfiable, undissatisfiable = [], [], [] - for sub in subs: - sat, dissat = sub.satisfaction(sat_material), sub.dissatisfaction() - if sat.witness is None: - unsatisfiable.append(sub) - elif dissat.witness is None: - undissatisfiable.append(sub) - else: - arbitrage.append( - (int(sat.has_sig), len(sat.witness) - len(dissat.witness), sub) - ) - - # If not enough (dis)satisfactions are available, fail. - if len(unsatisfiable) > len(subs) - k or len(undissatisfiable) > k: - return Satisfaction.unavailable() - - # Otherwise, satisfy the k most optimal ones. - arbitrage = sorted(arbitrage, key=lambda x: x[:2]) - optimal_sat = undissatisfiable + [a[2] for a in arbitrage] + unsatisfiable - to_satisfy = set(optimal_sat[:k]) - return sum( - [ - sub.satisfaction(sat_material) - if sub in to_satisfy - else sub.dissatisfaction() - for sub in subs[::-1] - ], - start=Satisfaction(witness=[]), - ) - - -class ExecutionInfo: - """Information about the execution of a Miniscript.""" - - def __init__(self, stat_ops, _dyn_ops, sat_size, dissat_size): - # The *maximum* number of *always* executed non-PUSH Script OPs to satisfy this - # Miniscript fragment non-malleably. - self._static_ops_count = stat_ops - # The maximum possible number of counted-as-executed-by-interpreter OPs if this - # fragment is executed. - # It is only >0 for an executed multi() branch. That is, for a CHECKMULTISIG that - # is not part of an unexecuted branch of an IF .. ENDIF. - self._dyn_ops_count = _dyn_ops - # The *maximum* number of stack elements to satisfy this Miniscript fragment - # non-malleably. 
- self.sat_elems = sat_size - # The *maximum* number of stack elements to dissatisfy this Miniscript fragment - # non-malleably. - self.dissat_elems = dissat_size - - @property - def ops_count(self): - """ - The worst-case number of OPs that would be considered executed by the Script - interpreter. - Note it is considered alone and not necessarily coherent with the other maxima. - """ - return self._static_ops_count + self._dyn_ops_count - - def is_dissatisfiable(self): - """Whether the Miniscript is *non-malleably* dissatisfiable.""" - return self.dissat_elems is not None - - def set_undissatisfiable(self): - """Set the Miniscript as being impossible to dissatisfy.""" - self.dissat_elems = None - - def from_concat(sub_a, sub_b, ops_count=0, disjunction=False): - """Compute the execution info from a Miniscript whose Script corresponds to - a concatenation of two subscript A and B. - - :param sub_a: The execution information of the subscript A. - :param sub_b: The execution information of the subscript B. - :param ops_count: The added number of static OPs added on top. - :param disjunction: Whether this fragment has an 'or()' semantic. - """ - # Number of static OPs is simple, they are all executed. - static_ops = sub_a._static_ops_count + sub_b._static_ops_count + ops_count - # Same for the dynamic ones, there is no conditional branch here. - dyn_ops = sub_a._dyn_ops_count + sub_b._dyn_ops_count - # If this is an 'or', only one needs to be satisfied. Pick the most expensive - # satisfaction/dissatisfaction pair. - # If not, both need to be anyways. - if disjunction: - first = add_optional(sub_a.sat_elems, sub_b.dissat_elems) - second = add_optional(sub_a.dissat_elems, sub_b.sat_elems) - sat_elems = max_optional(first, second) - else: - sat_elems = add_optional(sub_a.sat_elems, sub_b.sat_elems) - # In any case dissatisfying the fragment requires dissatisfying both concatenated - # subs. 
- dissat_elems = add_optional(sub_a.dissat_elems, sub_b.dissat_elems) - - return ExecutionInfo(static_ops, dyn_ops, sat_elems, dissat_elems) - - def from_or_uneven(sub_a, sub_b, ops_count=0): - """Compute the execution info from a Miniscript which always executes A and only - executes B depending on the outcome of A's execution. - - :param sub_a: The execution information of the subscript A. - :param sub_b: The execution information of the subscript B. - :param ops_count: The added number of static OPs added on top. - """ - # Number of static OPs is simple, they are all executed. - static_ops = sub_a._static_ops_count + sub_b._static_ops_count + ops_count - # If the first sub is non-malleably dissatisfiable, the worst case is executing - # both. Otherwise it is necessarily satisfying only the first one. - if sub_a.is_dissatisfiable(): - dyn_ops = sub_a._dyn_ops_count + sub_b._dyn_ops_count - else: - dyn_ops = sub_a._dyn_ops_count - # Either we satisfy A, or satisfy B (and thereby dissatisfy A). Pick the most - # expensive. - first = sub_a.sat_elems - second = add_optional(sub_a.dissat_elems, sub_b.sat_elems) - sat_elems = max_optional(first, second) - # We only take canonical dissatisfactions into account. - dissat_elems = add_optional(sub_a.dissat_elems, sub_b.dissat_elems) - - return ExecutionInfo(static_ops, dyn_ops, sat_elems, dissat_elems) - - def from_or_even(sub_a, sub_b, ops_count): - """Compute the execution info from a Miniscript which executes either A or B, but - never both. - - :param sub_a: The execution information of the subscript A. - :param sub_b: The execution information of the subscript B. - :param ops_count: The added number of static OPs added on top. - """ - # Number of static OPs is simple, they are all executed. - static_ops = sub_a._static_ops_count + sub_b._static_ops_count + ops_count - # Only one of the branch is executed, pick the most expensive one. - dyn_ops = max(sub_a._dyn_ops_count, sub_b._dyn_ops_count) - # Same. 
Also, we add a stack element used to tell which branch to take. - sat_elems = add_optional(max_optional(sub_a.sat_elems, sub_b.sat_elems), 1) - # Same here. - dissat_elems = add_optional( - max_optional(sub_a.dissat_elems, sub_b.dissat_elems), 1 - ) - - return ExecutionInfo(static_ops, dyn_ops, sat_elems, dissat_elems) - - def from_andor_uneven(sub_a, sub_b, sub_c, ops_count=0): - """Compute the execution info from a Miniscript which always executes A, and then - executes B if A returned True else executes C. Semantic: or(and(A,B), C). - - :param sub_a: The execution information of the subscript A. - :param sub_b: The execution information of the subscript B. - :param sub_b: The execution information of the subscript C. - :param ops_count: The added number of static OPs added on top. - """ - # Number of static OPs is simple, they are all executed. - static_ops = ( - sum(sub._static_ops_count for sub in [sub_a, sub_b, sub_c]) + ops_count - ) - # If the first sub is non-malleably dissatisfiable, the worst case is executing - # it and the most expensive between B and C. - # If it isn't the worst case is then necessarily to execute A and B. - if sub_a.is_dissatisfiable(): - dyn_ops = sub_a._dyn_ops_count + max( - sub_b._dyn_ops_count, sub_c._dyn_ops_count - ) - else: - # If the first isn't non-malleably dissatisfiable, the worst case is - # satisfying it (and necessarily satisfying the second one too) - dyn_ops = sub_a._dyn_ops_count + sub_b._dyn_ops_count - # Same for the number of stack elements (implicit from None here). - first = add_optional(sub_a.sat_elems, sub_b.sat_elems) - second = add_optional(sub_a.dissat_elems, sub_c.sat_elems) - sat_elems = max_optional(first, second) - # The only canonical dissatisfaction is dissatisfying A and C. 
- dissat_elems = add_optional(sub_a.dissat_elems, sub_c.dissat_elems) - - return ExecutionInfo(static_ops, dyn_ops, sat_elems, dissat_elems) - - # TODO: i think it'd be possible to not have this be special-cased to 'thresh()' - def from_thresh(k, subs): - """Compute the execution info from a Miniscript 'thresh()' fragment. Specialized - to this specifc fragment for now. - - :param k: The actual threshold of the 'thresh()' fragment. - :param subs: All the possible sub scripts. - """ - # All the OPs from the subs + n-1 * OP_ADD + 1 * OP_EQUAL - static_ops = sum(sub._static_ops_count for sub in subs) + len(subs) - # dyn_ops = sum(sorted([sub._dyn_ops_count for sub in subs], reverse=True)[:k]) - # All subs are executed, there is no OP_IF branch. - dyn_ops = sum([sub._dyn_ops_count for sub in subs]) - - # In order to estimate the worst case we simulate to satisfy the k subs whose - # sat/dissat ratio is the largest, and dissatisfy the others. - # We do so by iterating through all the subs, recording their sat-dissat "score" - # and those that either cannot be satisfied or dissatisfied. - arbitrage, unsatisfiable, undissatisfiable = [], [], [] - for sub in subs: - if sub.sat_elems is None: - unsatisfiable.append(sub) - elif sub.dissat_elems is None: - undissatisfiable.append(sub) - else: - arbitrage.append((sub.sat_elems - sub.dissat_elems, sub)) - # Of course, if too many can't be (dis)satisfied, we have a problem. - # Otherwise, simulate satisfying first the subs that must be (no dissatisfaction) - # then the most expensive ones, and then dissatisfy all the others. 
- if len(unsatisfiable) > len(subs) - k or len(undissatisfiable) > k: - sat_elems = None - else: - arbitrage = sorted(arbitrage, key=lambda x: x[0], reverse=True) - worst_sat = undissatisfiable + [a[1] for a in arbitrage] + unsatisfiable - sat_elems = sum( - [sub.sat_elems for sub in worst_sat[:k]] - + [sub.dissat_elems for sub in worst_sat[k:]] - ) - if len(undissatisfiable) > 0: - dissat_elems = None - else: - dissat_elems = sum([sub.dissat_elems for sub in subs]) - - return ExecutionInfo(static_ops, dyn_ops, sat_elems, dissat_elems) - - def from_wrap(sub, ops_count, dyn=0, sat=0, dissat=0): - """Compute the execution info from a Miniscript which always executes a subscript - but adds some logic around. - - :param sub: The execution information of the single subscript. - :param ops_count: The added number of static OPs added on top. - :param dyn: The added number of dynamic OPs added on top. - :param sat: The added number of satisfaction stack elements added on top. - :param dissat: The added number of dissatisfcation stack elements added on top. - """ - return ExecutionInfo( - sub._static_ops_count + ops_count, - sub._dyn_ops_count + dyn, - add_optional(sub.sat_elems, sat), - add_optional(sub.dissat_elems, dissat), - ) - - def from_wrap_dissat(sub, ops_count, dyn=0, sat=0, dissat=0): - """Compute the execution info from a Miniscript which always executes a subscript - but adds some logic around. - - :param sub: The execution information of the single subscript. - :param ops_count: The added number of static OPs added on top. - :param dyn: The added number of dynamic OPs added on top. - :param sat: The added number of satisfaction stack elements added on top. - :param dissat: The added number of dissatisfcation stack elements added on top. 
- """ - return ExecutionInfo( - sub._static_ops_count + ops_count, - sub._dyn_ops_count + dyn, - add_optional(sub.sat_elems, sat), - dissat, - ) diff --git a/bitcoin_client/ledger_bitcoin/bip380/utils/bignum.py b/bitcoin_client/ledger_bitcoin/bip380/utils/bignum.py deleted file mode 100644 index 138493918..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/utils/bignum.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2015-2020 The Bitcoin Core developers -# Copyright (c) 2021 Antoine Poinsot -# Distributed under the MIT software license, see the accompanying -# file LICENSE or http://www.opensource.org/licenses/mit-license.php. -"""Big number routines. - -This file is taken from the Bitcoin Core test framework. It was previously -copied from python-bitcoinlib. -""" - -import struct - - -# generic big endian MPI format - - -def bn_bytes(v, have_ext=False): - ext = 0 - if have_ext: - ext = 1 - return ((v.bit_length() + 7) // 8) + ext - - -def bn2bin(v): - s = bytearray() - i = bn_bytes(v) - while i > 0: - s.append((v >> ((i - 1) * 8)) & 0xFF) - i -= 1 - return s - - -def bn2mpi(v): - have_ext = False - if v.bit_length() > 0: - have_ext = (v.bit_length() & 0x07) == 0 - - neg = False - if v < 0: - neg = True - v = -v - - s = struct.pack(b">I", bn_bytes(v, have_ext)) - ext = bytearray() - if have_ext: - ext.append(0) - v_bin = bn2bin(v) - if neg: - if have_ext: - ext[0] |= 0x80 - else: - v_bin[0] |= 0x80 - return s + ext + v_bin - - -# bitcoin-specific little endian format, with implicit size -def mpi2vch(s): - r = s[4:] # strip size - r = r[::-1] # reverse string, converting BE->LE - return r - - -def bn2vch(v): - return bytes(mpi2vch(bn2mpi(v))) diff --git a/bitcoin_client/ledger_bitcoin/bip380/utils/hashes.py b/bitcoin_client/ledger_bitcoin/bip380/utils/hashes.py deleted file mode 100644 index 1124dc57a..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/utils/hashes.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Common Bitcoin hashes. 
-""" - -import hashlib -from .ripemd_fallback import ripemd160_fallback - - -def sha256(data): - """{data} must be bytes, returns sha256(data)""" - assert isinstance(data, bytes) - return hashlib.sha256(data).digest() - - -def hash160(data): - """{data} must be bytes, returns ripemd160(sha256(data))""" - assert isinstance(data, bytes) - if 'ripemd160' in hashlib.algorithms_available: - return hashlib.new("ripemd160", sha256(data)).digest() - return ripemd160_fallback(sha256(data)) diff --git a/bitcoin_client/ledger_bitcoin/bip380/utils/ripemd_fallback.py b/bitcoin_client/ledger_bitcoin/bip380/utils/ripemd_fallback.py deleted file mode 100644 index a4043de9b..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/utils/ripemd_fallback.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) 2021 Pieter Wuille -# Distributed under the MIT software license, see the accompanying -# file COPYING or http://www.opensource.org/licenses/mit-license.php. -# -# Taken from https://github.com/bitcoin/bitcoin/blob/124e75a41ea0f3f0e90b63b0c41813184ddce2ab/test/functional/test_framework/ripemd160.py - -# fmt: off - -""" -Pure Python RIPEMD160 implementation. - -WARNING: This implementation is NOT constant-time. -Do not use without understanding the implications. -""" - -# Message schedule indexes for the left path. -ML = [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 7, 4, 13, 1, 10, 6, 15, 3, 12, 0, 9, 5, 2, 14, 11, 8, - 3, 10, 14, 4, 9, 15, 8, 1, 2, 7, 0, 6, 13, 11, 5, 12, - 1, 9, 11, 10, 0, 8, 12, 4, 13, 3, 7, 15, 14, 5, 6, 2, - 4, 0, 5, 9, 7, 12, 2, 10, 14, 1, 3, 8, 11, 6, 15, 13 -] - -# Message schedule indexes for the right path. -MR = [ - 5, 14, 7, 0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12, - 6, 11, 3, 7, 0, 13, 5, 10, 14, 15, 8, 12, 4, 9, 1, 2, - 15, 5, 1, 3, 7, 14, 6, 9, 11, 8, 12, 2, 10, 0, 4, 13, - 8, 6, 4, 1, 3, 11, 15, 0, 5, 12, 2, 13, 9, 7, 10, 14, - 12, 15, 10, 4, 1, 5, 8, 7, 6, 2, 13, 14, 0, 3, 9, 11 -] - -# Rotation counts for the left path. 
-RL = [ - 11, 14, 15, 12, 5, 8, 7, 9, 11, 13, 14, 15, 6, 7, 9, 8, - 7, 6, 8, 13, 11, 9, 7, 15, 7, 12, 15, 9, 11, 7, 13, 12, - 11, 13, 6, 7, 14, 9, 13, 15, 14, 8, 13, 6, 5, 12, 7, 5, - 11, 12, 14, 15, 14, 15, 9, 8, 9, 14, 5, 6, 8, 6, 5, 12, - 9, 15, 5, 11, 6, 8, 13, 12, 5, 12, 13, 14, 11, 8, 5, 6 -] - -# Rotation counts for the right path. -RR = [ - 8, 9, 9, 11, 13, 15, 15, 5, 7, 7, 8, 11, 14, 14, 12, 6, - 9, 13, 15, 7, 12, 8, 9, 11, 7, 7, 12, 7, 6, 15, 13, 11, - 9, 7, 15, 11, 8, 6, 6, 14, 12, 13, 5, 14, 13, 13, 7, 5, - 15, 5, 8, 11, 14, 14, 6, 14, 6, 9, 12, 9, 12, 5, 15, 8, - 8, 5, 12, 9, 12, 5, 14, 6, 8, 13, 6, 5, 15, 13, 11, 11 -] - -# K constants for the left path. -KL = [0, 0x5a827999, 0x6ed9eba1, 0x8f1bbcdc, 0xa953fd4e] - -# K constants for the right path. -KR = [0x50a28be6, 0x5c4dd124, 0x6d703ef3, 0x7a6d76e9, 0] - - -def fi(x, y, z, i): - """The f1, f2, f3, f4, and f5 functions from the specification.""" - if i == 0: - return x ^ y ^ z - elif i == 1: - return (x & y) | (~x & z) - elif i == 2: - return (x | ~y) ^ z - elif i == 3: - return (x & z) | (y & ~z) - elif i == 4: - return x ^ (y | ~z) - else: - assert False - - -def rol(x, i): - """Rotate the bottom 32 bits of x left by i bits.""" - return ((x << i) | ((x & 0xffffffff) >> (32 - i))) & 0xffffffff - - -def compress(h0, h1, h2, h3, h4, block): - """Compress state (h0, h1, h2, h3, h4) with block.""" - # Left path variables. - al, bl, cl, dl, el = h0, h1, h2, h3, h4 - # Right path variables. - ar, br, cr, dr, er = h0, h1, h2, h3, h4 - # Message variables. - x = [int.from_bytes(block[4*i:4*(i+1)], 'little') for i in range(16)] - - # Iterate over the 80 rounds of the compression. - for j in range(80): - rnd = j >> 4 - # Perform left side of the transformation. - al = rol(al + fi(bl, cl, dl, rnd) + x[ML[j]] + KL[rnd], RL[j]) + el - al, bl, cl, dl, el = el, al, bl, rol(cl, 10), dl - # Perform right side of the transformation. 
- ar = rol(ar + fi(br, cr, dr, 4 - rnd) + x[MR[j]] + KR[rnd], RR[j]) + er - ar, br, cr, dr, er = er, ar, br, rol(cr, 10), dr - - # Compose old state, left transform, and right transform into new state. - return h1 + cl + dr, h2 + dl + er, h3 + el + ar, h4 + al + br, h0 + bl + cr - - -def ripemd160_fallback(data): - """Compute the RIPEMD-160 hash of data.""" - # Initialize state. - state = (0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0) - # Process full 64-byte blocks in the input. - for b in range(len(data) >> 6): - state = compress(*state, data[64*b:64*(b+1)]) - # Construct final blocks (with padding and size). - pad = b"\x80" + b"\x00" * ((119 - len(data)) & 63) - fin = data[len(data) & ~63:] + pad + (8 * len(data)).to_bytes(8, 'little') - # Process final blocks. - for b in range(len(fin) >> 6): - state = compress(*state, fin[64*b:64*(b+1)]) - # Produce output. - return b"".join((h & 0xffffffff).to_bytes(4, 'little') for h in state) \ No newline at end of file diff --git a/bitcoin_client/ledger_bitcoin/bip380/utils/script.py b/bitcoin_client/ledger_bitcoin/bip380/utils/script.py deleted file mode 100644 index 9ff0e703d..000000000 --- a/bitcoin_client/ledger_bitcoin/bip380/utils/script.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright (c) 2015-2020 The Bitcoin Core developers -# Copyright (c) 2021 Antoine Poinsot -# Distributed under the MIT software license, see the accompanying -# file LICENSE or http://www.opensource.org/licenses/mit-license.php. -"""Script utilities - -This file was taken from Bitcoin Core test framework, and was previously -modified from python-bitcoinlib. 
-""" -import struct - -from .bignum import bn2vch - - -OPCODE_NAMES = {} - - -class CScriptOp(int): - """A single script opcode""" - - __slots__ = () - - @staticmethod - def encode_op_pushdata(d): - """Encode a PUSHDATA op, returning bytes""" - if len(d) < 0x4C: - return b"" + bytes([len(d)]) + d # OP_PUSHDATA - elif len(d) <= 0xFF: - return b"\x4c" + bytes([len(d)]) + d # OP_PUSHDATA1 - elif len(d) <= 0xFFFF: - return b"\x4d" + struct.pack(b" 4: - raise ScriptNumError("Too large push") - - if size == 0: - return 0 - - # We always check for minimal encoding - if (data[size - 1] & 0x7f) == 0: - if size == 1 or (data[size - 2] & 0x80) == 0: - raise ScriptNumError("Non minimal encoding") - - res = int.from_bytes(data, byteorder="little") - - # Remove the sign bit if set, and negate the result - if data[size - 1] & 0x80: - return -(res & ~(0x80 << (size - 1))) - return res - - -class CScriptInvalidError(Exception): - """Base class for CScript exceptions""" - - pass - - -class CScriptTruncatedPushDataError(CScriptInvalidError): - """Invalid pushdata due to truncation""" - - def __init__(self, msg, data): - self.data = data - super(CScriptTruncatedPushDataError, self).__init__(msg) - - -# This is used, eg, for blockchain heights in coinbase scripts (bip34) -class CScriptNum: - __slots__ = ("value",) - - def __init__(self, d=0): - self.value = d - - @staticmethod - def encode(obj): - r = bytearray(0) - if obj.value == 0: - return bytes(r) - neg = obj.value < 0 - absvalue = -obj.value if neg else obj.value - while absvalue: - r.append(absvalue & 0xFF) - absvalue >>= 8 - if r[-1] & 0x80: - r.append(0x80 if neg else 0) - elif neg: - r[-1] |= 0x80 - return bytes([len(r)]) + r - - @staticmethod - def decode(vch): - result = 0 - # We assume valid push_size and minimal encoding - value = vch[1:] - if len(value) == 0: - return result - for i, byte in enumerate(value): - result |= int(byte) << 8 * i - if value[-1] >= 0x80: - # Mask for all but the highest result bit - num_mask = 
(2 ** (len(value) * 8) - 1) >> 1 - result &= num_mask - result *= -1 - return result - - -class CScript(bytes): - """Serialized script - - A bytes subclass, so you can use this directly whenever bytes are accepted. - Note that this means that indexing does *not* work - you'll get an index by - byte rather than opcode. This format was chosen for efficiency so that the - general case would not require creating a lot of little CScriptOP objects. - - iter(script) however does iterate by opcode. - """ - - __slots__ = () - - @classmethod - def __coerce_instance(cls, other): - # Coerce other into bytes - if isinstance(other, CScriptOp): - other = bytes([other]) - elif isinstance(other, CScriptNum): - if other.value == 0: - other = bytes([CScriptOp(OP_0)]) - else: - other = CScriptNum.encode(other) - elif isinstance(other, int): - if 0 <= other <= 16: - other = bytes([CScriptOp.encode_op_n(other)]) - elif other == -1: - other = bytes([OP_1NEGATE]) - else: - other = CScriptOp.encode_op_pushdata(bn2vch(other)) - elif isinstance(other, (bytes, bytearray)): - other = CScriptOp.encode_op_pushdata(other) - return other - - def __add__(self, other): - # Do the coercion outside of the try block so that errors in it are - # noticed. - other = self.__coerce_instance(other) - - try: - # bytes.__add__ always returns bytes instances unfortunately - return CScript(super(CScript, self).__add__(other)) - except TypeError: - raise TypeError("Can not add a %r instance to a CScript" % other.__class__) - - def join(self, iterable): - # join makes no sense for a CScript() - raise NotImplementedError - - def __new__(cls, value=b""): - if isinstance(value, bytes) or isinstance(value, bytearray): - return super(CScript, cls).__new__(cls, value) - else: - - def coerce_iterable(iterable): - for instance in iterable: - yield cls.__coerce_instance(instance) - - # Annoyingly on both python2 and python3 bytes.join() always - # returns a bytes instance even when subclassed. 
- return super(CScript, cls).__new__(cls, b"".join(coerce_iterable(value))) - - def raw_iter(self): - """Raw iteration - - Yields tuples of (opcode, data, sop_idx) so that the different possible - PUSHDATA encodings can be accurately distinguished, as well as - determining the exact opcode byte indexes. (sop_idx) - """ - i = 0 - while i < len(self): - sop_idx = i - opcode = self[i] - i += 1 - - if opcode > OP_PUSHDATA4: - yield (opcode, None, sop_idx) - else: - datasize = None - pushdata_type = None - if opcode < OP_PUSHDATA1: - pushdata_type = "PUSHDATA(%d)" % opcode - datasize = opcode - - elif opcode == OP_PUSHDATA1: - pushdata_type = "PUSHDATA1" - if i >= len(self): - raise CScriptInvalidError("PUSHDATA1: missing data length") - datasize = self[i] - i += 1 - - elif opcode == OP_PUSHDATA2: - pushdata_type = "PUSHDATA2" - if i + 1 >= len(self): - raise CScriptInvalidError("PUSHDATA2: missing data length") - datasize = self[i] + (self[i + 1] << 8) - i += 2 - - elif opcode == OP_PUSHDATA4: - pushdata_type = "PUSHDATA4" - if i + 3 >= len(self): - raise CScriptInvalidError("PUSHDATA4: missing data length") - datasize = ( - self[i] - + (self[i + 1] << 8) - + (self[i + 2] << 16) - + (self[i + 3] << 24) - ) - i += 4 - - else: - assert False # shouldn't happen - - data = bytes(self[i: i + datasize]) - - # Check for truncation - if len(data) < datasize: - raise CScriptTruncatedPushDataError( - "%s: truncated data" % pushdata_type, data - ) - - i += datasize - - yield (opcode, data, sop_idx) - - def __iter__(self): - """'Cooked' iteration - - Returns either a CScriptOP instance, an integer, or bytes, as - appropriate. - - See raw_iter() if you need to distinguish the different possible - PUSHDATA encodings. 
- """ - for (opcode, data, sop_idx) in self.raw_iter(): - if data is not None: - yield data - else: - opcode = CScriptOp(opcode) - - if opcode.is_small_int(): - yield opcode.decode_op_n() - else: - yield CScriptOp(opcode) - - def __repr__(self): - def _repr(o): - if isinstance(o, bytes): - return "x('%s')" % o.hex() - else: - return repr(o) - - ops = [] - i = iter(self) - while True: - op = None - try: - op = _repr(next(i)) - except CScriptTruncatedPushDataError as err: - op = "%s..." % (_repr(err.data), err) - break - except CScriptInvalidError as err: - op = "" % err - break - except StopIteration: - break - finally: - if op is not None: - ops.append(op) - - return "CScript([%s])" % ", ".join(ops) - - def GetSigOpCount(self, fAccurate): - """Get the SigOp count. - - fAccurate - Accurately count CHECKMULTISIG, see BIP16 for details. - - Note that this is consensus-critical. - """ - n = 0 - lastOpcode = OP_INVALIDOPCODE - for (opcode, data, sop_idx) in self.raw_iter(): - if opcode in (OP_CHECKSIG, OP_CHECKSIGVERIFY): - n += 1 - elif opcode in (OP_CHECKMULTISIG, OP_CHECKMULTISIGVERIFY): - if fAccurate and (OP_1 <= lastOpcode <= OP_16): - n += opcode.decode_op_n() - else: - n += 20 - lastOpcode = opcode - return n diff --git a/bitcoin_client/ledger_bitcoin/client.py b/bitcoin_client/ledger_bitcoin/client.py index 60ca5eec0..351370320 100644 --- a/bitcoin_client/ledger_bitcoin/client.py +++ b/bitcoin_client/ledger_bitcoin/client.py @@ -3,7 +3,9 @@ import base64 from io import BytesIO, BufferedReader -from .bip380.descriptors import Descriptor +from .embit.base import EmbitError +from .embit.descriptor import Descriptor +from .embit.networks import NETWORKS from .command_builder import BitcoinCommandBuilder, BitcoinInsType from .common import Chain, read_uint, read_varint @@ -111,12 +113,11 @@ def register_wallet(self, wallet: WalletPolicy) -> Tuple[bytes, bytes]: wallet_id = response[0:32] wallet_hmac = response[32:64] - if self._should_validate_address(wallet): - # 
sanity check: for miniscripts, derive the first address independently with python-bip380 - first_addr_device = self.get_wallet_address(wallet, wallet_hmac, 0, 0, False) + # sanity check: for miniscripts, derive the first address independently with embit + first_addr_device = self.get_wallet_address(wallet, wallet_hmac, 0, 0, False) - if first_addr_device != self._derive_segwit_address_for_policy(wallet, False, 0): - raise RuntimeError("Invalid address. Please update your Bitcoin app. If the problem persists, report a bug at https://github.com/LedgerHQ/app-bitcoin-new") + if first_addr_device != self._derive_address_for_policy(wallet, False, 0): + raise RuntimeError("Invalid address. Please update your Bitcoin app. If the problem persists, report a bug at https://github.com/LedgerHQ/app-bitcoin-new") return wallet_id, wallet_hmac @@ -154,11 +155,10 @@ def get_wallet_address( result = response.decode() - if self._should_validate_address(wallet): - # sanity check: for miniscripts, derive the address independently with python-bip380 + # sanity check: for miniscripts, derive the address independently with embit - if result != self._derive_segwit_address_for_policy(wallet, change, address_index): - raise RuntimeError("Invalid address. Please update your Bitcoin app. 
If the problem persists, report a bug at https://github.com/LedgerHQ/app-bitcoin-new") return result @@ -271,18 +271,16 @@ def sign_message(self, message: Union[str, bytes], bip32_path: str) -> str: return base64.b64encode(response).decode('utf-8') - def _should_validate_address(self, wallet: WalletPolicy) -> bool: - # TODO: extend to taproot miniscripts once supported - return wallet.descriptor_template.startswith("wsh(") and not wallet.descriptor_template.startswith("wsh(sortedmulti(") - - def _derive_segwit_address_for_policy(self, wallet: WalletPolicy, change: bool, address_index: int) -> bool: - desc = Descriptor.from_str(wallet.get_descriptor(change)) - desc.derive(address_index) - spk = desc.script_pubkey - if spk[0:2] != b'\x00\x20' or len(spk) != 34: - raise RuntimeError("Invalid scriptPubKey") - hrp = "bc" if self.chain == Chain.MAIN else "tb" - return segwit_addr.encode(hrp, 0, spk[2:]) + def _derive_address_for_policy(self, wallet: WalletPolicy, change: bool, address_index: int) -> Optional[str]: + desc_str = wallet.get_descriptor(change) + try: + desc = Descriptor.from_string(desc_str) + + desc = desc.derive(address_index) + net = NETWORKS['main'] if self.chain == Chain.MAIN else NETWORKS['test'] + return desc.script_pubkey().address(net) + except EmbitError: + return None def createClient(comm_client: Optional[TransportClient] = None, chain: Chain = Chain.MAIN, debug: bool = False) -> Union[LegacyClient, NewClient]: diff --git a/bitcoin_client/ledger_bitcoin/embit/LICENSE b/bitcoin_client/ledger_bitcoin/embit/LICENSE new file mode 100644 index 000000000..db295028b --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Stepan Snigirev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/bitcoin_client/ledger_bitcoin/embit/README.md b/bitcoin_client/ledger_bitcoin/embit/README.md new file mode 100644 index 000000000..66483faed --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/README.md @@ -0,0 +1,5 @@ +This is a stripped down version of the embit library, cloned at commit [189efc45](https://github.com/diybitcoinhardware/embit/tree/189efc4583d497a2b97632646daf1531d00442b0). + +Support for the `0` and `1` miniscript fragments was added after cloning. + +All the content of this folder is released according to the [LICENSE](LICENSE), as per the original repository. 
diff --git a/bitcoin_client/ledger_bitcoin/bip380/utils/__init__.py b/bitcoin_client/ledger_bitcoin/embit/__init__.py similarity index 100% rename from bitcoin_client/ledger_bitcoin/bip380/utils/__init__.py rename to bitcoin_client/ledger_bitcoin/embit/__init__.py diff --git a/bitcoin_client/ledger_bitcoin/embit/base.py b/bitcoin_client/ledger_bitcoin/embit/base.py new file mode 100644 index 000000000..9dbac7398 --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/base.py @@ -0,0 +1,116 @@ +"""Base classes""" +from io import BytesIO +from binascii import hexlify, unhexlify + + +class EmbitError(Exception): + """Generic Embit error""" + + pass + + +class EmbitBase: + @classmethod + def read_from(cls, stream, *args, **kwargs): + """All classes should be readable from stream""" + raise NotImplementedError( + "%s doesn't implement reading from stream" % cls.__name__ + ) + + @classmethod + def parse(cls, s: bytes, *args, **kwargs): + """Parse raw bytes""" + stream = BytesIO(s) + res = cls.read_from(stream, *args, **kwargs) + if len(stream.read(1)) > 0: + raise EmbitError("Unexpected extra bytes") + return res + + def write_to(self, stream, *args, **kwargs) -> int: + """All classes should be writable to stream""" + raise NotImplementedError( + "%s doesn't implement writing to stream" % type(self).__name__ + ) + + def serialize(self, *args, **kwargs) -> bytes: + """Serialize instance to raw bytes""" + stream = BytesIO() + self.write_to(stream, *args, **kwargs) + return stream.getvalue() + + def to_string(self, *args, **kwargs) -> str: + """ + String representation. + If not implemented - uses hex or calls to_base58() method if defined. 
+ """ + if hasattr(self, "to_base58"): + res = self.to_base58(*args, **kwargs) + if not isinstance(res, str): + raise ValueError("to_base58() must return string") + return res + return hexlify(self.serialize(*args, **kwargs)).decode() + + @classmethod + def from_string(cls, s, *args, **kwargs): + """Create class instance from string""" + if hasattr(cls, "from_base58"): + return cls.from_base58(s, *args, **kwargs) + return cls.parse(unhexlify(s)) + + def __str__(self): + """Internally calls `to_string()` method with no arguments""" + return self.to_string() + + def __repr__(self): + try: + return type(self).__name__ + "(%s)" % str(self) + except: + return type(self).__name__ + "()" + + def __eq__(self, other): + """Compare two objects by checking their serializations""" + if not hasattr(other, "serialize"): + return False + return self.serialize() == other.serialize() + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.serialize()) + + +class EmbitKey(EmbitBase): + def sec(self) -> bytes: + """ + Any EmbitKey should implement sec() method that returns + a sec-serialized public key + """ + raise NotImplementedError( + "%s doesn't implement sec() method" % type(self).__name__ + ) + + def xonly(self) -> bytes: + """xonly representation of the key""" + return self.sec()[1:33] + + @property + def is_private(self) -> bool: + """ + Any EmbitKey should implement `is_private` property to distinguish + between private and public keys. 
+ """ + raise NotImplementedError( + "%s doesn't implement is_private property" % type(self).__name__ + ) + + def __lt__(self, other): + # for lexagraphic ordering + return self.sec() < other.sec() + + def __gt__(self, other): + # for lexagraphic ordering + return self.sec() > other.sec() + + def __hash__(self): + return hash(self.serialize()) diff --git a/bitcoin_client/ledger_bitcoin/embit/base58.py b/bitcoin_client/ledger_bitcoin/embit/base58.py new file mode 100644 index 000000000..196a79dbf --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/base58.py @@ -0,0 +1,79 @@ +# Partially copy-pasted from python-bitcoinlib: +# https://github.com/petertodd/python-bitcoinlib/blob/master/bitcoin/base58.py + +"""Base58 encoding and decoding""" + +import binascii +from . import hashes + +B58_DIGITS = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" + + +def encode(b: bytes) -> str: + """Encode bytes to a base58-encoded string""" + + # Convert big-endian bytes to integer + n = int("0x0" + binascii.hexlify(b).decode("utf8"), 16) + + # Divide that integer into bas58 + chars = [] + while n > 0: + n, r = divmod(n, 58) + chars.append(B58_DIGITS[r]) + result = "".join(chars[::-1]) + + pad = 0 + for c in b: + if c == 0: + pad += 1 + else: + break + return B58_DIGITS[0] * pad + result + + +def decode(s: str) -> bytes: + """Decode a base58-encoding string, returning bytes""" + if not s: + return b"" + + # Convert the string to an integer + n = 0 + for c in s: + n *= 58 + if c not in B58_DIGITS: + raise ValueError("Character %r is not a valid base58 character" % c) + digit = B58_DIGITS.index(c) + n += digit + + # Convert the integer to bytes + h = "%x" % n + if len(h) % 2: + h = "0" + h + res = binascii.unhexlify(h.encode("utf8")) + + # Add padding back. 
+ pad = 0 + for c in s[:-1]: + if c == B58_DIGITS[0]: + pad += 1 + else: + break + return b"\x00" * pad + res + + +def encode_check(b: bytes) -> str: + """Encode bytes to a base58-encoded string with a checksum""" + return encode(b + hashes.double_sha256(b)[0:4]) + + +def decode_check(s: str) -> bytes: + """Decode a base58-encoding string with checksum check. + Returns bytes without checksum + """ + b = decode(s) + checksum = hashes.double_sha256(b[:-4])[:4] + if b[-4:] != checksum: + raise ValueError( + "Checksum mismatch: expected %r, calculated %r" % (b[-4:], checksum) + ) + return b[:-4] diff --git a/bitcoin_client/ledger_bitcoin/embit/bech32.py b/bitcoin_client/ledger_bitcoin/embit/bech32.py new file mode 100644 index 000000000..24ee7fa72 --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/bech32.py @@ -0,0 +1,146 @@ +# Copyright (c) 2017 Pieter Wuille +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +"""Reference implementation for Bech32 and segwit addresses.""" +from .misc import const + +CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" +BECH32_CONST = const(1) +BECH32M_CONST = const(0x2BC830A3) + + +class Encoding: + """Enumeration type to list the various supported encodings.""" + + BECH32 = 1 + BECH32M = 2 + + +def bech32_polymod(values): + """Internal function that computes the Bech32 checksum.""" + generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3] + chk = 1 + for value in values: + top = chk >> 25 + chk = (chk & 0x1FFFFFF) << 5 ^ value + for i in range(5): + chk ^= generator[i] if ((top >> i) & 1) else 0 + return chk + + +def bech32_hrp_expand(hrp: str): + """Expand the HRP into values for checksum computation.""" + return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp] + + +def bech32_verify_checksum(hrp, data): + """Verify a checksum given HRP and converted data characters.""" + check = bech32_polymod(bech32_hrp_expand(hrp) + data) + if check == BECH32_CONST: + return Encoding.BECH32 + elif check == BECH32M_CONST: + return Encoding.BECH32M + else: + return None + + +def bech32_create_checksum(encoding, hrp, data): + """Compute the checksum values given HRP and data.""" + values = bech32_hrp_expand(hrp) + data + const = BECH32M_CONST if encoding == Encoding.BECH32M else BECH32_CONST + polymod = bech32_polymod(values + [0, 0, 0, 0, 0, 0]) ^ const + return [(polymod >> 5 * (5 - i)) & 31 for i in range(6)] + + +def bech32_encode(encoding, hrp, data): + """Compute a Bech32 or Bech32m string given HRP and data values.""" + combined = data + bech32_create_checksum(encoding, hrp, data) + return hrp + "1" + "".join([CHARSET[d] for d in combined]) + + +def bech32_decode(bech): + """Validate a Bech32/Bech32m string, and determine HRP and data.""" + if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or ( + bech.lower() != bech and bech.upper() != bech + ): + return (None, None, None) + bech = bech.lower() + pos = 
bech.rfind("1") + if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: + return (None, None, None) + if not all(x in CHARSET for x in bech[pos + 1 :]): + return (None, None, None) + hrp = bech[:pos] + data = [CHARSET.find(x) for x in bech[pos + 1 :]] + encoding = bech32_verify_checksum(hrp, data) + if encoding is None: + return (None, None, None) + return (encoding, hrp, data[:-6]) + + +def convertbits(data, frombits, tobits, pad=True): + """General power-of-2 base conversion.""" + acc = 0 + bits = 0 + ret = [] + maxv = (1 << tobits) - 1 + max_acc = (1 << (frombits + tobits - 1)) - 1 + for value in data: + if value < 0 or (value >> frombits): + return None + acc = ((acc << frombits) | value) & max_acc + bits += frombits + while bits >= tobits: + bits -= tobits + ret.append((acc >> bits) & maxv) + if pad: + if bits: + ret.append((acc << (tobits - bits)) & maxv) + elif bits >= frombits or ((acc << (tobits - bits)) & maxv): + return None + return ret + + +def decode(hrp, addr): + """Decode a segwit address.""" + encoding, hrpgot, data = bech32_decode(addr) + if hrpgot != hrp: + return (None, None) + decoded = convertbits(data[1:], 5, 8, False) + if decoded is None or len(decoded) < 2 or len(decoded) > 40: + return (None, None) + if data[0] > 16: + return (None, None) + if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: + return (None, None) + if (data[0] == 0 and encoding != Encoding.BECH32) or ( + data[0] != 0 and encoding != Encoding.BECH32M + ): + return (None, None) + return (data[0], decoded) + + +def encode(hrp, witver, witprog): + """Encode a segwit address.""" + encoding = Encoding.BECH32 if witver == 0 else Encoding.BECH32M + ret = bech32_encode(encoding, hrp, [witver] + convertbits(witprog, 8, 5)) + if decode(hrp, ret) == (None, None): + return None + return ret diff --git a/bitcoin_client/ledger_bitcoin/embit/bip32.py b/bitcoin_client/ledger_bitcoin/embit/bip32.py new file mode 100644 index 000000000..e31cb7a26 --- /dev/null +++ 
b/bitcoin_client/ledger_bitcoin/embit/bip32.py @@ -0,0 +1,309 @@ +from . import ec +from .base import EmbitKey, EmbitError +from .misc import copy, const, secp256k1 +from .networks import NETWORKS +from . import base58 +from . import hashes +import hmac +from binascii import hexlify + +HARDENED_INDEX = const(0x80000000) + + +class HDError(EmbitError): + pass + + +class HDKey(EmbitKey): + """HD Private or Public key""" + + def __init__( + self, + key: EmbitKey, # more specifically, PrivateKey or PublicKey + chain_code: bytes, + version=None, + depth: int = 0, + fingerprint: bytes = b"\x00\x00\x00\x00", + child_number: int = 0, + ): + self.key = key + if len(key.serialize()) != 32 and len(key.serialize()) != 33: + raise HDError("Invalid key. Should be private or compressed public") + if version is not None: + self.version = version + else: + if len(key.serialize()) == 32: + self.version = NETWORKS["main"]["xprv"] + else: + self.version = NETWORKS["main"]["xpub"] + self.chain_code = chain_code + self.depth = depth + self.fingerprint = fingerprint + self._my_fingerprint = b"" + self.child_number = child_number + # check that base58[1:4] is "prv" or "pub" + if self.is_private and self.to_base58()[1:4] != "prv": + raise HDError("Invalid version") + if not self.is_private and self.to_base58()[1:4] != "pub": + raise HDError("Invalid version") + + @classmethod + def from_seed(cls, seed: bytes, version=NETWORKS["main"]["xprv"]): + """Creates a root private key from 64-byte seed""" + raw = hmac.new(b"Bitcoin seed", seed, digestmod="sha512").digest() + private_key = ec.PrivateKey(raw[:32]) + chain_code = raw[32:] + return cls(private_key, chain_code, version=version) + + @classmethod + def from_base58(cls, s: str): + b = base58.decode_check(s) + return cls.parse(b) + + @property + def my_fingerprint(self) -> bytes: + if not self._my_fingerprint: + sec = self.sec() + self._my_fingerprint = hashes.hash160(sec)[:4] + return self._my_fingerprint + + @property + def 
is_private(self) -> bool: + """checks if the HDKey is private or public""" + return self.key.is_private + + @property + def secret(self): + if not self.is_private: + raise HDError("Key is not private") + return self.key.secret + + def write_to(self, stream, version=None) -> int: + if version is None: + version = self.version + res = stream.write(version) + res += stream.write(bytes([self.depth])) + res += stream.write(self.fingerprint) + res += stream.write(self.child_number.to_bytes(4, "big")) + res += stream.write(self.chain_code) + if self.is_private: + res += stream.write(b"\x00") + res += stream.write(self.key.serialize()) + return res + + def to_base58(self, version=None) -> str: + b = self.serialize(version) + res = base58.encode_check(b) + if res[1:4] == "prv" and not self.is_private: + raise HDError("Invalid version for private key") + if res[1:4] == "pub" and self.is_private: + raise HDError("Invalid version for public key") + return res + + @classmethod + def from_string(cls, s: str): + return cls.from_base58(s) + + def to_string(self, version=None): + return self.to_base58(version) + + @classmethod + def read_from(cls, stream): + version = stream.read(4) + depth = stream.read(1)[0] + fingerprint = stream.read(4) + child_number = int.from_bytes(stream.read(4), "big") + chain_code = stream.read(32) + k = stream.read(33) + if k[0] == 0: + key = ec.PrivateKey.parse(k[1:]) + else: + key = ec.PublicKey.parse(k) + + if len(version) < 4 or len(fingerprint) < 4 or len(chain_code) < 32: + raise HDError("Not enough bytes") + hd = cls( + key, + chain_code, + version=version, + depth=depth, + fingerprint=fingerprint, + child_number=child_number, + ) + subver = hd.to_base58()[1:4] + if subver != "prv" and subver != "pub": + raise HDError("Invalid version") + if depth == 0 and child_number != 0: + raise HDError("zero depth with non-zero index") + if depth == 0 and fingerprint != b"\x00\x00\x00\x00": + raise HDError("zero depth with non-zero parent") + return hd + + 
def to_public(self, version=None): + if not self.is_private: + raise HDError("Already public") + if version is None: + # detect network + for net in NETWORKS: + for k in NETWORKS[net]: + if "prv" in k and NETWORKS[net][k] == self.version: + # xprv -> xpub, zprv -> zpub etc + version = NETWORKS[net][k.replace("prv", "pub")] + break + if version is None: + raise HDError("Can't find proper version. Provide it with version keyword") + return self.__class__( + self.key.get_public_key(), + self.chain_code, + version=version, + depth=self.depth, + fingerprint=self.fingerprint, + child_number=self.child_number, + ) + + def get_public_key(self): + return self.key.get_public_key() if self.is_private else self.key + + def sec(self) -> bytes: + """Returns SEC serialization of the public key""" + return self.key.sec() + + def xonly(self) -> bytes: + return self.key.xonly() + + def taproot_tweak(self, h=b""): + return HDKey( + self.key.taproot_tweak(h), + self.chain_code, + version=self.version, + depth=self.depth, + fingerprint=self.fingerprint, + child_number=self.child_number, + ) + + def child(self, index: int, hardened: bool = False): + """Derives a child HDKey""" + if index > 0xFFFFFFFF: + raise HDError("Index should be less then 2^32") + if hardened and index < HARDENED_INDEX: + index += HARDENED_INDEX + if index >= HARDENED_INDEX: + hardened = True + if hardened and not self.is_private: + raise HDError("Can't do hardened with public key") + + # we need pubkey for fingerprint anyways + sec = self.sec() + fingerprint = hashes.hash160(sec)[:4] + if hardened: + data = b"\x00" + self.key.serialize() + index.to_bytes(4, "big") + else: + data = sec + index.to_bytes(4, "big") + raw = hmac.new(self.chain_code, data, digestmod="sha512").digest() + secret = raw[:32] + chain_code = raw[32:] + if self.is_private: + secret = secp256k1.ec_privkey_add(secret, self.key.serialize()) + key = ec.PrivateKey(secret) + else: + # copy of internal secp256k1 point structure + point = 
copy(self.key._point) + point = secp256k1.ec_pubkey_add(point, secret) + key = ec.PublicKey(point) + return HDKey( + key, + chain_code, + version=self.version, + depth=self.depth + 1, + fingerprint=fingerprint, + child_number=index, + ) + + def derive(self, path): + """path: int array or a string starting with m/""" + if isinstance(path, str): + # string of the form m/44h/0'/ind + path = parse_path(path) + child = self + for idx in path: + child = child.child(idx) + return child + + def sign(self, msg_hash: bytes) -> ec.Signature: + """signs a hash of the message with the private key""" + if not self.is_private: + raise HDError("HD public key can't sign") + return self.key.sign(msg_hash) + + def schnorr_sign(self, msg_hash): + if not self.is_private: + raise HDError("HD public key can't sign") + return self.key.schnorr_sign(msg_hash) + + def verify(self, sig, msg_hash) -> bool: + return self.key.verify(sig, msg_hash) + + def schnorr_verify(self, sig, msg_hash) -> bool: + return self.key.schnorr_verify(sig, msg_hash) + + def __eq__(self, other): + # skip version + return self.serialize()[4:] == other.serialize()[4:] + + def __hash__(self): + return hash(self.serialize()) + + +def detect_version(path, default="xprv", network=None) -> bytes: + """ + Detects slip-132 version from the path for certain network. + Trying to be smart, use if you want, but with care. 
+ """ + key = default + net = network + if network is None: + net = NETWORKS["main"] + if isinstance(path, str): + path = parse_path(path) + if len(path) == 0: + return network[key] + if path[0] == HARDENED_INDEX + 84: + key = "z" + default[1:] + elif path[0] == HARDENED_INDEX + 49: + key = "y" + default[1:] + elif path[0] == HARDENED_INDEX + 48: + if len(path) >= 4: + if path[3] == HARDENED_INDEX + 1: + key = "Y" + default[1:] + elif path[3] == HARDENED_INDEX + 2: + key = "Z" + default[1:] + if network is None and len(path) > 1 and path[1] == HARDENED_INDEX + 1: + net = NETWORKS["test"] + return net[key] + + +def _parse_der_item(e: str) -> int: + if e[-1] in {"h", "H", "'"}: + return int(e[:-1]) + HARDENED_INDEX + else: + return int(e) + + +def parse_path(path: str) -> list: + """converts derivation path of the form m/44h/1'/0'/0/32 to int array""" + arr = path.rstrip("/").split("/") + if arr[0] == "m": + arr = arr[1:] + if len(arr) == 0: + return [] + return [_parse_der_item(e) for e in arr] + + +def path_to_str(path: list, fingerprint=None) -> str: + s = "m" if fingerprint is None else hexlify(fingerprint).decode() + for el in path: + if el >= HARDENED_INDEX: + s += "/%dh" % (el - HARDENED_INDEX) + else: + s += "/%d" % el + return s diff --git a/bitcoin_client/ledger_bitcoin/embit/compact.py b/bitcoin_client/ledger_bitcoin/embit/compact.py new file mode 100644 index 000000000..0138f394e --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/compact.py @@ -0,0 +1,41 @@ +""" Compact Int parsing / serialization """ +import io + + +def to_bytes(i: int) -> bytes: + """encodes an integer as a compact int""" + if i < 0: + raise ValueError("integer can't be negative: {}".format(i)) + order = 0 + while i >> (8 * (2**order)): + order += 1 + if order == 0: + if i < 0xFD: + return bytes([i]) + order = 1 + if order > 3: + raise ValueError("integer too large: {}".format(i)) + return bytes([0xFC + order]) + i.to_bytes(2**order, "little") + + +def from_bytes(b: bytes) -> int: 
""" Compact Int parsing / serialization """
import io


def to_bytes(i: int) -> bytes:
    """Encodes an integer as a Bitcoin CompactSize (varint).

    :raises ValueError: if i is negative or doesn't fit into 8 bytes.
    """
    if i < 0:
        raise ValueError("integer can't be negative: {}".format(i))
    # Find the smallest power-of-two byte length (1, 2, 4, 8) that fits i.
    order = 0
    while i >> (8 * (2**order)):
        order += 1
    if order == 0:
        if i < 0xFD:
            # values below 0xFD are encoded as a single byte
            return bytes([i])
        order = 1
    if order > 3:
        raise ValueError("integer too large: {}".format(i))
    # 0xFD / 0xFE / 0xFF prefix the 2-, 4- and 8-byte little-endian encodings.
    return bytes([0xFC + order]) + i.to_bytes(2**order, "little")


def from_bytes(b: bytes) -> int:
    """Decodes a CompactSize integer from bytes.

    :raises ValueError: if the payload is truncated or there are
        unconsumed bytes after the integer.
    """
    s = io.BytesIO(b)
    res = read_from(s)
    if len(s.read(1)) > 0:
        raise ValueError("Too many bytes")
    return res


def read_from(stream) -> int:
    """Reads a CompactSize integer from a binary stream.

    :raises TypeError: if the stream is not binary (read() returns non-bytes).
    :raises RuntimeError: if the stream is empty.
    :raises ValueError: if the multi-byte payload is truncated.
    """
    c = stream.read(1)
    if not isinstance(c, bytes):
        raise TypeError("Bytes must be returned from stream.read()")
    if len(c) != 1:
        raise RuntimeError("Can't read one byte from the stream")
    i = c[0]
    if i < 0xFD:
        return i
    bytes_to_read = 2 ** (i - 0xFC)
    data = stream.read(bytes_to_read)
    # BUGFIX: a truncated stream used to decode silently into a wrong
    # (smaller) value; fail loudly instead.
    if len(data) != bytes_to_read:
        raise ValueError("Not enough bytes to read a compact int")
    return int.from_bytes(data, "little")
import bip32, ec, compact, hashes +from ..bip32 import HARDENED_INDEX +from ..misc import read_until + + +class KeyOrigin: + def __init__(self, fingerprint: bytes, derivation: list): + self.fingerprint = fingerprint + self.derivation = derivation + + @classmethod + def from_string(cls, s: str): + arr = s.split("/") + mfp = unhexlify(arr[0]) + assert len(mfp) == 4 + arr[0] = "m" + path = "/".join(arr) + derivation = bip32.parse_path(path) + return cls(mfp, derivation) + + def __str__(self): + return bip32.path_to_str(self.derivation, fingerprint=self.fingerprint) + + +class AllowedDerivation(DescriptorBase): + # xpub/<0;1>/* - <0;1> is a set of allowed branches, wildcard * is stored as None + def __init__(self, indexes=[[0, 1], None]): + # check only one wildcard + if ( + len( + [i for i in indexes if i is None or (isinstance(i, list) and None in i)] + ) + > 1 + ): + raise ArgumentError("Only one wildcard is allowed") + # check only one set is in the derivation + if len([i for i in indexes if isinstance(i, list)]) > 1: + raise ArgumentError("Only one set of branches is allowed") + self.indexes = indexes + + @property + def is_wildcard(self): + return None in self.indexes + + def fill(self, idx, branch_index=None): + # None is ok + if idx is not None and (idx < 0 or idx >= HARDENED_INDEX): + raise ArgumentError("Hardened indexes are not allowed in wildcard") + arr = [i for i in self.indexes] + for i, el in enumerate(arr): + if el is None: + arr[i] = idx + if isinstance(el, list): + if branch_index is None: + arr[i] = el[0] + else: + if branch_index < 0 or branch_index >= len(el): + raise ArgumentError("Invalid branch index") + arr[i] = el[branch_index] + return arr + + def branch(self, branch_index): + arr = self.fill(None, branch_index) + return type(self)(arr) + + def check_derivation(self, derivation: list): + if len(derivation) != len(self.indexes): + return None + branch_idx = 0 # default branch if no branches in descriptor + idx = None + for i, el in 
enumerate(self.indexes): + der = derivation[i] + if isinstance(el, int): + if el != der: + return None + # branch + elif isinstance(el, list): + if der not in el: + return None + branch_idx = el.index(der) + # wildcard + elif el is None: + idx = der + # shouldn't happen + else: + raise ArgumentError("Strange derivation index...") + if branch_idx is not None and idx is not None: + return idx, branch_idx + + @classmethod + def default(cls): + return AllowedDerivation([[0, 1], None]) + + @property + def branches(self): + for el in self.indexes: + if isinstance(el, list): + return el + return None + + @property + def has_hardend(self): + for idx in self.indexes: + if isinstance(idx, int) and idx >= HARDENED_INDEX: + return True + if ( + isinstance(idx, list) + and len([i for i in idx if i >= HARDENED_INDEX]) > 0 + ): + return True + return False + + @classmethod + def from_string(cls, der: str, allow_hardened=False, allow_set=True): + if len(der) == 0: + return None + indexes = [ + cls.parse_element(d, allow_hardened, allow_set) for d in der.split("/") + ] + return cls(indexes) + + @classmethod + def parse_element(cls, d: str, allow_hardened=False, allow_set=True): + # wildcard + if d == "*": + return None + # branch set - legacy `{m,n}` + if d[0] == "{" and d[-1] == "}": + if not allow_set: + raise ArgumentError("Set is not allowed in derivation %s" % d) + return [ + cls.parse_element(dd, allow_hardened, allow_set=False) + for dd in d[1:-1].split(",") + ] + # branch set - multipart `` + if d[0] == "<" and d[-1] == ">": + if not allow_set: + raise ArgumentError("Set is not allowed in derivation %s" % d) + return [ + cls.parse_element(dd, allow_hardened, allow_set=False) + for dd in d[1:-1].split(";") + ] + idx = 0 + if d[-1] in ["h", "H", "'"]: + if not allow_hardened: + raise ArgumentError("Hardened derivation is not allowed in %s" % d) + idx = HARDENED_INDEX + d = d[:-1] + i = int(d) + if i < 0 or i >= HARDENED_INDEX: + raise ArgumentError( + "Derivation index can be 
in a range [0, %d)" % HARDENED_INDEX + ) + return idx + i + + def __str__(self): + r = "" + for idx in self.indexes: + if idx is None: + r += "/*" + if isinstance(idx, int): + if idx >= HARDENED_INDEX: + r += "/%dh" % (idx - HARDENED_INDEX) + else: + r += "/%d" % idx + if isinstance(idx, list): + r += "/<" + r += ";".join( + [ + str(i) if i < HARDENED_INDEX else str(i - HARDENED_INDEX) + "h" + for i in idx + ] + ) + r += ">" + return r + + +class Key(DescriptorBase): + def __init__( + self, + key, + origin=None, + derivation=None, + taproot=False, + xonly_repr=False, + ): + self.origin = origin + self.key = key + self.taproot = taproot + self.xonly_repr = xonly_repr and taproot + if not hasattr(key, "derive") and derivation: + raise ArgumentError("Key %s doesn't support derivation" % key) + self.allowed_derivation = derivation + + def __len__(self): + return 34 - int(self.taproot) # <33:sec> or <32:xonly> + + @property + def my_fingerprint(self): + if self.is_extended: + return self.key.my_fingerprint + return None + + @property + def fingerprint(self): + if self.origin is not None: + return self.origin.fingerprint + else: + if self.is_extended: + return self.key.my_fingerprint + return None + + @property + def derivation(self): + return [] if self.origin is None else self.origin.derivation + + @classmethod + def read_from(cls, s, taproot: bool = False): + """ + Reads key argument from stream. + If taproot is set to True - allows both x-only and sec pubkeys. + If taproot is False - will raise when finds xonly pubkey. 
+ """ + first = s.read(1) + origin = None + if first == b"[": + prefix, char = read_until(s, b"]") + if char != b"]": + raise ArgumentError("Invalid key - missing ]") + origin = KeyOrigin.from_string(prefix.decode()) + else: + s.seek(-1, 1) + k, char = read_until(s, b",)/") + der = b"" + # there is a following derivation + if char == b"/": + der, char = read_until(s, b"<{,)") + # legacy branches: {a,b,c...} + if char == b"{": + der += b"{" + branch, char = read_until(s, b"}") + if char is None: + raise ArgumentError("Failed reading the key, missing }") + der += branch + b"}" + rest, char = read_until(s, b",)") + der += rest + # multipart descriptor: + elif char == b"<": + der += b"<" + branch, char = read_until(s, b">") + if char is None: + raise ArgumentError("Failed reading the key, missing >") + der += branch + b">" + rest, char = read_until(s, b",)") + der += rest + if char is not None: + s.seek(-1, 1) + # parse key + k, xonly_repr = cls.parse_key(k, taproot) + # parse derivation + allow_hardened = isinstance(k, bip32.HDKey) and isinstance(k.key, ec.PrivateKey) + derivation = AllowedDerivation.from_string( + der.decode(), allow_hardened=allow_hardened + ) + return cls(k, origin, derivation, taproot, xonly_repr) + + @classmethod + def parse_key(cls, key: bytes, taproot: bool = False): + # convert to string + k = key.decode() + if len(k) in [66, 130] and k[:2] in ["02", "03", "04"]: + # bare public key + return ec.PublicKey.parse(unhexlify(k)), False + elif taproot and len(k) == 64: + # x-only pubkey + return ec.PublicKey.parse(b"\x02" + unhexlify(k)), True + elif k[1:4] in ["pub", "prv"]: + # bip32 key + return bip32.HDKey.from_base58(k), False + else: + return ec.PrivateKey.from_wif(k), False + + @property + def is_extended(self): + return isinstance(self.key, bip32.HDKey) + + def check_derivation(self, derivation_path): + rest = None + # full derivation path + if self.fingerprint == derivation_path.fingerprint: + origin = self.derivation + if origin == 
derivation_path.derivation[: len(origin)]: + rest = derivation_path.derivation[len(origin) :] + # short derivation path + if self.my_fingerprint == derivation_path.fingerprint: + rest = derivation_path.derivation + if self.allowed_derivation is None or rest is None: + return None + return self.allowed_derivation.check_derivation(rest) + + def get_public_key(self): + return ( + self.key.get_public_key() + if (self.is_extended or self.is_private) + else self.key + ) + + def sec(self): + return self.key.sec() + + def xonly(self): + return self.key.xonly() + + def taproot_tweak(self, h=b""): + assert self.taproot + return self.key.taproot_tweak(h) + + def serialize(self): + if self.taproot: + return self.sec()[1:33] + return self.sec() + + def compile(self): + d = self.serialize() + return compact.to_bytes(len(d)) + d + + @property + def prefix(self): + if self.origin: + return "[%s]" % self.origin + return "" + + @property + def suffix(self): + return "" if self.allowed_derivation is None else str(self.allowed_derivation) + + @property + def can_derive(self): + return self.allowed_derivation is not None and hasattr(self.key, "derive") + + @property + def branches(self): + return self.allowed_derivation.branches if self.allowed_derivation else None + + @property + def num_branches(self): + return 1 if self.branches is None else len(self.branches) + + def branch(self, branch_index=None): + der = ( + self.allowed_derivation.branch(branch_index) + if self.allowed_derivation is not None + else None + ) + return type(self)(self.key, self.origin, der, self.taproot) + + @property + def is_wildcard(self): + return self.allowed_derivation.is_wildcard if self.allowed_derivation else False + + def derive(self, idx, branch_index=None): + # nothing to derive + if self.allowed_derivation is None: + return self + der = self.allowed_derivation.fill(idx, branch_index=branch_index) + k = self.key.derive(der) + if self.origin: + origin = KeyOrigin(self.origin.fingerprint, 
self.origin.derivation + der) + else: + origin = KeyOrigin(self.key.child(0).fingerprint, der) + # empty derivation + derivation = None + return type(self)(k, origin, derivation, self.taproot) + + @property + def is_private(self): + return isinstance(self.key, ec.PrivateKey) or ( + self.is_extended and self.key.is_private + ) + + def to_public(self): + if not self.is_private: + return self + if isinstance(self.key, ec.PrivateKey): + return type(self)( + self.key.get_public_key(), + self.origin, + self.allowed_derivation, + self.taproot, + ) + else: + return type(self)( + self.key.to_public(), self.origin, self.allowed_derivation, self.taproot + ) + + @property + def private_key(self): + if not self.is_private: + raise ArgumentError("Key is not private") + # either HDKey.key or just the key + return self.key.key if self.is_extended else self.key + + @property + def secret(self): + return self.private_key.secret + + def to_string(self, version=None): + if isinstance(self.key, ec.PublicKey): + k = self.key.sec() if not self.xonly_repr else self.key.xonly() + return self.prefix + hexlify(k).decode() + if isinstance(self.key, bip32.HDKey): + return self.prefix + self.key.to_base58(version) + self.suffix + if isinstance(self.key, ec.PrivateKey): + return self.prefix + self.key.wif() + return self.prefix + self.key + + @classmethod + def from_string(cls, s, taproot=False): + return cls.parse(s.encode(), taproot) + + +class KeyHash(Key): + @classmethod + def parse_key(cls, k: bytes, *args, **kwargs): + # convert to string + kd = k.decode() + # raw 20-byte hash + if len(kd) == 40: + return kd, False + return super().parse_key(k, *args, **kwargs) + + def serialize(self, *args, **kwargs): + if isinstance(self.key, str): + return unhexlify(self.key) + # TODO: should it be xonly? 
+ if self.taproot: + return hashes.hash160(self.key.sec()[1:33]) + return hashes.hash160(self.key.sec()) + + def __len__(self): + return 21 # <20:pkh> + + def compile(self): + d = self.serialize() + return compact.to_bytes(len(d)) + d + + +class Number(DescriptorBase): + def __init__(self, num): + self.num = num + + @classmethod + def read_from(cls, s, taproot=False): + num = 0 + char = s.read(1) + while char in b"0123456789": + num = 10 * num + int(char.decode()) + char = s.read(1) + s.seek(-1, 1) + return cls(num) + + def compile(self): + if self.num == 0: + return b"\x00" + if self.num <= 16: + return bytes([80 + self.num]) + b = self.num.to_bytes(32, "little").rstrip(b"\x00") + if b[-1] >= 128: + b += b"\x00" + return bytes([len(b)]) + b + + def __len__(self): + return len(self.compile()) + + def __str__(self): + return "%d" % self.num + + +class Raw(DescriptorBase): + LEN = 32 + + def __init__(self, raw): + if len(raw) != self.LEN * 2: + raise ArgumentError("Invalid raw element length: %d" % len(raw)) + self.raw = unhexlify(raw) + + @classmethod + def read_from(cls, s, taproot=False): + return cls(s.read(2 * cls.LEN).decode()) + + def __str__(self): + return hexlify(self.raw).decode() + + def compile(self): + return compact.to_bytes(len(self.raw)) + self.raw + + def __len__(self): + return len(compact.to_bytes(self.LEN)) + self.LEN + + +class Raw32(Raw): + LEN = 32 + + def __len__(self): + return 33 + + +class Raw20(Raw): + LEN = 20 + + def __len__(self): + return 21 diff --git a/bitcoin_client/ledger_bitcoin/embit/descriptor/base.py b/bitcoin_client/ledger_bitcoin/embit/descriptor/base.py new file mode 100644 index 000000000..0c203b21b --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/descriptor/base.py @@ -0,0 +1,22 @@ +from io import BytesIO +from ..base import EmbitBase + + +class DescriptorBase(EmbitBase): + """ + Descriptor is purely text-based, so parse/serialize do + the same as from/to_string, just returning ascii bytes + instead of ascii 
# --- embit/descriptor/base.py ---
from io import BytesIO
from ..base import EmbitBase


class DescriptorBase(EmbitBase):
    """
    Descriptor is purely text-based, so parse/serialize do
    the same as from/to_string, just returning ascii bytes
    instead of ascii string.
    """

    @classmethod
    def from_string(cls, s: str, *args, **kwargs):
        # parse() consumes bytes, so encode the ascii string first
        return cls.parse(s.encode(), *args, **kwargs)

    def serialize(self, *args, **kwargs) -> bytes:
        stream = BytesIO()
        self.write_to(stream)
        return stream.getvalue()

    def to_string(self, *args, **kwargs) -> str:
        return self.serialize().decode()


# --- embit/descriptor/checksum.py ---
from .errors import DescriptorError


def polymod(c: int, val: int) -> int:
    """One step of the BIP-380 descriptor-checksum polynomial over GF(32)."""
    c0 = c >> 35
    c = ((c & 0x7FFFFFFFF) << 5) ^ val
    if c0 & 1:
        c ^= 0xF5DEE51989
    if c0 & 2:
        c ^= 0xA9FDCA3312
    if c0 & 4:
        c ^= 0x1BAB10E32D
    if c0 & 8:
        c ^= 0x3706B1677A
    if c0 & 16:
        c ^= 0x644D626FFD
    return c


def checksum(desc: str) -> str:
    """Calculate the 8-character checksum of a descriptor string.

    (Fixes docstring typo: "desciptor" -> "descriptor".)

    :raises DescriptorError: if the string contains a character outside
        the BIP-380 input charset.
    """
    INPUT_CHARSET = (
        "0123456789()[],'/*abcdefgh@:$%{}IJKLMNOPQRSTUVW"
        'XYZ&+-.;<=>?!^_|~ijklmnopqrstuvwxyzABCDEFGH`#"\\ '
    )
    CHECKSUM_CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"

    c = 1
    cls = 0
    clscount = 0
    for ch in desc:
        pos = INPUT_CHARSET.find(ch)
        if pos == -1:
            raise DescriptorError("Invalid character '%s' in the input string" % ch)
        # low 5 bits feed the checksum directly...
        c = polymod(c, pos & 31)
        # ...while the high bits are accumulated in groups of 3 symbols.
        cls = cls * 3 + (pos >> 5)
        clscount += 1
        if clscount == 3:
            c = polymod(c, cls)
            cls = 0
            clscount = 0
    if clscount > 0:
        c = polymod(c, cls)
    # 8 zero symbols make room for the checksum itself.
    for j in range(0, 8):
        c = polymod(c, 0)
    c ^= 1

    ret = [CHECKSUM_CHARSET[(c >> (5 * (7 - j))) & 31] for j in range(0, 8)]
    return "".join(ret)


def add_checksum(desc: str) -> str:
    """Add (or replace) the checksum of a descriptor string."""
    if "#" in desc:
        # drop an existing (possibly wrong) checksum before recomputing
        desc = desc.split("#")[0]
    return desc + "#" + checksum(desc)
b/bitcoin_client/ledger_bitcoin/embit/descriptor/descriptor.py new file mode 100644 index 000000000..9f585dc5e --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/descriptor/descriptor.py @@ -0,0 +1,384 @@ +from io import BytesIO +from .. import script +from ..networks import NETWORKS +from .errors import DescriptorError +from .base import DescriptorBase +from .miniscript import Miniscript, Multi, Sortedmulti +from .arguments import Key +from .taptree import TapTree + + +class Descriptor(DescriptorBase): + def __init__( + self, + miniscript=None, + sh=False, + wsh=True, + key=None, + wpkh=True, + taproot=False, + taptree=None, + ): + # TODO: add support for taproot scripts + # Should: + # - accept taptree without a key + # - accept key without taptree + # - raise if miniscript is not None, but taproot=True + # - raise if taptree is not None, but taproot=False + if key is None and miniscript is None and taptree is None: + raise DescriptorError("Provide a key, miniscript or taptree") + if miniscript is not None: + # will raise if can't verify + miniscript.verify() + if miniscript.type != "B": + raise DescriptorError("Top level miniscript should be 'B'") + # check all branches have the same length + branches = { + len(k.branches) for k in miniscript.keys if k.branches is not None + } + if len(branches) > 1: + raise DescriptorError("All branches should have the same length") + self.sh = sh + self.wsh = wsh + self.key = key + self.miniscript = miniscript + self.wpkh = wpkh + self.taproot = taproot + self.taptree = taptree or TapTree() + # make sure all keys are either taproot or not + for k in self.keys: + k.taproot = taproot + + @property + def script_len(self): + if self.taproot: + return 34 # OP_1 <32:xonly> + if self.miniscript: + return len(self.miniscript) + if self.wpkh: + return 22 # 00 <20:pkh> + return 25 # OP_DUP OP_HASH160 <20:pkh> OP_EQUALVERIFY OP_CHECKSIG + + @property + def num_branches(self): + return max([k.num_branches for k in self.keys]) + + def 
branch(self, branch_index=None): + if self.miniscript: + return type(self)( + self.miniscript.branch(branch_index), + self.sh, + self.wsh, + None, + self.wpkh, + self.taproot, + ) + else: + return type(self)( + None, + self.sh, + self.wsh, + self.key.branch(branch_index), + self.wpkh, + self.taproot, + self.taptree.branch(branch_index), + ) + + @property + def is_wildcard(self): + return any([key.is_wildcard for key in self.keys]) + + @property + def is_wrapped(self): + return self.sh and self.is_segwit + + @property + def is_legacy(self): + return not (self.is_segwit or self.is_taproot) + + @property + def is_segwit(self): + return ( + (self.wsh and self.miniscript) or (self.wpkh and self.key) or self.taproot + ) + + @property + def is_pkh(self): + return self.key is not None and not self.taproot + + @property + def is_taproot(self): + return self.taproot + + @property + def is_basic_multisig(self) -> bool: + # TODO: should be true for taproot basic multisig with NUMS as internal key + # Sortedmulti is subclass of Multi + return bool(self.miniscript and isinstance(self.miniscript, Multi)) + + @property + def is_sorted(self) -> bool: + return bool(self.is_basic_multisig and isinstance(self.miniscript, Sortedmulti)) + + def scriptpubkey_type(self): + if self.is_taproot: + return "p2tr" + if self.sh: + return "p2sh" + if self.is_pkh: + if self.is_legacy: + return "p2pkh" + if self.is_segwit: + return "p2wpkh" + else: + return "p2wsh" + + @property + def brief_policy(self): + if self.taptree: + return "taptree" + if self.key: + return "single key" + if self.is_basic_multisig: + return ( + str(self.miniscript.args[0]) + + " of " + + str(len(self.keys)) + + " multisig" + + (" (sorted)" if self.is_sorted else "") + ) + return "miniscript" + + @property + def full_policy(self): + if (self.key and not self.taptree) or self.is_basic_multisig: + return self.brief_policy + s = str(self.miniscript or self) + for i, k in enumerate(self.keys): + s = s.replace(str(k), chr(65 + 
i)) + return s + + def derive(self, idx, branch_index=None): + if self.miniscript: + return type(self)( + self.miniscript.derive(idx, branch_index), + self.sh, + self.wsh, + None, + self.wpkh, + self.taproot, + ) + else: + return type(self)( + None, + self.sh, + self.wsh, + self.key.derive(idx, branch_index), + self.wpkh, + self.taproot, + self.taptree.derive(idx, branch_index), + ) + + def to_public(self): + if self.miniscript: + return type(self)( + self.miniscript.to_public(), + self.sh, + self.wsh, + None, + self.wpkh, + self.taproot, + ) + else: + return type(self)( + None, + self.sh, + self.wsh, + self.key.to_public(), + self.wpkh, + self.taproot, + self.taptree.to_public(), + ) + + def owns(self, psbt_scope): + """Checks if psbt input or output belongs to this descriptor""" + # we can't check if we don't know script_pubkey + if psbt_scope.script_pubkey is None: + return False + # quick check of script_pubkey type + if psbt_scope.script_pubkey.script_type() != self.scriptpubkey_type(): + return False + for pub, der in psbt_scope.bip32_derivations.items(): + # check of the fingerprints + for k in self.keys: + if not k.is_extended: + continue + res = k.check_derivation(der) + if res: + idx, branch_idx = res + sc = self.derive(idx, branch_index=branch_idx).script_pubkey() + # if derivation is found but scriptpubkey doesn't match - fail + return sc == psbt_scope.script_pubkey + for pub, (leafs, der) in psbt_scope.taproot_bip32_derivations.items(): + # check of the fingerprints + for k in self.keys: + if not k.is_extended: + continue + res = k.check_derivation(der) + if res: + idx, branch_idx = res + sc = self.derive(idx, branch_index=branch_idx).script_pubkey() + # if derivation is found but scriptpubkey doesn't match - fail + return sc == psbt_scope.script_pubkey + return False + + def check_derivation(self, derivation_path): + for k in self.keys: + # returns a tuple branch_idx, idx + der = k.check_derivation(derivation_path) + if der is not None: + return der + 
return None + + def witness_script(self): + if self.wsh and self.miniscript is not None: + return script.Script(self.miniscript.compile()) + + def redeem_script(self): + if not self.sh: + return None + if self.miniscript: + if not self.wsh: + return script.Script(self.miniscript.compile()) + else: + return script.p2wsh(script.Script(self.miniscript.compile())) + else: + return script.p2wpkh(self.key) + + def script_pubkey(self): + # covers sh-wpkh, sh and sh-wsh + if self.taproot: + return script.p2tr(self.key, self.taptree) + if self.sh: + return script.p2sh(self.redeem_script()) + if self.wsh: + return script.p2wsh(self.witness_script()) + if self.miniscript: + return script.Script(self.miniscript.compile()) + if self.wpkh: + return script.p2wpkh(self.key) + return script.p2pkh(self.key) + + def address(self, network=NETWORKS["main"]): + return self.script_pubkey().address(network) + + @property + def keys(self): + if self.taptree and self.key: + return [self.key] + self.taptree.keys + elif self.taptree: + return self.taptree.keys + elif self.key: + return [self.key] + return self.miniscript.keys + + @classmethod + def from_string(cls, desc): + s = BytesIO(desc.encode()) + res = cls.read_from(s) + left = s.read() + if len(left) > 0 and not left.startswith(b"#"): + raise DescriptorError("Unexpected characters after descriptor: %r" % left) + return res + + @classmethod + def read_from(cls, s): + # starts with sh(wsh()), sh() or wsh() + start = s.read(7) + sh = False + wsh = False + wpkh = False + is_miniscript = True + taproot = False + taptree = TapTree() + if start.startswith(b"tr("): + taproot = True + s.seek(-4, 1) + elif start.startswith(b"sh(wsh("): + sh = True + wsh = True + elif start.startswith(b"wsh("): + sh = False + wsh = True + s.seek(-3, 1) + elif start.startswith(b"sh(wpkh"): + is_miniscript = False + sh = True + wpkh = True + assert s.read(1) == b"(" + elif start.startswith(b"wpkh("): + is_miniscript = False + wpkh = True + s.seek(-2, 1) + elif 
start.startswith(b"pkh("): + is_miniscript = False + s.seek(-3, 1) + elif start.startswith(b"sh("): + sh = True + wsh = False + s.seek(-4, 1) + else: + raise ValueError("Invalid descriptor (starts with '%s')" % start.decode()) + # taproot always has a key, and may have taptree miniscript + if taproot: + miniscript = None + key = Key.read_from(s, taproot=True) + nbrackets = 1 + c = s.read(1) + # TODO: should it be ok to pass just taptree without a key? + # check if we have taptree after the key + if c != b",": + s.seek(-1, 1) + else: + taptree = TapTree.read_from(s) + elif is_miniscript: + miniscript = Miniscript.read_from(s) + key = None + nbrackets = int(sh) + int(wsh) + # single key for sure + else: + miniscript = None + key = Key.read_from(s, taproot=taproot) + nbrackets = 1 + int(sh) + end = s.read(nbrackets) + if end != b")" * nbrackets: + raise ValueError( + "Invalid descriptor (expected ')' but ends with '%s')" % end.decode() + ) + return cls( + miniscript, + sh=sh, + wsh=wsh, + key=key, + wpkh=wpkh, + taproot=taproot, + taptree=taptree, + ) + + def to_string(self): + if self.taproot: + if self.taptree: + return "tr(%s,%s)" % (self.key, self.taptree) + return "tr(%s)" % self.key + if self.miniscript is not None: + res = str(self.miniscript) + if self.wsh: + res = "wsh(%s)" % res + else: + if self.wpkh: + res = "wpkh(%s)" % self.key + else: + res = "pkh(%s)" % self.key + if self.sh: + res = "sh(%s)" % res + return res diff --git a/bitcoin_client/ledger_bitcoin/embit/descriptor/errors.py b/bitcoin_client/ledger_bitcoin/embit/descriptor/errors.py new file mode 100644 index 000000000..b125f9f6c --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/descriptor/errors.py @@ -0,0 +1,17 @@ +from ..base import EmbitError + + +class DescriptorError(EmbitError): + pass + + +class MiniscriptError(DescriptorError): + pass + + +class ArgumentError(MiniscriptError): + pass + + +class KeyError(ArgumentError): + pass diff --git 
class Miniscript(DescriptorBase):
    """Base class for every miniscript fragment.

    A node keeps its parsed sub-expressions in ``self.args``; subclasses
    declare ``NAME``, ``NARGS``, ``ARGCLS`` and usually ``TYPE``/``PROPS``,
    and implement ``inner_compile()`` returning the script bytes.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        # taproot context decides which fragments are legal (multi vs multi_a)
        self.taproot = kwargs.get("taproot", False)

    def compile(self):
        """Script bytes of this fragment (delegates to inner_compile)."""
        return self.inner_compile()

    def verify(self):
        """Recursively run the validity checks of all sub-expressions."""
        for sub in self.args:
            if isinstance(sub, Miniscript):
                sub.verify()

    @property
    def keys(self):
        """All keys used by this node and its children."""
        direct = [a for a in self.args if isinstance(a, (Key, KeyHash))]
        nested = [a.keys for a in self.args if isinstance(a, Miniscript)]
        return sum(nested, direct)

    def derive(self, idx, branch_index=None):
        """Copy of the tree with every derivable argument derived at *idx*."""
        derived = [
            a.derive(idx, branch_index) if hasattr(a, "derive") else a
            for a in self.args
        ]
        return type(self)(*derived, taproot=self.taproot)

    def to_public(self):
        """Copy of the tree with private keys replaced by public ones."""
        pub_args = [
            a.to_public() if hasattr(a, "to_public") else a for a in self.args
        ]
        return type(self)(*pub_args, taproot=self.taproot)

    def branch(self, branch_index):
        """Copy of the tree with multipath keys pinned to *branch_index*."""
        picked = [
            a.branch(branch_index) if hasattr(a, "branch") else a
            for a in self.args
        ]
        return type(self)(*picked, taproot=self.taproot)

    @property
    def properties(self):
        # static for most fragments; overridden where props depend on children
        return self.PROPS

    @property
    def type(self):
        # one of B/V/K/W; overridden where the type depends on children
        return self.TYPE

    @classmethod
    def read_from(cls, s, taproot=False):
        """Parse one (possibly wrapped) fragment from stream *s*."""
        token, delim = read_until(s, b"(,)")
        token = token.decode()
        wrappers = ""
        if ":" in token:
            wrappers, token = token.split(":")
        if token not in OPERATOR_NAMES:
            raise MiniscriptError("Unknown operator '%s'" % token)
        # number of arguments, classes of args, compile fn, type, validity checker
        op_cls = OPERATORS[OPERATOR_NAMES.index(token)]
        if op_cls.NARGS != 0 and delim != b"(":
            raise MiniscriptError("Missing operator")
        if op_cls.NARGS is None or op_cls.NARGS > 0:
            args = op_cls.read_arguments(s, taproot=taproot)
        else:
            # zero-argument fragment ("0"/"1"): push the delimiter back
            s.seek(-1, 1)
            args = []
        node = op_cls(*args, taproot=taproot)
        # wrappers apply right to left: the letter next to ':' is innermost
        for ch in reversed(wrappers):
            if ch not in WRAPPER_NAMES:
                raise MiniscriptError("Unknown wrapper")
            node = WRAPPERS[WRAPPER_NAMES.index(ch)](node, taproot=taproot)
        return node

    @classmethod
    def read_arguments(cls, s, taproot=False):
        """Parse the comma-separated argument list up to the closing ')'."""
        args = []
        if cls.NARGS is None:
            # variable arity: ARGCLS may give distinct first/rest classes
            if type(cls.ARGCLS) == tuple:
                firstcls, nextcls = cls.ARGCLS
            else:
                firstcls = nextcls = cls.ARGCLS
            args.append(firstcls.read_from(s, taproot=taproot))
            while True:
                delim = s.read(1)
                if delim == b")":
                    break
                if delim != b",":
                    raise MiniscriptError(
                        "Expected , or ), got: %s" % (delim + s.read())
                    )
                args.append(nextcls.read_from(s, taproot=taproot))
        else:
            for i in range(cls.NARGS):
                args.append(cls.ARGCLS.read_from(s, taproot=taproot))
                if i < cls.NARGS - 1:
                    delim = s.read(1)
                    if delim != b",":
                        raise MiniscriptError("Missing arguments, %s" % delim)
            delim = s.read(1)
            if delim != b")":
                raise MiniscriptError("Expected ) got %s" % (delim + s.read()))
        return args

    def __str__(self):
        return type(self).NAME + "(" + ",".join(str(a) for a in self.args) + ")"

    def __len__(self):
        """Length of the compiled script, override this if you know the length"""
        return len(self.compile())

    def len_args(self):
        """Sum of the compiled lengths of all arguments."""
        return sum(len(a) for a in self.args)


########### Known fragments (miniscript operators) ##############


class OneArg(Miniscript):
    """Convenience base for single-argument fragments."""

    NARGS = 1

    @property
    def arg(self):
        return self.args[0]

    @property
    def carg(self):
        # compiled argument
        return self.arg.compile()


class NumberZero(Miniscript):
    # 0
    NARGS = 0
    NAME = "0"
    TYPE = "B"
    PROPS = "zud"

    def inner_compile(self):
        return b"\x00"

    def __len__(self):
        return 1
return b"\x00" + + def __len__(self): + return 1 + + +class NumberOne(Miniscript): + # 1 + + NARGS = 0 + NAME = "1" + TYPE = "B" + PROPS = "zu" + + def inner_compile(self): + return b"\x51" + + def __len__(self): + return 1 + + +class PkK(OneArg): + # + NAME = "pk_k" + ARGCLS = Key + TYPE = "K" + PROPS = "ondu" + + def inner_compile(self): + return self.carg + + def __len__(self): + return self.len_args() + + +class PkH(OneArg): + # DUP HASH160 EQUALVERIFY + NAME = "pk_h" + ARGCLS = KeyHash + TYPE = "K" + PROPS = "ndu" + + def inner_compile(self): + return b"\x76\xa9" + self.carg + b"\x88" + + def __len__(self): + return self.len_args() + 3 + + +class Older(OneArg): + # CHECKSEQUENCEVERIFY + NAME = "older" + ARGCLS = Number + TYPE = "B" + PROPS = "z" + + def inner_compile(self): + return self.carg + b"\xb2" + + def verify(self): + super().verify() + if (self.arg.num < 1) or (self.arg.num >= 0x80000000): + raise MiniscriptError( + "%s should have an argument in range [1, 0x80000000)" % self.NAME + ) + + def __len__(self): + return self.len_args() + 1 + + +class After(Older): + # CHECKLOCKTIMEVERIFY + NAME = "after" + + def inner_compile(self): + return self.carg + b"\xb1" + + +class Sha256(OneArg): + # SIZE <32> EQUALVERIFY SHA256 EQUAL + NAME = "sha256" + ARGCLS = Raw32 + TYPE = "B" + PROPS = "ondu" + + def inner_compile(self): + return b"\x82" + Number(32).compile() + b"\x88\xa8" + self.carg + b"\x87" + + def __len__(self): + return self.len_args() + 6 + + +class Hash256(Sha256): + # SIZE <32> EQUALVERIFY HASH256 EQUAL + NAME = "hash256" + + def inner_compile(self): + return b"\x82" + Number(32).compile() + b"\x88\xaa" + self.carg + b"\x87" + + +class Ripemd160(Sha256): + # SIZE <32> EQUALVERIFY RIPEMD160 EQUAL + NAME = "ripemd160" + ARGCLS = Raw20 + + def inner_compile(self): + return b"\x82" + Number(32).compile() + b"\x88\xa6" + self.carg + b"\x87" + + +class Hash160(Ripemd160): + # SIZE <32> EQUALVERIFY HASH160 EQUAL + NAME = "hash160" + + def 
class AndOr(Miniscript):
    # [X] NOTIF [Z] ELSE [Y] ENDIF
    NAME = "andor"
    NARGS = 3
    ARGCLS = Miniscript

    @property
    def type(self):
        # andor(X,Y,Z) has the shared type of Y and Z
        return self.args[1].type

    def verify(self):
        """Check andor typing rules: X is Bdu; Y and Z share a type in B/K/V."""
        super().verify()
        if self.args[0].type != "B":
            raise MiniscriptError("andor: X should be 'B'")
        px = self.args[0].properties
        # FIX: the spec requires X to be *both* dissatisfiable (d) and unit
        # (u). The previous `and` only rejected X missing both properties,
        # inconsistent with the OrC/OrD checks elsewhere in this file.
        if "d" not in px or "u" not in px:
            raise MiniscriptError("andor: X should be 'du'")
        if self.args[1].type != self.args[2].type:
            raise MiniscriptError("andor: Y and Z should have the same types")
        if self.args[1].type not in "BKV":
            raise MiniscriptError("andor: Y and Z should be B K or V")

    @property
    def properties(self):
        # props: z=zXzYzZ; o=zXoYoZ or oXzYzZ; u=uYuZ; d=dZ
        props = ""
        px, py, pz = [arg.properties for arg in self.args]
        if "z" in px and "z" in py and "z" in pz:
            props += "z"
        if ("z" in px and "o" in py and "o" in pz) or (
            "o" in px and "z" in py and "z" in pz
        ):
            props += "o"
        if "u" in py and "u" in pz:
            props += "u"
        if "d" in pz:
            props += "d"
        return props

    def inner_compile(self):
        # X NOTIF Z ELSE Y ENDIF
        return (
            self.args[0].compile()
            + b"\x64"
            + self.args[2].compile()
            + b"\x67"
            + self.args[1].compile()
            + b"\x68"
        )

    def __len__(self):
        return self.len_args() + 3


class AndV(Miniscript):
    # [X] [Y]
    NAME = "and_v"
    NARGS = 2
    ARGCLS = Miniscript

    def inner_compile(self):
        # plain concatenation: X leaves nothing (V), then Y runs
        return self.args[0].compile() + self.args[1].compile()

    def __len__(self):
        return self.len_args()

    def verify(self):
        """X is V; Y is B, K, or V."""
        super().verify()
        if self.args[0].type != "V":
            raise MiniscriptError("and_v: X should be 'V'")
        if self.args[1].type not in "BKV":
            raise MiniscriptError("and_v: Y should be B K or V")

    @property
    def type(self):
        # same as Y
        return self.args[1].type

    @property
    def properties(self):
        # z=zXzY; o=zXoY or zYoX; n=nX or zXnY; u=uY
        px, py = [arg.properties for arg in self.args]
        props = ""
        if "z" in px and "z" in py:
            props += "z"
        if ("z" in px and "o" in py) or ("z" in py and "o" in px):
            props += "o"
        if "n" in px or ("z" in px and "n" in py):
            props += "n"
        if "u" in py:
            props += "u"
        return props


class AndB(Miniscript):
    # [X] [Y] BOOLAND
    NAME = "and_b"
    NARGS = 2
    ARGCLS = Miniscript
    TYPE = "B"

    def inner_compile(self):
        return self.args[0].compile() + self.args[1].compile() + b"\x9a"

    def __len__(self):
        return self.len_args() + 1

    def verify(self):
        """X is B; Y is W."""
        super().verify()
        if self.args[0].type != "B":
            raise MiniscriptError("and_b: X should be B")
        if self.args[1].type != "W":
            raise MiniscriptError("and_b: Y should be W")

    @property
    def properties(self):
        # z=zXzY; o=zXoY or zYoX; n=nX or zXnY; d=dXdY; always u
        px, py = [arg.properties for arg in self.args]
        props = ""
        if "z" in px and "z" in py:
            props += "z"
        if ("z" in px and "o" in py) or ("z" in py and "o" in px):
            props += "o"
        if "n" in px or ("z" in px and "n" in py):
            props += "n"
        if "d" in px and "d" in py:
            props += "d"
        props += "u"
        return props


class AndN(Miniscript):
    # [X] NOTIF 0 ELSE [Y] ENDIF
    # equivalent to andor(X,Y,0)
    NAME = "and_n"
    NARGS = 2
    ARGCLS = Miniscript

    def inner_compile(self):
        return (
            self.args[0].compile()
            + b"\x64"
            + Number(0).compile()
            + b"\x67"
            + self.args[1].compile()
            + b"\x68"
        )

    def __len__(self):
        return self.len_args() + 4

    @property
    def type(self):
        # same as Y
        return self.args[1].type

    def verify(self):
        """X is Bdu; Y is B (the Z branch is the constant 0)."""
        super().verify()
        if self.args[0].type != "B":
            raise MiniscriptError("and_n: X should be 'B'")
        px = self.args[0].properties
        # FIX: same as AndOr above — the spec requires X to be Bdu, so reject
        # X missing *either* property, not only X missing both.
        if "d" not in px or "u" not in px:
            raise MiniscriptError("and_n: X should be 'du'")
        if self.args[1].type != "B":
            raise MiniscriptError("and_n: Y should be B")

    @property
    def properties(self):
        # andor(X,Y,0) rules with Z = 0, whose props are "zud"
        props = ""
        px, py = [arg.properties for arg in self.args]
        pz = "zud"
        if "z" in px and "z" in py and "z" in pz:
            props += "z"
        if ("z" in px and "o" in py and "o" in pz) or (
            "o" in px and "z" in py and "z" in pz
        ):
            props += "o"
        if "u" in py and "u" in pz:
            props += "u"
        if "d" in pz:
            props += "d"
        return props
class OrB(Miniscript):
    # [X] [Z] BOOLOR
    NAME = "or_b"
    NARGS = 2
    ARGCLS = Miniscript
    TYPE = "B"

    def inner_compile(self):
        return self.args[0].compile() + self.args[1].compile() + b"\x9b"

    def __len__(self):
        return self.len_args() + 1

    def verify(self):
        """X is Bd; Z is Wd."""
        super().verify()
        x, z = self.args
        if x.type != "B":
            raise MiniscriptError("or_b: X should be B")
        if "d" not in x.properties:
            raise MiniscriptError("or_b: X should be d")
        if z.type != "W":
            raise MiniscriptError("or_b: Z should be W")
        if "d" not in z.properties:
            raise MiniscriptError("or_b: Z should be d")

    @property
    def properties(self):
        # z=zXzZ; o=zXoZ or zZoX; always d and u
        px, pz = (a.properties for a in self.args)
        props = ""
        if "z" in px and "z" in pz:
            props += "z"
        if ("z" in px and "o" in pz) or ("z" in pz and "o" in px):
            props += "o"
        return props + "du"


class OrC(Miniscript):
    # [X] NOTIF [Z] ENDIF
    NAME = "or_c"
    NARGS = 2
    ARGCLS = Miniscript
    TYPE = "V"

    def inner_compile(self):
        return self.args[0].compile() + b"\x64" + self.args[1].compile() + b"\x68"

    def __len__(self):
        return self.len_args() + 2

    def verify(self):
        """X is Bdu; Z is V."""
        super().verify()
        if self.args[0].type != "B":
            raise MiniscriptError("or_c: X should be B")
        if self.args[1].type != "V":
            raise MiniscriptError("or_c: Z should be V")
        px = self.args[0].properties
        if "d" not in px or "u" not in px:
            raise MiniscriptError("or_c: X should be du")

    @property
    def properties(self):
        # z=zXzZ; o=oXzZ
        px, pz = (a.properties for a in self.args)
        props = ""
        if "z" in px and "z" in pz:
            props += "z"
        if "o" in px and "z" in pz:
            props += "o"
        return props


class OrD(Miniscript):
    # [X] IFDUP NOTIF [Z] ENDIF
    NAME = "or_d"
    NARGS = 2
    ARGCLS = Miniscript
    TYPE = "B"

    def inner_compile(self):
        return self.args[0].compile() + b"\x73\x64" + self.args[1].compile() + b"\x68"

    def __len__(self):
        return self.len_args() + 3

    def verify(self):
        """X is Bdu; Z is B."""
        super().verify()
        if self.args[0].type != "B":
            raise MiniscriptError("or_d: X should be B")
        if self.args[1].type != "B":
            raise MiniscriptError("or_d: Z should be B")
        px = self.args[0].properties
        if "d" not in px or "u" not in px:
            raise MiniscriptError("or_d: X should be du")

    @property
    def properties(self):
        # z=zXzZ; o=oXzZ; d=dZ; u=uZ
        px, pz = (a.properties for a in self.args)
        props = ""
        if "z" in px and "z" in pz:
            props += "z"
        if "o" in px and "z" in pz:
            props += "o"
        if "d" in pz:
            props += "d"
        if "u" in pz:
            props += "u"
        return props


class OrI(Miniscript):
    # IF [X] ELSE [Z] ENDIF
    NAME = "or_i"
    NARGS = 2
    ARGCLS = Miniscript

    def inner_compile(self):
        return (
            b"\x63"
            + self.args[0].compile()
            + b"\x67"
            + self.args[1].compile()
            + b"\x68"
        )

    def __len__(self):
        return self.len_args() + 3

    def verify(self):
        """Both branches must share one of the types B, K or V."""
        super().verify()
        if self.args[0].type != self.args[1].type:
            raise MiniscriptError("or_i: X and Z should be the same type")
        if self.args[0].type not in "BKV":
            raise MiniscriptError("or_i: X and Z should be B K or V")

    @property
    def type(self):
        return self.args[0].type

    @property
    def properties(self):
        # o if both z; u=uXuZ; d=dX or dZ
        px, pz = (a.properties for a in self.args)
        props = ""
        if "z" in px and "z" in pz:
            props += "o"
        if "u" in px and "u" in pz:
            props += "u"
        if "d" in px or "d" in pz:
            props += "d"
        return props
class Thresh(Miniscript):
    # [X1] [X2] ADD ... [Xn] ADD <k> EQUAL
    NAME = "thresh"
    NARGS = None
    ARGCLS = (Number, Miniscript)
    TYPE = "B"

    def inner_compile(self):
        # X1, then every following sub-expression with an OP_ADD,
        # finally the threshold k and OP_EQUAL
        parts = [self.args[1].compile()]
        parts += [sub.compile() + b"\x93" for sub in self.args[2:]]
        parts.append(self.args[0].compile() + b"\x87")
        return b"".join(parts)

    def __len__(self):
        # n-1 OP_ADDs plus the final OP_EQUAL
        return self.len_args() + len(self.args) - 1

    def verify(self):
        """1 <= k <= n; X1 is Bdu; the others are Wdu."""
        super().verify()
        if self.args[0].num < 1 or self.args[0].num >= len(self.args):
            raise MiniscriptError(
                "thresh: Invalid k! Should be 1 <= k <= %d, got %d"
                % (len(self.args) - 1, self.args[0].num)
            )
        if self.args[1].type != "B":
            raise MiniscriptError("thresh: X1 should be B")
        p1 = self.args[1].properties
        if "d" not in p1 or "u" not in p1:
            raise MiniscriptError("thresh: X1 should be du")
        for i, sub in enumerate(self.args[2:]):
            if sub.type != "W":
                raise MiniscriptError("thresh: X%d should be W" % (i + 1))
            p = sub.properties
            if "d" not in p or "u" not in p:
                raise MiniscriptError("thresh: X%d should be du" % (i + 1))

    @property
    def properties(self):
        # z if every sub-expression is z; o if all-but-one are z and it is o;
        # always d and u
        parr = [sub.properties for sub in self.args[1:]]
        props = ""
        if all("z" in p for p in parr):
            props += "z"
        non_z = [p for p in parr if "z" not in p]
        if len(non_z) == 1 and "o" in non_z[0]:
            props += "o"
        return props + "du"


class Multi(Miniscript):
    # <k> <key1> ... <keyn> <n> CHECKMULTISIG
    NAME = "multi"
    NARGS = None
    ARGCLS = (Number, Key)
    TYPE = "B"
    PROPS = "ndu"
    # multi is only valid outside taproot; multi_a flips this flag
    _expected_taproot = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.taproot is not self._expected_taproot:
            raise MiniscriptError(
                "%s can't be used if taproot is %s" % (self.NAME, self.taproot)
            )

    def inner_compile(self):
        # args are (k, key1..keyn); append n and OP_CHECKMULTISIG
        compiled = b"".join(a.compile() for a in self.args)
        return compiled + Number(len(self.args) - 1).compile() + b"\xae"

    def __len__(self):
        return self.len_args() + 2

    def verify(self):
        super().verify()
        if self.args[0].num < 1 or self.args[0].num > (len(self.args) - 1):
            raise MiniscriptError(
                "multi: 1 <= k <= %d, got %d" % ((len(self.args) - 1), self.args[0].num)
            )


class Sortedmulti(Multi):
    # <k> <key1> ... <keyn> <n> CHECKMULTISIG, keys sorted lexicographically
    NAME = "sortedmulti"

    def inner_compile(self):
        sorted_keys = sorted(key.compile() for key in self.args[1:])
        return (
            self.args[0].compile()
            + b"".join(sorted_keys)
            + Number(len(self.args) - 1).compile()
            + b"\xae"
        )


class MultiA(Multi):
    # <key1> CHECKSIG <key2> CHECKSIGADD ... <keyn> CHECKSIGADD <k> NUMEQUAL
    NAME = "multi_a"
    _expected_taproot = True

    def inner_compile(self):
        out = self.args[1].compile() + b"\xac"
        for key in self.args[2:]:
            out += key.compile() + b"\xba"
        return out + self.args[0].compile() + b"\x9c"

    def __len__(self):
        # one opcode per key plus the final NUMEQUAL
        return self.len_args() + len(self.args)


class SortedmultiA(MultiA):
    # same as multi_a, but keys are sorted lexicographically
    NAME = "sortedmulti_a"

    def inner_compile(self):
        keys = sorted(key.compile() for key in self.args[1:])
        out = keys[0] + b"\xac"
        for key in keys[1:]:
            out += key + b"\xba"
        return out + self.args[0].compile() + b"\x9c"
class Pk(OneArg):
    # <key> CHECKSIG
    NAME = "pk"
    ARGCLS = Key
    TYPE = "B"
    PROPS = "ondu"

    def inner_compile(self):
        return self.carg + b"\xac"

    def __len__(self):
        return self.len_args() + 1


class Pkh(OneArg):
    # DUP HASH160 <keyhash> EQUALVERIFY CHECKSIG
    NAME = "pkh"
    ARGCLS = KeyHash
    TYPE = "B"
    PROPS = "ndu"

    def inner_compile(self):
        return b"\x76\xa9" + self.carg + b"\x88\xac"

    def __len__(self):
        return self.len_args() + 4

    # TODO: 0, 1 - they are without brackets, so it should be different...


OPERATORS = [
    NumberZero,
    NumberOne,
    PkK,
    PkH,
    Older,
    After,
    Sha256,
    Hash256,
    Ripemd160,
    Hash160,
    AndOr,
    AndV,
    AndB,
    AndN,
    OrB,
    OrC,
    OrD,
    OrI,
    Thresh,
    Multi,
    Sortedmulti,
    MultiA,
    SortedmultiA,
    Pk,
    Pkh,
]
# read_from() looks fragments up by name via index into OPERATORS
OPERATOR_NAMES = [op.NAME for op in OPERATORS]


class Wrapper(OneArg):
    """Base class for the single-letter wrappers (a: s: c: t: d: v: ...)."""

    ARGCLS = Miniscript

    @property
    def op(self):
        # the wrapper letter is the lowercase class name
        return type(self).__name__.lower()

    def __str__(self):
        if isinstance(self.arg, Wrapper):
            # consecutive wrappers are concatenated without a colon
            return self.op + str(self.arg)
        # the innermost wrapper carries the colon
        return self.op + ":" + str(self.arg)


class A(Wrapper):
    # TOALTSTACK [X] FROMALTSTACK
    TYPE = "W"

    def inner_compile(self):
        return b"\x6b" + self.carg + b"\x6c"

    def __len__(self):
        return len(self.arg) + 2

    def verify(self):
        super().verify()
        if self.arg.type != "B":
            raise MiniscriptError("a: X should be B")

    @property
    def properties(self):
        # keeps d and u of X
        px = self.arg.properties
        return "".join(p for p in "du" if p in px)


class S(Wrapper):
    # SWAP [X]
    TYPE = "W"

    def inner_compile(self):
        return b"\x7c" + self.carg

    def __len__(self):
        return len(self.arg) + 1

    def verify(self):
        super().verify()
        if self.arg.type != "B":
            raise MiniscriptError("s: X should be B")
        if "o" not in self.arg.properties:
            raise MiniscriptError("s: X should be o")

    @property
    def properties(self):
        # keeps d and u of X
        px = self.arg.properties
        return "".join(p for p in "du" if p in px)


class C(Wrapper):
    # [X] CHECKSIG
    TYPE = "B"

    def inner_compile(self):
        return self.carg + b"\xac"

    def __len__(self):
        return len(self.arg) + 1

    def verify(self):
        super().verify()
        if self.arg.type != "K":
            raise MiniscriptError("c: X should be K")

    @property
    def properties(self):
        # keeps o, n, d of X and is always u
        px = self.arg.properties
        return "".join(p for p in "ond" if p in px) + "u"


class T(Wrapper):
    # [X] 1  (equivalent to and_v(X,1))
    TYPE = "B"

    def inner_compile(self):
        return self.carg + Number(1).compile()

    def __len__(self):
        return len(self.arg) + 1

    @property
    def properties(self):
        # and_v property rules with Y = 1, whose props are "zu"
        px = self.arg.properties
        py = "zu"
        props = ""
        if "z" in px and "z" in py:
            props += "z"
        if ("z" in px and "o" in py) or ("z" in py and "o" in px):
            props += "o"
        if "n" in px or ("z" in px and "n" in py):
            props += "n"
        if "u" in py:
            props += "u"
        return props


class D(Wrapper):
    # DUP IF [X] ENDIF
    TYPE = "B"

    def inner_compile(self):
        return b"\x76\x63" + self.carg + b"\x68"

    def __len__(self):
        return len(self.arg) + 3

    def verify(self):
        super().verify()
        if self.arg.type != "V":
            raise MiniscriptError("d: X should be V")
        if "z" not in self.arg.properties:
            raise MiniscriptError("d: X should be z")

    @property
    def properties(self):
        # u holds only under taproot (minimalif), see
        # https://github.com/bitcoin/bitcoin/pull/24906
        props = "ndu" if self.taproot else "nd"
        if "z" in self.arg.properties:
            props += "o"
        return props


class V(Wrapper):
    # [X] VERIFY (or VERIFY version of last opcode in [X])
    TYPE = "V"

    def inner_compile(self):
        """Checks last check code and makes it verify"""
        compiled = self.carg
        # CHECKSIG/CHECKMULTISIG/NUMEQUAL/EQUAL have a -VERIFY form at opcode+1
        if compiled[-1] in [0xAC, 0xAE, 0x9C, 0x87]:
            return compiled[:-1] + bytes([compiled[-1] + 1])
        return compiled + b"\x69"

    def verify(self):
        super().verify()
        if self.arg.type != "B":
            raise MiniscriptError("v: X should be B")

    @property
    def properties(self):
        # keeps z, o, n of X
        px = self.arg.properties
        return "".join(p for p in "zon" if p in px)
[X]) + TYPE = "V" + + def inner_compile(self): + """Checks last check code and makes it verify""" + if self.carg[-1] in [0xAC, 0xAE, 0x9C, 0x87]: + return self.carg[:-1] + bytes([self.carg[-1] + 1]) + return self.carg + b"\x69" + + def verify(self): + super().verify() + if self.arg.type != "B": + raise MiniscriptError("v: X should be B") + + @property + def properties(self): + props = "" + px = self.arg.properties + for p in ["z", "o", "n"]: + if p in px: + props += p + return props + + +class J(Wrapper): + # SIZE 0NOTEQUAL IF [X] ENDIF + TYPE = "B" + + def inner_compile(self): + return b"\x82\x92\x63" + self.carg + b"\x68" + + def verify(self): + super().verify() + if self.arg.type != "B": + raise MiniscriptError("j: X should be B") + if "n" not in self.arg.properties: + raise MiniscriptError("j: X should be n") + + @property + def properties(self): + props = "nd" + px = self.arg.properties + for p in ["o", "u"]: + if p in px: + props += p + return props + + +class N(Wrapper): + # [X] 0NOTEQUAL + TYPE = "B" + + def inner_compile(self): + return self.carg + b"\x92" + + def __len__(self): + return len(self.arg) + 1 + + def verify(self): + super().verify() + if self.arg.type != "B": + raise MiniscriptError("n: X should be B") + + @property + def properties(self): + props = "u" + px = self.arg.properties + for p in ["z", "o", "n", "d"]: + if p in px: + props += p + return props + + +class L(Wrapper): + # IF 0 ELSE [X] ENDIF + TYPE = "B" + + def inner_compile(self): + return b"\x63" + Number(0).compile() + b"\x67" + self.carg + b"\x68" + + def __len__(self): + return len(self.arg) + 4 + + def verify(self): + # both are B, K, or V + super().verify() + if self.arg.type != "B": + raise MiniscriptError("or_i: X and Z should be the same type") + + @property + def properties(self): + # o=zXzZ; u=uXuZ; d=dX or dZ + props = "d" + pz = self.arg.properties + if "z" in pz: + props += "o" + if "u" in pz: + props += "u" + return props + + +class U(L): + # IF [X] ELSE 0 ENDIF + def 
inner_compile(self): + return b"\x63" + self.carg + b"\x67" + Number(0).compile() + b"\x68" + + def __len__(self): + return len(self.arg) + 4 + + +WRAPPERS = [A, S, C, T, D, V, J, N, L, U] +WRAPPER_NAMES = [w.__name__.lower() for w in WRAPPERS] diff --git a/bitcoin_client/ledger_bitcoin/embit/descriptor/taptree.py b/bitcoin_client/ledger_bitcoin/embit/descriptor/taptree.py new file mode 100644 index 000000000..7f611e5ec --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/descriptor/taptree.py @@ -0,0 +1,151 @@ +from .errors import MiniscriptError +from .base import DescriptorBase +from .miniscript import Miniscript +from ..hashes import tagged_hash +from ..script import Script + + +class TapLeaf(DescriptorBase): + def __init__(self, miniscript=None, version=0xC0): + self.miniscript = miniscript + self.version = version + + def __str__(self): + return str(self.miniscript) + + @classmethod + def read_from(cls, s): + ms = Miniscript.read_from(s, taproot=True) + return cls(ms) + + def serialize(self): + if self.miniscript is None: + return b"" + return bytes([self.version]) + Script(self.miniscript.compile()).serialize() + + @property + def keys(self): + return self.miniscript.keys + + def derive(self, *args, **kwargs): + if self.miniscript is None: + return type(self)(None, version=self.version) + return type(self)( + self.miniscript.derive(*args, **kwargs), + self.version, + ) + + def branch(self, *args, **kwargs): + if self.miniscript is None: + return type(self)(None, version=self.version) + return type(self)( + self.miniscript.branch(*args, **kwargs), + self.version, + ) + + def to_public(self, *args, **kwargs): + if self.miniscript is None: + return type(self)(None, version=self.version) + return type(self)( + self.miniscript.to_public(*args, **kwargs), + self.version, + ) + + +def _tweak_helper(tree): + # https://github.com/bitcoin/bips/blob/master/bip-0341.mediawiki#constructing-and-spending-taproot-outputs + if isinstance(tree, TapTree): + tree = tree.tree 
+ if isinstance(tree, TapLeaf): + # one leaf on this branch + h = tagged_hash("TapLeaf", tree.serialize()) + return ([(tree, b"")], h) + left, left_h = _tweak_helper(tree[0]) + right, right_h = _tweak_helper(tree[1]) + ret = [(leaf, c + right_h) for leaf, c in left] + [ + (leaf, c + left_h) for leaf, c in right + ] + if right_h < left_h: + left_h, right_h = right_h, left_h + return (ret, tagged_hash("TapBranch", left_h + right_h)) + + +class TapTree(DescriptorBase): + def __init__(self, tree=None): + """tree can be None, TapLeaf or a tuple (taptree, taptree)""" + self.tree = tree + # make sure all keys are taproot + for k in self.keys: + k.taproot = True + + def __bool__(self): + return bool(self.tree) + + def tweak(self): + if self.tree is None: + return b"" + _, h = _tweak_helper(self.tree) + return h + + @property + def keys(self): + if self.tree is None: + return [] + if isinstance(self.tree, TapLeaf): + return self.tree.keys + left, right = self.tree + return left.keys + right.keys + + @classmethod + def read_from(cls, s): + c = s.read(1) + if len(c) == 0: + return cls() + if c == b"{": # more than one miniscript + left = cls.read_from(s) + c = s.read(1) + if c == b"}": + return left + if c != b",": + raise MiniscriptError("Invalid taptree syntax: expected ','") + right = cls.read_from(s) + if s.read(1) != b"}": + raise MiniscriptError("Invalid taptree syntax: expected '}'") + return cls((left, right)) + s.seek(-1, 1) + ms = TapLeaf.read_from(s) + return cls(ms) + + def derive(self, *args, **kwargs): + if self.tree is None: + return type(self)(None) + if isinstance(self.tree, TapLeaf): + return type(self)(self.tree.derive(*args, **kwargs)) + left, right = self.tree + return type(self)((left.derive(*args, **kwargs), right.derive(*args, **kwargs))) + + def branch(self, *args, **kwargs): + if self.tree is None: + return type(self)(None) + if isinstance(self.tree, TapLeaf): + return type(self)(self.tree.branch(*args, **kwargs)) + left, right = self.tree + return 
class ECError(EmbitError):
    """Raised for invalid EC keys and signatures."""

    pass


class Signature(EmbitBase):
    """ECDSA signature wrapping a libsecp256k1 signature object."""

    def __init__(self, sig):
        self._sig = sig

    def write_to(self, stream) -> int:
        return stream.write(secp256k1.ecdsa_signature_serialize_der(self._sig))

    @classmethod
    def read_from(cls, stream):
        # DER: 0x30 <len> <payload> — the second byte gives the remaining length
        header = stream.read(2)
        der = header + stream.read(header[1])
        return cls(secp256k1.ecdsa_signature_parse_der(der))


class SchnorrSig(EmbitBase):
    """A 64-byte BIP340 Schnorr signature."""

    def __init__(self, sig):
        assert len(sig) == 64
        self._sig = sig

    def write_to(self, stream) -> int:
        return stream.write(self._sig)

    @classmethod
    def read_from(cls, stream):
        return cls(stream.read(64))


class PublicKey(EmbitKey):
    """A secp256k1 public key with a compressed/uncompressed serialization flag."""

    def __init__(self, point: bytes, compressed: bool = True):
        self._point = point
        self.compressed = compressed

    @classmethod
    def read_from(cls, stream):
        prefix = stream.read(1)
        if prefix not in [b"\x02", b"\x03", b"\x04"]:
            raise ECError("Invalid public key")
        # 0x04 marks an uncompressed key (64 payload bytes), otherwise 32
        payload = 64 if prefix == b"\x04" else 32
        sec_bytes = prefix + stream.read(payload)
        try:
            point = secp256k1.ec_pubkey_parse(sec_bytes)
        except Exception as e:
            raise ECError(str(e))
        return cls(point, sec_bytes[0] != 0x04)

    def sec(self) -> bytes:
        """Sec representation of the key"""
        flag = secp256k1.EC_COMPRESSED if self.compressed else secp256k1.EC_UNCOMPRESSED
        return secp256k1.ec_pubkey_serialize(self._point, flag)

    def xonly(self) -> bytes:
        # x coordinate only, per BIP340
        return self.sec()[1:33]

    def taproot_tweak(self, h=b""):
        """Returns a tweaked public key"""
        x = self.xonly()
        tweak = hashes.tagged_hash("TapTweak", x + h)
        if not secp256k1.ec_seckey_verify(tweak):
            raise EmbitError("Tweak is too large")
        # lift x to the even-Y point and add tweak*G
        internal = secp256k1.ec_pubkey_parse(b"\x02" + x)
        output = secp256k1.ec_pubkey_add(internal, tweak)
        sec = secp256k1.ec_pubkey_serialize(output)
        return PublicKey.from_xonly(sec[1:33])

    def write_to(self, stream) -> int:
        return stream.write(self.sec())

    def serialize(self) -> bytes:
        return self.sec()

    def verify(self, sig, msg_hash) -> bool:
        return bool(secp256k1.ecdsa_verify(sig._sig, msg_hash, self._point))

    def _xonly(self):
        """Returns internal representation of the xonly-pubkey (64 bytes)"""
        pub, _ = secp256k1.xonly_pubkey_from_pubkey(self._point)
        return pub

    @classmethod
    def from_xonly(cls, data: bytes):
        assert len(data) == 32
        # x-only keys are interpreted as having an even Y coordinate
        return cls.parse(b"\x02" + data)

    def schnorr_verify(self, sig, msg_hash) -> bool:
        return bool(secp256k1.schnorrsig_verify(sig._sig, msg_hash, self._xonly()))

    @classmethod
    def from_string(cls, s):
        return cls.parse(unhexlify(s))

    @property
    def is_private(self) -> bool:
        return False

    def to_string(self):
        return hexlify(self.sec()).decode()

    # comparisons use lexicographic ordering of the SEC serialization
    def __lt__(self, other):
        return self.sec() < other.sec()

    def __gt__(self, other):
        return self.sec() > other.sec()

    def __eq__(self, other):
        return self.sec() == other.sec()

    def __hash__(self):
        return hash(self._point)
class PrivateKey(EmbitKey):
    """A secp256k1 private key with compression and network metadata."""

    def __init__(self, secret, compressed: bool = True, network=NETWORKS["main"]):
        """Creates a private key from 32-byte array"""
        if len(secret) != 32:
            raise ECError("Secret should be 32-byte array")
        if not secp256k1.ec_seckey_verify(secret):
            raise ECError("Secret is not valid (larger then N?)")
        self.compressed = compressed
        self._secret = secret
        self.network = network

    def wif(self, network=None) -> str:
        """Export private key as Wallet Import Format string.
        Prefix 0x80 is used for mainnet, 0xEF for testnet.
        This class doesn't store this information though.
        """
        net = self.network if network is None else network
        payload = net["wif"] + self._secret
        if self.compressed:
            # trailing 0x01 marks a compressed-pubkey WIF
            payload += bytes([0x01])
        return base58.encode_check(payload)

    @property
    def secret(self):
        return self._secret

    def sec(self) -> bytes:
        """Sec representation of the corresponding public key"""
        return self.get_public_key().sec()

    def xonly(self) -> bytes:
        return self.sec()[1:]

    def taproot_tweak(self, h=b""):
        """Returns a tweaked private key"""
        sec = self.sec()
        x = sec[1:33]
        tweak = hashes.tagged_hash("TapTweak", x + h)
        if not secp256k1.ec_seckey_verify(tweak):
            raise EmbitError("Tweak is too large")
        # negate first if our public key has an odd Y (prefix != 0x02)
        if sec[0] == 0x02:
            secret = self._secret
        else:
            secret = secp256k1.ec_privkey_negate(self._secret)
        tweaked = secp256k1.ec_privkey_add(secret, tweak)
        pk = PrivateKey(tweaked)
        if pk.sec()[0] == 0x03:
            # ensure the result maps to an even-Y public key
            pk = PrivateKey(secp256k1.ec_privkey_negate(tweaked))
        return pk

    @classmethod
    def from_wif(cls, s):
        """Import private key from Wallet Import Format string."""
        raw = base58.decode_check(s)
        prefix = raw[:1]
        network = None
        for net in NETWORKS:
            if NETWORKS[net]["wif"] == prefix:
                network = NETWORKS[net]
        secret = raw[1:33]
        compressed = False
        if len(raw) not in [33, 34]:
            raise ECError("Wrong WIF length")
        if len(raw) == 34:
            if raw[-1] != 0x01:
                raise ECError("Wrong WIF compressed flag")
            compressed = True
        return cls(secret, compressed, network)

    # to unify API
    def to_base58(self, network=None) -> str:
        return self.wif(network)

    @classmethod
    def from_base58(cls, s):
        return cls.from_wif(s)

    def get_public_key(self) -> PublicKey:
        return PublicKey(secp256k1.ec_pubkey_create(self._secret), self.compressed)

    def to_public(self) -> PublicKey:
        """Alias to get_public_key for API consistency"""
        return self.get_public_key()

    def sign(self, msg_hash, grind=True) -> Signature:
        """ECDSA-sign msg_hash; with grind=True retry with extra entropy
        until the DER signature fits in 70 bytes."""
        sig = Signature(secp256k1.ecdsa_sign(msg_hash, self._secret))
        if grind:
            counter = 1
            while len(sig.serialize()) > 70:
                sig = Signature(
                    secp256k1.ecdsa_sign(
                        msg_hash, self._secret, None, counter.to_bytes(32, "little")
                    )
                )
                counter += 1
                # just in case we get in infinite loop for some reason
                if counter > 200:
                    break
        return sig

    def schnorr_sign(self, msg_hash) -> SchnorrSig:
        return SchnorrSig(secp256k1.schnorrsig_sign(msg_hash, self._secret))

    def verify(self, sig, msg_hash) -> bool:
        return self.get_public_key().verify(sig, msg_hash)

    def schnorr_verify(self, sig, msg_hash) -> bool:
        return self.get_public_key().schnorr_verify(sig, msg_hash)

    def write_to(self, stream) -> int:
        # writes the raw 32-byte secret
        return stream.write(self._secret)

    def ecdh(self, public_key: PublicKey, hashfn=None, data=None) -> bytes:
        point = secp256k1.ec_pubkey_parse(public_key.sec())
        return secp256k1.ecdh(point, self._secret, hashfn, data)

    @classmethod
    def read_from(cls, stream):
        # just to unify the API
        return cls(stream.read(32))

    @property
    def is_private(self) -> bool:
        return True


# Nothing up my sleeve point for no-internal-key taproot
# see https://github.com/bitcoin/bips/blob/master/bip-0341.mediawiki#constructing-and-spending-taproot-outputs
NUMS_PUBKEY = PublicKey.from_string(
    "0250929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0"
)
a/bitcoin_client/ledger_bitcoin/embit/hashes.py b/bitcoin_client/ledger_bitcoin/embit/hashes.py new file mode 100644 index 000000000..c5edd081f --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/hashes.py @@ -0,0 +1,41 @@ +import hashlib + +try: + # this will work with micropython and python < 3.10 + # but will raise and exception if ripemd is not supported (python3.10, openssl 3) + hashlib.new("ripemd160") + + def ripemd160(msg: bytes) -> bytes: + return hashlib.new("ripemd160", msg).digest() + +except: + # otherwise use pure python implementation + from .util.py_ripemd160 import ripemd160 + + +def double_sha256(msg: bytes) -> bytes: + """sha256(sha256(msg)) -> bytes""" + return hashlib.sha256(hashlib.sha256(msg).digest()).digest() + + +def hash160(msg: bytes) -> bytes: + """ripemd160(sha256(msg)) -> bytes""" + return ripemd160(hashlib.sha256(msg).digest()) + + +def sha256(msg: bytes) -> bytes: + """one-line sha256(msg) -> bytes""" + return hashlib.sha256(msg).digest() + + +def tagged_hash(tag: str, data: bytes) -> bytes: + """BIP-Schnorr tag-specific key derivation""" + hashtag = hashlib.sha256(tag.encode()).digest() + return hashlib.sha256(hashtag + hashtag + data).digest() + + +def tagged_hash_init(tag: str, data: bytes = b""): + """Prepares a tagged hash function to digest extra data""" + hashtag = hashlib.sha256(tag.encode()).digest() + h = hashlib.sha256(hashtag + hashtag + data) + return h diff --git a/bitcoin_client/ledger_bitcoin/embit/misc.py b/bitcoin_client/ledger_bitcoin/embit/misc.py new file mode 100644 index 000000000..fc2c8046d --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/misc.py @@ -0,0 +1,70 @@ +"""Misc utility functions used across embit""" +import sys + +# implementation-specific functions and libraries: +if sys.implementation.name == "micropython": + from micropython import const + import secp256k1 +else: + from .util import secp256k1 + + def const(x): + return x + + +try: + # if urandom is available from os module: + from os 
try:
    # if urandom is available from os module:
    from os import urandom as urandom
except ImportError:
    # otherwise - try reading from /dev/urandom
    def urandom(n: int) -> bytes:
        with open("/dev/urandom", "rb") as f:
            return f.read(n)


def getrandbits(k: int) -> int:
    """Return a uniformly distributed random integer in [0, 2**k).

    Draws one extra byte and reduces mod 2**k; since 2**k divides the byte
    range this truncation introduces no bias.
    """
    b = urandom(k // 8 + 1)
    return int.from_bytes(b, "big") % (2**k)


def secure_randint(vmin: int, vmax: int) -> int:
    """
    Normal random.randint uses PRNG that is not suitable
    for cryptographic applications.
    This one uses os.urandom for randomness.

    Returns a uniform integer in the inclusive range [vmin, vmax]
    via rejection sampling.
    """
    assert vmax > vmin
    delta = vmax - vmin
    # BUGFIX: was math.ceil(math.log2(delta + 1)). Float rounding can
    # under-count bits for very large deltas, silently biasing the result
    # (the top of the range would become unreachable). int.bit_length()
    # computes the same quantity exactly and needs no math import.
    nbits = delta.bit_length()
    randn = getrandbits(nbits)
    while randn > delta:
        randn = getrandbits(nbits)
    return vmin + randn


def copy(a: bytes) -> bytes:
    """Ugly copy that works everywhere incl micropython"""
    if len(a) == 0:
        return b""
    return a[:1] + a[1:]


def read_until(s, chars=b",)(#"):
    """Read from stream until one of `char` characters.
    By default `chars=,)(#`.

    Return a tuple (result: bytes, char: bytes | None)
    where result is bytes read from the stream until char,
    char contains this character or None if the end of stream reached.
    """
    res = b""
    chunk = b""
    while True:
        chunk = s.read(1)
        if len(chunk) == 0:
            return res, None
        if chunk in chars:
            return res, chunk
        res += chunk
+ """ + res = b"" + chunk = b"" + while True: + chunk = s.read(1) + if len(chunk) == 0: + return res, None + if chunk in chars: + return res, chunk + res += chunk diff --git a/bitcoin_client/ledger_bitcoin/embit/networks.py b/bitcoin_client/ledger_bitcoin/embit/networks.py new file mode 100644 index 000000000..6f1a54180 --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/networks.py @@ -0,0 +1,76 @@ +from .misc import const + +NETWORKS = { + "main": { + "name": "Mainnet", + "wif": b"\x80", + "p2pkh": b"\x00", + "p2sh": b"\x05", + "bech32": "bc", + "xprv": b"\x04\x88\xad\xe4", + "xpub": b"\x04\x88\xb2\x1e", + "yprv": b"\x04\x9d\x78\x78", + "zprv": b"\x04\xb2\x43\x0c", + "Yprv": b"\x02\x95\xb0\x05", + "Zprv": b"\x02\xaa\x7a\x99", + "ypub": b"\x04\x9d\x7c\xb2", + "zpub": b"\x04\xb2\x47\x46", + "Ypub": b"\x02\x95\xb4\x3f", + "Zpub": b"\x02\xaa\x7e\xd3", + "bip32": const(0), # coin type for bip32 derivation + }, + "test": { + "name": "Testnet", + "wif": b"\xEF", + "p2pkh": b"\x6F", + "p2sh": b"\xC4", + "bech32": "tb", + "xprv": b"\x04\x35\x83\x94", + "xpub": b"\x04\x35\x87\xcf", + "yprv": b"\x04\x4a\x4e\x28", + "zprv": b"\x04\x5f\x18\xbc", + "Yprv": b"\x02\x42\x85\xb5", + "Zprv": b"\x02\x57\x50\x48", + "ypub": b"\x04\x4a\x52\x62", + "zpub": b"\x04\x5f\x1c\xf6", + "Ypub": b"\x02\x42\x89\xef", + "Zpub": b"\x02\x57\x54\x83", + "bip32": const(1), + }, + "regtest": { + "name": "Regtest", + "wif": b"\xEF", + "p2pkh": b"\x6F", + "p2sh": b"\xC4", + "bech32": "bcrt", + "xprv": b"\x04\x35\x83\x94", + "xpub": b"\x04\x35\x87\xcf", + "yprv": b"\x04\x4a\x4e\x28", + "zprv": b"\x04\x5f\x18\xbc", + "Yprv": b"\x02\x42\x85\xb5", + "Zprv": b"\x02\x57\x50\x48", + "ypub": b"\x04\x4a\x52\x62", + "zpub": b"\x04\x5f\x1c\xf6", + "Ypub": b"\x02\x42\x89\xef", + "Zpub": b"\x02\x57\x54\x83", + "bip32": const(1), + }, + "signet": { + "name": "Signet", + "wif": b"\xEF", + "p2pkh": b"\x6F", + "p2sh": b"\xC4", + "bech32": "tb", + "xprv": b"\x04\x35\x83\x94", + "xpub": b"\x04\x35\x87\xcf", + "yprv": 
b"\x04\x4a\x4e\x28", + "zprv": b"\x04\x5f\x18\xbc", + "Yprv": b"\x02\x42\x85\xb5", + "Zprv": b"\x02\x57\x50\x48", + "ypub": b"\x04\x4a\x52\x62", + "zpub": b"\x04\x5f\x1c\xf6", + "Ypub": b"\x02\x42\x89\xef", + "Zpub": b"\x02\x57\x54\x83", + "bip32": const(1), + }, +} diff --git a/bitcoin_client/ledger_bitcoin/embit/script.py b/bitcoin_client/ledger_bitcoin/embit/script.py new file mode 100644 index 000000000..5cea7f98f --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/script.py @@ -0,0 +1,212 @@ +from .networks import NETWORKS +from . import base58 +from . import bech32 +from . import hashes +from . import compact +from .base import EmbitBase, EmbitError + +SIGHASH_ALL = 1 + + +class Script(EmbitBase): + def __init__(self, data=b""): + self.data = data + + def address(self, network=NETWORKS["main"]): + script_type = self.script_type() + data = self.data + + if script_type is None: + raise ValueError("This type of script doesn't have address representation") + + if script_type == "p2pkh": + d = network["p2pkh"] + data[3:23] + return base58.encode_check(d) + + if script_type == "p2sh": + d = network["p2sh"] + data[2:22] + return base58.encode_check(d) + + if script_type in ["p2wpkh", "p2wsh", "p2tr"]: + ver = data[0] + # FIXME: should be one of OP_N + if ver > 0: + ver = ver % 0x50 + return bech32.encode(network["bech32"], ver, data[2:]) + + # we should never get here + raise ValueError("Unsupported script type") + + def push(self, data): + self.data += compact.to_bytes(len(data)) + data + + def script_type(self): + data = self.data + # OP_DUP OP_HASH160 <20:hash160(pubkey)> OP_EQUALVERIFY OP_CHECKSIG + if len(data) == 25 and data[:3] == b"\x76\xa9\x14" and data[-2:] == b"\x88\xac": + return "p2pkh" + # OP_HASH160 <20:hash160(script)> OP_EQUAL + if len(data) == 23 and data[:2] == b"\xa9\x14" and data[-1] == 0x87: + return "p2sh" + # 0 <20:hash160(pubkey)> + if len(data) == 22 and data[:2] == b"\x00\x14": + return "p2wpkh" + # 0 <32:sha256(script)> + if len(data) 
== 34 and data[:2] == b"\x00\x20": + return "p2wsh" + # OP_1 + if len(data) == 34 and data[:2] == b"\x51\x20": + return "p2tr" + # unknown type + return None + + def write_to(self, stream): + res = stream.write(compact.to_bytes(len(self.data))) + res += stream.write(self.data) + return res + + @classmethod + def read_from(cls, stream): + l = compact.read_from(stream) + data = stream.read(l) + if len(data) != l: + raise ValueError("Cant read %d bytes" % l) + return cls(data) + + @classmethod + def from_address(cls, addr: str): + """ + Decodes a bitcoin address and returns corresponding scriptpubkey. + """ + return address_to_scriptpubkey(addr) + + def __eq__(self, other): + return self.data == other.data + + def __ne__(self, other): + return self.data != other.data + + def __hash__(self): + return hash(self.data) + + def __len__(self): + return len(self.data) + + +class Witness(EmbitBase): + def __init__(self, items=[]): + self.items = items[:] + + def write_to(self, stream): + res = stream.write(compact.to_bytes(len(self.items))) + for item in self.items: + res += stream.write(compact.to_bytes(len(item))) + res += stream.write(item) + return res + + @classmethod + def read_from(cls, stream): + num = compact.read_from(stream) + items = [] + for i in range(num): + l = compact.read_from(stream) + data = stream.read(l) + items.append(data) + return cls(items) + + def __hash__(self): + return hash(self.items) + + def __len__(self): + return len(self.items) + + +def p2pkh(pubkey): + """Return Pay-To-Pubkey-Hash ScriptPubkey""" + return Script(b"\x76\xa9\x14" + hashes.hash160(pubkey.sec()) + b"\x88\xac") + + +def p2sh(script): + """Return Pay-To-Script-Hash ScriptPubkey""" + return Script(b"\xa9\x14" + hashes.hash160(script.data) + b"\x87") + + +def p2wpkh(pubkey): + """Return Pay-To-Witness-Pubkey-Hash ScriptPubkey""" + return Script(b"\x00\x14" + hashes.hash160(pubkey.sec())) + + +def p2wsh(script): + """Return Pay-To-Witness-Pubkey-Hash ScriptPubkey""" + return 
Script(b"\x00\x20" + hashes.sha256(script.data)) + + +def p2tr(pubkey, script_tree=None): + """Return Pay-To-Taproot ScriptPubkey""" + if script_tree is None: + h = b"" + else: + h = script_tree.tweak() + output_pubkey = pubkey.taproot_tweak(h) + return Script(b"\x51\x20" + output_pubkey.xonly()) + + +def p2pkh_from_p2wpkh(script): + """Convert p2wpkh to p2pkh script""" + return Script(b"\x76\xa9" + script.serialize()[2:] + b"\x88\xac") + + +def multisig(m: int, pubkeys): + if m <= 0 or m > 16: + raise ValueError("m must be between 1 and 16") + n = len(pubkeys) + if n < m or n > 16: + raise ValueError("Number of pubkeys must be between %d and 16" % m) + data = bytes([80 + m]) + for pubkey in pubkeys: + sec = pubkey.sec() + data += bytes([len(sec)]) + sec + # OP_m ... OP_n OP_CHECKMULTISIG + data += bytes([80 + n, 0xAE]) + return Script(data) + + +def address_to_scriptpubkey(addr): + # try with base58 address + try: + data = base58.decode_check(addr) + prefix = data[:1] + for net in NETWORKS.values(): + if prefix == net["p2pkh"]: + return Script(b"\x76\xa9\x14" + data[1:] + b"\x88\xac") + elif prefix == net["p2sh"]: + return Script(b"\xa9\x14" + data[1:] + b"\x87") + except: + # fail - then it's bech32 address + hrp = addr.split("1")[0] + ver, data = bech32.decode(hrp, addr) + if ver not in [0, 1] or len(data) not in [20, 32]: + raise EmbitError("Invalid bech32 address") + if ver == 1 and len(data) != 32: + raise EmbitError("Invalid bech32 address") + # OP_1..OP_N + if ver > 0: + ver += 0x50 + return Script(bytes([ver, len(data)] + data)) + + +def script_sig_p2pkh(signature, pubkey, sighash=SIGHASH_ALL): + sec = pubkey.sec() + der = signature.serialize() + bytes([sighash]) + data = compact.to_bytes(len(der)) + der + compact.to_bytes(len(sec)) + sec + return Script(data) + + +def script_sig_p2sh(redeem_script): + """Creates scriptsig for p2sh""" + # FIXME: implement for legacy p2sh as well + return Script(redeem_script.serialize()) + + +def witness_p2wpkh(signature, 
import ctypes, os
import ctypes.util
import functools
import platform
import threading

from ctypes import (
    cast,
    byref,
    c_char,
    c_byte,
    c_int,
    c_uint,
    c_char_p,
    c_size_t,
    c_void_p,
    c_uint64,
    create_string_buffer,
    CFUNCTYPE,
    POINTER,
)

# Module-wide lock serializing every call into the shared libsecp256k1
# context (the context object itself is shared mutable state).
_lock = threading.Lock()


# @locked decorator
def locked(func):
    """Decorator: run `func` while holding the module-wide lock."""

    # BUGFIX: use functools.wraps so the wrapper keeps the wrapped function's
    # __name__/__doc__ (the plain wrapper hid them, hurting debugging/repr).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with _lock:
            return func(*args, **kwargs)

    return wrapper
+CONTEXT_VERIFY = 0b0100000001 +CONTEXT_SIGN = 0b1000000001 +CONTEXT_NONE = 0b0000000001 + +# Flags to pass to ec_pubkey_serialize +EC_COMPRESSED = 0b0100000010 +EC_UNCOMPRESSED = 0b0000000010 + + +def _copy(a: bytes) -> bytes: + """Ugly copy that works everywhere incl micropython""" + if len(a) == 0: + return b"" + return a[:1] + a[1:] + + +def _find_library(): + library_path = None + extension = "" + if platform.system() == "Darwin": + extension = ".dylib" + elif platform.system() == "Linux": + extension = ".so" + elif platform.system() == "Windows": + extension = ".dll" + + path = os.path.join( + os.path.dirname(__file__), + "prebuilt/libsecp256k1_%s_%s%s" + % (platform.system().lower(), platform.machine().lower(), extension), + ) + if os.path.isfile(path): + return path + # try searching + if not library_path: + library_path = ctypes.util.find_library("libsecp256k1") + if not library_path: + library_path = ctypes.util.find_library("secp256k1") + # library search failed + if not library_path: + if platform.system() == "Linux" and os.path.isfile( + "/usr/local/lib/libsecp256k1.so.0" + ): + library_path = "/usr/local/lib/libsecp256k1.so.0" + return library_path + + +@locked +def _init(flags=(CONTEXT_SIGN | CONTEXT_VERIFY)): + library_path = _find_library() + # meh, can't find library + if not library_path: + raise RuntimeError( + "Can't find libsecp256k1 library. Make sure to compile and install it." 
+ ) + + secp256k1 = ctypes.cdll.LoadLibrary(library_path) + + secp256k1.secp256k1_context_create.argtypes = [c_uint] + secp256k1.secp256k1_context_create.restype = c_void_p + + secp256k1.secp256k1_context_randomize.argtypes = [c_void_p, c_char_p] + secp256k1.secp256k1_context_randomize.restype = c_int + + secp256k1.secp256k1_ec_seckey_verify.argtypes = [c_void_p, c_char_p] + secp256k1.secp256k1_ec_seckey_verify.restype = c_int + + secp256k1.secp256k1_ec_privkey_negate.argtypes = [c_void_p, c_char_p] + secp256k1.secp256k1_ec_privkey_negate.restype = c_int + + secp256k1.secp256k1_ec_pubkey_negate.argtypes = [c_void_p, c_char_p] + secp256k1.secp256k1_ec_pubkey_negate.restype = c_int + + secp256k1.secp256k1_ec_privkey_tweak_add.argtypes = [c_void_p, c_char_p, c_char_p] + secp256k1.secp256k1_ec_privkey_tweak_add.restype = c_int + + secp256k1.secp256k1_ec_privkey_tweak_mul.argtypes = [c_void_p, c_char_p, c_char_p] + secp256k1.secp256k1_ec_privkey_tweak_mul.restype = c_int + + secp256k1.secp256k1_ec_pubkey_create.argtypes = [c_void_p, c_void_p, c_char_p] + secp256k1.secp256k1_ec_pubkey_create.restype = c_int + + secp256k1.secp256k1_ec_pubkey_parse.argtypes = [c_void_p, c_char_p, c_char_p, c_int] + secp256k1.secp256k1_ec_pubkey_parse.restype = c_int + + secp256k1.secp256k1_ec_pubkey_serialize.argtypes = [ + c_void_p, + c_char_p, + c_void_p, + c_char_p, + c_uint, + ] + secp256k1.secp256k1_ec_pubkey_serialize.restype = c_int + + secp256k1.secp256k1_ec_pubkey_tweak_add.argtypes = [c_void_p, c_char_p, c_char_p] + secp256k1.secp256k1_ec_pubkey_tweak_add.restype = c_int + + secp256k1.secp256k1_ec_pubkey_tweak_mul.argtypes = [c_void_p, c_char_p, c_char_p] + secp256k1.secp256k1_ec_pubkey_tweak_mul.restype = c_int + + secp256k1.secp256k1_ecdsa_signature_parse_compact.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + ] + secp256k1.secp256k1_ecdsa_signature_parse_compact.restype = c_int + + secp256k1.secp256k1_ecdsa_signature_serialize_compact.argtypes = [ + c_void_p, + c_char_p, + 
c_char_p, + ] + secp256k1.secp256k1_ecdsa_signature_serialize_compact.restype = c_int + + secp256k1.secp256k1_ecdsa_signature_parse_der.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_uint, + ] + secp256k1.secp256k1_ecdsa_signature_parse_der.restype = c_int + + secp256k1.secp256k1_ecdsa_signature_serialize_der.argtypes = [ + c_void_p, + c_char_p, + c_void_p, + c_char_p, + ] + secp256k1.secp256k1_ecdsa_signature_serialize_der.restype = c_int + + secp256k1.secp256k1_ecdsa_sign.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_char_p, + c_void_p, + c_char_p, + ] + secp256k1.secp256k1_ecdsa_sign.restype = c_int + + secp256k1.secp256k1_ecdsa_verify.argtypes = [c_void_p, c_char_p, c_char_p, c_char_p] + secp256k1.secp256k1_ecdsa_verify.restype = c_int + + secp256k1.secp256k1_ec_pubkey_combine.argtypes = [ + c_void_p, + c_char_p, + c_void_p, + c_size_t, + ] + secp256k1.secp256k1_ec_pubkey_combine.restype = c_int + + # ecdh + try: + secp256k1.secp256k1_ecdh.argtypes = [ + c_void_p, # ctx + c_char_p, # output + c_char_p, # point + c_char_p, # scalar + CFUNCTYPE, # hashfp + c_void_p, # data + ] + secp256k1.secp256k1_ecdh.restype = c_int + except: + pass + + # schnorr sig + try: + secp256k1.secp256k1_xonly_pubkey_from_pubkey.argtypes = [ + c_void_p, # ctx + c_char_p, # xonly pubkey + POINTER(c_int), # parity + c_char_p, # pubkey + ] + secp256k1.secp256k1_xonly_pubkey_from_pubkey.restype = c_int + + secp256k1.secp256k1_schnorrsig_verify.argtypes = [ + c_void_p, # ctx + c_char_p, # sig + c_char_p, # msg + c_char_p, # pubkey + ] + secp256k1.secp256k1_schnorrsig_verify.restype = c_int + + secp256k1.secp256k1_schnorrsig_sign.argtypes = [ + c_void_p, # ctx + c_char_p, # sig + c_char_p, # msg + c_char_p, # keypair + c_void_p, # nonce_function + c_char_p, # extra data + ] + secp256k1.secp256k1_schnorrsig_sign.restype = c_int + + secp256k1.secp256k1_keypair_create.argtypes = [ + c_void_p, # ctx + c_char_p, # keypair + c_char_p, # secret + ] + 
secp256k1.secp256k1_keypair_create.restype = c_int + except: + pass + + # recoverable module + try: + secp256k1.secp256k1_ecdsa_sign_recoverable.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_char_p, + c_void_p, + c_void_p, + ] + secp256k1.secp256k1_ecdsa_sign_recoverable.restype = c_int + + secp256k1.secp256k1_ecdsa_recoverable_signature_parse_compact.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_int, + ] + secp256k1.secp256k1_ecdsa_recoverable_signature_parse_compact.restype = c_int + + secp256k1.secp256k1_ecdsa_recoverable_signature_serialize_compact.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_char_p, + ] + secp256k1.secp256k1_ecdsa_recoverable_signature_serialize_compact.restype = ( + c_int + ) + + secp256k1.secp256k1_ecdsa_recoverable_signature_convert.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + ] + secp256k1.secp256k1_ecdsa_recoverable_signature_convert.restype = c_int + + secp256k1.secp256k1_ecdsa_recover.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_char_p, + ] + secp256k1.secp256k1_ecdsa_recover.restype = c_int + except: + pass + + # zkp modules + try: + # generator module + secp256k1.secp256k1_generator_parse.argtypes = [c_void_p, c_char_p, c_char_p] + secp256k1.secp256k1_generator_parse.restype = c_int + + secp256k1.secp256k1_generator_serialize.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + ] + secp256k1.secp256k1_generator_serialize.restype = c_int + + secp256k1.secp256k1_generator_generate.argtypes = [c_void_p, c_char_p, c_char_p] + secp256k1.secp256k1_generator_generate.restype = c_int + + secp256k1.secp256k1_generator_generate_blinded.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_char_p, + ] + secp256k1.secp256k1_generator_generate_blinded.restype = c_int + + # pederson commitments + secp256k1.secp256k1_pedersen_commitment_parse.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + ] + secp256k1.secp256k1_pedersen_commitment_parse.restype = c_int + + 
secp256k1.secp256k1_pedersen_commitment_serialize.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + ] + secp256k1.secp256k1_pedersen_commitment_serialize.restype = c_int + + secp256k1.secp256k1_pedersen_commit.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_uint64, + c_char_p, + ] + secp256k1.secp256k1_pedersen_commit.restype = c_int + + secp256k1.secp256k1_pedersen_blind_generator_blind_sum.argtypes = [ + c_void_p, # const secp256k1_context* ctx, + POINTER(c_uint64), # const uint64_t *value, + c_void_p, # const unsigned char* const* generator_blind, + c_void_p, # unsigned char* const* blinding_factor, + c_size_t, # size_t n_total, + c_size_t, # size_t n_inputs + ] + secp256k1.secp256k1_pedersen_blind_generator_blind_sum.restype = c_int + + secp256k1.secp256k1_pedersen_verify_tally.argtypes = [ + c_void_p, + c_void_p, + c_size_t, + c_void_p, + c_size_t, + ] + secp256k1.secp256k1_pedersen_verify_tally.restype = c_int + + # rangeproof + secp256k1.secp256k1_rangeproof_rewind.argtypes = [ + c_void_p, # ctx + c_char_p, # vbf out + POINTER(c_uint64), # value out + c_char_p, # message out + POINTER(c_uint64), # msg out len + c_char_p, # nonce + POINTER(c_uint64), # min value + POINTER(c_uint64), # max value + c_char_p, # pedersen commitment + c_char_p, # range proof + c_uint64, # proof len + c_char_p, # extra commitment (scriptpubkey) + c_uint64, # extra len + c_char_p, # generator + ] + secp256k1.secp256k1_rangeproof_rewind.restype = c_int + + secp256k1.secp256k1_rangeproof_verify.argtypes = [ + c_void_p, # ctx + POINTER(c_uint64), # min value + POINTER(c_uint64), # max value + c_char_p, # pedersen commitment + c_char_p, # proof + c_uint64, # proof len + c_char_p, # extra + c_uint64, # extra len + c_char_p, # generator + ] + secp256k1.secp256k1_rangeproof_verify.restype = c_int + + secp256k1.secp256k1_rangeproof_sign.argtypes = [ + c_void_p, # ctx + c_char_p, # proof + POINTER(c_uint64), # plen + c_uint64, # min_value + c_char_p, # commit + c_char_p, # blind + 
c_char_p, # nonce + c_int, # exp + c_int, # min_bits + c_uint64, # value + c_char_p, # message + c_uint64, # msg_len + c_char_p, # extra_commit + c_uint64, # extra_commit_len + c_char_p, # gen + ] + secp256k1.secp256k1_rangeproof_sign.restype = c_int + + # musig + secp256k1.secp256k1_musig_pubkey_combine.argtypes = [ + c_void_p, + c_void_p, + c_char_p, + c_void_p, + c_void_p, + c_size_t, + ] + secp256k1.secp256k1_musig_pubkey_combine.restype = c_int + + # surjection proofs + secp256k1.secp256k1_surjectionproof_initialize.argtypes = [ + c_void_p, # const secp256k1_context* ctx, + c_char_p, # secp256k1_surjectionproof* proof, + POINTER(c_size_t), # size_t *input_index, + c_void_p, # c_char_p, # const secp256k1_fixed_asset_tag* fixed_input_tags, + c_size_t, # const size_t n_input_tags, + c_size_t, # const size_t n_input_tags_to_use, + c_char_p, # const secp256k1_fixed_asset_tag* fixed_output_tag, + c_size_t, # const size_t n_max_iterations, + c_char_p, # const unsigned char *random_seed32 + ] + secp256k1.secp256k1_surjectionproof_initialize.restype = c_int + + secp256k1.secp256k1_surjectionproof_generate.argtypes = [ + c_void_p, # const secp256k1_context* ctx, + c_char_p, # secp256k1_surjectionproof* proof, + c_char_p, # const secp256k1_generator* ephemeral_input_tags, + c_size_t, # size_t n_ephemeral_input_tags, + c_char_p, # const secp256k1_generator* ephemeral_output_tag, + c_size_t, # size_t input_index, + c_char_p, # const unsigned char *input_blinding_key, + c_char_p, # const unsigned char *output_blinding_key + ] + secp256k1.secp256k1_surjectionproof_generate.restype = c_int + + secp256k1.secp256k1_surjectionproof_verify.argtypes = [ + c_void_p, # const secp256k1_context* ctx, + c_char_p, # const secp256k1_surjectionproof* proof, + c_char_p, # const secp256k1_generator* ephemeral_input_tags, + c_size_t, # size_t n_ephemeral_input_tags, + c_char_p, # const secp256k1_generator* ephemeral_output_tag + ] + secp256k1.secp256k1_surjectionproof_verify.restype = c_int 
+ + secp256k1.secp256k1_surjectionproof_serialize.argtypes = [ + c_void_p, # const secp256k1_context* ctx, + c_char_p, # unsigned char *output, + POINTER(c_size_t), # size_t *outputlen, + c_char_p, # const secp256k1_surjectionproof *proof + ] + secp256k1.secp256k1_surjectionproof_serialize.restype = c_int + + secp256k1.secp256k1_surjectionproof_serialized_size.argtypes = [ + c_void_p, # const secp256k1_context* ctx, + c_char_p, # const secp256k1_surjectionproof* proof + ] + secp256k1.secp256k1_surjectionproof_serialized_size.restype = c_size_t + + secp256k1.secp256k1_surjectionproof_parse.argtypes = [ + c_void_p, + c_char_p, + c_char_p, + c_size_t, + ] + secp256k1.secp256k1_surjectionproof_parse.restype = c_int + + except: + pass + + secp256k1.ctx = secp256k1.secp256k1_context_create(flags) + + r = secp256k1.secp256k1_context_randomize(secp256k1.ctx, os.urandom(32)) + + return secp256k1 + + +_secp = _init() + + +# bindings equal to ones in micropython +@locked +def context_randomize(seed, context=_secp.ctx): + if len(seed) != 32: + raise ValueError("Seed should be 32 bytes long") + if _secp.secp256k1_context_randomize(context, seed) == 0: + raise RuntimeError("Failed to randomize context") + + +@locked +def ec_pubkey_create(secret, context=_secp.ctx): + if len(secret) != 32: + raise ValueError("Private key should be 32 bytes long") + pub = bytes(64) + r = _secp.secp256k1_ec_pubkey_create(context, pub, secret) + if r == 0: + raise ValueError("Invalid private key") + return pub + + +@locked +def ec_pubkey_parse(sec, context=_secp.ctx): + if len(sec) != 33 and len(sec) != 65: + raise ValueError("Serialized pubkey should be 33 or 65 bytes long") + if len(sec) == 33: + if sec[0] != 0x02 and sec[0] != 0x03: + raise ValueError("Compressed pubkey should start with 0x02 or 0x03") + else: + if sec[0] != 0x04: + raise ValueError("Uncompressed pubkey should start with 0x04") + pub = bytes(64) + r = _secp.secp256k1_ec_pubkey_parse(context, pub, sec, len(sec)) + if r == 0: + 
raise ValueError("Failed parsing public key") + return pub + + +@locked +def ec_pubkey_serialize(pubkey, flag=EC_COMPRESSED, context=_secp.ctx): + if len(pubkey) != 64: + raise ValueError("Pubkey should be 64 bytes long") + if flag not in [EC_COMPRESSED, EC_UNCOMPRESSED]: + raise ValueError("Invalid flag") + sec = bytes(33) if (flag == EC_COMPRESSED) else bytes(65) + sz = c_size_t(len(sec)) + r = _secp.secp256k1_ec_pubkey_serialize(context, sec, byref(sz), pubkey, flag) + if r == 0: + raise ValueError("Failed to serialize pubkey") + return sec + + +@locked +def ecdsa_signature_parse_compact(compact_sig, context=_secp.ctx): + if len(compact_sig) != 64: + raise ValueError("Compact signature should be 64 bytes long") + sig = bytes(64) + r = _secp.secp256k1_ecdsa_signature_parse_compact(context, sig, compact_sig) + if r == 0: + raise ValueError("Failed parsing compact signature") + return sig + + +@locked +def ecdsa_signature_parse_der(der, context=_secp.ctx): + sig = bytes(64) + r = _secp.secp256k1_ecdsa_signature_parse_der(context, sig, der, len(der)) + if r == 0: + raise ValueError("Failed parsing compact signature") + return sig + + +@locked +def ecdsa_signature_serialize_der(sig, context=_secp.ctx): + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + der = bytes(78) # max + sz = c_size_t(len(der)) + r = _secp.secp256k1_ecdsa_signature_serialize_der(context, der, byref(sz), sig) + if r == 0: + raise ValueError("Failed serializing der signature") + return der[: sz.value] + + +@locked +def ecdsa_signature_serialize_compact(sig, context=_secp.ctx): + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + ser = bytes(64) + r = _secp.secp256k1_ecdsa_signature_serialize_compact(context, ser, sig) + if r == 0: + raise ValueError("Failed serializing der signature") + return ser + + +@locked +def ecdsa_signature_normalize(sig, context=_secp.ctx): + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + 
sig2 = bytes(64) + r = _secp.secp256k1_ecdsa_signature_normalize(context, sig2, sig) + return sig2 + + +@locked +def ecdsa_verify(sig, msg, pub, context=_secp.ctx): + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + if len(msg) != 32: + raise ValueError("Message should be 32 bytes long") + if len(pub) != 64: + raise ValueError("Public key should be 64 bytes long") + r = _secp.secp256k1_ecdsa_verify(context, sig, msg, pub) + return bool(r) + + +@locked +def ecdsa_sign(msg, secret, nonce_function=None, extra_data=None, context=_secp.ctx): + if len(msg) != 32: + raise ValueError("Message should be 32 bytes long") + if len(secret) != 32: + raise ValueError("Secret key should be 32 bytes long") + if extra_data and len(extra_data) != 32: + raise ValueError("Extra data should be 32 bytes long") + sig = bytes(64) + r = _secp.secp256k1_ecdsa_sign( + context, sig, msg, secret, nonce_function, extra_data + ) + if r == 0: + raise ValueError("Failed to sign") + return sig + + +@locked +def ec_seckey_verify(secret, context=_secp.ctx): + if len(secret) != 32: + raise ValueError("Secret should be 32 bytes long") + return bool(_secp.secp256k1_ec_seckey_verify(context, secret)) + + +@locked +def ec_privkey_negate(secret, context=_secp.ctx): + if len(secret) != 32: + raise ValueError("Secret should be 32 bytes long") + b = _copy(secret) + _secp.secp256k1_ec_privkey_negate(context, b) + return b + + +@locked +def ec_pubkey_negate(pubkey, context=_secp.ctx): + if len(pubkey) != 64: + raise ValueError("Pubkey should be a 64-byte structure") + pub = _copy(pubkey) + r = _secp.secp256k1_ec_pubkey_negate(context, pub) + if r == 0: + raise ValueError("Failed to negate pubkey") + return pub + + +@locked +def ec_privkey_tweak_add(secret, tweak, context=_secp.ctx): + if len(secret) != 32 or len(tweak) != 32: + raise ValueError("Secret and tweak should both be 32 bytes long") + t = _copy(tweak) + if _secp.secp256k1_ec_privkey_tweak_add(context, secret, tweak) == 0: + 
raise ValueError("Failed to tweak the secret") + return None + + +@locked +def ec_pubkey_tweak_add(pub, tweak, context=_secp.ctx): + if len(pub) != 64: + raise ValueError("Public key should be 64 bytes long") + if len(tweak) != 32: + raise ValueError("Tweak should be 32 bytes long") + t = _copy(tweak) + if _secp.secp256k1_ec_pubkey_tweak_add(context, pub, tweak) == 0: + raise ValueError("Failed to tweak the public key") + return None + + +@locked +def ec_privkey_add(secret, tweak, context=_secp.ctx): + if len(secret) != 32 or len(tweak) != 32: + raise ValueError("Secret and tweak should both be 32 bytes long") + # ugly copy that works in mpy and py + s = _copy(secret) + t = _copy(tweak) + if _secp.secp256k1_ec_privkey_tweak_add(context, s, t) == 0: + raise ValueError("Failed to tweak the secret") + return s + + +@locked +def ec_pubkey_add(pub, tweak, context=_secp.ctx): + if len(pub) != 64: + raise ValueError("Public key should be 64 bytes long") + if len(tweak) != 32: + raise ValueError("Tweak should be 32 bytes long") + p = _copy(pub) + if _secp.secp256k1_ec_pubkey_tweak_add(context, p, tweak) == 0: + raise ValueError("Failed to tweak the public key") + return p + + +@locked +def ec_privkey_tweak_mul(secret, tweak, context=_secp.ctx): + if len(secret) != 32 or len(tweak) != 32: + raise ValueError("Secret and tweak should both be 32 bytes long") + if _secp.secp256k1_ec_privkey_tweak_mul(context, secret, tweak) == 0: + raise ValueError("Failed to tweak the secret") + + +@locked +def ec_pubkey_tweak_mul(pub, tweak, context=_secp.ctx): + if len(pub) != 64: + raise ValueError("Public key should be 64 bytes long") + if len(tweak) != 32: + raise ValueError("Tweak should be 32 bytes long") + if _secp.secp256k1_ec_pubkey_tweak_mul(context, pub, tweak) == 0: + raise ValueError("Failed to tweak the public key") + + +@locked +def ec_pubkey_combine(*args, context=_secp.ctx): + pub = bytes(64) + pubkeys = (c_char_p * len(args))(*args) + r = 
# ecdh
@locked
def ecdh(pubkey, scalar, hashfn=None, data=None, context=_secp.ctx):
    """Compute an ECDH shared secret; optional custom hash callback."""
    if not len(pubkey) == 64:
        raise ValueError("Pubkey should be 64 bytes long")
    if not len(scalar) == 32:
        raise ValueError("Scalar should be 32 bytes long")
    secret = bytes(32)
    if hashfn is None:
        res = _secp.secp256k1_ecdh(context, secret, pubkey, scalar, None, None)
    else:

        def _hashfn(out, x, y):
            # Bridge the Python hash callback into the C calling convention;
            # returning 0 tells libsecp256k1 the callback failed.
            x = ctypes.string_at(x, 32)
            y = ctypes.string_at(y, 32)
            try:
                res = hashfn(x, y, data)
            except Exception:  # was `as e` with the name never used
                return 0
            out = cast(out, POINTER(c_char * 32))
            out.contents.value = res
            return 1

        HASHFN = CFUNCTYPE(c_int, c_void_p, c_void_p, c_void_p)
        res = _secp.secp256k1_ecdh(
            context, secret, pubkey, scalar, HASHFN(_hashfn), data
        )
    if res != 1:
        raise RuntimeError("Failed to compute the shared secret")
    return secret


# schnorrsig
@locked
def xonly_pubkey_from_pubkey(pubkey, context=_secp.ctx):
    """Return (xonly_pubkey, parity) for a 64-byte internal pubkey."""
    if len(pubkey) != 64:
        raise ValueError("Pubkey should be 64 bytes long")
    pointer = POINTER(c_int)
    parity = pointer(c_int(0))
    xonly_pub = bytes(64)
    res = _secp.secp256k1_xonly_pubkey_from_pubkey(context, xonly_pub, parity, pubkey)
    if res != 1:
        raise RuntimeError("Failed to convert the pubkey")
    return xonly_pub, bool(parity.contents.value)


@locked
def schnorrsig_verify(sig, msg, pubkey, context=_secp.ctx):
    """Verify a BIP340 Schnorr signature; returns bool."""
    # NOTE(review): asserts are stripped under `python -O`; length checks
    # would then be skipped.
    assert len(sig) == 64
    assert len(msg) == 32
    assert len(pubkey) == 64
    res = _secp.secp256k1_schnorrsig_verify(context, sig, msg, pubkey)
    return bool(res)


@locked
def keypair_create(secret, context=_secp.ctx):
    """Build a 96-byte libsecp256k1 keypair object from a 32-byte secret."""
    assert len(secret) == 32
    keypair = bytes(96)
    r = _secp.secp256k1_keypair_create(context, keypair, secret)
    if r == 0:
        raise ValueError("Failed to create keypair")
    return keypair


# not @locked because it uses keypair_create inside
def schnorrsig_sign(
    msg, keypair, nonce_function=None, extra_data=None, context=_secp.ctx
):
    """BIP340-sign a 32-byte message. `keypair` may be a 32-byte secret
    (converted via keypair_create) or a ready 96-byte keypair."""
    assert len(msg) == 32
    if len(keypair) == 32:
        keypair = keypair_create(keypair, context=context)
    with _lock:
        assert len(keypair) == 96
        sig = bytes(64)
        r = _secp.secp256k1_schnorrsig_sign(
            context, sig, msg, keypair, nonce_function, extra_data
        )
        if r == 0:
            raise ValueError("Failed to sign")
        return sig


# recoverable
@locked
def ecdsa_sign_recoverable(msg, secret, context=_secp.ctx):
    """Sign a 32-byte message hash producing a 65-byte recoverable signature."""
    if len(msg) != 32:
        raise ValueError("Message should be 32 bytes long")
    if len(secret) != 32:
        raise ValueError("Secret key should be 32 bytes long")
    sig = bytes(65)
    r = _secp.secp256k1_ecdsa_sign_recoverable(context, sig, msg, secret, None, None)
    if r == 0:
        raise ValueError("Failed to sign")
    return sig


@locked
def ecdsa_recoverable_signature_serialize_compact(sig, context=_secp.ctx):
    """Serialize a recoverable signature; returns (64-byte compact, recid)."""
    if len(sig) != 65:
        raise ValueError("Recoverable signature should be 65 bytes long")
    ser = bytes(64)
    idx = bytes(1)
    r = _secp.secp256k1_ecdsa_recoverable_signature_serialize_compact(
        context, ser, idx, sig
    )
    if r == 0:
        raise ValueError("Failed serializing der signature")
    return ser, idx[0]


@locked
def ecdsa_recoverable_signature_parse_compact(compact_sig, recid, context=_secp.ctx):
    """Parse a 64-byte compact signature plus recovery id into 65-byte form."""
    if len(compact_sig) != 64:
        raise ValueError("Signature should be 64 bytes long")
    sig = bytes(65)
    r = _secp.secp256k1_ecdsa_recoverable_signature_parse_compact(
        context, sig, compact_sig, recid
    )
    if r == 0:
        raise ValueError("Failed parsing compact signature")
    return sig


@locked
def ecdsa_recoverable_signature_convert(sigin, context=_secp.ctx):
    """Convert a 65-byte recoverable signature to the plain 64-byte form."""
    if len(sigin) != 65:
        raise ValueError("Recoverable signature should be 65 bytes long")
    sig = bytes(64)
    r = _secp.secp256k1_ecdsa_recoverable_signature_convert(context, sig, sigin)
    if r == 0:
        raise ValueError("Failed converting signature")
    return sig


@locked
def ecdsa_recover(sig, msghash, context=_secp.ctx):
    """Recover the 64-byte internal pubkey from a recoverable signature."""
    if len(sig) != 65:
        raise ValueError("Recoverable signature should be 65 bytes long")
    if len(msghash) != 32:
        raise ValueError("Message should be 32 bytes long")
    pub = bytes(64)
    r = _secp.secp256k1_ecdsa_recover(context, pub, sig, msghash)
    if r == 0:
        raise ValueError("Failed to recover public key")
    return pub
msghash, context=_secp.ctx): + if len(sig) != 65: + raise ValueError("Recoverable signature should be 65 bytes long") + if len(msghash) != 32: + raise ValueError("Message should be 32 bytes long") + pub = bytes(64) + r = _secp.secp256k1_ecdsa_recover(context, pub, sig, msghash) + if r == 0: + raise ValueError("Failed to recover public key") + return pub + + +# zkp modules + + +@locked +def pedersen_commitment_parse(inp, context=_secp.ctx): + if len(inp) != 33: + raise ValueError("Serialized commitment should be 33 bytes long") + commit = bytes(64) + r = _secp.secp256k1_pedersen_commitment_parse(context, commit, inp) + if r == 0: + raise ValueError("Failed to parse commitment") + return commit + + +@locked +def pedersen_commitment_serialize(commit, context=_secp.ctx): + if len(commit) != 64: + raise ValueError("Commitment should be 64 bytes long") + sec = bytes(33) + r = _secp.secp256k1_pedersen_commitment_serialize(context, sec, commit) + if r == 0: + raise ValueError("Failed to serialize commitment") + return sec + + +@locked +def pedersen_commit(vbf, value, gen, context=_secp.ctx): + if len(gen) != 64: + raise ValueError("Generator should be 64 bytes long") + if len(vbf) != 32: + raise ValueError(f"Blinding factor should be 32 bytes long, not {len(vbf)}") + commit = bytes(64) + r = _secp.secp256k1_pedersen_commit(context, commit, vbf, value, gen) + if r == 0: + raise ValueError("Failed to create commitment") + return commit + + +@locked +def pedersen_blind_generator_blind_sum( + values, gens, vbfs, num_inputs, context=_secp.ctx +): + vals = (c_uint64 * len(values))(*values) + vbf = bytes(vbfs[-1]) + p = c_char_p(vbf) # obtain a pointer of various types + address = cast(p, c_void_p).value + + vbfs_joined = (c_char_p * len(vbfs))(*vbfs[:-1], address) + gens_joined = (c_char_p * len(gens))(*gens) + res = _secp.secp256k1_pedersen_blind_generator_blind_sum( + context, vals, gens_joined, vbfs_joined, len(values), num_inputs + ) + if res == 0: + raise ValueError("Failed 
to get the last blinding factor.") + res = (c_char * 32).from_address(address).raw + assert len(res) == 32 + return res + + +@locked +def pedersen_verify_tally(ins, outs, context=_secp.ctx): + in_ptr = (c_char_p * len(ins))(*ins) + out_ptr = (c_char_p * len(outs))(*outs) + res = _secp.secp256k1_pedersen_verify_tally( + context, in_ptr, len(in_ptr), out_ptr, len(out_ptr) + ) + return bool(res) + + +# generator +@locked +def generator_parse(inp, context=_secp.ctx): + if len(inp) != 33: + raise ValueError("Serialized generator should be 33 bytes long") + gen = bytes(64) + r = _secp.secp256k1_generator_parse(context, gen, inp) + if r == 0: + raise ValueError("Failed to parse generator") + return gen + + +@locked +def generator_generate(asset, context=_secp.ctx): + if len(asset) != 32: + raise ValueError("Asset should be 32 bytes long") + gen = bytes(64) + r = _secp.secp256k1_generator_generate(context, gen, asset) + if r == 0: + raise ValueError("Failed to generate generator") + return gen + + +@locked +def generator_generate_blinded(asset, abf, context=_secp.ctx): + if len(asset) != 32: + raise ValueError("Asset should be 32 bytes long") + if len(abf) != 32: + raise ValueError("Asset blinding factor should be 32 bytes long") + gen = bytes(64) + r = _secp.secp256k1_generator_generate_blinded(context, gen, asset, abf) + if r == 0: + raise ValueError("Failed to generate generator") + return gen + + +@locked +def generator_serialize(generator, context=_secp.ctx): + if len(generator) != 64: + raise ValueError("Generator should be 64 bytes long") + sec = bytes(33) + if _secp.secp256k1_generator_serialize(context, sec, generator) == 0: + raise RuntimeError("Failed to serialize generator") + return sec + + +# rangeproof +@locked +def rangeproof_rewind( + proof, + nonce, + value_commitment, + script_pubkey, + generator, + message_length=64, + context=_secp.ctx, +): + if len(generator) != 64: + raise ValueError("Generator should be 64 bytes long") + if len(nonce) != 32: + raise 
ValueError("Nonce should be 32 bytes long") + if len(value_commitment) != 64: + raise ValueError("Value commitment should be 64 bytes long") + + pointer = POINTER(c_uint64) + + msg = b"\x00" * message_length + msglen = pointer(c_uint64(len(msg))) + + vbf_out = b"\x00" * 32 + value_out = pointer(c_uint64(0)) + min_value = pointer(c_uint64(0)) + max_value = pointer(c_uint64(0)) + res = _secp.secp256k1_rangeproof_rewind( + context, + vbf_out, + value_out, + msg, + msglen, + nonce, + min_value, + max_value, + value_commitment, + proof, + len(proof), + script_pubkey, + len(script_pubkey), + generator, + ) + if res != 1: + raise RuntimeError("Failed to rewind the proof") + return ( + value_out.contents.value, + vbf_out, + msg[: msglen.contents.value], + min_value.contents.value, + max_value.contents.value, + ) + + +# rangeproof + + +@locked +def rangeproof_verify( + proof, value_commitment, script_pubkey, generator, context=_secp.ctx +): + if len(generator) != 64: + raise ValueError("Generator should be 64 bytes long") + if len(value_commitment) != 64: + raise ValueError("Value commitment should be 64 bytes long") + + pointer = POINTER(c_uint64) + min_value = pointer(c_uint64(0)) + max_value = pointer(c_uint64(0)) + res = _secp.secp256k1_rangeproof_verify( + context, + min_value, + max_value, + value_commitment, + proof, + len(proof), + script_pubkey, + len(script_pubkey), + generator, + ) + if res != 1: + raise RuntimeError("Failed to verify the proof") + return min_value.contents.value, max_value.contents.value + + +@locked +def rangeproof_sign( + nonce, + value, + value_commitment, + vbf, + message, + extra, + gen, + min_value=1, + exp=0, + min_bits=52, + context=_secp.ctx, +): + if value == 0: + min_value = 0 + if len(gen) != 64: + raise ValueError("Generator should be 64 bytes long") + if len(nonce) != 32: + raise ValueError("Nonce should be 32 bytes long") + if len(value_commitment) != 64: + raise ValueError("Value commitment should be 64 bytes long") + if len(vbf) 
!= 32: + raise ValueError("Value blinding factor should be 32 bytes long") + proof = bytes(5134) + pointer = POINTER(c_uint64) + prooflen = pointer(c_uint64(len(proof))) + res = _secp.secp256k1_rangeproof_sign( + context, + proof, + prooflen, + min_value, + value_commitment, + vbf, + nonce, + exp, + min_bits, + value, + message, + len(message), + extra, + len(extra), + gen, + ) + if res != 1: + raise RuntimeError("Failed to generate the proof") + return bytes(proof[: prooflen.contents.value]) + + +@locked +def musig_pubkey_combine(*args, context=_secp.ctx): + pub = bytes(64) + # TODO: strange that behaviour is different from pubkey_combine... + pubkeys = b"".join(args) # (c_char_p * len(args))(*args) + res = _secp.secp256k1_musig_pubkey_combine( + context, None, pub, None, pubkeys, len(args) + ) + if res == 0: + raise ValueError("Failed to combine pubkeys") + return pub + + +# surjection proof +@locked +def surjectionproof_initialize( + in_tags, out_tag, seed, tags_to_use=None, iterations=100, context=_secp.ctx +): + if tags_to_use is None: + tags_to_use = min(3, len(in_tags)) + if seed is None: + seed = os.urandom(32) + proof = bytes(4 + 8 + 256 // 8 + 32 * 257) + pointer = POINTER(c_size_t) + input_index = pointer(c_size_t(0)) + input_tags = b"".join(in_tags) + res = _secp.secp256k1_surjectionproof_initialize( + context, + proof, + input_index, + input_tags, + len(in_tags), + tags_to_use, + out_tag, + iterations, + seed, + ) + if res == 0: + raise RuntimeError("Failed to initialize the proof") + return proof, input_index.contents.value + + +@locked +def surjectionproof_generate( + proof, in_idx, in_tags, out_tag, in_abf, out_abf, context=_secp.ctx +): + res = _secp.secp256k1_surjectionproof_generate( + context, + proof, + b"".join(in_tags), + len(in_tags), + out_tag, + in_idx, + in_abf, + out_abf, + ) + if not res: + raise RuntimeError("Failed to generate surjection proof") + return proof + + +@locked +def surjectionproof_verify(proof, in_tags, out_tag, 
context=_secp.ctx): + res = _secp.secp256k1_surjectionproof_verify( + context, proof, b"".join(in_tags), len(in_tags), out_tag + ) + return bool(res) + + +@locked +def surjectionproof_serialize(proof, context=_secp.ctx): + s = _secp.secp256k1_surjectionproof_serialized_size(context, proof) + b = bytes(s) + pointer = POINTER(c_size_t) + sz = pointer(c_size_t(s)) + _secp.secp256k1_surjectionproof_serialize(context, b, sz, proof) + if s != sz.contents.value: + raise RuntimeError("Failed to serialize surjection proof - size mismatch") + return b + + +@locked +def surjectionproof_parse(proof, context=_secp.ctx): + parsed_proof = bytes(4 + 8 + 256 // 8 + 32 * 257) + res = _secp.secp256k1_surjectionproof_parse( + context, parsed_proof, proof, len(proof) + ) + if res == 0: + raise RuntimeError("Failed to parse surjection proof") + return parsed_proof diff --git a/bitcoin_client/ledger_bitcoin/embit/util/key.py b/bitcoin_client/ledger_bitcoin/embit/util/key.py new file mode 100644 index 000000000..13b01d955 --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/util/key.py @@ -0,0 +1,597 @@ +""" +Copy-paste from key.py in bitcoin test_framework. +This is a fallback option if the library can't do ctypes bindings to secp256k1 library. +""" +import random +import hmac +import hashlib + + +def TaggedHash(tag, data): + ss = hashlib.sha256(tag.encode("utf-8")).digest() + ss += ss + ss += data + return hashlib.sha256(ss).digest() + + +def modinv(a, n): + """Compute the modular inverse of a modulo n using the extended Euclidean + Algorithm. See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers. 
+ """ + # TODO: Change to pow(a, -1, n) available in Python 3.8 + t1, t2 = 0, 1 + r1, r2 = n, a + while r2 != 0: + q = r1 // r2 + t1, t2 = t2, t1 - q * t2 + r1, r2 = r2, r1 - q * r2 + if r1 > 1: + return None + if t1 < 0: + t1 += n + return t1 + + +def xor_bytes(b0, b1): + return bytes(x ^ y for (x, y) in zip(b0, b1)) + + +def jacobi_symbol(n, k): + """Compute the Jacobi symbol of n modulo k + + See http://en.wikipedia.org/wiki/Jacobi_symbol + + For our application k is always prime, so this is the same as the Legendre symbol. + """ + assert k > 0 and k & 1, "jacobi symbol is only defined for positive odd k" + n %= k + t = 0 + while n != 0: + while n & 1 == 0: + n >>= 1 + r = k & 7 + t ^= r == 3 or r == 5 + n, k = k, n + t ^= n & k & 3 == 3 + n = n % k + if k == 1: + return -1 if t else 1 + return 0 + + +def modsqrt(a, p): + """Compute the square root of a modulo p when p % 4 = 3. + + The Tonelli-Shanks algorithm can be used. See https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm + + Limiting this function to only work for p % 4 = 3 means we don't need to + iterate through the loop. The highest n such that p - 1 = 2^n Q with Q odd + is n = 1. Therefore Q = (p-1)/2 and sqrt = a^((Q+1)/2) = a^((p+1)/4) + + secp256k1's is defined over field of size 2**256 - 2**32 - 977, which is 3 mod 4. + """ + if p % 4 != 3: + raise NotImplementedError("modsqrt only implemented for p % 4 = 3") + sqrt = pow(a, (p + 1) // 4, p) + if pow(sqrt, 2, p) == a % p: + return sqrt + return None + + +class EllipticCurve: + def __init__(self, p, a, b): + """Initialize elliptic curve y^2 = x^3 + a*x + b over GF(p).""" + self.p = p + self.a = a % p + self.b = b % p + + def affine(self, p1): + """Convert a Jacobian point tuple p1 to affine form, or None if at infinity. 
+ + An affine point is represented as the Jacobian (x, y, 1)""" + x1, y1, z1 = p1 + if z1 == 0: + return None + inv = modinv(z1, self.p) + inv_2 = (inv**2) % self.p + inv_3 = (inv_2 * inv) % self.p + return ((inv_2 * x1) % self.p, (inv_3 * y1) % self.p, 1) + + def has_even_y(self, p1): + """Whether the point p1 has an even Y coordinate when expressed in affine coordinates.""" + return not (p1[2] == 0 or self.affine(p1)[1] & 1) + + def negate(self, p1): + """Negate a Jacobian point tuple p1.""" + x1, y1, z1 = p1 + return (x1, (self.p - y1) % self.p, z1) + + def on_curve(self, p1): + """Determine whether a Jacobian tuple p is on the curve (and not infinity)""" + x1, y1, z1 = p1 + z2 = pow(z1, 2, self.p) + z4 = pow(z2, 2, self.p) + return ( + z1 != 0 + and ( + pow(x1, 3, self.p) + + self.a * x1 * z4 + + self.b * z2 * z4 + - pow(y1, 2, self.p) + ) + % self.p + == 0 + ) + + def is_x_coord(self, x): + """Test whether x is a valid X coordinate on the curve.""" + x_3 = pow(x, 3, self.p) + return jacobi_symbol(x_3 + self.a * x + self.b, self.p) != -1 + + def lift_x(self, x): + """Given an X coordinate on the curve, return a corresponding affine point for which the Y coordinate is even.""" + x_3 = pow(x, 3, self.p) + v = x_3 + self.a * x + self.b + y = modsqrt(v, self.p) + if y is None: + return None + return (x, self.p - y if y & 1 else y, 1) + + def double(self, p1): + """Double a Jacobian tuple p1 + + See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Doubling + """ + x1, y1, z1 = p1 + if z1 == 0: + return (0, 1, 0) + y1_2 = (y1**2) % self.p + y1_4 = (y1_2**2) % self.p + x1_2 = (x1**2) % self.p + s = (4 * x1 * y1_2) % self.p + m = 3 * x1_2 + if self.a: + m += self.a * pow(z1, 4, self.p) + m = m % self.p + x2 = (m**2 - 2 * s) % self.p + y2 = (m * (s - x2) - 8 * y1_4) % self.p + z2 = (2 * y1 * z1) % self.p + return (x2, y2, z2) + + def add_mixed(self, p1, p2): + """Add a Jacobian tuple p1 and an affine tuple p2 + + See 
https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Addition (with affine point) + """ + x1, y1, z1 = p1 + x2, y2, z2 = p2 + assert z2 == 1 + # Adding to the point at infinity is a no-op + if z1 == 0: + return p2 + z1_2 = (z1**2) % self.p + z1_3 = (z1_2 * z1) % self.p + u2 = (x2 * z1_2) % self.p + s2 = (y2 * z1_3) % self.p + if x1 == u2: + if y1 != s2: + # p1 and p2 are inverses. Return the point at infinity. + return (0, 1, 0) + # p1 == p2. The formulas below fail when the two points are equal. + return self.double(p1) + h = u2 - x1 + r = s2 - y1 + h_2 = (h**2) % self.p + h_3 = (h_2 * h) % self.p + u1_h_2 = (x1 * h_2) % self.p + x3 = (r**2 - h_3 - 2 * u1_h_2) % self.p + y3 = (r * (u1_h_2 - x3) - y1 * h_3) % self.p + z3 = (h * z1) % self.p + return (x3, y3, z3) + + def add(self, p1, p2): + """Add two Jacobian tuples p1 and p2 + + See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Addition + """ + x1, y1, z1 = p1 + x2, y2, z2 = p2 + # Adding the point at infinity is a no-op + if z1 == 0: + return p2 + if z2 == 0: + return p1 + # Adding an Affine to a Jacobian is more efficient since we save field multiplications and squarings when z = 1 + if z1 == 1: + return self.add_mixed(p2, p1) + if z2 == 1: + return self.add_mixed(p1, p2) + z1_2 = (z1**2) % self.p + z1_3 = (z1_2 * z1) % self.p + z2_2 = (z2**2) % self.p + z2_3 = (z2_2 * z2) % self.p + u1 = (x1 * z2_2) % self.p + u2 = (x2 * z1_2) % self.p + s1 = (y1 * z2_3) % self.p + s2 = (y2 * z1_3) % self.p + if u1 == u2: + if s1 != s2: + # p1 and p2 are inverses. Return the point at infinity. + return (0, 1, 0) + # p1 == p2. The formulas below fail when the two points are equal. 
+ return self.double(p1) + h = u2 - u1 + r = s2 - s1 + h_2 = (h**2) % self.p + h_3 = (h_2 * h) % self.p + u1_h_2 = (u1 * h_2) % self.p + x3 = (r**2 - h_3 - 2 * u1_h_2) % self.p + y3 = (r * (u1_h_2 - x3) - s1 * h_3) % self.p + z3 = (h * z1 * z2) % self.p + return (x3, y3, z3) + + def mul(self, ps): + """Compute a (multi) point multiplication + + ps is a list of (Jacobian tuple, scalar) pairs. + """ + r = (0, 1, 0) + for i in range(255, -1, -1): + r = self.double(r) + for p, n in ps: + if (n >> i) & 1: + r = self.add(r, p) + return r + + +SECP256K1_FIELD_SIZE = 2**256 - 2**32 - 977 +SECP256K1 = EllipticCurve(SECP256K1_FIELD_SIZE, 0, 7) +SECP256K1_G = ( + 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, + 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8, + 1, +) +SECP256K1_ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 +SECP256K1_ORDER_HALF = SECP256K1_ORDER // 2 + + +class ECPubKey: + """A secp256k1 public key""" + + def __init__(self): + """Construct an uninitialized public key""" + self.valid = False + + def set(self, data): + """Construct a public key from a serialization in compressed or uncompressed format""" + if len(data) == 65 and data[0] == 0x04: + p = ( + int.from_bytes(data[1:33], "big"), + int.from_bytes(data[33:65], "big"), + 1, + ) + self.valid = SECP256K1.on_curve(p) + if self.valid: + self.p = p + self.compressed = False + elif len(data) == 33 and (data[0] == 0x02 or data[0] == 0x03): + x = int.from_bytes(data[1:33], "big") + if SECP256K1.is_x_coord(x): + p = SECP256K1.lift_x(x) + # Make the Y coordinate odd if required (lift_x always produces + # a point with an even Y coordinate). 
+ if data[0] & 1: + p = SECP256K1.negate(p) + self.p = p + self.valid = True + self.compressed = True + else: + self.valid = False + else: + self.valid = False + + @property + def is_compressed(self): + return self.compressed + + @property + def is_valid(self): + return self.valid + + def get_bytes(self): + assert self.valid + p = SECP256K1.affine(self.p) + if p is None: + return None + if self.compressed: + return bytes([0x02 + (p[1] & 1)]) + p[0].to_bytes(32, "big") + else: + return bytes([0x04]) + p[0].to_bytes(32, "big") + p[1].to_bytes(32, "big") + + def verify_ecdsa(self, sig, msg, low_s=True): + """Verify a strictly DER-encoded ECDSA signature against this pubkey. + + See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the + ECDSA verifier algorithm""" + assert self.valid + + # Extract r and s from the DER formatted signature. Return false for + # any DER encoding errors. + if sig[1] + 2 != len(sig): + return False + if len(sig) < 4: + return False + if sig[0] != 0x30: + return False + if sig[2] != 0x02: + return False + rlen = sig[3] + if len(sig) < 6 + rlen: + return False + if rlen < 1 or rlen > 33: + return False + if sig[4] >= 0x80: + return False + if rlen > 1 and (sig[4] == 0) and not (sig[5] & 0x80): + return False + r = int.from_bytes(sig[4 : 4 + rlen], "big") + if sig[4 + rlen] != 0x02: + return False + slen = sig[5 + rlen] + if slen < 1 or slen > 33: + return False + if len(sig) != 6 + rlen + slen: + return False + if sig[6 + rlen] >= 0x80: + return False + if slen > 1 and (sig[6 + rlen] == 0) and not (sig[7 + rlen] & 0x80): + return False + s = int.from_bytes(sig[6 + rlen : 6 + rlen + slen], "big") + + # Verify that r and s are within the group order + if r < 1 or s < 1 or r >= SECP256K1_ORDER or s >= SECP256K1_ORDER: + return False + if low_s and s >= SECP256K1_ORDER_HALF: + return False + z = int.from_bytes(msg, "big") + + # Run verifier algorithm on r, s + w = modinv(s, SECP256K1_ORDER) + u1 = z * w % 
SECP256K1_ORDER + u2 = r * w % SECP256K1_ORDER + R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, u1), (self.p, u2)])) + if R is None or (R[0] % SECP256K1_ORDER) != r: + return False + return True + + +def generate_privkey(): + """Generate a valid random 32-byte private key.""" + return random.randrange(1, SECP256K1_ORDER).to_bytes(32, "big") + + +class ECKey: + """A secp256k1 private key""" + + def __init__(self): + self.valid = False + + def set(self, secret, compressed): + """Construct a private key object with given 32-byte secret and compressed flag.""" + assert len(secret) == 32 + secret = int.from_bytes(secret, "big") + self.valid = secret > 0 and secret < SECP256K1_ORDER + if self.valid: + self.secret = secret + self.compressed = compressed + + def generate(self, compressed=True): + """Generate a random private key (compressed or uncompressed).""" + self.set(generate_privkey(), compressed) + + def get_bytes(self): + """Retrieve the 32-byte representation of this key.""" + assert self.valid + return self.secret.to_bytes(32, "big") + + @property + def is_valid(self): + return self.valid + + @property + def is_compressed(self): + return self.compressed + + def get_pubkey(self): + """Compute an ECPubKey object for this secret key.""" + assert self.valid + ret = ECPubKey() + p = SECP256K1.mul([(SECP256K1_G, self.secret)]) + ret.p = p + ret.valid = True + ret.compressed = self.compressed + return ret + + def sign_ecdsa(self, msg, nonce_function=None, extra_data=None, low_s=True): + """Construct a DER-encoded ECDSA signature with this key. 
+ + See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the + ECDSA signer algorithm.""" + assert self.valid + z = int.from_bytes(msg, "big") + if nonce_function is None: + nonce_function = deterministic_k + k = nonce_function(self.secret, z, extra_data=extra_data) + R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, k)])) + r = R[0] % SECP256K1_ORDER + s = (modinv(k, SECP256K1_ORDER) * (z + self.secret * r)) % SECP256K1_ORDER + if low_s and s > SECP256K1_ORDER_HALF: + s = SECP256K1_ORDER - s + # Represent in DER format. The byte representations of r and s have + # length rounded up (255 bits becomes 32 bytes and 256 bits becomes 33 + # bytes). + rb = r.to_bytes((r.bit_length() + 8) // 8, "big") + sb = s.to_bytes((s.bit_length() + 8) // 8, "big") + return ( + b"\x30" + + bytes([4 + len(rb) + len(sb), 2, len(rb)]) + + rb + + bytes([2, len(sb)]) + + sb + ) + + +def deterministic_k(secret, z, extra_data=None): + # RFC6979, optimized for secp256k1 + k = b"\x00" * 32 + v = b"\x01" * 32 + if z > SECP256K1_ORDER: + z -= SECP256K1_ORDER + z_bytes = z.to_bytes(32, "big") + secret_bytes = secret.to_bytes(32, "big") + if extra_data is not None: + z_bytes += extra_data + k = hmac.new(k, v + b"\x00" + secret_bytes + z_bytes, "sha256").digest() + v = hmac.new(k, v, "sha256").digest() + k = hmac.new(k, v + b"\x01" + secret_bytes + z_bytes, "sha256").digest() + v = hmac.new(k, v, "sha256").digest() + while True: + v = hmac.new(k, v, "sha256").digest() + candidate = int.from_bytes(v, "big") + if candidate >= 1 and candidate < SECP256K1_ORDER: + return candidate + k = hmac.new(k, v + b"\x00", "sha256").digest() + v = hmac.new(k, v, "sha256").digest() + + +def compute_xonly_pubkey(key): + """Compute an x-only (32 byte) public key from a (32 byte) private key. + + This also returns whether the resulting public key was negated. 
+ """ + + assert len(key) == 32 + x = int.from_bytes(key, "big") + if x == 0 or x >= SECP256K1_ORDER: + return (None, None) + P = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, x)])) + return (P[0].to_bytes(32, "big"), not SECP256K1.has_even_y(P)) + + +def tweak_add_privkey(key, tweak): + """Tweak a private key (after negating it if needed).""" + + assert len(key) == 32 + assert len(tweak) == 32 + + x = int.from_bytes(key, "big") + if x == 0 or x >= SECP256K1_ORDER: + return None + if not SECP256K1.has_even_y(SECP256K1.mul([(SECP256K1_G, x)])): + x = SECP256K1_ORDER - x + t = int.from_bytes(tweak, "big") + if t >= SECP256K1_ORDER: + return None + x = (x + t) % SECP256K1_ORDER + if x == 0: + return None + return x.to_bytes(32, "big") + + +def tweak_add_pubkey(key, tweak): + """Tweak a public key and return whether the result had to be negated.""" + + assert len(key) == 32 + assert len(tweak) == 32 + + x_coord = int.from_bytes(key, "big") + if x_coord >= SECP256K1_FIELD_SIZE: + return None + P = SECP256K1.lift_x(x_coord) + if P is None: + return None + t = int.from_bytes(tweak, "big") + if t >= SECP256K1_ORDER: + return None + Q = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, t), (P, 1)])) + if Q is None: + return None + return (Q[0].to_bytes(32, "big"), not SECP256K1.has_even_y(Q)) + + +def verify_schnorr(key, sig, msg): + """Verify a Schnorr signature (see BIP 340). + - key is a 32-byte xonly pubkey (computed using compute_xonly_pubkey). 
+ - sig is a 64-byte Schnorr signature + - msg is a 32-byte message + """ + assert len(key) == 32 + assert len(msg) == 32 + assert len(sig) == 64 + + x_coord = int.from_bytes(key, "big") + if x_coord == 0 or x_coord >= SECP256K1_FIELD_SIZE: + return False + P = SECP256K1.lift_x(x_coord) + if P is None: + return False + r = int.from_bytes(sig[0:32], "big") + if r >= SECP256K1_FIELD_SIZE: + return False + s = int.from_bytes(sig[32:64], "big") + if s >= SECP256K1_ORDER: + return False + e = ( + int.from_bytes(TaggedHash("BIP0340/challenge", sig[0:32] + key + msg), "big") + % SECP256K1_ORDER + ) + R = SECP256K1.mul([(SECP256K1_G, s), (P, SECP256K1_ORDER - e)]) + if not SECP256K1.has_even_y(R): + return False + if ((r * R[2] * R[2]) % SECP256K1_FIELD_SIZE) != R[0]: + return False + return True + + +def sign_schnorr(key, msg, aux=None, flip_p=False, flip_r=False): + """Create a Schnorr signature (see BIP 340).""" + + assert len(key) == 32 + assert len(msg) == 32 + if aux is not None: + assert len(aux) == 32 + + sec = int.from_bytes(key, "big") + if sec == 0 or sec >= SECP256K1_ORDER: + return None + P = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, sec)])) + if SECP256K1.has_even_y(P) == flip_p: + sec = SECP256K1_ORDER - sec + if aux is not None: + t = (sec ^ int.from_bytes(TaggedHash("BIP0340/aux", aux), "big")).to_bytes( + 32, "big" + ) + else: + t = sec.to_bytes(32, "big") + kp = ( + int.from_bytes( + TaggedHash("BIP0340/nonce", t + P[0].to_bytes(32, "big") + msg), "big" + ) + % SECP256K1_ORDER + ) + assert kp != 0 + R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, kp)])) + k = kp if SECP256K1.has_even_y(R) != flip_r else SECP256K1_ORDER - kp + e = ( + int.from_bytes( + TaggedHash( + "BIP0340/challenge", + R[0].to_bytes(32, "big") + P[0].to_bytes(32, "big") + msg, + ), + "big", + ) + % SECP256K1_ORDER + ) + return R[0].to_bytes(32, "big") + ((k + e * sec) % SECP256K1_ORDER).to_bytes( + 32, "big" + ) diff --git 
a/bitcoin_client/ledger_bitcoin/embit/util/py_ripemd160.py b/bitcoin_client/ledger_bitcoin/embit/util/py_ripemd160.py new file mode 100644 index 000000000..7eeaa56ca --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/util/py_ripemd160.py @@ -0,0 +1,407 @@ +# Copyright (c) 2021 Pieter Wuille +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Pure Python RIPEMD160 implementation.""" + +# Message schedule indexes for the left path. +ML = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 7, + 4, + 13, + 1, + 10, + 6, + 15, + 3, + 12, + 0, + 9, + 5, + 2, + 14, + 11, + 8, + 3, + 10, + 14, + 4, + 9, + 15, + 8, + 1, + 2, + 7, + 0, + 6, + 13, + 11, + 5, + 12, + 1, + 9, + 11, + 10, + 0, + 8, + 12, + 4, + 13, + 3, + 7, + 15, + 14, + 5, + 6, + 2, + 4, + 0, + 5, + 9, + 7, + 12, + 2, + 10, + 14, + 1, + 3, + 8, + 11, + 6, + 15, + 13, +] + +# Message schedule indexes for the right path. +MR = [ + 5, + 14, + 7, + 0, + 9, + 2, + 11, + 4, + 13, + 6, + 15, + 8, + 1, + 10, + 3, + 12, + 6, + 11, + 3, + 7, + 0, + 13, + 5, + 10, + 14, + 15, + 8, + 12, + 4, + 9, + 1, + 2, + 15, + 5, + 1, + 3, + 7, + 14, + 6, + 9, + 11, + 8, + 12, + 2, + 10, + 0, + 4, + 13, + 8, + 6, + 4, + 1, + 3, + 11, + 15, + 0, + 5, + 12, + 2, + 13, + 9, + 7, + 10, + 14, + 12, + 15, + 10, + 4, + 1, + 5, + 8, + 7, + 6, + 2, + 13, + 14, + 0, + 3, + 9, + 11, +] + +# Rotation counts for the left path. +RL = [ + 11, + 14, + 15, + 12, + 5, + 8, + 7, + 9, + 11, + 13, + 14, + 15, + 6, + 7, + 9, + 8, + 7, + 6, + 8, + 13, + 11, + 9, + 7, + 15, + 7, + 12, + 15, + 9, + 11, + 7, + 13, + 12, + 11, + 13, + 6, + 7, + 14, + 9, + 13, + 15, + 14, + 8, + 13, + 6, + 5, + 12, + 7, + 5, + 11, + 12, + 14, + 15, + 14, + 15, + 9, + 8, + 9, + 14, + 5, + 6, + 8, + 6, + 5, + 12, + 9, + 15, + 5, + 11, + 6, + 8, + 13, + 12, + 5, + 12, + 13, + 14, + 11, + 8, + 5, + 6, +] + +# Rotation counts for the right path. 
+RR = [ + 8, + 9, + 9, + 11, + 13, + 15, + 15, + 5, + 7, + 7, + 8, + 11, + 14, + 14, + 12, + 6, + 9, + 13, + 15, + 7, + 12, + 8, + 9, + 11, + 7, + 7, + 12, + 7, + 6, + 15, + 13, + 11, + 9, + 7, + 15, + 11, + 8, + 6, + 6, + 14, + 12, + 13, + 5, + 14, + 13, + 13, + 7, + 5, + 15, + 5, + 8, + 11, + 14, + 14, + 6, + 14, + 6, + 9, + 12, + 9, + 12, + 5, + 15, + 8, + 8, + 5, + 12, + 9, + 12, + 5, + 14, + 6, + 8, + 13, + 6, + 5, + 15, + 13, + 11, + 11, +] + +# K constants for the left path. +KL = [0, 0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xA953FD4E] + +# K constants for the right path. +KR = [0x50A28BE6, 0x5C4DD124, 0x6D703EF3, 0x7A6D76E9, 0] + + +def fi(x, y, z, i): + """The f1, f2, f3, f4, and f5 functions from the specification.""" + if i == 0: + return x ^ y ^ z + elif i == 1: + return (x & y) | (~x & z) + elif i == 2: + return (x | ~y) ^ z + elif i == 3: + return (x & z) | (y & ~z) + elif i == 4: + return x ^ (y | ~z) + else: + assert False + + +def rol(x, i): + """Rotate the bottom 32 bits of x left by i bits.""" + return ((x << i) | ((x & 0xFFFFFFFF) >> (32 - i))) & 0xFFFFFFFF + + +def compress(h0, h1, h2, h3, h4, block): + """Compress state (h0, h1, h2, h3, h4) with block.""" + # Left path variables. + al, bl, cl, dl, el = h0, h1, h2, h3, h4 + # Right path variables. + ar, br, cr, dr, er = h0, h1, h2, h3, h4 + # Message variables. + x = [int.from_bytes(block[4 * i : 4 * (i + 1)], "little") for i in range(16)] + + # Iterate over the 80 rounds of the compression. + for j in range(80): + rnd = j >> 4 + # Perform left side of the transformation. + al = rol(al + fi(bl, cl, dl, rnd) + x[ML[j]] + KL[rnd], RL[j]) + el + al, bl, cl, dl, el = el, al, bl, rol(cl, 10), dl + # Perform right side of the transformation. + ar = rol(ar + fi(br, cr, dr, 4 - rnd) + x[MR[j]] + KR[rnd], RR[j]) + er + ar, br, cr, dr, er = er, ar, br, rol(cr, 10), dr + + # Compose old state, left transform, and right transform into new state. 
+ return h1 + cl + dr, h2 + dl + er, h3 + el + ar, h4 + al + br, h0 + bl + cr + + +def ripemd160(data): + """Compute the RIPEMD-160 hash of data.""" + # Initialize state. + state = (0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0) + # Process full 64-byte blocks in the input. + for b in range(len(data) >> 6): + state = compress(*state, data[64 * b : 64 * (b + 1)]) + # Construct final blocks (with padding and size). + pad = b"\x80" + b"\x00" * ((119 - len(data)) & 63) + fin = data[len(data) & ~63 :] + pad + (8 * len(data)).to_bytes(8, "little") + # Process final blocks. + for b in range(len(fin) >> 6): + state = compress(*state, fin[64 * b : 64 * (b + 1)]) + # Produce output. + return b"".join((h & 0xFFFFFFFF).to_bytes(4, "little") for h in state) diff --git a/bitcoin_client/ledger_bitcoin/embit/util/py_secp256k1.py b/bitcoin_client/ledger_bitcoin/embit/util/py_secp256k1.py new file mode 100644 index 000000000..851408635 --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/util/py_secp256k1.py @@ -0,0 +1,384 @@ +""" +This is a fallback option if the library can't do ctypes bindings to secp256k1 library. +Mimics the micropython bindings and internal representation of data structs in secp256k1. +""" + +from . import key as _key + +# Flags to pass to context_create. 
+CONTEXT_VERIFY = 0b0100000001 +CONTEXT_SIGN = 0b1000000001 +CONTEXT_NONE = 0b0000000001 + +# Flags to pass to ec_pubkey_serialize +EC_COMPRESSED = 0b0100000010 +EC_UNCOMPRESSED = 0b0000000010 + + +def context_randomize(seed, context=None): + pass + + +def _reverse64(b): + """Converts (a,b) from big to little endian to be consistent with secp256k1""" + x = b[:32] + y = b[32:] + return x[::-1] + y[::-1] + + +def _pubkey_serialize(pub): + """Returns pubkey representation like secp library""" + b = pub.get_bytes()[1:] + return _reverse64(b) + + +def _pubkey_parse(b): + """Returns pubkey class instance""" + pub = _key.ECPubKey() + pub.set(b"\x04" + _reverse64(b)) + return pub + + +def ec_pubkey_create(secret, context=None): + if len(secret) != 32: + raise ValueError("Private key should be 32 bytes long") + pk = _key.ECKey() + pk.set(secret, compressed=False) + if not pk.is_valid: + raise ValueError("Invalid private key") + return _pubkey_serialize(pk.get_pubkey()) + + +def ec_pubkey_parse(sec, context=None): + if len(sec) != 33 and len(sec) != 65: + raise ValueError("Serialized pubkey should be 33 or 65 bytes long") + if len(sec) == 33: + if sec[0] != 0x02 and sec[0] != 0x03: + raise ValueError("Compressed pubkey should start with 0x02 or 0x03") + else: + if sec[0] != 0x04: + raise ValueError("Uncompressed pubkey should start with 0x04") + pub = _key.ECPubKey() + pub.set(sec) + pub.compressed = False + if not pub.is_valid: + raise ValueError("Failed parsing public key") + return _pubkey_serialize(pub) + + +def ec_pubkey_serialize(pubkey, flag=EC_COMPRESSED, context=None): + if len(pubkey) != 64: + raise ValueError("Pubkey should be 64 bytes long") + if flag not in [EC_COMPRESSED, EC_UNCOMPRESSED]: + raise ValueError("Invalid flag") + pub = _pubkey_parse(pubkey) + if not pub.is_valid: + raise ValueError("Failed to serialize pubkey") + if flag == EC_COMPRESSED: + pub.compressed = True + return pub.get_bytes() + + +def ecdsa_signature_parse_compact(compact_sig, 
context=None): + if len(compact_sig) != 64: + raise ValueError("Compact signature should be 64 bytes long") + sig = _reverse64(compact_sig) + return sig + + +def ecdsa_signature_parse_der(der, context=None): + if der[1] + 2 != len(der): + raise ValueError("Failed parsing compact signature") + if len(der) < 4: + raise ValueError("Failed parsing compact signature") + if der[0] != 0x30: + raise ValueError("Failed parsing compact signature") + if der[2] != 0x02: + raise ValueError("Failed parsing compact signature") + rlen = der[3] + if len(der) < 6 + rlen: + raise ValueError("Failed parsing compact signature") + if rlen < 1 or rlen > 33: + raise ValueError("Failed parsing compact signature") + if der[4] >= 0x80: + raise ValueError("Failed parsing compact signature") + if rlen > 1 and (der[4] == 0) and not (der[5] & 0x80): + raise ValueError("Failed parsing compact signature") + r = int.from_bytes(der[4 : 4 + rlen], "big") + if der[4 + rlen] != 0x02: + raise ValueError("Failed parsing compact signature") + slen = der[5 + rlen] + if slen < 1 or slen > 33: + raise ValueError("Failed parsing compact signature") + if len(der) != 6 + rlen + slen: + raise ValueError("Failed parsing compact signature") + if der[6 + rlen] >= 0x80: + raise ValueError("Failed parsing compact signature") + if slen > 1 and (der[6 + rlen] == 0) and not (der[7 + rlen] & 0x80): + raise ValueError("Failed parsing compact signature") + s = int.from_bytes(der[6 + rlen : 6 + rlen + slen], "big") + + # Verify that r and s are within the group order + if r < 1 or s < 1 or r >= _key.SECP256K1_ORDER or s >= _key.SECP256K1_ORDER: + raise ValueError("Failed parsing compact signature") + if s >= _key.SECP256K1_ORDER_HALF: + raise ValueError("Failed parsing compact signature") + + return r.to_bytes(32, "little") + s.to_bytes(32, "little") + + +def ecdsa_signature_serialize_der(sig, context=None): + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + r = int.from_bytes(sig[:32], "little") 
+ s = int.from_bytes(sig[32:], "little") + rb = r.to_bytes((r.bit_length() + 8) // 8, "big") + sb = s.to_bytes((s.bit_length() + 8) // 8, "big") + return ( + b"\x30" + + bytes([4 + len(rb) + len(sb), 2, len(rb)]) + + rb + + bytes([2, len(sb)]) + + sb + ) + + +def ecdsa_signature_serialize_compact(sig, context=None): + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + return _reverse64(sig) + + +def ecdsa_signature_normalize(sig, context=None): + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + r = int.from_bytes(sig[:32], "little") + s = int.from_bytes(sig[32:], "little") + if s >= _key.SECP256K1_ORDER_HALF: + s = _key.SECP256K1_ORDER - s + return r.to_bytes(32, "little") + s.to_bytes(32, "little") + + +def ecdsa_verify(sig, msg, pub, context=None): + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + if len(msg) != 32: + raise ValueError("Message should be 32 bytes long") + if len(pub) != 64: + raise ValueError("Public key should be 64 bytes long") + pubkey = _pubkey_parse(pub) + return pubkey.verify_ecdsa(ecdsa_signature_serialize_der(sig), msg) + + +def ecdsa_sign(msg, secret, nonce_function=None, extra_data=None, context=None): + if len(msg) != 32: + raise ValueError("Message should be 32 bytes long") + if len(secret) != 32: + raise ValueError("Secret key should be 32 bytes long") + pk = _key.ECKey() + pk.set(secret, False) + sig = pk.sign_ecdsa(msg, nonce_function, extra_data) + return ecdsa_signature_parse_der(sig) + + +def ec_seckey_verify(secret, context=None): + if len(secret) != 32: + raise ValueError("Secret should be 32 bytes long") + pk = _key.ECKey() + pk.set(secret, compressed=False) + return pk.is_valid + + +def ec_privkey_negate(secret, context=None): + # negate in place + if len(secret) != 32: + raise ValueError("Secret should be 32 bytes long") + s = int.from_bytes(secret, "big") + s2 = _key.SECP256K1_ORDER - s + return s2.to_bytes(32, "big") + + +def 
ec_pubkey_negate(pubkey, context=None): + if len(pubkey) != 64: + raise ValueError("Pubkey should be a 64-byte structure") + sec = ec_pubkey_serialize(pubkey) + return ec_pubkey_parse(bytes([0x05 - sec[0]]) + sec[1:]) + + +def ec_privkey_tweak_add(secret, tweak, context=None): + res = ec_privkey_add(secret, tweak) + for i in range(len(secret)): + secret[i] = res[i] + + +def ec_pubkey_tweak_add(pub, tweak, context=None): + res = ec_pubkey_add(pub, tweak) + for i in range(len(pub)): + pub[i] = res[i] + + +def ec_privkey_add(secret, tweak, context=None): + if len(secret) != 32 or len(tweak) != 32: + raise ValueError("Secret and tweak should both be 32 bytes long") + s = int.from_bytes(secret, "big") + t = int.from_bytes(tweak, "big") + r = (s + t) % _key.SECP256K1_ORDER + return r.to_bytes(32, "big") + + +def ec_pubkey_add(pub, tweak, context=None): + if len(pub) != 64: + raise ValueError("Public key should be 64 bytes long") + if len(tweak) != 32: + raise ValueError("Tweak should be 32 bytes long") + pubkey = _pubkey_parse(pub) + pubkey.compressed = True + t = int.from_bytes(tweak, "big") + Q = _key.SECP256K1.affine( + _key.SECP256K1.mul([(_key.SECP256K1_G, t), (pubkey.p, 1)]) + ) + if Q is None: + return None + return Q[0].to_bytes(32, "little") + Q[1].to_bytes(32, "little") + + +# def ec_privkey_tweak_mul(secret, tweak, context=None): +# if len(secret)!=32 or len(tweak)!=32: +# raise ValueError("Secret and tweak should both be 32 bytes long") +# s = int.from_bytes(secret, 'big') +# t = int.from_bytes(tweak, 'big') +# if t > _key.SECP256K1_ORDER or s > _key.SECP256K1_ORDER: +# raise ValueError("Failed to tweak the secret") +# r = pow(s, t, _key.SECP256K1_ORDER) +# res = r.to_bytes(32, 'big') +# for i in range(len(secret)): +# secret[i] = res[i] + +# def ec_pubkey_tweak_mul(pub, tweak, context=None): +# if len(pub)!=64: +# raise ValueError("Public key should be 64 bytes long") +# if len(tweak)!=32: +# raise ValueError("Tweak should be 32 bytes long") +# if 
_secp.secp256k1_ec_pubkey_tweak_mul(context, pub, tweak) == 0: +# raise ValueError("Failed to tweak the public key") + +# def ec_pubkey_combine(*args, context=None): +# pub = bytes(64) +# pubkeys = (c_char_p * len(args))(*args) +# r = _secp.secp256k1_ec_pubkey_combine(context, pub, pubkeys, len(args)) +# if r == 0: +# raise ValueError("Failed to negate pubkey") +# return pub + +# schnorrsig + + +def xonly_pubkey_from_pubkey(pubkey, context=None): + if len(pubkey) != 64: + raise ValueError("Pubkey should be 64 bytes long") + sec = ec_pubkey_serialize(pubkey) + parity = sec[0] == 0x03 + pub = ec_pubkey_parse(b"\x02" + sec[1:33]) + return pub, parity + + +def schnorrsig_verify(sig, msg, pubkey, context=None): + assert len(sig) == 64 + assert len(msg) == 32 + assert len(pubkey) == 64 + sec = ec_pubkey_serialize(pubkey) + return _key.verify_schnorr(sec[1:33], sig, msg) + + +def keypair_create(secret, context=None): + pub = ec_pubkey_create(secret) + pub2, parity = xonly_pubkey_from_pubkey(pub) + keypair = secret + pub + return keypair + + +def schnorrsig_sign(msg, keypair, nonce_function=None, extra_data=None, context=None): + assert len(msg) == 32 + if len(keypair) == 32: + keypair = keypair_create(keypair, context=context) + assert len(keypair) == 96 + return _key.sign_schnorr(keypair[:32], msg, extra_data) + + +# recoverable + + +def ecdsa_sign_recoverable(msg, secret, context=None): + sig = ecdsa_sign(msg, secret) + pub = ec_pubkey_create(secret) + # Search for correct index. Not efficient but I am lazy. 
+ # For efficiency use c-bindings to libsecp256k1 + for i in range(4): + if ecdsa_recover(sig + bytes([i]), msg) == pub: + return sig + bytes([i]) + raise ValueError("Failed to sign") + + +def ecdsa_recoverable_signature_serialize_compact(sig, context=None): + if len(sig) != 65: + raise ValueError("Recoverable signature should be 65 bytes long") + compact = ecdsa_signature_serialize_compact(sig[:64]) + return compact, sig[64] + + +def ecdsa_recoverable_signature_parse_compact(compact_sig, recid, context=None): + if len(compact_sig) != 64: + raise ValueError("Signature should be 64 bytes long") + # TODO: also check r value so recid > 2 makes sense + if recid < 0 or recid > 4: + raise ValueError("Failed parsing compact signature") + return ecdsa_signature_parse_compact(compact_sig) + bytes([recid]) + + +def ecdsa_recoverable_signature_convert(sigin, context=None): + if len(sigin) != 65: + raise ValueError("Recoverable signature should be 65 bytes long") + return sigin[:64] + + +def ecdsa_recover(sig, msghash, context=None): + if len(sig) != 65: + raise ValueError("Recoverable signature should be 65 bytes long") + if len(msghash) != 32: + raise ValueError("Message should be 32 bytes long") + idx = sig[-1] + r = int.from_bytes(sig[:32], "little") + s = int.from_bytes(sig[32:64], "little") + z = int.from_bytes(msghash, "big") + # r = Rx mod N, so R can be 02x, 03x, 02(N+x), 03(N+x) + # two latter cases only if N+x < P + r_candidates = [ + b"\x02" + r.to_bytes(32, "big"), + b"\x03" + r.to_bytes(32, "big"), + ] + if r + _key.SECP256K1_ORDER < _key.SECP256K1_FIELD_SIZE: + r2 = r + _key.SECP256K1_ORDER + r_candidates = r_candidates + [ + b"\x02" + r2.to_bytes(32, "big"), + b"\x03" + r2.to_bytes(32, "big"), + ] + if idx >= len(r_candidates): + raise ValueError("Failed to recover public key") + R = _key.ECPubKey() + R.set(r_candidates[idx]) + # s = (z + d * r)/k + # (R*s/r - z/r*G) = P + rinv = _key.modinv(r, _key.SECP256K1_ORDER) + u1 = (s * rinv) % _key.SECP256K1_ORDER + u2 
= (z * rinv) % _key.SECP256K1_ORDER + P1 = _key.SECP256K1.mul([(R.p, u1)]) + P2 = _key.SECP256K1.negate(_key.SECP256K1.mul([(_key.SECP256K1_G, u2)])) + P = _key.SECP256K1.affine(_key.SECP256K1.add(P1, P2)) + result = P[0].to_bytes(32, "little") + P[1].to_bytes(32, "little") + # verify signature at the end + pubkey = _pubkey_parse(result) + if not pubkey.is_valid: + raise ValueError("Failed to recover public key") + if not ecdsa_verify(sig[:64], msghash, result): + raise ValueError("Failed to recover public key") + return result diff --git a/bitcoin_client/ledger_bitcoin/embit/util/secp256k1.py b/bitcoin_client/ledger_bitcoin/embit/util/secp256k1.py new file mode 100644 index 000000000..a3ed8d9a7 --- /dev/null +++ b/bitcoin_client/ledger_bitcoin/embit/util/secp256k1.py @@ -0,0 +1,12 @@ +try: + # if it's micropython + from micropython import const + from secp256k1 import * +except: + # we are in python + try: + # try ctypes bindings + from .ctypes_secp256k1 import * + except: + # fallback to python version + from .py_secp256k1 import * diff --git a/bitcoin_client/pyproject.toml b/bitcoin_client/pyproject.toml index 894551620..5020eb701 100644 --- a/bitcoin_client/pyproject.toml +++ b/bitcoin_client/pyproject.toml @@ -1,10 +1,8 @@ [build-system] requires = [ - "bip32~=3.0", - "coincurve~=18.0", - "typing-extensions>=3.7", "ledgercomm>=1.1.0", "setuptools>=42", + "typing-extensions>=3.7", "wheel" ] build-backend = "setuptools.build_meta" diff --git a/bitcoin_client/setup.cfg b/bitcoin_client/setup.cfg index 21656c888..9d44335f8 100644 --- a/bitcoin_client/setup.cfg +++ b/bitcoin_client/setup.cfg @@ -18,8 +18,6 @@ classifiers = packages = find: python_requires = >=3.7 install_requires= - bip32~=3.0, - coincurve~=18.0, typing-extensions>=3.7 ledgercomm>=1.1.0 packaging>=21.3