Skip to content

Commit

Permalink
Add tests for import of X448 keys
Browse files Browse the repository at this point in the history
  • Loading branch information
Legrandin committed Sep 8, 2024
1 parent ae90a3a commit cbed604
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 8 deletions.
46 changes: 45 additions & 1 deletion lib/Crypto/Protocol/DH.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from Crypto.Util.number import long_to_bytes
from Crypto.PublicKey.ECC import (EccKey,
construct,
_import_curve25519_public_key)
_import_curve25519_public_key,
_import_curve448_public_key)


def _compute_ecdh(key_priv, key_pub):
Expand All @@ -11,6 +12,8 @@ def _compute_ecdh(key_priv, key_pub):

if key_priv.curve == "Curve25519":
z = bytearray(pointP.x.to_bytes(32, byteorder='little'))
elif key_priv.curve == "Curve448":
z = bytearray(pointP.x.to_bytes(56, byteorder='little'))
else:
# See Section 5.7.1.2 in NIST SP 800-56Ar3
z = long_to_bytes(pointP.x, pointP.size_in_bytes())
Expand Down Expand Up @@ -58,6 +61,47 @@ def import_x25519_private_key(encoded):
return construct(seed=encoded, curve="Curve25519")


def import_x448_public_key(encoded):
"""Create a new X448 public key object,
starting from the key encoded as raw ``bytes``,
in the format described in RFC7748.
Args:
encoded (bytes):
The x448 public key to import.
It must be 56 bytes.
Returns:
:class:`Crypto.PublicKey.EccKey` : a new ECC key object.
Raises:
ValueError: when the given key cannot be parsed.
"""

x = _import_curve448_public_key(encoded)
return construct(curve='Curve448', point_x=x)


def import_x448_private_key(encoded):
"""Create a new X448 private key object,
starting from the key encoded as raw ``bytes``,
in the format described in RFC7748.
Args:
encoded (bytes):
The X448 private key to import.
It must be 56 bytes.
Returns:
:class:`Crypto.PublicKey.EccKey` : a new ECC key object.
Raises:
ValueError: when the given key cannot be parsed.
"""

return construct(seed=encoded, curve="Curve448")


