Skip to content

Commit

Permalink
In ECC, drive logic with enums, not strings
Browse files Browse the repository at this point in the history
  • Loading branch information
Legrandin committed Aug 17, 2024
1 parent 35fcd99 commit e5ee493
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 29 deletions.
42 changes: 20 additions & 22 deletions lib/Crypto/PublicKey/ECC.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from Crypto.Random import get_random_bytes

from ._point import EccPoint, EccXPoint, _curves
from ._point import CurveID as _CurveID


class UnsupportedEccFeature(ValueError):
Expand Down Expand Up @@ -124,7 +125,7 @@ def __init__(self, **kwargs):
# NIST P curves work with d, EdDSA works with seed

# RFC 8032, 5.1.5
if self._curve.name == "ed25519":
if self._curve.id == _CurveID.ED25519:
if self._d is not None:
raise ValueError("Parameter d can only be used with NIST P curves")
if len(self._seed) != 32:
Expand All @@ -136,7 +137,7 @@ def __init__(self, **kwargs):
tmp[31] = (tmp[31] & 0x7F) | 0x40
self._d = Integer.from_bytes(tmp, byteorder='little')
# RFC 8032, 5.2.5
elif self._curve.name == "ed448":
elif self._curve.id == _CurveID.ED448:
if self._d is not None:
raise ValueError("Parameter d can only be used with NIST P curves")
if len(self._seed) != 57:
Expand All @@ -149,7 +150,7 @@ def __init__(self, **kwargs):
tmp[56] = 0
self._d = Integer.from_bytes(tmp, byteorder='little')
# RFC 7748, 5
elif self._curve.name == "curve25519":
elif self._curve.id == _CurveID.CURVE25519:
if self._d is not None:
raise ValueError("Parameter d can only be used with NIST P curves")
if len(self._seed) != 32:
Expand All @@ -165,9 +166,6 @@ def __init__(self, **kwargs):
if not 1 <= self._d < self._curve.order:
raise ValueError("Parameter d must be an integer smaller than the curve order")

def _is_eddsa(self):
return self._curve.desc in ("Ed25519", "Ed448")

def __eq__(self, other):
if not isinstance(other, EccKey):
return False
Expand All @@ -179,13 +177,13 @@ def __eq__(self, other):

def __repr__(self):
if self.has_private():
if self._is_eddsa():
if self._curve.is_edwards:
extra = ", seed=%s" % tostr(binascii.hexlify(self._seed))
else:
extra = ", d=%d" % int(self._d)
else:
extra = ""
if self._curve.name == "curve25519":
if self._curve.id == _CurveID.CURVE25519:
x = self.pointQ.x
result = "EccKey(curve='%s', point_x=%d%s)" % (self._curve.desc, x, extra)
else:
Expand Down Expand Up @@ -249,7 +247,7 @@ def public_key(self):
return EccKey(curve=self._curve.desc, point=self.pointQ)

def _export_SEC1(self, compress):
if self._curve.desc in ("Ed25519", "Ed448", "Curve25519"):
if not self._curve.is_weierstrass:
raise ValueError("SEC1 format is only supported for NIST P curves")

# See 2.2 in RFC5480 and 2.3.3 in SEC1
Expand Down Expand Up @@ -278,29 +276,29 @@ def _export_SEC1(self, compress):

def _export_eddsa_public(self):
x, y = self.pointQ.xy
if self._curve.name == "ed25519":
if self._curve.id == _CurveID.ED25519:
result = bytearray(y.to_bytes(32, byteorder='little'))
result[31] = ((x & 1) << 7) | result[31]
elif self._curve.name == "ed448":
elif self._curve.id == _CurveID.ED448:
result = bytearray(y.to_bytes(57, byteorder='little'))
result[56] = (x & 1) << 7
else:
raise ValueError("Not an EdDSA key to export")
return bytes(result)

def _export_montgomery_public(self):
if self._curve.desc != "Curve25519":
if not self._curve.is_montgomery:
raise ValueError("Not a Montgomery key to export")
x = self.pointQ.x
result = bytearray(x.to_bytes(32, byteorder='little'))
return bytes(result)

def _export_subjectPublicKeyInfo(self, compress):
if self._is_eddsa():
if self._curve.is_edwards:
oid = self._curve.oid
public_key = self._export_eddsa_public()
params = None
elif self._curve.desc == "Curve25519":
elif self._curve.is_montgomery:
oid = self._curve.oid
public_key = self._export_montgomery_public()
params = None
Expand Down Expand Up @@ -523,9 +521,9 @@ def export_key(self, **kwargs):

use_pkcs8 = args.pop("use_pkcs8", True)
if use_pkcs8 is False:
if self._is_eddsa():
if self._curve.is_edwards:
raise ValueError("'pkcs8' must be True for EdDSA curves")
if self._curve.desc == "Curve25519":
if self._curve.is_montgomery:
raise ValueError("'pkcs8' must be True for Curve25519")
if 'protection' in args:
raise ValueError("'protection' is only supported for PKCS#8")
Expand Down Expand Up @@ -559,9 +557,9 @@ def export_key(self, **kwargs):
elif ext_format == "SEC1":
return self._export_SEC1(compress)
elif ext_format == "raw":
if self._curve.name in ('ed25519', 'ed448'):
if self._curve.is_edwards:
return self._export_eddsa_public()
elif self._curve.name in ('curve25519',):
elif self._curve.is_montgomery:
return self._export_montgomery_public()
else:
return self._export_SEC1(compress)
Expand All @@ -588,13 +586,13 @@ def generate(**kwargs):
if kwargs:
raise TypeError("Unknown parameters: " + str(kwargs))

