Skip to content

Commit

Permalink
WIP batch message endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
steventux committed Jan 23, 2025
1 parent 49cc82a commit 7af6d7b
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/notify/app/route_handlers/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from flask import request
import app.validators.request_validator as request_validator
import app.services.message_batch_dispatcher as message_batch_dispatcher


def batch():
json_data = request.json or {}
valid_headers, error_message = request_validator.verify_headers(dict(request.headers))

if not valid_headers:
return {"status": "failed", "error": error_message}, 401

if not request_validator.verify_signature(dict(request.headers), json_data):
return {"status": "failed", "error": "Invalid signature"}, 403

valid_body, error_message = request_validator.verify_body(json_data)

if not valid_body:
return {"status": "failed", "error": error_message}, 422

success, response = message_batch_dispatcher.dispatch(json_data)
if success:
return {"status": "success", "response": response}, 200
else:
return {"status": "failed", "error": response}, 500
28 changes: 28 additions & 0 deletions src/notify/app/services/message_batch_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import app.utils.access_token as access_token
import logging
import os
import requests
import uuid


def dispatch(body: dict) -> tuple[bool, str]:
response = requests.post(url(), json=body, headers=headers())
logging.info(f"Response from Notify API {url()}: {response.status_code}")

if response.status_code == 201:
return True, response.text
else:
return False, response.text


def headers() -> dict:
return {
"content-type": "application/vnd.api+json",
"accept": "application/vnd.api+json",
"x-correlation-id": str(uuid.uuid4()),
"authorization": "Bearer " + access_token.get_token(),
}


def url() -> str:
return f"{os.getenv('NOTIFY_API_URL')}/comms/v1/message-batches"
48 changes: 48 additions & 0 deletions src/notify/app/services/message_batch_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import app.utils.database as database
from collections import defaultdict
import database.models as models
from itertools import chain
from sqlalchemy.orm import Session


def save_batch(data, response, status) -> tuple[bool, str]:
try:
with Session(database.engine()) as session:
message_batch = models.MessageBatch(
batch_id=data["id"],
batch_reference=data["attributes"]["batch_reference"],
details=data,
response=response,
status=status,
)
session.add(message_batch)
session.flush()

for message in merged_messages(data, response):
message = models.Message(
batch_id=message_batch.id,
details=message,
message_id=message["id"],
message_reference=message["message_reference"],
nhs_number=message["recipient"]["nhs_number"],
recipient_id=message["attributes"]["recipient_id"],
)
session.add(message)

session.commit()

return True, f"Batch #{message_batch.id} saved successfully"
except Exception as e:
return False, str(e)


def merged_messages(data: dict, response: dict) -> list[dict]:
message_chain = chain(
data["attributes"]["messages"],
response["messages"]
)
collector = defaultdict(dict)
for collectible in message_chain:
collector[collectible["messageReference"]].update(collectible.items())

return list(collector.values())
80 changes: 80 additions & 0 deletions src/notify/app/utils/access_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import datetime
import jwt
import logging
import os
import requests
import time
import uuid


def get_token() -> str:
if not os.getenv("OAUTH2_API_KEY"):
return "awaiting-token"

jwt: str = generate_auth_jwt()
headers: dict = {"Content-Type": "application/x-www-form-urlencoded"}

body = {
"grant_type": "client_credentials",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": jwt,
}

response = requests.post(
os.getenv("OAUTH2_TOKEN_URL"),
data=body,
headers=headers,
)
logging.info(f"Response from OAuth2 token provider: {response.status_code}")
response_json = response.json()

if response.status_code == 200:
access_token = response_json["access_token"]
else:
access_token = ""
logging.error("Failed to get access token")
logging.error(response_json)

return access_token


def generate_auth_jwt() -> str:
algorithm: str = "RS512"
headers: dict = {
"alg": algorithm,
"typ": "JWT",
"kid": os.getenv("OAUTH2_API_KID")
}
api_key: str = os.getenv("OAUTH2_API_KEY")

payload: dict = {
"sub": api_key,
"iss": api_key,
"jti": str(uuid.uuid4()),
"aud": os.getenv("OAUTH2_TOKEN_URL"),
"exp": int(time.time()) + 300, # 5mins in the future
}

private_key = os.getenv("PRIVATE_KEY")

return generate_jwt(
algorithm, private_key, headers,
payload, expiry_minutes=5
)


