Skip to content

Commit f1111d6

Browse files
authored
Mapping (#23)
* always return a key id when accessing `Jwk.kid` * use UserDict instead of dict as base class for Jwk, JwkSet and JwsJson. Accept `Mapping` everywhere instead of `dict`
1 parent 6a73468 commit f1111d6

File tree

20 files changed

+496
-445
lines changed

20 files changed

+496
-445
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ repos:
2626
hooks:
2727
- id: blacken-docs
2828
- repo: https://github.com/astral-sh/ruff-pre-commit
29-
rev: v0.1.11
29+
rev: v0.1.13
3030
hooks:
3131
- id: ruff
3232
args: [ --fix ]
33-
- id: ruff-format
33+
- id: ruff-format
3434
- repo: https://github.com/pre-commit/mirrors-mypy
3535
rev: v1.8.0
3636
hooks:
@@ -44,6 +44,6 @@ repos:
4444
additional_dependencies:
4545
- types-cryptography==3.3.23.2
4646
- pytest-mypy==0.10.3
47-
- binapy==0.7.0
47+
- binapy==0.8.0
4848
- freezegun==1.2.2
4949
- jwcrypto==1.5.0

README.md

Lines changed: 329 additions & 327 deletions
Large diffs are not rendered by default.

jwskate/enums.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class KeyManagementAlgs:
7878
A128GCMKW = "A128GCMKW"
7979
A192GCMKW = "A192GCMKW"
8080
A256GCMKW = "A256GCMKW"
81-
dir = "dir" # noqa: A003
81+
dir = "dir"
8282

8383
PBES2_HS256_A128KW = "PBES2-HS256+A128KW"
8484
PBES2_HS384_A192KW = "PBES2-HS384+A192KW"

jwskate/jwa/encryption/aesgcm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def encrypt(
5050
if not isinstance(plaintext, bytes):
5151
plaintext = bytes(plaintext)
5252
ciphertext_with_tag = BinaPy(aead.AESGCM(self.key).encrypt(iv, plaintext, aad))
53-
ciphertext, tag = ciphertext_with_tag.cut_at(-self.tag_size)
53+
ciphertext, tag = ciphertext_with_tag.split_at(-self.tag_size)
5454
return ciphertext, tag
5555

5656
def decrypt(

jwskate/jwa/signature/ec.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def sign(self, data: bytes | SupportsBytes) -> BinaPy:
4444
with self.private_key_required() as key:
4545
dss_sig = key.sign(data, ec.ECDSA(self.hashing_alg))
4646
r, s = asymmetric.utils.decode_dss_signature(dss_sig)
47-
return BinaPy.from_int(r, self.curve.coordinate_size) + BinaPy.from_int(s, self.curve.coordinate_size)
47+
return BinaPy.from_int(r, length=self.curve.coordinate_size) + BinaPy.from_int(
48+
s, length=self.curve.coordinate_size
49+
)
4850

4951
@override
5052
def verify(self, data: bytes | SupportsBytes, signature: bytes | SupportsBytes) -> bool:

jwskate/jwe/compact.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,11 @@ def enc(self) -> str:
140140
def encrypt(
141141
cls,
142142
plaintext: bytes | SupportsBytes,
143-
key: Jwk | dict[str, Any] | Any,
143+
key: Jwk | Mapping[str, Any] | Any,
144144
*,
145145
enc: str,
146146
alg: str | None = None,
147-
extra_headers: dict[str, Any] | None = None,
147+
extra_headers: Mapping[str, Any] | None = None,
148148
cek: bytes | None = None,
149149
iv: bytes | None = None,
150150
epk: Jwk | None = None,
@@ -188,7 +188,7 @@ def encrypt(
188188

189189
def unwrap_cek(
190190
self,
191-
key_or_password: Jwk | dict[str, Any] | bytes | str,
191+
key_or_password: Jwk | Mapping[str, Any] | bytes | str,
192192
alg: str | None = None,
193193
algs: Iterable[str] | None = None,
194194
) -> Jwk:
@@ -220,7 +220,7 @@ def unwrap_cek(
220220

221221
def decrypt(
222222
self,
223-
key: Jwk | dict[str, Any] | Any,
223+
key: Jwk | Mapping[str, Any] | Any,
224224
*,
225225
alg: str | None = None,
226226
algs: Iterable[str] | None = None,
@@ -249,7 +249,7 @@ def decrypt(
249249

250250
def decrypt_jwt(
251251
self,
252-
key: Jwk | dict[str, Any] | Any,
252+
key: Jwk | Mapping[str, Any] | Any,
253253
*,
254254
alg: str | None = None,
255255
algs: Iterable[str] | None = None,

jwskate/jwk/base.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import warnings
15+
from copy import copy
1516
from dataclasses import dataclass
1617
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Mapping, SupportsBytes
1718

@@ -154,7 +155,7 @@ def generate_for_kty(cls, kty: str, **kwargs: Any) -> Jwk:
154155
"shake256": "shake256",
155156
}
156157

157-
def __new__(cls, key: Jwk | dict[str, Any] | Any, **kwargs: Any) -> Jwk:
158+
def __new__(cls, key: Jwk | Mapping[str, Any] | Any, **kwargs: Any) -> Jwk:
158159
"""Overridden `__new__` to make the Jwk constructor smarter.
159160
160161
The `Jwk` constructor will accept:
@@ -171,7 +172,7 @@ def __new__(cls, key: Jwk | dict[str, Any] | Any, **kwargs: Any) -> Jwk:
171172
if cls == Jwk:
172173
if isinstance(key, Jwk):
173174
return cls.from_cryptography_key(key.cryptography_key, **kwargs)
174-
if isinstance(key, dict):
175+
if isinstance(key, Mapping):
175176
kty: str | None = key.get("kty")
176177
if kty is None:
177178
msg = "A Json Web Key must have a Key Type (kty)"
@@ -188,9 +189,9 @@ def __new__(cls, key: Jwk | dict[str, Any] | Any, **kwargs: Any) -> Jwk:
188189
return cls.from_json(key)
189190
else:
190191
return cls.from_cryptography_key(key, **kwargs)
191-
return super().__new__(cls, key, **kwargs)
192+
return super().__new__(cls)
192193

193-
def __init__(self, params: dict[str, Any] | Any, *, include_kid_thumbprint: bool = False):
194+
def __init__(self, params: Mapping[str, Any] | Any, *, include_kid_thumbprint: bool = False):
194195
if isinstance(params, dict): # this is to avoid double init due to the __new__ above
195196
super().__init__({key: val for key, val in params.items() if val is not None})
196197
self._validate()
@@ -275,7 +276,8 @@ def __setitem__(self, key: str, value: Any) -> None:
275276
RuntimeError: when trying to modify cryptographic attributes
276277
277278
"""
278-
if key in self.PARAMS:
279+
# don't allow modifying private attributes after the key has been initialized
280+
if key in self.PARAMS and hasattr(self, "cryptography_key"):
279281
msg = "JWK key attributes cannot be modified."
280282
raise RuntimeError(msg)
281283
super().__setitem__(key, value)
@@ -305,12 +307,18 @@ def alg(self) -> str | None:
305307
return alg
306308

307309
@property
308-
def kid(self) -> str | None:
309-
"""Return the JWK key ID (kid), if present."""
310+
def kid(self) -> str:
311+
"""Return the JWK key ID (kid).
312+
313+
If the kid is not explicitly set, the RFC7638 key thumbprint is returned.
314+
315+
"""
310316
kid = self.get("kid")
311317
if kid is not None and not isinstance(kid, str): # pragma: no branch
312318
msg = f"invalid kid type {type(kid)}"
313319
raise TypeError(msg, kid)
320+
if kid is None:
321+
return self.thumbprint()
314322
return kid
315323

316324
@property
@@ -1220,7 +1228,7 @@ def copy(self) -> Jwk:
12201228
a copy of this key, with the same value
12211229
12221230
"""
1223-
return Jwk(super().copy())
1231+
return Jwk(copy(self.data))
12241232

12251233
def with_kid_thumbprint(self, *, force: bool = False) -> Jwk:
12261234
"""Include the JWK thumbprint as `kid`.

jwskate/jwk/ec.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ def private(cls, *, crv: str, x: int, y: int, d: int, **params: Any) -> ECJwk:
150150
dict(
151151
kty=cls.KTY,
152152
crv=crv,
153-
x=BinaPy.from_int(x, coord_size).to("b64u").ascii(),
154-
y=BinaPy.from_int(y, coord_size).to("b64u").ascii(),
155-
d=BinaPy.from_int(d, coord_size).to("b64u").ascii(),
153+
x=BinaPy.from_int(x, length=coord_size).to("b64u").ascii(),
154+
y=BinaPy.from_int(y, length=coord_size).to("b64u").ascii(),
155+
d=BinaPy.from_int(d, length=coord_size).to("b64u").ascii(),
156156
**{k: v for k, v in params.items() if v is not None},
157157
)
158158
)
@@ -218,12 +218,12 @@ def from_cryptography_key(cls, cryptography_key: Any, **kwargs: Any) -> ECJwk:
218218
msg = f"Unsupported Curve {cryptography_key.curve.name}"
219219
raise NotImplementedError(msg)
220220

221-
x = BinaPy.from_int(public_numbers.x, crv.coordinate_size).to("b64u").ascii()
222-
y = BinaPy.from_int(public_numbers.y, crv.coordinate_size).to("b64u").ascii()
221+
x = BinaPy.from_int(public_numbers.x, length=crv.coordinate_size).to("b64u").ascii()
222+
y = BinaPy.from_int(public_numbers.y, length=crv.coordinate_size).to("b64u").ascii()
223223
parameters = {"kty": KeyTypes.EC, "crv": crv.name, "x": x, "y": y}
224224
if isinstance(cryptography_key, ec.EllipticCurvePrivateKey):
225225
pn = cryptography_key.private_numbers() # type: ignore[attr-defined]
226-
d = BinaPy.from_int(pn.private_value, crv.coordinate_size).to("b64u").ascii()
226+
d = BinaPy.from_int(pn.private_value, length=crv.coordinate_size).to("b64u").ascii()
227227
parameters["d"] = d
228228

229229
return cls(parameters)

jwskate/jwk/jwks.py

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Iterable
5+
from typing import Any, Iterable, Mapping
6+
7+
from typing_extensions import override
68

79
from jwskate.token import BaseJsonDict
810

@@ -17,8 +19,8 @@ class JwkSet(BaseJsonDict):
1719
methods to get the keys, add or remove keys, and verify signatures
1820
using keys from this set.
1921
20-
- a `dict` from the parsed JSON object representing this JwkSet (in paramter `jwks`)
21-
- a list of `Jwk` (in parameter `keys`
22+
- a `dict` from the parsed JSON object representing this JwkSet (in parameter `jwks`)
23+
- a list of `Jwk` (in parameter `keys`)
2224
- nothing, to initialize an empty JwkSet
2325
2426
Args:
@@ -29,21 +31,23 @@ class JwkSet(BaseJsonDict):
2931

3032
def __init__(
3133
self,
32-
jwks: dict[str, Any] | None = None,
33-
keys: Iterable[Jwk | dict[str, Any]] | None = None,
34+
jwks: Mapping[str, Any] | None = None,
35+
keys: Iterable[Jwk | Mapping[str, Any]] | None = None,
3436
):
35-
if jwks is None and keys is None:
36-
keys = []
37-
38-
if jwks is not None:
39-
keys = jwks.pop("keys", [])
40-
super().__init__(jwks) # init the dict with all the dict content that is not keys
37+
super().__init__({k: v for k, v in jwks.items() if k != "keys"} if jwks else {})
38+
if keys is None and jwks is not None and "keys" in jwks:
39+
keys = jwks.get("keys")
40+
if keys:
41+
for key in keys:
42+
self.add_jwk(key)
43+
44+
@override
45+
def __setitem__(self, name: str, value: Any) -> None:
46+
if name == "keys":
47+
for key in value:
48+
self.add_jwk(key)
4149
else:
42-
super().__init__()
43-
44-
if keys is not None:
45-
for jwk in keys:
46-
self.add_jwk(jwk)
50+
super().__setitem__(name, value)
4751

4852
@property
4953
def jwks(self) -> list[Jwk]:
@@ -53,7 +57,7 @@ def jwks(self) -> list[Jwk]:
5357
a list of `Jwk`
5458
5559
"""
56-
return self.get("keys", []) # type: ignore[no-any-return]
60+
return self.get("keys", [])
5761

5862
def get_jwk_by_kid(self, kid: str) -> Jwk:
5963
"""Return a Jwk from this JwkSet, based on its kid.
@@ -84,35 +88,23 @@ def __len__(self) -> int:
8488

8589
def add_jwk(
8690
self,
87-
key: Jwk | dict[str, Any] | Any,
88-
kid: str | None = None,
89-
use: str | None = None,
91+
key: Jwk | Mapping[str, Any] | Any,
9092
) -> str:
9193
"""Add a Jwk in this JwkSet.
9294
9395
Args:
9496
key: the Jwk to add (either a `Jwk` instance, or a dict containing the Jwk parameters)
95-
kid: the kid to use, if `jwk` doesn't contain one
96-
use: the defined use for the added Jwk
9797
9898
Returns:
99-
the kid from the added Jwk (it may be generated if no kid is provided)
99+
the key ID. It will be generated if missing from the given Jwk.
100100
101101
"""
102-
key = to_jwk(key)
103-
104-
self.setdefault("keys", [])
102+
key = to_jwk(key).with_kid_thumbprint()
105103

106-
kid = key.get("kid", kid)
107-
if not kid:
108-
kid = key.thumbprint()
109-
key["kid"] = kid
110-
use = key.get("use", use)
111-
if use:
112-
key["use"] = use
113-
self.jwks.append(key)
104+
self.data.setdefault("keys", [])
105+
self.data["keys"].append(key)
114106

115-
return kid
107+
return key.kid
116108

117109
def remove_jwk(self, kid: str) -> None:
118110
"""Remove a Jwk from this JwkSet, based on a `kid`.
@@ -198,7 +190,7 @@ def verify(
198190
jwk = self.get_jwk_by_kid(kid)
199191
return jwk.verify(data, signature, alg=alg, algs=algs)
200192

201-
# otherwise, try all keys which support the given alg(s)
193+
# otherwise, try all keys that support the given alg(s)
202194
if algs is None:
203195
if alg is not None:
204196
algs = (alg,)

jwskate/jws/compact.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from functools import cached_property
6-
from typing import TYPE_CHECKING, Any, Iterable, SupportsBytes
6+
from typing import TYPE_CHECKING, Any, Iterable, Mapping, SupportsBytes
77

88
from binapy import BinaPy
99
from typing_extensions import Self
@@ -62,9 +62,9 @@ def __init__(self, value: bytes | str, max_size: int = 16 * 1024):
6262
def sign(
6363
cls,
6464
payload: bytes | SupportsBytes,
65-
key: Jwk | dict[str, Any] | Any,
65+
key: Jwk | Mapping[str, Any] | Any,
6666
alg: str | None = None,
67-
extra_headers: dict[str, Any] | None = None,
67+
extra_headers: Mapping[str, Any] | None = None,
6868
) -> JwsCompact:
6969
"""Sign a payload and returns the resulting JwsCompact.
7070
@@ -132,7 +132,7 @@ def signed_part(self) -> bytes:
132132

133133
def verify_signature(
134134
self,
135-
key: Jwk | dict[str, Any] | Any,
135+
key: Jwk | Mapping[str, Any] | Any,
136136
*,
137137
alg: str | None = None,
138138
algs: Iterable[str] | None = None,
@@ -153,7 +153,7 @@ def verify_signature(
153153

154154
def verify(
155155
self,
156-
key: Jwk | dict[str, Any] | Any,
156+
key: Jwk | Mapping[str, Any] | Any,
157157
*,
158158
alg: str | None = None,
159159
algs: Iterable[str] | None = None,

0 commit comments

Comments
 (0)