def key_agreement(**kwargs):
"""Perform a Diffie-Hellman key agreement.
Expand Down
100 changes: 93 additions & 7 deletions lib/Crypto/SelfTest/Protocol/test_ecdh.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,18 +358,103 @@ def test_weak(self):
# it will set the MSB to zero (as required by RFC7748, Section 5),
# therefore leading to another public key (and to a point which is
# not of low order anymore).
#"cdeb7a7c3b41b8ae1656e3faf19fc46ada098deb9c32b1fd866205165f49b880",
#"4c9c95bca3508c24b1d0b1559c83ef5b04445cc4581c8e86d8224eddd09f11d7",
#"d9ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
#"daffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
#"dbffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
# "cdeb7a7c3b41b8ae1656e3faf19fc46ada098deb9c32b1fd866205165f49b880",
# "4c9c95bca3508c24b1d0b1559c83ef5b04445cc4581c8e86d8224eddd09f11d7",
# "d9ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
# "daffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
# "dbffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
)

for x in weak_keys:
self.assertRaises(ValueError,
DH.import_x25519_public_key,
unhexlify(x))


class X448_Tests(unittest.TestCase):

def test_rfc7748_1(self):
tvs = (
("3d262fddf9ec8e88495266fea19a34d28882acef045104d0d1aae121700a779c984c24f8cdd78fbff44943eba368f54b29259a4f1c600ad3",
"06fce640fa3487bfda5f6cf2d5263f8aad88334cbd07437f020f08f9814dc031ddbdc38c19c6da2583fa5429db94ada18aa7a7fb4ef8a086",
"ce3e4ff95a60dc6697da1db1d85e6afbdf79b50a2412d7546d5f239fe14fbaadeb445fc66a01b0779d98223961111e21766282f73dd96b6f"),
("203d494428b8399352665ddca42f9de8fef600908e0d461cb021f8c538345dd77c3e4806e25f46d3315c44e0a5b4371282dd2c8d5be3095f",
"0fbcc2f993cd56d3305b0b7d9e55d4c1a8fb5dbb52f8e9a1e9b6201b165d015894e56c4d3570bee52fe205e28a78b91cdfbde71ce8d157db",
"884a02576239ff7a2f2f63b2db6a9ff37047ac13568e1e30fe63c4a7ad1b3ee3a5700df34321d62077e63633c575c1c954514e99da7c179d"),
)

for tv1, tv2, tv3 in tvs:
priv_key = DH.import_x448_private_key(unhexlify(tv1))
pub_key = DH.import_x448_public_key(unhexlify(tv2))
result = key_agreement(static_pub=pub_key,
static_priv=priv_key,
kdf=lambda x: x)
self.assertEqual(result, unhexlify(tv3))

def test_rfc7748_2(self):
k = unhexlify("0500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")

priv_key = DH.import_x448_private_key(k)
pub_key = DH.import_x448_public_key(k)
result = key_agreement(static_pub=pub_key,
static_priv=priv_key,
kdf=lambda x: x)
self.assertEqual(
result,
unhexlify("3f482c8a9f19b01e6c46ee9711d9dc14fd4bf67af30765c2ae2b846a4d23a8cd0db897086239492caf350b51f833868b9bc2b3bca9cf4113")
)

for _ in range(999):
priv_key = DH.import_x448_private_key(result)
pub_key = DH.import_x448_public_key(k)
k = result
result = key_agreement(static_pub=pub_key,
static_priv=priv_key,
kdf=lambda x: x)

self.assertEqual(
result,
unhexlify("aa3b4749d55b9daf1e5b00288826c467274ce3ebbdd5c17b975e09d4af6c67cf10d087202db88286e2b79fceea3ec353ef54faa26e219f38")
)

def test_rfc7748_3(self):
tv1 = "9a8f4925d1519f5775cf46b04b5800d4ee9ee8bae8bc5565d498c28dd9c9baf574a9419744897391006382a6f127ab1d9ac2d8c0a598726b"
tv2 = "9b08f7cc31b7e3e67d22d5aea121074a273bd2b83de09c63faa73d2c22c5d9bbc836647241d953d40c5b12da88120d53177f80e532c41fa0"
tv3 = "1c306a7ac2a0e2e0990b294470cba339e6453772b075811d8fad0d1d6927c120bb5ee8972b0d3e21374c9c921b09d1b0366f10b65173992d"
tv4 = "3eb7a829b0cd20f5bcfc0b599b6feccf6da4627107bdb0d4f345b43027d8b972fc3e34fb4232a13ca706dcb57aec3dae07bdc1c67bf33609"
tv5 = "07fff4181ac6cc95ec1c16a94a0f74d12da232ce40a77552281d282bb60c0b56fd2464c335543936521c24403085d59a449a5037514a879d"

alice_priv_key = DH.import_x448_private_key(unhexlify(tv1))
alice_pub_key = DH.import_x448_public_key(unhexlify(tv2))
bob_priv_key = DH.import_x448_private_key(unhexlify(tv3))
bob_pub_key = DH.import_x448_public_key(unhexlify(tv4))
secret = unhexlify(tv5)

result1 = key_agreement(static_pub=alice_pub_key,
static_priv=bob_priv_key,
kdf=lambda x: x)
result2 = key_agreement(static_pub=bob_pub_key,
static_priv=alice_priv_key,
kdf=lambda x: x)
self.assertEqual(result1, secret)
self.assertEqual(result2, secret)

def test_weak(self):

weak_keys = (
"0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"0100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"fefffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffffffffffffffffffffffffffffffffffffffffffffffffff",
"fffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffffffffffffffffffffffffffffffffffffffffffffffffff",
"00000000000000000000000000000000000000000000000000000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
)

for x in weak_keys:
self.assertRaises(ValueError,
DH.import_x448_public_key,
unhexlify(x))


class TestVectorsXECDHWycheproof(unittest.TestCase):

desc = "Wycheproof XECDH tests"
Expand Down Expand Up @@ -444,7 +529,7 @@ def test_verify(self, tv):
assert not tv.valid
assert "Unsupported ECC" in str(e)
return
except ValueError as e:
except ValueError:
assert tv.valid
assert tv.warning
assert "LowOrderPublic" in tv.flags
Expand Down Expand Up @@ -487,7 +572,7 @@ def base64url_decode(input_str):
else:
assert "Incorrect length" in str(e)
return
except ValueError as e:
except ValueError:
assert tv.valid
else:
raise ValueError("Unknown encoding", tv.encoding)
Expand Down Expand Up @@ -518,6 +603,7 @@ def get_tests(config={}):
tests += [TestVectorsECDHWycheproof()]
tests += list_test_cases(ECDH_Tests)
tests += list_test_cases(X25519_Tests)
tests += list_test_cases(X448_Tests)
tests += [TestVectorsXECDHWycheproof()]

slow_tests = config.get('slow_tests')
Expand Down

0 comments on commit cbed604

Please sign in to comment.