Skip to content

Commit fc2a108

Browse files
Merge pull request #11 from bitovi/more-pr-feedback
More pr feedback
2 parents 524e430 + 71bc2f0 commit fc2a108

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

encryption_jwt/codec_server.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import requests
77
from aiohttp import hdrs, web
88
from google.protobuf import json_format
9+
from jwt import PyJWK
910
from jwt.algorithms import RSAAlgorithm
1011
from temporalio.api.cloud.cloudservice.v1 import GetUsersRequest
1112
from temporalio.api.common.v1 import Payloads
@@ -18,9 +19,9 @@
1819

1920
TEMPORAL_CLIENT_CLOUD_API_VERSION = "2024-05-13-00"
2021

21-
temporal_ops_address = "saas-api.tmprl.cloud:443"
22-
if os.environ.get("TEMPORAL_OPS_ADDRESS"):
23-
temporal_ops_address = os.environ.get("TEMPORAL_OPS_ADDRESS")
22+
temporal_ops_address = (
23+
os.environ.get("TEMPORAL_OPS_ADDRESS") or "saas-api.tmprl.cloud:443"
24+
)
2425

2526

2627
def build_codec_server() -> web.Application:
@@ -76,8 +77,8 @@ async def decryption_authorized(email: str, namespace: str) -> bool:
7677

7778
def make_handler(fn: str):
7879
async def handler(req: web.Request):
79-
namespace = req.headers.get("x-namespace")
80-
auth_header = req.headers.get("Authorization")
80+
namespace = req.headers.get("x-namespace") or "default"
81+
auth_header = req.headers.get("Authorization") or ""
8182
_bearer, encoded = auth_header.split(" ")
8283

8384
# Extract the kid from the Auth header
@@ -90,20 +91,20 @@ async def handler(req: web.Request):
9091
jwks = requests.get(jwks_url).json()
9192

9293
# Extract Temporal Cloud's public key
93-
public_key = None
94+
pyjwk = None
9495
for key in jwks["keys"]:
9596
if key["kid"] == kid:
9697
# Convert JWKS key to PEM format
97-
public_key = RSAAlgorithm.from_jwk(key)
98+
pyjwk = PyJWK.from_dict(key)
9899
break
99100

100-
if public_key is None:
101+
if pyjwk is None:
101102
raise ValueError("Public key not found in JWKS")
102103

103104
# Decode the jwt, verifying against Temporal Cloud's public key
104105
decoded = jwt.decode(
105106
encoded,
106-
public_key,
107+
pyjwk.key,
107108
algorithms=[algorithm],
108109
audience=[
109110
"https://saas-api.tmprl.cloud",
@@ -156,7 +157,7 @@ async def handler(req: web.Request):
156157
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
157158
ssl_context.check_hostname = False
158159
ssl_context.load_cert_chain(
159-
os.environ.get("SSL_PEM"), os.environ.get("SSL_KEY")
160+
os.environ.get("SSL_PEM") or "", os.environ.get("SSL_KEY") or ""
160161
)
161162

162163
web.run_app(

encryption_jwt/encryptor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import base64
23
import logging
34
import os
@@ -38,16 +39,20 @@ async def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
3839

3940
nonce = os.urandom(12)
4041
encryptor = AESGCM(data_key_plaintext)
41-
return nonce + encryptor.encrypt(nonce, data, None), base64.b64encode(
42-
data_key_encrypted
42+
encrypted = asyncio.get_running_loop().run_in_executor(
43+
None, encryptor.encrypt, nonce, data, None
4344
)
45+
return nonce + await encrypted, base64.b64encode(data_key_encrypted)
4446

4547
async def decrypt(self, data_key_encrypted_base64, data: bytes) -> bytes:
4648
"""Encrypt data using a key from KMS."""
4749
data_key_encrypted = base64.b64decode(data_key_encrypted_base64)
4850
data_key_plaintext = await self.__decrypt_data_key(data_key_encrypted)
4951
encryptor = AESGCM(data_key_plaintext)
50-
return encryptor.decrypt(data[:12], data[12:], None)
52+
decrypted = await asyncio.get_running_loop().run_in_executor(
53+
None, encryptor.decrypt, data[:12], data[12:], None
54+
)
55+
return decrypted
5156

5257
async def __create_data_key(self, namespace: str):
5358
"""Get a set of keys from AWS KMS that can be used to encrypt data."""

0 commit comments

Comments
 (0)