def generate_jwt(
algorithm: str,
private_key,
headers: dict,
payload: dict,
expiry_minutes: int = None,
) -> str:
if expiry_minutes:
expiry_date = (
datetime.datetime.now(datetime.timezone.utc) +
datetime.timedelta(minutes=expiry_minutes)
)
payload["exp"] = expiry_date

return jwt.encode(payload, private_key, algorithm, headers)
6 changes: 6 additions & 0 deletions src/notify/app/utils/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import psycopg2
import time
from sqlalchemy import create_engine


def connection():
Expand All @@ -17,3 +18,8 @@ def connection():
logging.debug(f"Connected to database in {(end - start)}s")

return conn


def engine():
import pdb; pdb.set_trace()
return create_engine(connection().dsn)
64 changes: 64 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,70 @@ def truncate_table():
logging.error(f"Error: {e}")


@pytest.fixture
def message_batch_post_body():
return {
"data": {
"type": "MessageBatch",
"attributes": {
"routingPlanId": "b838b13c-f98c-4def-93f0-515d4e4f4ee1",
"messageBatchReference": "da0b1495-c7cb-468c-9d81-07dee089d728",
"messages": [
{
"messageReference": "703b8008-545d-4a04-bb90-1f2946ce1575",
"recipient": {
"nhsNumber": "9990548609",
"contactDetails": {
"email": "[email protected]",
"sms": "07777777777",
"address": {
"lines": [
"NHS England",
"6th Floor",
"7&8 Wellington Place",
"Leeds",
"West Yorkshire"
],
"postcode": "LS1 4AP"
}
}
},
"originator": {
"odsCode": "X26"
},
"personalisation": {}
}
]
}
}
}


@pytest.fixture
def message_batch_post_response():
return {
"data": {
"type": "MessageBatch",
"id": "2ZljUiS8NjJNs95PqiYOO7gAfJb",
"attributes": {
"messageBatchReference": "da0b1495-c7cb-468c-9d81-07dee089d728",
"routingPlan": {
"id": "b838b13c-f98c-4def-93f0-515d4e4f4ee1",
"name": "Plan Abc",
"version": "ztoe2qRAM8M8vS0bqajhyEBcvXacrGPp",
"createdDate": "2023-11-17T14:27:51.413Z"
},
"messages": [
{
"messageReference": "703b8008-545d-4a04-bb90-1f2946ce1575",
"id": "2WL3qFTEFM0qMY8xjRbt1LIKCzM"
}
]
}
}
}


@pytest.fixture
def channel_status_post_body():
return {
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/notify/app/services/test_message_batch_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import app.services.message_batch_recorder as message_batch_recorder

def test_save_batch(message_batch_post_body, message_batch_post_response):
success, response = message_batch_recorder.save_batch(
message_batch_post_body["data"],
message_batch_post_response,
"sent"
)
import pdb; pdb.set_trace()
assert success
assert response == "Batch #1 saved successfully"
53 changes: 53 additions & 0 deletions tests/unit/notify/app/utils/test_access_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import access_token
import logging
import pytest
import requests_mock
import cryptography.hazmat.primitives.asymmetric.rsa as rsa
from cryptography.hazmat.primitives import serialization


@pytest.fixture
def setup(monkeypatch):
"""Set up environment variables and private key for tests."""
monkeypatch.setenv("OAUTH2_TOKEN_URL", "http://tokens.example.com")
monkeypatch.setenv("OAUTH2_API_KEY", "an_api_key")
monkeypatch.setenv("OAUTH2_API_KID", "a_kid")
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048
)
private_key_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
).decode()
monkeypatch.setenv("PRIVATE_KEY", private_key_pem)


def test_get_token_successful_response(setup):
"""Test that a valid response returns the expected access token."""
with requests_mock.Mocker() as mock:
mock.post(
"http://tokens.example.com/",
json={"access_token": "an_access_token"},
)

token = access_token.get_token()
assert token == "an_access_token"


def test_get_token_error_response(setup, mocker):
"""Test that an error response results in an empty token and logs errors."""
error_logging_spy = mocker.spy(logging, "error")

with requests_mock.Mocker() as mock:
mock.post(
"http://tokens.example.com/",
status_code=403,
json={"error": "an_error"},
)

token = access_token.get_token()
assert token == ""
error_logging_spy.assert_any_call("Failed to get access token")
error_logging_spy.assert_any_call({"error": "an_error"})

0 comments on commit 7af6d7b

Please sign in to comment.