if _curves[curve_name].name == "ed25519":
if _curves[curve_name].id == _CurveID.ED25519:
seed = randfunc(32)
new_key = EccKey(curve=curve_name, seed=seed)
elif _curves[curve_name].name == "ed448":
elif _curves[curve_name].id == _CurveID.ED448:
seed = randfunc(57)
new_key = EccKey(curve=curve_name, seed=seed)
elif _curves[curve_name].name == "curve25519":
elif _curves[curve_name].id == _CurveID.CURVE25519:
seed = randfunc(32)
new_key = EccKey(curve=curve_name, seed=seed)
_validate_x25519_public_key(new_key)
Expand Down Expand Up @@ -647,7 +645,7 @@ def construct(**kwargs):
if "point" in kwargs:
raise TypeError("Unknown keyword: point")

if curve.desc == "Curve25519":
if curve.id == _CurveID.CURVE25519:

if point_x is not None:
kwargs["point"] = EccXPoint(point_x, curve_name)
Expand Down
34 changes: 27 additions & 7 deletions lib/Crypto/PublicKey/_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@
from Crypto.Random.random import getrandbits


class CurveID(object):
P192 = 1
P224 = 2
P256 = 3
P384 = 4
P521 = 5
ED25519 = 6
ED448 = 7
CURVE25519 = 8


class _Curves(object):

curves = {}
Expand Down Expand Up @@ -42,34 +53,42 @@ def load(self, name):
if name in self.p192_names:
from . import _nist_ecc
p192 = _nist_ecc.p192_curve()
p192.id = CurveID.P192
self.curves.update(dict.fromkeys(self.p192_names, p192))
elif name in self.p224_names:
from . import _nist_ecc
p224 = _nist_ecc.p224_curve()
p224.id = CurveID.P224
self.curves.update(dict.fromkeys(self.p224_names, p224))
elif name in self.p256_names:
from . import _nist_ecc
p256 = _nist_ecc.p256_curve()
p256.id = CurveID.P256
self.curves.update(dict.fromkeys(self.p256_names, p256))
elif name in self.p384_names:
from . import _nist_ecc
p384 = _nist_ecc.p384_curve()
p384.id = CurveID.P384
self.curves.update(dict.fromkeys(self.p384_names, p384))
elif name in self.p521_names:
from . import _nist_ecc
p521 = _nist_ecc.p521_curve()
p521.id = CurveID.P521
self.curves.update(dict.fromkeys(self.p521_names, p521))
elif name in self.ed25519_names:
from . import _edwards
ed25519 = _edwards.ed25519_curve()
ed25519.id = CurveID.ED25519
self.curves.update(dict.fromkeys(self.ed25519_names, ed25519))
elif name in self.ed448_names:
from . import _edwards
ed448 = _edwards.ed448_curve()
ed448.id = CurveID.ED448
self.curves.update(dict.fromkeys(self.ed448_names, ed448))
elif name in self.curve25519_names:
from . import _montgomery
curve25519 = _montgomery.curve25519_curve()
curve25519.id = CurveID.CURVE25519
self.curves.update(dict.fromkeys(self.curve25519_names, curve25519))
else:
raise ValueError("Unsupported curve '%s'" % name)
Expand All @@ -84,6 +103,10 @@ def __getitem__(self, name):
curve.G = EccXPoint(curve.Gx, name)
else:
curve.G = EccPoint(curve.Gx, curve.Gy, name)
curve.is_edwards = curve.id in (CurveID.ED25519, CurveID.ED448)
curve.is_montgomery = curve.id in (CurveID.CURVE25519,)
curve.is_weierstrass = not (curve.is_edwards or
curve.is_montgomery)
return curve

def items(self):
Expand Down Expand Up @@ -125,7 +148,7 @@ def __init__(self, x, y, curve="p256"):
raise ValueError("Unknown curve name %s" % str(curve))
self._curve_name = curve

if self._curve.desc == "Curve25519":
if self._curve.id == CurveID.CURVE25519:
raise ValueError("EccPoint cannot be created for Curve25519")

modulus_bytes = self.size_in_bytes()
Expand Down Expand Up @@ -197,21 +220,18 @@ def copy(self):
np = EccPoint(x, y, self._curve_name)
return np

def _is_eddsa(self):
return self._curve.name in ("ed25519", "ed448")

def is_point_at_infinity(self):
"""``True`` if this is the *point-at-infinity*."""

if self._curve.name in ("ed25519", "ed448"):
if self._curve.is_edwards:
return self.x == 0
else:
return self.xy == (0, 0)

def point_at_infinity(self):
"""Return the *point-at-infinity* for the curve."""

if self._curve.name in ("ed25519", "ed448"):
if self._curve.is_edwards:
return EccPoint(0, 1, self._curve_name)
else:
return EccPoint(0, 0, self._curve_name)
Expand Down Expand Up @@ -328,7 +348,7 @@ def __init__(self, x, curve):
raise ValueError("Unknown curve name %s" % str(curve))
self._curve_name = curve

if self._curve.desc != "Curve25519":
if self._curve.id != CurveID.CURVE25519:
raise ValueError("EccXPoint can only be created for Curve25519")

modulus_bytes = self.size_in_bytes()
Expand Down

0 comments on commit e5ee493

Please sign in to comment.