From a49892d74249bc7d9ed3f2b5e34c8f4178f077ef Mon Sep 17 00:00:00 2001 From: Salvatore Ingala <6681844+bigspider@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:02:21 +0200 Subject: [PATCH] Added python standalone implementation of MuSig2 signing, and tests --- test_utils/bip0327.py | 465 +++++++++++++++++++ test_utils/musig2.py | 831 ++++++++++++++++++++++++++++++++++ test_utils/taproot.py | 262 +++++++++++ test_utils/taproot_sighash.py | 85 ++++ tests/test_musig2.py | 69 +++ 5 files changed, 1712 insertions(+) create mode 100644 test_utils/bip0327.py create mode 100644 test_utils/musig2.py create mode 100644 test_utils/taproot.py create mode 100644 test_utils/taproot_sighash.py create mode 100644 tests/test_musig2.py diff --git a/test_utils/bip0327.py b/test_utils/bip0327.py new file mode 100644 index 000000000..79149743f --- /dev/null +++ b/test_utils/bip0327.py @@ -0,0 +1,465 @@ +# from https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0327/reference.py +# Distributed as part of BIP-0327 under the BSD-3-Clause license + +# BIP327 reference implementation +# +# WARNING: This implementation is for demonstration purposes only and _not_ to +# be used in production environments. The code is vulnerable to timing attacks, +# for example. + +# fmt: off + +from typing import List, Optional, Tuple, NewType, NamedTuple +import hashlib +import secrets + +# +# The following helper functions were copied from the BIP-340 reference implementation: +# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py +# + +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 + +# Points are tuples of X and Y coordinates and the point at infinity is +# represented by the None keyword. +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) + +Point = Tuple[int, int] + +# This implementation can be sped up by storing the midstate after hashing +# tag_hash instead of rehashing it all the time. +def tagged_hash(tag: str, msg: bytes) -> bytes: + tag_hash = hashlib.sha256(tag.encode()).digest() + return hashlib.sha256(tag_hash + tag_hash + msg).digest() + +def is_infinite(P: Optional[Point]) -> bool: + return P is None + +def x(P: Point) -> int: + assert not is_infinite(P) + return P[0] + +def y(P: Point) -> int: + assert not is_infinite(P) + return P[1] + +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: + if P1 is None: + return P2 + if P2 is None: + return P1 + if (x(P1) == x(P2)) and (y(P1) != y(P2)): + return None + if P1 == P2: + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p + else: + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p + x3 = (lam * lam - x(P1) - x(P2)) % p + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) + +def point_mul(P: Optional[Point], n: int) -> Optional[Point]: + R = None + for i in range(256): + if (n >> i) & 1: + R = point_add(R, P) + P = point_add(P, P) + return R + +def bytes_from_int(x: int) -> bytes: + return x.to_bytes(32, byteorder="big") + +def lift_x(b: bytes) -> Optional[Point]: + x = int_from_bytes(b) + if x >= p: + return None + y_sq = (pow(x, 3, p) + 7) % p + y = pow(y_sq, (p + 1) // 4, p) + if pow(y, 2, p) != y_sq: + return None + return (x, y if y & 1 == 0 else p-y) + +def int_from_bytes(b: bytes) -> int: + return int.from_bytes(b, byteorder="big") + +def has_even_y(P: Point) -> bool: + assert not is_infinite(P) + return y(P) % 2 == 0 + +def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool: + if len(msg) != 32: + raise ValueError('The message must be a 32-byte array.') + if len(pubkey) != 32: + raise ValueError('The public key must be a 32-byte array.') + if len(sig) != 64: + raise ValueError('The signature must be a 64-byte array.') + P = lift_x(pubkey) + r = int_from_bytes(sig[0:32]) + s = int_from_bytes(sig[32:64]) + if (P is None) or (r >= p) or (s >= n): + return False + e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n + R = point_add(point_mul(G, s), point_mul(P, n - e)) + if (R is None) or (not has_even_y(R)) or (x(R) != r): + return False + return True + +# +# End of helper functions copied from BIP-340 reference implementation. +# + +PlainPk = NewType('PlainPk', bytes) +XonlyPk = NewType('XonlyPk', bytes) + +# There are two types of exceptions that can be raised by this implementation: +# - ValueError for indicating that an input doesn't conform to some function +# precondition (e.g. an input array is the wrong length, a serialized +# representation doesn't have the correct format). +# - InvalidContributionError for indicating that a signer (or the +# aggregator) is misbehaving in the protocol. +# +# Assertions are used to (1) satisfy the type-checking system, and (2) check for +# inconvenient events that can't happen except with negligible probability (e.g. +# output of a hash function is 0) and can't be manually triggered by any +# signer. + +# This exception is raised if a party (signer or nonce aggregator) sends invalid +# values. Actual implementations should not crash when receiving invalid +# contributions. Instead, they should hold the offending party accountable. +class InvalidContributionError(Exception): + def __init__(self, signer, contrib): + self.signer = signer + # contrib is one of "pubkey", "pubnonce", "aggnonce", or "psig". + self.contrib = contrib + +infinity = None + +def xbytes(P: Point) -> bytes: + return bytes_from_int(x(P)) + +def cbytes(P: Point) -> bytes: + a = b'\x02' if has_even_y(P) else b'\x03' + return a + xbytes(P) + +def cbytes_ext(P: Optional[Point]) -> bytes: + if is_infinite(P): + return (0).to_bytes(33, byteorder='big') + assert P is not None + return cbytes(P) + +def point_negate(P: Optional[Point]) -> Optional[Point]: + if P is None: + return P + return (x(P), p - y(P)) + +def cpoint(x: bytes) -> Point: + if len(x) != 33: + raise ValueError('x is not a valid compressed point.') + P = lift_x(x[1:33]) + if P is None: + raise ValueError('x is not a valid compressed point.') + if x[0] == 2: + return P + elif x[0] == 3: + P = point_negate(P) + assert P is not None + return P + else: + raise ValueError('x is not a valid compressed point.') + +def cpoint_ext(x: bytes) -> Optional[Point]: + if x == (0).to_bytes(33, 'big'): + return None + else: + return cpoint(x) + +# Return the plain public key corresponding to a given secret key +def individual_pk(seckey: bytes) -> PlainPk: + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): + raise ValueError('The secret key must be an integer in the range 1..n-1.') + P = point_mul(G, d0) + assert P is not None + return PlainPk(cbytes(P)) + +def key_sort(pubkeys: List[PlainPk]) -> List[PlainPk]: + pubkeys.sort() + return pubkeys + +KeyAggContext = NamedTuple('KeyAggContext', [('Q', Point), + ('gacc', int), + ('tacc', int)]) + +def get_xonly_pk(keyagg_ctx: KeyAggContext) -> XonlyPk: + Q, _, _ = keyagg_ctx + return XonlyPk(xbytes(Q)) + +def key_agg(pubkeys: List[PlainPk]) -> KeyAggContext: + pk2 = get_second_key(pubkeys) + u = len(pubkeys) + Q = infinity + for i in range(u): + try: + P_i = cpoint(pubkeys[i]) + except ValueError: + raise InvalidContributionError(i, "pubkey") + a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2) + Q = point_add(Q, point_mul(P_i, a_i)) + # Q is not the point at infinity except with negligible probability. + assert(Q is not None) + gacc = 1 + tacc = 0 + return KeyAggContext(Q, gacc, tacc) + +def hash_keys(pubkeys: List[PlainPk]) -> bytes: + return tagged_hash('KeyAgg list', b''.join(pubkeys)) + +def get_second_key(pubkeys: List[PlainPk]) -> PlainPk: + u = len(pubkeys) + for j in range(1, u): + if pubkeys[j] != pubkeys[0]: + return pubkeys[j] + return PlainPk(b'\x00'*33) + +def key_agg_coeff(pubkeys: List[PlainPk], pk_: PlainPk) -> int: + pk2 = get_second_key(pubkeys) + return key_agg_coeff_internal(pubkeys, pk_, pk2) + +def key_agg_coeff_internal(pubkeys: List[PlainPk], pk_: PlainPk, pk2: PlainPk) -> int: + L = hash_keys(pubkeys) + if pk_ == pk2: + return 1 + return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n + +def apply_tweak(keyagg_ctx: KeyAggContext, tweak: bytes, is_xonly: bool) -> KeyAggContext: + if len(tweak) != 32: + raise ValueError('The tweak must be a 32-byte array.') + Q, gacc, tacc = keyagg_ctx + if is_xonly and not has_even_y(Q): + g = n - 1 + else: + g = 1 + t = int_from_bytes(tweak) + if t >= n: + raise ValueError('The tweak must be less than n.') + Q_ = point_add(point_mul(Q, g), point_mul(G, t)) + if Q_ is None: + raise ValueError('The result of tweaking cannot be infinity.') + gacc_ = g * gacc % n + tacc_ = (t + g * tacc) % n + return KeyAggContext(Q_, gacc_, tacc_) + +def bytes_xor(a: bytes, b: bytes) -> bytes: + return bytes(x ^ y for x, y in zip(a, b)) + +def nonce_hash(rand: bytes, pk: PlainPk, aggpk: XonlyPk, i: int, msg_prefixed: bytes, extra_in: bytes) -> int: + buf = b'' + buf += rand + buf += len(pk).to_bytes(1, 'big') + buf += pk + buf += len(aggpk).to_bytes(1, 'big') + buf += aggpk + buf += msg_prefixed + buf += len(extra_in).to_bytes(4, 'big') + buf += extra_in + buf += i.to_bytes(1, 'big') + return int_from_bytes(tagged_hash('MuSig/nonce', buf)) + +def nonce_gen_internal(rand_: bytes, sk: Optional[bytes], pk: PlainPk, aggpk: Optional[XonlyPk], msg: Optional[bytes], extra_in: Optional[bytes]) -> Tuple[bytearray, bytes]: + if sk is not None: + rand = bytes_xor(sk, tagged_hash('MuSig/aux', rand_)) + else: + rand = rand_ + if aggpk is None: + aggpk = XonlyPk(b'') + if msg is None: + msg_prefixed = b'\x00' + else: + msg_prefixed = b'\x01' + msg_prefixed += len(msg).to_bytes(8, 'big') + msg_prefixed += msg + if extra_in is None: + extra_in = b'' + k_1 = nonce_hash(rand, pk, aggpk, 0, msg_prefixed, extra_in) % n + k_2 = nonce_hash(rand, pk, aggpk, 1, msg_prefixed, extra_in) % n + # k_1 == 0 or k_2 == 0 cannot occur except with negligible probability. + assert k_1 != 0 + assert k_2 != 0 + R_s1 = point_mul(G, k_1) + R_s2 = point_mul(G, k_2) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + secnonce = bytearray(bytes_from_int(k_1) + bytes_from_int(k_2) + pk) + return secnonce, pubnonce + +def nonce_gen(sk: Optional[bytes], pk: PlainPk, aggpk: Optional[XonlyPk], msg: Optional[bytes], extra_in: Optional[bytes]) -> Tuple[bytearray, bytes]: + if sk is not None and len(sk) != 32: + raise ValueError('The optional byte array sk must have length 32.') + if aggpk is not None and len(aggpk) != 32: + raise ValueError('The optional byte array aggpk must have length 32.') + rand_ = secrets.token_bytes(32) + return nonce_gen_internal(rand_, sk, pk, aggpk, msg, extra_in) + +def nonce_agg(pubnonces: List[bytes]) -> bytes: + u = len(pubnonces) + aggnonce = b'' + for j in (1, 2): + R_j = infinity + for i in range(u): + try: + R_ij = cpoint(pubnonces[i][(j-1)*33:j*33]) + except ValueError: + raise InvalidContributionError(i, "pubnonce") + R_j = point_add(R_j, R_ij) + aggnonce += cbytes_ext(R_j) + return aggnonce + +SessionContext = NamedTuple('SessionContext', [('aggnonce', bytes), + ('pubkeys', List[PlainPk]), + ('tweaks', List[bytes]), + ('is_xonly', List[bool]), + ('msg', bytes)]) + +def key_agg_and_tweak(pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool]): + if len(tweaks) != len(is_xonly): + raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.') + keyagg_ctx = key_agg(pubkeys) + v = len(tweaks) + for i in range(v): + keyagg_ctx = apply_tweak(keyagg_ctx, tweaks[i], is_xonly[i]) + return keyagg_ctx + +def get_session_values(session_ctx: SessionContext) -> Tuple[Point, int, int, int, Point, int]: + (aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx + Q, gacc, tacc = key_agg_and_tweak(pubkeys, tweaks, is_xonly) + b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + xbytes(Q) + msg)) % n + try: + R_1 = cpoint_ext(aggnonce[0:33]) + R_2 = cpoint_ext(aggnonce[33:66]) + except ValueError: + # Nonce aggregator sent invalid nonces + raise InvalidContributionError(None, "aggnonce") + R_ = point_add(R_1, point_mul(R_2, b)) + R = R_ if not is_infinite(R_) else G + assert R is not None + e = int_from_bytes(tagged_hash('BIP0340/challenge', xbytes(R) + xbytes(Q) + msg)) % n + return (Q, gacc, tacc, b, R, e) + +def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int: + (_, pubkeys, _, _, _) = session_ctx + pk = PlainPk(cbytes(P)) + if pk not in pubkeys: + raise ValueError('The signer\'s pubkey must be included in the list of pubkeys.') + return key_agg_coeff(pubkeys, pk) + +def sign(secnonce: bytearray, sk: bytes, session_ctx: SessionContext) -> bytes: + (Q, gacc, _, b, R, e) = get_session_values(session_ctx) + k_1_ = int_from_bytes(secnonce[0:32]) + k_2_ = int_from_bytes(secnonce[32:64]) + # Overwrite the secnonce argument with zeros such that subsequent calls of + # sign with the same secnonce raise a ValueError. + secnonce[:64] = bytearray(b'\x00'*64) + if not 0 < k_1_ < n: + raise ValueError('first secnonce value is out of range.') + if not 0 < k_2_ < n: + raise ValueError('second secnonce value is out of range.') + k_1 = k_1_ if has_even_y(R) else n - k_1_ + k_2 = k_2_ if has_even_y(R) else n - k_2_ + d_ = int_from_bytes(sk) + if not 0 < d_ < n: + raise ValueError('secret key value is out of range.') + P = point_mul(G, d_) + assert P is not None + pk = cbytes(P) + if not pk == secnonce[64:97]: + raise ValueError('Public key does not match nonce_gen argument') + a = get_session_key_agg_coeff(session_ctx, P) + g = 1 if has_even_y(Q) else n - 1 + d = g * gacc * d_ % n + s = (k_1 + b * k_2 + e * a * d) % n + psig = bytes_from_int(s) + R_s1 = point_mul(G, k_1_) + R_s2 = point_mul(G, k_2_) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + # Optional correctness check. The result of signing should pass signature verification. + assert partial_sig_verify_internal(psig, pubnonce, pk, session_ctx) + return psig + +def det_nonce_hash(sk_: bytes, aggothernonce: bytes, aggpk: bytes, msg: bytes, i: int) -> int: + buf = b'' + buf += sk_ + buf += aggothernonce + buf += aggpk + buf += len(msg).to_bytes(8, 'big') + buf += msg + buf += i.to_bytes(1, 'big') + return int_from_bytes(tagged_hash('MuSig/deterministic/nonce', buf)) + +def deterministic_sign(sk: bytes, aggothernonce: bytes, pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, rand: Optional[bytes]) -> Tuple[bytes, bytes]: + if rand is not None: + sk_ = bytes_xor(sk, tagged_hash('MuSig/aux', rand)) + else: + sk_ = sk + aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly)) + + k_1 = det_nonce_hash(sk_, aggothernonce, aggpk, msg, 0) % n + k_2 = det_nonce_hash(sk_, aggothernonce, aggpk, msg, 1) % n + # k_1 == 0 or k_2 == 0 cannot occur except with negligible probability. + assert k_1 != 0 + assert k_2 != 0 + + R_s1 = point_mul(G, k_1) + R_s2 = point_mul(G, k_2) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + secnonce = bytearray(bytes_from_int(k_1) + bytes_from_int(k_2) + individual_pk(sk)) + try: + aggnonce = nonce_agg([pubnonce, aggothernonce]) + except Exception: + raise InvalidContributionError(None, "aggothernonce") + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + psig = sign(secnonce, sk, session_ctx) + return (pubnonce, psig) + +def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, i: int) -> bool: + if len(pubnonces) != len(pubkeys): + raise ValueError('The `pubnonces` and `pubkeys` arrays must have the same length.') + if len(tweaks) != len(is_xonly): + raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.') + aggnonce = nonce_agg(pubnonces) + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx) + +def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk: bytes, session_ctx: SessionContext) -> bool: + (Q, gacc, _, b, R, e) = get_session_values(session_ctx) + s = int_from_bytes(psig) + if s >= n: + return False + R_s1 = cpoint(pubnonce[0:33]) + R_s2 = cpoint(pubnonce[33:66]) + Re_s_ = point_add(R_s1, point_mul(R_s2, b)) + Re_s = Re_s_ if has_even_y(R) else point_negate(Re_s_) + P = cpoint(pk) + if P is None: + return False + a = get_session_key_agg_coeff(session_ctx, P) + g = 1 if has_even_y(Q) else n - 1 + g_ = g * gacc % n + return point_mul(G, s) == point_add(Re_s, point_mul(P, e * a * g_ % n)) + +def partial_sig_agg(psigs: List[bytes], session_ctx: SessionContext) -> bytes: + (Q, _, tacc, _, R, e) = get_session_values(session_ctx) + s = 0 + u = len(psigs) + for i in range(u): + s_i = int_from_bytes(psigs[i]) + if s_i >= n: + raise InvalidContributionError(i, "psig") + s = (s + s_i) % n + g = 1 if has_even_y(Q) else n - 1 + s = (s + e * g * tacc) % n + return xbytes(R) + bytes_from_int(s) diff --git a/test_utils/musig2.py b/test_utils/musig2.py new file mode 100644 index 000000000..3537aa0bd --- /dev/null +++ b/test_utils/musig2.py @@ -0,0 +1,831 @@ +""" +This module contains a complete, minimal, standalone MuSig cosigner implementation. +It is NOT a cryptographically secure implementation, and it is only to be used for +testing purposes. + +In lack of a library for wallet policies in python, a minimal version of it for +the purpose of parsing and processing tr() descriptors is implemented here, using +embit for the the final task of compiling simple miniscript descriptors to Script. + +The main objects and methods exported in this class are: + +- PsbtMusig2Cosigner: an abstract class that represents a cosigner in MuSig2. +- HotMusig2Cosigner: an implementation of PsbtMusig2Cosigner that contains a hot + extended private key. Useful for tests. +- run_musig2_test: tests a full signing cycle for a generic list of PsbtMusig2Cosigners. +""" + + +import hashlib +import hmac +from io import BytesIO +import re +from re import Match + +from dataclasses import dataclass +import secrets +import struct +from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union +from abc import ABC, abstractmethod + +import base58 + +from test_utils.taproot_sighash import SIGHASH_DEFAULT, TaprootSignatureHash + +from . import bip0327, bip0340, sha256 +from . import taproot + +from bitcoin_client.ledger_bitcoin.embit.descriptor.miniscript import Miniscript +from bitcoin_client.ledger_bitcoin.psbt import PSBT, PartiallySignedInput +from bitcoin_client.ledger_bitcoin.key import G, ExtendedKey, bytes_to_point, point_add, point_mul, point_to_bytes +from bitcoin_client.ledger_bitcoin.wallet import WalletPolicy + + +HARDENED_INDEX = 0x80000000 + + +def tapleaf_hash(script: Optional[bytes], leaf_version=b'\xC0') -> Optional[bytes]: + if script is None: + return None + return taproot.tagged_hash( + "TapLeaf", + leaf_version + taproot.ser_script(script) + ) + + +@dataclass +class PlainKeyPlaceholder: + key_index: int + num1: int + num2: int + + +@dataclass +class Musig2KeyPlaceholder: + key_indexes: List[int] + num1: int + num2: int + + +KeyPlaceholder = Union[PlainKeyPlaceholder, Musig2KeyPlaceholder] + + +def parse_placeholder(placeholder_str: str) -> KeyPlaceholder: + """Parses a placeholder string to create a KeyPlaceholder object.""" + if placeholder_str.startswith('musig'): + key_indexes_str = placeholder_str[6:placeholder_str.index( + ')/<')].split(',') + key_indexes = [int(index.strip('@')) for index in key_indexes_str] + + nums_part = placeholder_str[placeholder_str.index(')/<') + 3:-3] + num1, num2 = map(int, nums_part.split(';')) + + return Musig2KeyPlaceholder(key_indexes, num1, num2) + elif placeholder_str.startswith('@'): + parts = placeholder_str.split('/') + key_index = int(parts[0].strip('@')) + + # Remove '<' from the start and '>' from the end + nums_part = parts[1][1:-1] + num1, num2 = map(int, nums_part.split(';')) + + return PlainKeyPlaceholder(key_index, num1, num2) + else: + raise ValueError("Invalid placeholder string") + + +def extract_placeholders(desc_tmpl: str) -> List[KeyPlaceholder]: + """Extracts and parses all placeholders in a descriptor template, from left to right.""" + + pattern = r'musig\((?:@\d+,)*(?:@\d+)\)/<\d+;\d+>/\*|@\d+/<\d+;\d+>/\*' + matches = [(match.group(), match.start()) + for match in re.finditer(pattern, desc_tmpl)] + sorted_matches = sorted(matches, key=lambda x: x[1]) + return [parse_placeholder(match[0]) for match in sorted_matches] + + +def musig(pubkeys: Iterable[bytes], version_bytes: bytes) -> Tuple[str, bip0327.KeyAggContext]: + """ + Constructs the musig2 aggregated extended public key from a list of + compressed public keys, and the version bytes. + """ + + assert all(len(pk) == 33 for pk in pubkeys) + assert len(version_bytes) == 4 + + depth = b'\x00' + fingerprint = b'\x00\x00\x00\x00' + child_number = b'\x00\x00\x00\x00' + + key_agg_ctx = bip0327.key_agg(pubkeys) + Q = key_agg_ctx.Q + compressed_pubkey = ( + b'\x02' if Q[1] % 2 == 0 else b'\x03') + bip0327.get_xonly_pk(key_agg_ctx) + chaincode = bytes.fromhex( + "868087ca02a6f974c4598924c36b57762d32cb45717167e300622c7167e38965") + ext_pubkey = version_bytes + depth + fingerprint + \ + child_number + chaincode + compressed_pubkey + return base58.b58encode_check(ext_pubkey).decode(), key_agg_ctx + + +def aggregate_musig_pubkey(keys_info: Iterable[str]) -> Tuple[str, bip0327.KeyAggContext]: + """ + Constructs the musig2 aggregated extended public key from the list of keys info + of the participating keys. + """ + + pubkeys: list[bytes] = [] + versions: Set[str] = set() + for ki in keys_info: + start = ki.find(']') + xpub = ki[start + 1:] + xpub_bytes = base58.b58decode_check(xpub) + versions.add(xpub_bytes[:4]) + pubkeys.append(xpub_bytes[-33:]) + + if len(versions) > 1: + raise ValueError( + "All the extended public keys should be from the same network") + + return musig(pubkeys, versions.pop()) + + +def derive_from_key_info(key_info: str, steps: List[int]) -> str: + start = key_info.find(']') + pk = ExtendedKey.deserialize(key_info[start + 1:]) + return pk.derive_pub_path(steps).to_string() + + +def derive_plain_descriptor(desc_tmpl: str, keys_info: List[str], is_change: bool, address_index: int): + """ + Given a wallet policy, and the change/address_index combination, computes the corresponding descriptor. + It replaces /** with /<0;1>/* + It also replaces each musig() key expression with the corresponding xpub. + The resulting descriptor can be used with descriptor libraries that do not support musig or wallet policies. + """ + + desc_tmpl = desc_tmpl.replace("/**", "/<0;1>/*") + desc_tmpl = desc_tmpl.replace("*", str(address_index)) + + # Replace each with M if is_change is False, otherwise with N + def replace_m_n(match: Match[str]): + m, n = match.groups() + return m if not is_change else n + + desc_tmpl = re.sub(r'<([^;]+);([^>]+)>', replace_m_n, desc_tmpl) + + # Replace musig(...) expressions + def replace_musig(match: Match[str]): + musig_content = match.group(1) + steps = [int(x) for x in match.group(2).split("/")] + + assert len(steps) == 2 + + key_indexes = [int(i.strip('@')) for i in musig_content.split(',')] + key_infos = [keys_info[i] for i in key_indexes] + agg_xpub = aggregate_musig_pubkey(key_infos)[0] + + return derive_from_key_info(agg_xpub, steps) + + desc_tmpl = re.sub(r'musig\(([^)]+)\)/(\d+/\d+)', replace_musig, desc_tmpl) + + # Replace @i/a/b with the i-th element in keys_info, deriving the key appropriately + # to get a plain xpub + def replace_key_index(match): + index, step1, step2 = [int(x) for x in match.group(1).split('/')] + return derive_from_key_info(keys_info[index], [step1, step2]) + + desc_tmpl = re.sub(r'@(\d+/\d+/\d+)', replace_key_index, desc_tmpl) + + return desc_tmpl + + +class Tree: + """ + Recursive structure that represents a taptree, or one of its subtrees. + It can either contain a single descriptor template (if it's a tapleaf), or exactly two child Trees. + """ + + def __init__(self, content: Union[str, Tuple['Tree', 'Tree']]): + if isinstance(content, str): + self.script = content + self.left, self.right = (None, None) + else: + self.script = None + self.left, self.right = content + + @property + def is_leaf(self) -> bool: + return self.script is not None + + def __str__(self): + if self.is_leaf: + return self.script + else: + return f'{{{str(self.left)},{str(self.right)}}}' + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, str]]: + """ + Generates an iterator over the placeholders contained in the scripts of the tree's leaf nodes. + + Yields: + Iterator[Tuple[KeyPlaceholder, str]]: An iterator over tuples containing a KeyPlaceholder and its associated script. + """ + + if self.is_leaf: + assert self.script is not None + for placeholder in extract_placeholders(self.script): + yield (placeholder, self.script) + else: + assert self.left is not None and self.right is not None + for placeholder, script in self.left.placeholders(): + yield (placeholder, script) + for placeholder, script in self.right.placeholders(): + yield (placeholder, script) + + def get_taptree_hash(self, keys_info: List[str], is_change: bool, address_index: int) -> bytes: + if self.is_leaf: + assert self.script is not None + leaf_desc = derive_plain_descriptor( + self.script, keys_info, is_change, address_index) + + s = BytesIO(leaf_desc.encode()) + desc: Miniscript = Miniscript.read_from( + s, taproot=True) + + return tapleaf_hash(desc.compile()) + + else: + assert self.left is not None and self.right is not None + left_h = self.left.get_taptree_hash( + keys_info, is_change, address_index) + right_h = self.left.get_taptree_hash( + keys_info, is_change, address_index) + if left_h <= right_h: + return taproot.tagged_hash("TapBranch", left_h + right_h) + else: + return taproot.tagged_hash("TapBranch", right_h + left_h) + + +class TrDescriptorTemplate: + """ + Represents a descriptor template for a tr(KEY) or a tr(KEY,TREE). + This is minimal implementation in order to enable iterating over the placeholders, + and compile the corresponding leaf scripts. + """ + + def __init__(self, key: KeyPlaceholder, tree=Optional[Tree]): + self.key: KeyPlaceholder = key + self.tree: Optional[Tree] = tree + + @classmethod + def from_string(cls, input_string): + parser = cls.Parser(input_string.replace("/**", "/<0;1>/*")) + return parser.parse() + + class Parser: + def __init__(self, input): + self.input = input + self.index = 0 + self.length = len(input) + + def parse(self): + if self.input.startswith('tr('): + self.consume('tr(') + key = self.parse_keyplaceholder() + tree = None + if self.peek() == ',': + self.consume(',') + tree = self.parse_tree() + self.consume(')') + return TrDescriptorTemplate(key, tree) + else: + raise Exception( + "Syntax error: Input does not start with 'tr('") + + def parse_keyplaceholder(self): + if self.peek() == '@': + self.consume('@') + key_index = self.parse_num() + self.consume('/<') + num1 = self.parse_num() + self.consume(';') + num2 = self.parse_num() + self.consume('>/*') + return PlainKeyPlaceholder(key_index, num1, num2) + elif self.input[self.index:self.index+6] == 'musig(': + self.consume('musig(') + key_indexes = self.parse_key_indexes() + self.consume(')/<') + num1 = self.parse_num() + self.consume(';') + num2 = self.parse_num() + self.consume('>/*') + return Musig2KeyPlaceholder(key_indexes, num1, num2) + else: + raise Exception("Syntax error in key placeholder") + + def parse_tree(self) -> Tree: + if self.peek() == '{': + self.consume('{') + tree1 = self.parse_tree() + self.consume(',') + tree2 = self.parse_tree() + self.consume('}') + return Tree((tree1, tree2)) + else: + return Tree(self.parse_script()) + + def parse_script(self) -> str: + start = self.index + nesting = 0 + while self.index < self.length and (nesting > 0 or self.input[self.index] not in ('}', ',', ')')): + if self.input[self.index] == '(': + nesting = nesting + 1 + elif self.input[self.index] == ')': + nesting = nesting - 1 + + self.index += 1 + return self.input[start:self.index] + + def parse_key_indexes(self): + nums = [] + self.consume('@') + nums.append(self.parse_num()) + while self.peek() == ',': + self.consume(',@') + nums.append(self.parse_num()) + return nums + + def parse_num(self): + start = self.index + while self.index < self.length and self.input[self.index].isdigit(): + self.index += 1 + return int(self.input[start:self.index]) + + def consume(self, char): + if self.input[self.index:self.index+len(char)] == char: + self.index += len(char) + else: + raise Exception( + f"Syntax error: Expected '{char}'; rest: {self.input[self.index:]}") + + def peek(self): + return self.input[self.index] if self.index < self.length else None + + def placeholders(self) -> Iterator[Tuple[KeyPlaceholder, Optional[str]]]: + """ + Generates an iterator over the placeholders contained in the template and its tree, also + yielding the corresponding leaf script descriptor (or None for the keypath placeholder). + + Yields: + Iterator[Tuple[KeyPlaceholder, Optional[str]]]: An iterator over tuples containing a KeyPlaceholder and an optional associated script. + """ + + yield (self.key, None) + + if self.tree is not None: + for placeholder, script in self.tree.placeholders(): + yield (placeholder, script) + + def get_taptree_hash(self, is_change: bool, address_index: int) -> bytes: + if self.tree is None: + raise ValueError("There is no taptree") + return self.tree.get_taptree_hash(is_change, address_index) + + +class PsbtMusig2Cosigner(ABC): + @abstractmethod + def get_participant_pubkey(self) -> bip0327.Point: + """ + This method should returns this cosigner's public key. + """ + pass + + @abstractmethod + def generate_public_nonces(self, psbt: PSBT) -> None: + """ + This method should generate public nonces and modify the given Psbt object in-place. + It should raise an exception in case of failure. + """ + pass + + @abstractmethod + def generate_partial_signatures(self, psbt: PSBT) -> None: + """ + Receives a PSBT that contains all the participants' public nonces, and adds this participant's partial signature. + It should raise an exception in case of failure. + """ + pass + + +def find_change_and_addr_index_for_musig(input_psbt: PartiallySignedInput, placeholder: Musig2KeyPlaceholder, agg_xpub: ExtendedKey): + num1, num2 = placeholder.num1, placeholder.num2 + + # Iterate through tap key origins in the input + # TODO: this might be made more precise (e.g. use the leaf_hash from the tap_bip32_paths items) + for xonly, (_, key_origin) in input_psbt.tap_bip32_paths.items(): + der_path = key_origin.path + # Check if the fingerprint matches the expected pattern and the derivation path has the correct structure + if key_origin.fingerprint == b'\x00\x00\x00\x00' and len(der_path) == 2 and der_path[0] < HARDENED_INDEX and der_path[1] < HARDENED_INDEX and (der_path[0] == num1 or der_path[0] == num2): + if xonly != agg_xpub.derive_pub_path(der_path).pubkey[1:]: + continue + + # Determine if the address is a change address and extract the address index + is_change = (der_path[0] == num2) + addr_index = int(der_path[1]) + return is_change, addr_index + + return None + + +def get_bip32_tweaks(ext_key: ExtendedKey, steps: List[int]) -> List[bytes]: + """ + Generate BIP32 tweaks for a series of derivation steps on an extended key. + + Args: + ext_key (ExtendedKey): The extended public key. + steps (List[int]): A list of derivation steps (must be unhardened). + + Returns: + List[bytes]: The list of additive tweaks for those derivation steps. + """ + + result = [] + + cur_pubkey = ext_key.pubkey + cur_chaincode = ext_key.chaincode + + for step in steps: + if step < 0 or step >= HARDENED_INDEX: + raise ValueError("Invalid unhardened derivation step") + + data = cur_pubkey + struct.pack(">L", step) + Ihmac = hmac.new(cur_chaincode, data, hashlib.sha512).digest() + Il = Ihmac[:32] + Ir = Ihmac[32:] + + result.append(Il) + + Il_int = int.from_bytes(Il, 'big') + child_pubkey_point = point_add(point_mul(G, Il_int), + bytes_to_point(cur_pubkey)) + child_pubkey = point_to_bytes(child_pubkey_point) + + cur_pubkey = child_pubkey + cur_chaincode = Ir + + return result + + +def process_placeholder( + wallet_policy: WalletPolicy, + psbt_input: PartiallySignedInput, + placeholder: Musig2KeyPlaceholder, + keyagg_ctx: bip0327.KeyAggContext, + agg_xpub: ExtendedKey, + tapleaf_desc: Optional[str], + desc_tmpl: TrDescriptorTemplate +) -> Optional[Tuple[List[bytes], List[bool], Optional[bytes], bytes]]: + """ + This method encapsulates all the precomputations that are done for a certain + wallet policy, psbt input and musig() placeholder that are common to both the + nonce generation and the partial signature generation flows. + + Returs a tuple containing: + - tweaks: a list of tweaks to be applied to the aggregate musig key + - is_xonly_tweak: a list of boolean of the same length of tweaks, specifying for + each of them if it's a plain tweak or an x-only tweak + - leaf_script: the compiled leaf script, or None for a taproot keypath spend + - aggpk_tweaked: the value of the aggregate pubkey after applying the tweaks + """ + res = find_change_and_addr_index_for_musig( + psbt_input, placeholder, agg_xpub) + if res is None: + return None + is_change, address_index = res + + leaf_script = None + if tapleaf_desc is not None: + leaf_desc = derive_plain_descriptor( + tapleaf_desc, wallet_policy.keys_info, is_change, address_index) + s = BytesIO(leaf_desc.encode()) + desc: Miniscript = Miniscript.read_from(s, taproot=True) + leaf_script = desc.compile() + + tweaks = [] + is_xonly_tweak = [] + + # Compute bip32 tweaks + bip32_tweaks = get_bip32_tweaks(agg_xpub, [ + placeholder.num2 if is_change else placeholder.num1, + address_index + ]) + for tweak in bip32_tweaks: + tweaks.append(tweak) + is_xonly_tweak.append(False) + + if leaf_script is None: + t = agg_xpub.pubkey[-32:] + if desc_tmpl.tree is not None: + t += desc_tmpl.get_taptree_hash(is_change, address_index) + tweaks.append(taproot.tagged_hash("TapTweak", t)) + is_xonly_tweak.append(True) + + keyagg_ctx = aggregate_musig_pubkey( + wallet_policy.keys_info[i] for i in placeholder.key_indexes)[1] + + for tweak, is_xonly in zip(tweaks, is_xonly_tweak): + keyagg_ctx = bip0327.apply_tweak(keyagg_ctx, tweak, is_xonly) + + aggpk_tweaked = bip0327.get_xonly_pk(keyagg_ctx) + + return (tweaks, is_xonly_tweak, leaf_script, aggpk_tweaked) + + +class HotMusig2Cosigner(PsbtMusig2Cosigner): + """ + Implements a PsbtMusig2Cosigner for a given wallet policy and a private + that appears as one of the key in a musig() key expression. + """ + + def __init__(self, wallet_policy: WalletPolicy, privkey: str) -> None: + super().__init__() + + self.wallet_policy = wallet_policy + self.privkey = ExtendedKey.deserialize(privkey) + + assert self.privkey.to_string() == privkey + + self.musig_psbt_sessions: Dict[bytes, bytes] = {} + + assert self.privkey.is_private + + def compute_psbt_session_id(self, psbt: PSBT) -> bytes: + psbt.tx.rehash() + return sha256(psbt.tx.hash + self.wallet_policy.id) + + def get_participant_pubkey(self) -> bip0327.Point: + return bip0327.cpoint(self.privkey.pubkey) + + def generate_public_nonces(self, psbt: PSBT) -> None: + desc_tmpl = TrDescriptorTemplate.from_string( + self.wallet_policy.descriptor_template) + + psbt_session_id = self.compute_psbt_session_id(psbt) + + # root of all pseudorandomness for this psbt session + rand_seed = secrets.token_bytes(32) + + for placeholder_index, (placeholder, tapleaf_desc) in enumerate(desc_tmpl.placeholders()): + if not isinstance(placeholder, Musig2KeyPlaceholder): + continue + + agg_xpub_str, keyagg_ctx = aggregate_musig_pubkey( + self.wallet_policy.keys_info[i] for i in placeholder.key_indexes) + agg_xpub = ExtendedKey.deserialize(agg_xpub_str) + + for input_index, input in enumerate(psbt.inputs): + result = process_placeholder( + self.wallet_policy, input, placeholder, keyagg_ctx, agg_xpub, tapleaf_desc, desc_tmpl) + if result is None: + continue + + (_, _, leaf_script, aggpk_tweaked) = result + + rand_i_j = sha256( + rand_seed + + input_index.to_bytes(4, byteorder='big') + + placeholder_index.to_bytes(4, byteorder='big') + ) + + # secnonce: bytearray + # pubnonce: bytes + _, pubnonce = bip0327.nonce_gen_internal( + rand_=rand_i_j, + sk=self.privkey.privkey, + pk=self.privkey.pubkey, + aggpk=aggpk_tweaked, + msg=None, + extra_in=None + ) + + pubnonce_identifier = ( + self.privkey.pubkey, + aggpk_tweaked, + tapleaf_hash(leaf_script) + ) + + input.musig2_pub_nonces[pubnonce_identifier] = pubnonce + + self.musig_psbt_sessions[psbt_session_id] = rand_seed + + def generate_partial_signatures(self, psbt: PSBT) -> None: + desc_tmpl = TrDescriptorTemplate.from_string( + self.wallet_policy.descriptor_template) + + psbt_session_id = self.compute_psbt_session_id(psbt) + + # Get the session's randomness seed, while simultaneously deleting it from the open sessions + rand_seed = self.musig_psbt_sessions.pop(psbt_session_id, None) + + if rand_seed is None: + raise ValueError( + "No musig signing session for this psbt") + + for placeholder_index, (placeholder, tapleaf_desc) in enumerate(desc_tmpl.placeholders()): + if not isinstance(placeholder, Musig2KeyPlaceholder): + continue + + agg_xpub_str, keyagg_ctx = aggregate_musig_pubkey( + self.wallet_policy.keys_info[i] for i in placeholder.key_indexes) + agg_xpub = ExtendedKey.deserialize(agg_xpub_str) + + for input_index, input in enumerate(psbt.inputs): + result = process_placeholder( + self.wallet_policy, input, placeholder, keyagg_ctx, agg_xpub, tapleaf_desc, desc_tmpl) + if result is None: + continue + + (tweaks, is_xonly_tweak, leaf_script, aggpk_tweaked) = result + + rand_i_j = sha256( + rand_seed + + input_index.to_bytes(4, byteorder='big') + + placeholder_index.to_bytes(4, byteorder='big') + ) + + secnonce, pubnonce = bip0327.nonce_gen_internal( + rand_=rand_i_j, + sk=self.privkey.privkey, + pk=self.privkey.pubkey, + aggpk=aggpk_tweaked, + msg=None, + extra_in=None + ) + + pubkeys_in_musig: List[ExtendedKey] = [] + my_key_index_in_musig: Optional[int] = None + for i in placeholder.key_indexes: + k_i = self.wallet_policy.keys_info[i] + xpub_i = k_i[k_i.find(']') + 1:] + pubkeys_in_musig.append(ExtendedKey.deserialize(xpub_i)) + + if xpub_i == self.privkey.neutered().to_string(): + my_key_index_in_musig = i + + if my_key_index_in_musig is None: + raise ValueError("No internal key found in musig") + + nonces: List[bytes] = [] + for participant_key in pubkeys_in_musig: + participant_pubnonce_identifier = ( + participant_key.pubkey, + bytes(aggpk_tweaked), + tapleaf_hash(leaf_script) + ) + + if participant_key.pubkey == self.privkey.pubkey and input.musig2_pub_nonces[participant_pubnonce_identifier] != pubnonce: + raise ValueError( + f"Public nonce in psbt didn't match the expected one for cosigner {self.privkey.pubkey}") + + if participant_pubnonce_identifier in input.musig2_pub_nonces: + nonces.append( + input.musig2_pub_nonces[participant_pubnonce_identifier]) + else: + raise ValueError( + f"Missing pubnonce for pubkey {participant_key.pubkey.hex()} in psbt") + + if leaf_script is None: + sighash = TaprootSignatureHash( + txTo=psbt.tx, + spent_utxos=[ + psbt.inputs[i].witness_utxo for i in range(len(psbt.inputs))], + hash_type=input.sighash or SIGHASH_DEFAULT, + input_index=input_index, + ) + else: + sighash = TaprootSignatureHash( + txTo=psbt.tx, + spent_utxos=[ + psbt.inputs[i].witness_utxo for i in range(len(psbt.inputs))], + hash_type=input.sighash or SIGHASH_DEFAULT, + input_index=input_index, + scriptpath=True, + script=leaf_script + ) + + aggnonce = bip0327.nonce_agg(nonces) + + session_ctx = bip0327.SessionContext( + aggnonce=aggnonce, + pubkeys=[pk.pubkey for pk in pubkeys_in_musig], + tweaks=tweaks, + is_xonly=is_xonly_tweak, + msg=sighash) + + partial_sig = bip0327.sign( + secnonce, self.privkey.privkey, session_ctx) + + pubnonce_identifier = ( + self.privkey.pubkey, + aggpk_tweaked, + tapleaf_hash(leaf_script) + ) + + input.musig2_partial_sigs[pubnonce_identifier] = partial_sig + + +def run_musig2_test(wallet_policy: WalletPolicy, psbt: PSBT, cosigners: List[PsbtMusig2Cosigner], sighashes: list[bytes]): + """ + This performs the following steps: + - go through all the cosigners to let them add their pubnonce; + - go through all the cosigners to let them add their partial signature; + - aggregate the partial signatures to produce the final Schnorr signature; + - verify that the produced signature is valid for the provided sighash. + + The sighashes (one per input) are given as argument and are assument to be correct. + """ + + if len(psbt.inputs) != len(sighashes): + raise ValueError("The sighashes") + + for signer in cosigners: + signer.generate_public_nonces(psbt) + + for signer in cosigners: + signer.generate_partial_signatures(psbt) + + desc_tmpl = TrDescriptorTemplate.from_string( + wallet_policy.descriptor_template) + + for placeholder, tapleaf_desc in desc_tmpl.placeholders(): + if not isinstance(placeholder, Musig2KeyPlaceholder): + continue + + agg_xpub_str, keyagg_ctx = aggregate_musig_pubkey( + wallet_policy.keys_info[i] for i in placeholder.key_indexes) + agg_xpub = ExtendedKey.deserialize(agg_xpub_str) + + for input_index, input in enumerate(psbt.inputs): + result = process_placeholder( + wallet_policy, input, placeholder, keyagg_ctx, agg_xpub, tapleaf_desc, desc_tmpl) + + if result is None: + raise RuntimeError( + "Unexpected: processing the musig placeholder failed") + + (tweaks, is_xonly_tweak, leaf_script, aggpk_tweaked) = result + + pubkeys_in_musig: List[ExtendedKey] = [] + for i in placeholder.key_indexes: + k_i = wallet_policy.keys_info[i] + xpub_i = k_i[k_i.find(']') + 1:] + pubkeys_in_musig.append(ExtendedKey.deserialize(xpub_i)) + + nonces: List[bytes] = [] + for participant_key in pubkeys_in_musig: + pubnonce_identifier = ( + participant_key.pubkey, + bytes(aggpk_tweaked), + tapleaf_hash(leaf_script) + ) + + if pubnonce_identifier in input.musig2_pub_nonces: + nonces.append( + input.musig2_pub_nonces[pubnonce_identifier]) + else: + raise ValueError( + f"Missing pubnonce for pubkey {participant_key.pubkey.hex()} in psbt") + + aggnonce = bip0327.nonce_agg(nonces) + + sighash = sighashes[input_index] + + session_ctx = bip0327.SessionContext( + aggnonce=aggnonce, + pubkeys=[pk.pubkey for pk in pubkeys_in_musig], + tweaks=tweaks, + is_xonly=is_xonly_tweak, + msg=sighash) + + # collect partial signatures + psigs: List[bytes] = [] + + for participant_key in pubkeys_in_musig: + pubnonce_identifier = ( + participant_key.pubkey, + bytes(aggpk_tweaked), + tapleaf_hash(leaf_script) + ) + + if pubnonce_identifier in input.musig2_partial_sigs: + psigs.append( + input.musig2_partial_sigs[pubnonce_identifier]) + else: + raise ValueError( + f"Missing partial signature for pubkey {participant_key.pubkey.hex()} in psbt") + + sig = bip0327.partial_sig_agg(psigs, session_ctx) + + assert (bip0340.schnorr_verify(sighash, aggpk_tweaked, sig)) diff --git a/test_utils/taproot.py b/test_utils/taproot.py new file mode 100644 index 000000000..3f84ab56e --- /dev/null +++ b/test_utils/taproot.py @@ -0,0 +1,262 @@ +# from BIP-0340 and BIP-0341 +# - https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0340.mediawiki +# - https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0341.mediawiki +# Distributed under the BSD-3-Clause license + +# fmt: off + +# Set DEBUG to True to get a detailed debug output including +# intermediate values during key generation, signing, and +# verification. This is implemented via calls to the +# debug_print_vars() function. +# +# If you want to print values on an individual basis, use +# the pretty() function, e.g., print(pretty(foo)). +import hashlib +from typing import Any, Optional, Tuple + + +DEBUG = False + +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 +SECP256K1_ORDER = n + +# Points are tuples of X and Y coordinates and the point at infinity is +# represented by the None keyword. +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) + +Point = Tuple[int, int] + +# This implementation can be sped up by storing the midstate after hashing +# tag_hash instead of rehashing it all the time. +def tagged_hash(tag: str, msg: bytes) -> bytes: + tag_hash = hashlib.sha256(tag.encode()).digest() + return hashlib.sha256(tag_hash + tag_hash + msg).digest() + +def is_infinite(P: Optional[Point]) -> bool: + return P is None + +def x(P: Point) -> int: + assert not is_infinite(P) + return P[0] + +def y(P: Point) -> int: + assert not is_infinite(P) + return P[1] + +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: + if P1 is None: + return P2 + if P2 is None: + return P1 + if (x(P1) == x(P2)) and (y(P1) != y(P2)): + return None + if P1 == P2: + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p + else: + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p + x3 = (lam * lam - x(P1) - x(P2)) % p + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) + +def point_mul(P: Optional[Point], n: int) -> Optional[Point]: + R = None + for i in range(256): + if (n >> i) & 1: + R = point_add(R, P) + P = point_add(P, P) + return R + +def bytes_from_int(x: int) -> bytes: + return x.to_bytes(32, byteorder="big") + +def bytes_from_point(P: Point) -> bytes: + return bytes_from_int(x(P)) + +def xor_bytes(b0: bytes, b1: bytes) -> bytes: + return bytes(x ^ y for (x, y) in zip(b0, b1)) + +def lift_x(x: int) -> Optional[Point]: + if x >= p: + return None + y_sq = (pow(x, 3, p) + 7) % p + y = pow(y_sq, (p + 1) // 4, p) + if pow(y, 2, p) != y_sq: + return None + return (x, y if y & 1 == 0 else p-y) + +def int_from_bytes(b: bytes) -> int: + return int.from_bytes(b, byteorder="big") + +def hash_sha256(b: bytes) -> bytes: + return hashlib.sha256(b).digest() + +def has_even_y(P: Point) -> bool: + assert not is_infinite(P) + return y(P) % 2 == 0 + +def pubkey_gen(seckey: bytes) -> bytes: + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): + raise ValueError('The secret key must be an integer in the range 1..n-1.') + P = point_mul(G, d0) + assert P is not None + return bytes_from_point(P) + +def schnorr_sign(msg: bytes, seckey: bytes, aux_rand: bytes) -> bytes: + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): + raise ValueError('The secret key must be an integer in the range 1..n-1.') + if len(aux_rand) != 32: + raise ValueError('aux_rand must be 32 bytes instead of %i.' % len(aux_rand)) + P = point_mul(G, d0) + assert P is not None + d = d0 if has_even_y(P) else n - d0 + t = xor_bytes(bytes_from_int(d), tagged_hash("BIP0340/aux", aux_rand)) + k0 = int_from_bytes(tagged_hash("BIP0340/nonce", t + bytes_from_point(P) + msg)) % n + if k0 == 0: + raise RuntimeError('Failure. This happens only with negligible probability.') + R = point_mul(G, k0) + assert R is not None + k = n - k0 if not has_even_y(R) else k0 + e = int_from_bytes(tagged_hash("BIP0340/challenge", bytes_from_point(R) + bytes_from_point(P) + msg)) % n + sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n) + debug_print_vars() + if not schnorr_verify(msg, bytes_from_point(P), sig): + raise RuntimeError('The created signature does not pass verification.') + return sig + +def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool: + if len(pubkey) != 32: + raise ValueError('The public key must be a 32-byte array.') + if len(sig) != 64: + raise ValueError('The signature must be a 64-byte array.') + P = lift_x(int_from_bytes(pubkey)) + r = int_from_bytes(sig[0:32]) + s = int_from_bytes(sig[32:64]) + if (P is None) or (r >= p) or (s >= n): + debug_print_vars() + return False + e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n + R = point_add(point_mul(G, s), point_mul(P, n - e)) + if (R is None) or (not has_even_y(R)) or (x(R) != r): + debug_print_vars() + return False + debug_print_vars() + return True + +import inspect + +def pretty(v: Any) -> Any: + if isinstance(v, bytes): + return '0x' + v.hex() + if isinstance(v, int): + return pretty(bytes_from_int(v)) + if isinstance(v, tuple): + return tuple(map(pretty, v)) + return v + +def debug_print_vars() -> None: + if DEBUG: + current_frame = inspect.currentframe() + assert current_frame is not None + frame = current_frame.f_back + assert frame is not None + print(' Variables in function ', frame.f_code.co_name, ' at line ', frame.f_lineno, ':', sep='') + for var_name, var_val in frame.f_locals.items(): + print(' ' + var_name.rjust(11, ' '), '==', pretty(var_val)) + + +import struct + +def ser_compact_size(l): + r = b"" + if l < 253: + r = struct.pack("B", l) + elif l < 0x10000: + r = struct.pack("= SECP256K1_ORDER: + raise ValueError + P = lift_x(int_from_bytes(pubkey)) + if P is None: + raise ValueError + Q = point_add(P, point_mul(G, t)) + return 0 if has_even_y(Q) else 1, bytes_from_int(x(Q)) + +def taproot_tweak_seckey(seckey0, h): + seckey0 = int_from_bytes(seckey0) + P = point_mul(G, seckey0) + seckey = seckey0 if has_even_y(P) else SECP256K1_ORDER - seckey0 + t = int_from_bytes(tagged_hash("TapTweak", bytes_from_int(x(P)) + h)) + if t >= SECP256K1_ORDER: + raise ValueError + return bytes_from_int((seckey + t) % SECP256K1_ORDER) + +def taproot_tree_helper(script_tree): + if isinstance(script_tree, tuple): + leaf_version, script = script_tree + h = tagged_hash("TapLeaf", bytes([leaf_version]) + ser_script(script)) + return ([((leaf_version, script), bytes())], h) + left, left_h = taproot_tree_helper(script_tree[0]) + right, right_h = taproot_tree_helper(script_tree[1]) + ret = [(l, c + right_h) for l, c in left] + [(l, c + left_h) for l, c in right] + if right_h < left_h: + left_h, right_h = right_h, left_h + return (ret, tagged_hash("TapBranch", left_h + right_h)) + +def taproot_output_script(internal_pubkey, script_tree): + """Given a internal public key and a tree of scripts, compute the output script. + script_tree is either: + - a (leaf_version, script) tuple (leaf_version is 0xc0 for [[bip-0342.mediawiki|BIP342]] scripts) + - a list of two elements, each with the same structure as script_tree itself + - None + """ + if script_tree is None: + h = bytes() + else: + _, h = taproot_tree_helper(script_tree) + _, output_pubkey = taproot_tweak_pubkey(internal_pubkey, h) + return bytes([0x51, 0x20]) + output_pubkey + + +# Tweak without tag +def tweak_pubkey(pubkey, data: bytes): + assert len(data) == 32 + t = int_from_bytes(data) + if t >= SECP256K1_ORDER: + raise ValueError + P = lift_x(int_from_bytes(pubkey)) + if P is None: + raise ValueError + Q = point_add(P, point_mul(G, t)) + return 0 if has_even_y(Q) else 1, bytes_from_int(x(Q)) diff --git a/test_utils/taproot_sighash.py b/test_utils/taproot_sighash.py new file mode 100644 index 000000000..073be97e2 --- /dev/null +++ b/test_utils/taproot_sighash.py @@ -0,0 +1,85 @@ +# Based on code from the bitcoin's functional test framework, extracted from: +# https://github.com/bitcoin/bitcoin/blob/58446e1d92c7da46da1fc48e1eb5eefe2e0748cb/test/functional/feature_taproot.py +# +# Copyright (c) 2015-2022 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying + + +import struct +from test_utils import sha256 +from test_utils.taproot import ser_string, tagged_hash + + +def BIP341_sha_prevouts(txTo): + return sha256(b"".join(i.prevout.serialize() for i in txTo.vin)) + + +def BIP341_sha_amounts(spent_utxos): + return sha256(b"".join(struct.pack("