From 2bef605ac6a5b950fe524de7f89a354e38851d4a Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 11:38:28 +0000 Subject: [PATCH 01/10] Ensure local test runs use test db --- .env.test | 4 ++-- test.sh | 3 +++ tests/conftest.py | 3 +++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.env.test b/.env.test index 42cd95d6..96054f43 100644 --- a/.env.test +++ b/.env.test @@ -8,8 +8,8 @@ OAUTH2_API_KEY="" OAUTH2_API_KID="" OAUTH2_TOKEN_URL="" DATABASE_HOST=localhost -DATABASE_NAME=communication_management -DATABASE_PASSWORD=not-a-secret +DATABASE_NAME=communication_management_test +DATABASE_PASSWORD="" DATABASE_SSLMODE=allow DATABASE_USER=postgres PRIVATE_KEY="" diff --git a/test.sh b/test.sh index 049ff88e..2a24e80e 100755 --- a/test.sh +++ b/test.sh @@ -1,5 +1,8 @@ #!/bin/bash +source .env.test +sudo -u ${DATABASE_USER} psql -c "CREATE DATABASE ${DATABASE_NAME};" + ./test-setup.sh ./test-unit.sh ./test-integration.sh diff --git a/tests/conftest.py b/tests/conftest.py index 701f89d3..16b3bb29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,7 @@ import pytest +import dotenv + +dotenv.load_dotenv(".env.test") @pytest.hookimpl(hookwrapper=True) From 5a8f46fa72b093efa7ad76bdd1b2802c58cf9b6a Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 11:43:07 +0000 Subject: [PATCH 02/10] Add status recorder service When the status callback endpoint receives requests the post body will be saved to the database to record the channel and message status depending on the payload type. This commit adds the datastore and service module which handles this persistence. --- src/notify/app/route_handlers/status.py | 3 +- src/notify/app/services/__init__.py | 0 src/notify/app/services/datastore.py | 68 +++++++++++++++++ src/notify/app/services/status_recorder.py | 34 +++++++++ src/notify/app/utils/__init__.py | 0 src/notify/app/utils/database.py | 40 ++++++++++ .../notify/app/services/test_datastore.py | 73 +++++++++++++++++++ .../services/test_message_status_recorder.py | 56 ++++++++++++++ 8 files changed, 273 insertions(+), 1 deletion(-) create mode 100644 src/notify/app/services/__init__.py create mode 100644 src/notify/app/services/datastore.py create mode 100644 src/notify/app/services/status_recorder.py create mode 100644 src/notify/app/utils/__init__.py create mode 100644 src/notify/app/utils/database.py create mode 100644 tests/unit/notify/app/services/test_datastore.py create mode 100644 tests/unit/notify/app/services/test_message_status_recorder.py diff --git a/src/notify/app/route_handlers/status.py b/src/notify/app/route_handlers/status.py index 09d60c6e..f2a2b4c2 100644 --- a/src/notify/app/route_handlers/status.py +++ b/src/notify/app/route_handlers/status.py @@ -1,6 +1,7 @@ from flask import request import json import app.validators.request_validator as request_validator +import app.services.status_recorder as status_recorder def create(): @@ -14,7 +15,7 @@ def create(): status_code = 422 body = {"status": "error"} else: - # status_recorder.save_statuses(body_dict) + status_recorder.save_statuses(dict(request.form)) status_code = 200 body = {"status": "success"} diff --git a/src/notify/app/services/__init__.py b/src/notify/app/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/notify/app/services/datastore.py b/src/notify/app/services/datastore.py new file mode 100644 index 00000000..a1dd3cb3 --- /dev/null +++ b/src/notify/app/services/datastore.py @@ -0,0 +1,68 @@ +import app.utils.database as database +import logging +import psycopg2 + +INSERT_BATCH_MESSAGE = """ + INSERT INTO batch_messages ( + batch_id, + details, + message_reference, + nhs_number, + recipient_id, + status + ) VALUES ( + %(batch_id)s, + %(details)s, + %(message_reference)s, + %(nhs_number)s, + %(recipient_id)s, + %(status)s + ) RETURNING batch_id, message_reference""" + +INSERT_STATUS = """ + INSERT INTO {table_name} ( + idempotency_key, + message_id, + message_reference, + details, + status + ) VALUES ( + %(idempotency_key)s, + %(message_id)s, + %(message_reference)s, + %(details)s, + %(status)s + ) RETURNING idempotency_key""" + +STATUS_TABLE_NAMES_BY_TYPE = { + "ChannelStatus": "channel_statuses", + "MessageStatus": "message_statuses" +} + + +def create_batch_message_record(batch_message_data: dict) -> bool | list[str, str]: + try: + with database.connection() as conn: + with conn.cursor() as cur: + cur.execute(INSERT_BATCH_MESSAGE, batch_message_data) + return cur.fetchone() + + except psycopg2.Error as e: + logging.error("Error creating batch message record") + logging.error(f"{type(e).__name__} : {e}") + return False + + +def create_status_record(status_type: str, status_data: dict) -> bool | str: + table_name = STATUS_TABLE_NAMES_BY_TYPE[status_type] + try: + with database.connection() as conn: + with conn.cursor() as cur: + cur.execute(INSERT_STATUS.format(table_name=table_name), status_data) + + return cur.fetchone()[0] + + except psycopg2.Error as e: + logging.error("Error creating status record") + logging.error(f"{type(e).__name__} : {e}") + return False diff --git a/src/notify/app/services/status_recorder.py b/src/notify/app/services/status_recorder.py new file mode 100644 index 00000000..18554d11 --- /dev/null +++ b/src/notify/app/services/status_recorder.py @@ -0,0 +1,34 @@ +import app.services.datastore as datastore +import json +import logging + + +def save_statuses(request_body: dict) -> None: + statuses: list[dict] = status_params(request_body) + status_type = request_body["data"][0]["type"] + + for status in statuses: + datastore.create_status_record(status_type, status) + + return None + + +def status_params(request_body: dict) -> list[dict]: + params = [] + for status_data in request_body["data"]: + try: + attributes = status_data["attributes"] + meta = status_data["meta"] + params.append({ + "details": json.dumps(request_body), + "idempotency_key": meta["idempotencyKey"], + "message_id": attributes["messageId"], + "message_reference": attributes["messageReference"], + "status": attributes.get("messageStatus", attributes.get("channelStatus")), + }) + except KeyError as e: + logging.error(f"Missing required field: {e}") + logging.error(f"Request body: {request_body}") + continue + + return params diff --git a/src/notify/app/utils/__init__.py b/src/notify/app/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/notify/app/utils/database.py b/src/notify/app/utils/database.py new file mode 100644 index 00000000..8f182da8 --- /dev/null +++ b/src/notify/app/utils/database.py @@ -0,0 +1,40 @@ +import logging +import os +import psycopg2 +import time + +SCHEMA_FILE_PATH = f"{os.path.dirname(__file__)}/../../../../database/schema.sql" +SCHEMA_INITIALISED_SQL = """ + SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name = 'channel_statuses') +""" + + +def connection(): + start = time.time() + conn = psycopg2.connect( + dbname=os.environ["DATABASE_NAME"], + user=os.environ["DATABASE_USER"], + host=os.environ["DATABASE_HOST"], + password=os.environ["DATABASE_PASSWORD"], + sslmode=os.getenv("DATABASE_SSLMODE", "require"), + ) + end = time.time() + logging.debug(f"Connected to database in {(end - start)}s") + + check_and_initialise_schema(conn) + + return conn + + +def check_and_initialise_schema(conn: psycopg2.extensions.connection): + if bool(os.getenv("SCHEMA_INITIALISED")): + return + + with conn.cursor() as cur: + cur.execute(SCHEMA_INITIALISED_SQL) + if not bool(cur.fetchone()[0]): + logging.info("Initialising schema") + cur.execute(open(SCHEMA_FILE_PATH, "r").read()) + + conn.commit() + os.environ["SCHEMA_INITIALISED"] = "true" diff --git a/tests/unit/notify/app/services/test_datastore.py b/tests/unit/notify/app/services/test_datastore.py new file mode 100644 index 00000000..f6cda4b5 --- /dev/null +++ b/tests/unit/notify/app/services/test_datastore.py @@ -0,0 +1,73 @@ +import app.services.datastore as datastore +import pytest + + +@pytest.fixture +def mock_connection(mocker): + return mocker.patch("app.utils.database.connection") + + +@pytest.fixture +def mock_cursor(mocker, mock_connection): + return mock_connection().__enter__().cursor().__enter__() + + +@pytest.fixture +def batch_message_data(autouse=True): + return { + "batch_id": "0b3b3b3b-3b3b-3b3b-3b3b-3b3b3b3b3b3b", + "details": "Test details", + "message_reference": "0b3b3b3b-3b3b-3b-3b3b-3b3b3b3b3b3b", + "nhs_number": "1234567890", + "recipient_id": "e3e7b3b3-3b3b-3b-3b3b-3b3b3b3b3b3b", + "status": "test_status", + } + + +@pytest.fixture +def message_status_data(autouse=True): + return { + "idempotency_key": "0b3b3b3b-3b3b-3b3b-3b3b-3b3b3b3b3b3b", + "message_id": "0x0x0x0xabx0x0", + "message_reference": "0b3b3b3b-3b3b-3b3b-3b3b-3b3b3b3b3b3b", + "details": "Test details", + "status": "test_status", + } + + +def test_create_batch_message_record(mock_cursor): + """Test the SQL execution of batch message record creation.""" + datastore.create_batch_message_record(batch_message_data) + + mock_cursor.execute.assert_called_with(datastore.INSERT_BATCH_MESSAGE, batch_message_data) + mock_cursor.fetchone.assert_called_once() + + +def test_create_batch_message_record_with_error(mock_cursor): + """Test the SQL execution of batch message record creation with an error.""" + mock_cursor.execute.side_effect = Exception("Test error") + + with pytest.raises(Exception): + assert datastore.create_batch_message_record(batch_message_data) is False + + mock_cursor.execute.assert_called_with(datastore.INSERT_BATCH_MESSAGE, batch_message_data) + mock_cursor.fetchone.assert_not_called() + + +def test_create_message_status_record(mock_cursor): + """Test the SQL execution of message status record creation.""" + datastore.create_status_record("MessageStatus", message_status_data) + + mock_cursor.execute.assert_called_with(datastore.INSERT_STATUS.format(table_name="message_statuses"), message_status_data) + mock_cursor.fetchone.assert_called_once() + + +def test_create_message_status_record_with_error(mock_cursor): + """Test the SQL execution of message status record creation with an error.""" + mock_cursor.execute.side_effect = Exception("Test error") + + with pytest.raises(Exception): + assert datastore.create_status_record("MessageStatus", message_status_data) is False + + mock_cursor.execute.assert_called_with(datastore.INSERT_STATUS.format(table_name="message_statuses"), message_status_data) + mock_cursor.fetchone.assert_not_called() diff --git a/tests/unit/notify/app/services/test_message_status_recorder.py b/tests/unit/notify/app/services/test_message_status_recorder.py new file mode 100644 index 00000000..d25d7625 --- /dev/null +++ b/tests/unit/notify/app/services/test_message_status_recorder.py @@ -0,0 +1,56 @@ +import json +import app.services.status_recorder as status_recorder + + +def test_save_statuses_with_channel_status_data(mocker, channel_status_post_body): + """Test saving channel status data to datastore""" + mock_datastore = mocker.patch("app.services.status_recorder.datastore") + + assert status_recorder.save_statuses(channel_status_post_body) is None + + mock_datastore.create_status_record.assert_called_once_with("ChannelStatus", { + "details": json.dumps(channel_status_post_body), + "idempotency_key": "2515ae6b3a08339fba3534f3b17cd57cd573c57d25b25b9aae08e42dc9f0a445", #gitleaks:allow + "message_id": "2WL3qFTEFM0qMY8xjRbt1LIKCzM", + "message_reference": "1642109b-69eb-447f-8f97-ab70a74f5db4", + "status": "delivered" + }) + + +def test_save_statuses_with_message_status_data(mocker, message_status_post_body): + """Test saving message status data to datastore""" + mock_datastore = mocker.patch("app.services.status_recorder.datastore") + + assert status_recorder.save_statuses(message_status_post_body) is None + + mock_datastore.create_status_record.assert_called_once_with("MessageStatus", { + "details": json.dumps(message_status_post_body), + "idempotency_key": "2515ae6b3a08339fba3534f3b17cd57cd573c57d25b25b9aae08e42dc9f0a445", #gitleaks:allow + "message_id": "2WL3qFTEFM0qMY8xjRbt1LIKCzM", + "message_reference": "1642109b-69eb-447f-8f97-ab70a74f5db4", + "status": "sending" + }) + + +def test_status_params(message_status_post_body): + """Test conversion of request body to message status parameters""" + expected = [ + { + "details": json.dumps(message_status_post_body), + "idempotency_key": "2515ae6b3a08339fba3534f3b17cd57cd573c57d25b25b9aae08e42dc9f0a445", #gitleaks:allow + "message_id": "2WL3qFTEFM0qMY8xjRbt1LIKCzM", + "message_reference": "1642109b-69eb-447f-8f97-ab70a74f5db4", + "status": "sending" + }, + ] + + assert status_recorder.status_params(message_status_post_body) == expected + + +def test_status_params_with_missing_field(message_status_post_body): + """Test conversion of request body with missing field to message status parameters""" + message_status_post_body["data"][0]["attributes"].pop("messageReference") + + expected = [] + + assert status_recorder.status_params(message_status_post_body) == expected From e1126d1043596542488cc00635d96e1de7629acf Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 12:00:20 +0000 Subject: [PATCH 03/10] Use psycopg2 sql.Identifier to safely interpolate the table name There's good support for composition of SQL in psycopg2 so use it. --- src/notify/app/services/datastore.py | 9 +++++---- tests/unit/notify/app/services/test_datastore.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/notify/app/services/datastore.py b/src/notify/app/services/datastore.py index a1dd3cb3..c905146c 100644 --- a/src/notify/app/services/datastore.py +++ b/src/notify/app/services/datastore.py @@ -1,6 +1,7 @@ import app.utils.database as database import logging import psycopg2 +from psycopg2 import sql INSERT_BATCH_MESSAGE = """ INSERT INTO batch_messages ( @@ -19,8 +20,8 @@ %(status)s ) RETURNING batch_id, message_reference""" -INSERT_STATUS = """ - INSERT INTO {table_name} ( +INSERT_STATUS = sql.SQL(""" + INSERT INTO {table} ( idempotency_key, message_id, message_reference, @@ -32,7 +33,7 @@ %(message_reference)s, %(details)s, %(status)s - ) RETURNING idempotency_key""" + ) RETURNING idempotency_key""") STATUS_TABLE_NAMES_BY_TYPE = { "ChannelStatus": "channel_statuses", @@ -58,7 +59,7 @@ def create_status_record(status_type: str, status_data: dict) -> bool | str: try: with database.connection() as conn: with conn.cursor() as cur: - cur.execute(INSERT_STATUS.format(table_name=table_name), status_data) + cur.execute(INSERT_STATUS.format(table=sql.Identifier(table_name)), status_data) return cur.fetchone()[0] diff --git a/tests/unit/notify/app/services/test_datastore.py b/tests/unit/notify/app/services/test_datastore.py index f6cda4b5..1247daf6 100644 --- a/tests/unit/notify/app/services/test_datastore.py +++ b/tests/unit/notify/app/services/test_datastore.py @@ -1,5 +1,6 @@ import app.services.datastore as datastore import pytest +from psycopg2 import sql @pytest.fixture @@ -58,7 +59,7 @@ def test_create_message_status_record(mock_cursor): """Test the SQL execution of message status record creation.""" datastore.create_status_record("MessageStatus", message_status_data) - mock_cursor.execute.assert_called_with(datastore.INSERT_STATUS.format(table_name="message_statuses"), message_status_data) + mock_cursor.execute.assert_called_with(datastore.INSERT_STATUS.format(table=sql.Identifier("message_statuses")), message_status_data) mock_cursor.fetchone.assert_called_once() @@ -69,5 +70,5 @@ def test_create_message_status_record_with_error(mock_cursor): with pytest.raises(Exception): assert datastore.create_status_record("MessageStatus", message_status_data) is False - mock_cursor.execute.assert_called_with(datastore.INSERT_STATUS.format(table_name="message_statuses"), message_status_data) + mock_cursor.execute.assert_called_with(datastore.INSERT_STATUS.format(table=sql.Identifier("message_statuses")), message_status_data) mock_cursor.fetchone.assert_not_called() From 9722944e58259996ad981e40f8dead8e7d2dda7f Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 12:19:21 +0000 Subject: [PATCH 04/10] Add FLASK_DEBUG env vars --- .env.example | 1 + .env.test | 1 + 2 files changed, 2 insertions(+) diff --git a/.env.example b/.env.example index 19bf884d..518de78c 100644 --- a/.env.example +++ b/.env.example @@ -2,6 +2,7 @@ AZURITE_CONNECTION_STRING="" APPLICATION_ID="" BLOB_CONTAINER_NAME=pilot-data +FLASK_DEBUG=true NOTIFY_API_KEY="" NOTIFY_API_URL="" NOTIFY_FUNCTION_URL="" diff --git a/.env.test b/.env.test index 96054f43..a190469d 100644 --- a/.env.test +++ b/.env.test @@ -1,6 +1,7 @@ AZURITE_CONNECTION_STRING="DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;" APPLICATION_ID="" BLOB_CONTAINER_NAME=pilot-data +FLASK_DEBUG=true NOTIFY_API_KEY="" NOTIFY_API_URL=https://sandbox.api.service.nhs.uk NOTIFY_FUNCTION_URL=http://localhost:7072/api/notify/message/send From acf011089ac4139b1b7325761e3f88fbae24e231 Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 15:23:39 +0000 Subject: [PATCH 05/10] Extract hmac signature generation We do this in test setup quite a few times so make it a utility. --- src/notify/app/utils/hmac_signature.py | 10 ++++ .../app/validators/request_validator.py | 12 ++--- .../notify/app/utils/test_hmac_signature.py | 50 +++++++++++++++++++ .../app/validators/test_request_validator.py | 14 +++--- 4 files changed, 72 insertions(+), 14 deletions(-) create mode 100644 src/notify/app/utils/hmac_signature.py create mode 100644 tests/unit/notify/app/utils/test_hmac_signature.py diff --git a/src/notify/app/utils/hmac_signature.py b/src/notify/app/utils/hmac_signature.py new file mode 100644 index 00000000..beb571e7 --- /dev/null +++ b/src/notify/app/utils/hmac_signature.py @@ -0,0 +1,10 @@ +import hmac +import hashlib + + +def create_digest(secret: str, message: str) -> str: + return hmac.new( + bytes(secret, 'ASCII'), + msg=bytes(message, 'ASCII'), + digestmod=hashlib.sha256 + ).hexdigest() diff --git a/src/notify/app/validators/request_validator.py b/src/notify/app/validators/request_validator.py index 7d42d2b9..4f7a286c 100644 --- a/src/notify/app/validators/request_validator.py +++ b/src/notify/app/validators/request_validator.py @@ -1,7 +1,10 @@ import hashlib import hmac +import json +import logging import os import app.validators.schema_validator as schema_validator +import app.utils.hmac_signature as hmac_signature API_KEY_HEADER_NAME = 'x-api-key' SIGNATURE_HEADER_NAME = 'x-hmac-sha256-signature' @@ -19,14 +22,11 @@ def verify_headers(headers: dict) -> bool: return True -def verify_signature(headers: dict, body: str) -> bool: +def verify_signature(headers: dict, body: dict) -> bool: lc_headers = header_keys_to_lower(headers) + body_str = json.dumps(body, sort_keys=True) - expected_signature = hmac.new( - bytes(signature_secret(), 'ASCII'), - msg=bytes(body, 'ASCII'), - digestmod=hashlib.sha256 - ).hexdigest() + expected_signature = hmac_signature.create_digest(signature_secret(), body_str) return hmac.compare_digest( expected_signature, diff --git a/tests/unit/notify/app/utils/test_hmac_signature.py b/tests/unit/notify/app/utils/test_hmac_signature.py new file mode 100644 index 00000000..1f735cde --- /dev/null +++ b/tests/unit/notify/app/utils/test_hmac_signature.py @@ -0,0 +1,50 @@ +import app.utils.hmac_signature as hmac_signature +import hmac +import hashlib + + +def test_valid_hmac_signature(): + """Test a valid HMAC signature matches.""" + secret = 'secret' + message = 'message' + + expected_signature = hmac.new( + bytes(secret, 'ASCII'), + msg=bytes(message, 'ASCII'), + digestmod=hashlib.sha256 + ).hexdigest() + + actual_signature = hmac_signature.create_digest(secret, message) + + assert hmac.compare_digest(expected_signature, actual_signature) + + +def test_unmatched_message_in_hmac_signature(): + """Test that a different message creates a different signature.""" + secret = 'secret' + message = 'message' + + expected_signature = hmac.new( + bytes(secret, 'ASCII'), + msg=bytes("nope", 'ASCII'), + digestmod=hashlib.sha256 + ).hexdigest() + + actual_signature = hmac_signature.create_digest(secret, message) + + assert not hmac.compare_digest(expected_signature, actual_signature) + +def test_unmatched_secret_in_hmac_signature(): + """Test that a different secret creates a different signature""" + secret = 'secret' + message = 'message' + + expected_signature = hmac.new( + bytes("nope", 'ASCII'), + msg=bytes(message, 'ASCII'), + digestmod=hashlib.sha256 + ).hexdigest() + + actual_signature = hmac_signature.create_digest(secret, message) + + assert not hmac.compare_digest(expected_signature, actual_signature) diff --git a/tests/unit/notify/app/validators/test_request_validator.py b/tests/unit/notify/app/validators/test_request_validator.py index 65e2971a..bb5e83f0 100644 --- a/tests/unit/notify/app/validators/test_request_validator.py +++ b/tests/unit/notify/app/validators/test_request_validator.py @@ -1,6 +1,8 @@ +import app.validators.request_validator as request_validator +import app.utils.hmac_signature as hmac_signature import hashlib import hmac -import app.validators.request_validator as request_validator +import json import pytest @@ -14,19 +16,15 @@ def setup(monkeypatch): def test_verify_signature_invalid(setup): """Test that an invalid signature fails verification.""" headers = {request_validator.SIGNATURE_HEADER_NAME: 'signature'} - body = 'body' + body = {'data': 'body'} assert not request_validator.verify_signature(headers, body) def test_verify_signature_valid(setup): """Test that a valid signature passes verification.""" - body = 'body' - signature = hmac.new( - bytes('application_id.api_key', 'ASCII'), - msg=bytes(body, 'ASCII'), - digestmod=hashlib.sha256 - ).hexdigest() + body = {'data': 'body'} + signature = hmac_signature.create_digest('application_id.api_key', json.dumps(body)) headers = {request_validator.SIGNATURE_HEADER_NAME: signature} assert request_validator.verify_signature(headers, body) From 24ffac4556ae62bc442a8dd34af611d67f77518a Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 15:26:54 +0000 Subject: [PATCH 06/10] Read the post body as json and ensure key sorting is consistent --- src/notify/app/route_handlers/status.py | 7 ++-- .../test_status_create_endpoint.py | 32 ++++++++++++------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/notify/app/route_handlers/status.py b/src/notify/app/route_handlers/status.py index f2a2b4c2..e23355ca 100644 --- a/src/notify/app/route_handlers/status.py +++ b/src/notify/app/route_handlers/status.py @@ -1,5 +1,4 @@ from flask import request -import json import app.validators.request_validator as request_validator import app.services.status_recorder as status_recorder @@ -8,14 +7,14 @@ def create(): if request_validator.verify_headers(dict(request.headers)) is False: status_code = 401 body = {"status": "error"} - elif request_validator.verify_signature(dict(request.headers), json.dumps(request.form)) is False: + elif request_validator.verify_signature(dict(request.headers), request.json) is False: status_code = 403 body = {"status": "error"} - elif request_validator.verify_body(dict(request.form)) is False: + elif request_validator.verify_body(request.json)[0] is False: status_code = 422 body = {"status": "error"} else: - status_recorder.save_statuses(dict(request.form)) + status_recorder.save_statuses(request.json) status_code = 200 body = {"status": "success"} diff --git a/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py b/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py index a9200cdc..c37033f0 100644 --- a/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py +++ b/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py @@ -1,5 +1,6 @@ from app import create_app from app.validators.request_validator import API_KEY_HEADER_NAME, SIGNATURE_HEADER_NAME, signature_secret +import app.utils.hmac_signature as hmac_signature import hashlib import hmac import json @@ -19,28 +20,37 @@ def client(): yield app.test_client() -def test_status_create_request_validation_fails(setup, client): +def test_status_create_request_validation_fails(setup, client, message_status_post_body): """Test that invalid request header values fail HMAC signature validation.""" - data = {"some": "data"} headers = {API_KEY_HEADER_NAME: "api_key", SIGNATURE_HEADER_NAME: "signature"} - response = client.post('/api/status/create', data=data, headers=headers) + response = client.post('/api/status/create', json=message_status_post_body, headers=headers) assert response.status_code == 403 assert response.get_json() == {"status": "error"} -def test_status_create_request_validation_succeeds(setup, client): +def test_status_create_body_validation_fails(setup, client, message_status_post_body): + """Test that invalid request body fails schema validation.""" + message_status_post_body["data"][0]["attributes"]["messageStatus"] = "invalid" + signature = hmac_signature.create_digest(signature_secret(), json.dumps(message_status_post_body, sort_keys=True)) + + headers = {API_KEY_HEADER_NAME: "api_key", SIGNATURE_HEADER_NAME: signature} + + response = client.post('/api/status/create', json=message_status_post_body, headers=headers) + + assert response.status_code == 422 + assert response.get_json() == {"status": "error"} + + + +def test_status_create_request_validation_succeeds(setup, client, message_status_post_body): """Test that valid request header values pass HMAC signature validation.""" - data = {"some": "data"} - signature = hmac.new( - bytes(signature_secret(), 'ASCII'), - msg=bytes(json.dumps(data), 'ASCII'), - digestmod=hashlib.sha256 - ).hexdigest() + signature = hmac_signature.create_digest(signature_secret(), json.dumps(message_status_post_body, sort_keys=True)) + headers = {API_KEY_HEADER_NAME: "api_key", SIGNATURE_HEADER_NAME: signature} - response = client.post('/api/status/create', data=data, headers=headers) + response = client.post('/api/status/create', json=message_status_post_body, headers=headers) assert response.status_code == 200 assert response.get_json() == {"status": "success"} From 372eb626ff7d6bed0de276b4d0272aa48fede824 Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 16:12:01 +0000 Subject: [PATCH 07/10] Truncate all tables after each test --- tests/conftest.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 16b3bb29..70540e8d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import app.utils.database as database import pytest import dotenv @@ -20,6 +21,16 @@ def pytest_runtest_makereport(item, call): report.location = tuple(location) +@pytest.fixture(autouse=True, scope="function") +def truncate_table(): + with database.connection() as conn: + with conn.cursor() as cur: + cur.execute("TRUNCATE TABLE batch_messages") + cur.execute("TRUNCATE TABLE channel_statuses") + cur.execute("TRUNCATE TABLE message_statuses") + cur.connection.commit() + + @pytest.fixture def channel_status_post_body(): return { From 4263a471cd90e0706f7a1f3dcd25a9cb05a40050 Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 16:14:09 +0000 Subject: [PATCH 08/10] Return useful error descriptions in the status/create response --- src/notify/app/route_handlers/status.py | 29 ++++++++++--------- src/notify/app/services/status_recorder.py | 4 +-- .../app/validators/request_validator.py | 15 +++++----- .../test_status_create_endpoint.py | 4 +-- ...us_recorder.py => test_status_recorder.py} | 4 +-- .../app/validators/test_request_validator.py | 8 ++--- 6 files changed, 33 insertions(+), 31 deletions(-) rename tests/unit/notify/app/services/{test_message_status_recorder.py => test_status_recorder.py} (99%) diff --git a/src/notify/app/route_handlers/status.py b/src/notify/app/route_handlers/status.py index e23355ca..0c23b67b 100644 --- a/src/notify/app/route_handlers/status.py +++ b/src/notify/app/route_handlers/status.py @@ -4,18 +4,19 @@ def create(): - if request_validator.verify_headers(dict(request.headers)) is False: - status_code = 401 - body = {"status": "error"} - elif request_validator.verify_signature(dict(request.headers), request.json) is False: - status_code = 403 - body = {"status": "error"} - elif request_validator.verify_body(request.json)[0] is False: - status_code = 422 - body = {"status": "error"} - else: - status_recorder.save_statuses(request.json) - status_code = 200 - body = {"status": "success"} + json_data = request.json or {} + valid_headers, error_message = request_validator.verify_headers(dict(request.headers)) - return body, status_code + if not valid_headers: + return {"status": error_message}, 401 + + if not request_validator.verify_signature(dict(request.headers), json_data): + return {"status": "Invalid signature"}, 403 + + valid_body, error_message = request_validator.verify_body(json_data) + + if not valid_body: + return {"status": error_message}, 422 + + if status_recorder.save_statuses(json_data): + return {"status": "success"}, 200 diff --git a/src/notify/app/services/status_recorder.py b/src/notify/app/services/status_recorder.py index 18554d11..d1efb6e3 100644 --- a/src/notify/app/services/status_recorder.py +++ b/src/notify/app/services/status_recorder.py @@ -3,14 +3,14 @@ import logging -def save_statuses(request_body: dict) -> None: +def save_statuses(request_body: dict) -> bool: statuses: list[dict] = status_params(request_body) status_type = request_body["data"][0]["type"] for status in statuses: datastore.create_status_record(status_type, status) - return None + return True def status_params(request_body: dict) -> list[dict]: diff --git a/src/notify/app/validators/request_validator.py b/src/notify/app/validators/request_validator.py index 4f7a286c..5780c424 100644 --- a/src/notify/app/validators/request_validator.py +++ b/src/notify/app/validators/request_validator.py @@ -1,7 +1,6 @@ import hashlib import hmac import json -import logging import os import app.validators.schema_validator as schema_validator import app.utils.hmac_signature as hmac_signature @@ -10,16 +9,18 @@ SIGNATURE_HEADER_NAME = 'x-hmac-sha256-signature' -def verify_headers(headers: dict) -> bool: +def verify_headers(headers: dict) -> tuple[bool, str]: lc_headers = header_keys_to_lower(headers) - if (lc_headers.get(API_KEY_HEADER_NAME) is None or - lc_headers.get(API_KEY_HEADER_NAME) != os.getenv('NOTIFY_API_KEY')): - return False + if lc_headers.get(API_KEY_HEADER_NAME) is None: + return False, "Missing API key header" + + if lc_headers.get(API_KEY_HEADER_NAME) != os.getenv('NOTIFY_API_KEY'): + return False, "Invalid API key" if lc_headers.get(SIGNATURE_HEADER_NAME) is None: - return False + return False, "Missing signature header" - return True + return True, "" def verify_signature(headers: dict, body: dict) -> bool: diff --git a/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py b/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py index c37033f0..6300dec1 100644 --- a/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py +++ b/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py @@ -27,7 +27,7 @@ def test_status_create_request_validation_fails(setup, client, message_status_po response = client.post('/api/status/create', json=message_status_post_body, headers=headers) assert response.status_code == 403 - assert response.get_json() == {"status": "error"} + assert response.get_json() == {"status": "Invalid signature"} def test_status_create_body_validation_fails(setup, client, message_status_post_body): @@ -40,7 +40,7 @@ def test_status_create_body_validation_fails(setup, client, message_status_post_ response = client.post('/api/status/create', json=message_status_post_body, headers=headers) assert response.status_code == 422 - assert response.get_json() == {"status": "error"} + assert response.get_json() == {"status": "'invalid' is not one of ['created', 'pending_enrichment', 'enriched', 'sending', 'delivered', 'failed']"} diff --git a/tests/unit/notify/app/services/test_message_status_recorder.py b/tests/unit/notify/app/services/test_status_recorder.py similarity index 99% rename from tests/unit/notify/app/services/test_message_status_recorder.py rename to tests/unit/notify/app/services/test_status_recorder.py index d25d7625..1612d21d 100644 --- a/tests/unit/notify/app/services/test_message_status_recorder.py +++ b/tests/unit/notify/app/services/test_status_recorder.py @@ -6,7 +6,7 @@ def test_save_statuses_with_channel_status_data(mocker, channel_status_post_body """Test saving channel status data to datastore""" mock_datastore = mocker.patch("app.services.status_recorder.datastore") - assert status_recorder.save_statuses(channel_status_post_body) is None + assert status_recorder.save_statuses(channel_status_post_body) mock_datastore.create_status_record.assert_called_once_with("ChannelStatus", { "details": json.dumps(channel_status_post_body), @@ -21,7 +21,7 @@ def test_save_statuses_with_message_status_data(mocker, message_status_post_body """Test saving message status data to datastore""" mock_datastore = mocker.patch("app.services.status_recorder.datastore") - assert status_recorder.save_statuses(message_status_post_body) is None + assert status_recorder.save_statuses(message_status_post_body) mock_datastore.create_status_record.assert_called_once_with("MessageStatus", { "details": json.dumps(message_status_post_body), diff --git a/tests/unit/notify/app/validators/test_request_validator.py b/tests/unit/notify/app/validators/test_request_validator.py index bb5e83f0..010e014c 100644 --- a/tests/unit/notify/app/validators/test_request_validator.py +++ b/tests/unit/notify/app/validators/test_request_validator.py @@ -33,19 +33,19 @@ def test_verify_signature_valid(setup): def test_verify_headers_missing_all(setup): """Test that missing all headers fails verification.""" headers = {} - assert not request_validator.verify_headers(headers) + assert request_validator.verify_headers(headers) == (False, 'Missing API key header') def test_verify_headers_missing_api_key(setup): """Test that missing API key header fails verification.""" headers = {request_validator.SIGNATURE_HEADER_NAME: 'signature'} - assert not request_validator.verify_headers(headers) + assert request_validator.verify_headers(headers) == (False, 'Missing API key header') def test_verify_headers_missing_signature(setup): """Test that missing signature header fails verification.""" headers = {request_validator.API_KEY_HEADER_NAME: 'api_key'} - assert not request_validator.verify_headers(headers) + assert request_validator.verify_headers(headers) == (False, 'Missing signature header') def test_verify_headers_valid(setup): @@ -60,4 +60,4 @@ def test_verify_headers_valid(setup): def test_verify_headers_invalid_api_key(setup): """Test that an invalid API key fails verification.""" headers = {request_validator.API_KEY_HEADER_NAME: 'invalid_api_key'} - assert not request_validator.verify_headers(headers) + assert request_validator.verify_headers(headers) == (False, 'Invalid API key') From ec5d8251d6031f091843e9ac6713f12389997375 Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Thu, 16 Jan 2025 16:25:06 +0000 Subject: [PATCH 09/10] Add postgres to CI test jobs --- .env.test | 2 +- .github/workflows/stage-2-test.yaml | 55 +++++++++++++++++++++++++---- tests/conftest.py | 6 ++-- 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/.env.test b/.env.test index a190469d..c5696dc8 100644 --- a/.env.test +++ b/.env.test @@ -8,7 +8,7 @@ NOTIFY_FUNCTION_URL=http://localhost:7072/api/notify/message/send OAUTH2_API_KEY="" OAUTH2_API_KID="" OAUTH2_TOKEN_URL="" -DATABASE_HOST=localhost +DATABASE_HOST=127.0.0.1 DATABASE_NAME=communication_management_test DATABASE_PASSWORD="" DATABASE_SSLMODE=allow diff --git a/.github/workflows/stage-2-test.yaml b/.github/workflows/stage-2-test.yaml index 9fc987e1..d0b2bc1c 100644 --- a/.github/workflows/stage-2-test.yaml +++ b/.github/workflows/stage-2-test.yaml @@ -42,12 +42,6 @@ jobs: with: python-version: "3.11" - - name: Install dependencies - run: | - python -m venv venv - source venv/bin/activate - python -m pip install --upgrade pip - test-lint: name: "Linting" runs-on: ubuntu-latest @@ -66,7 +60,31 @@ jobs: needs: set-up-dependencies runs-on: ubuntu-latest timeout-minutes: 5 + services: + postgres: + image: postgres:11.6-alpine + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: "" + POSTGRES_SSLMODE: "disable" + ports: + - 5432:5432 + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + + env: + CI: true + DATABASE_HOST: 127.0.0.1 + DATABASE_PORT: 5432 + DATABASE_NAME: communication_management_test + DATABASE_USER: postgres + DATABASE_PASSWORD: "" + DATABASE_SSLMODE: "disable" + steps: + - name: "Create test db" + run: | + psql -U postgres -h 127.0.0.1 -d postgres -tc "CREATE DATABASE communication_management_test;" + - name: "Checkout code" uses: actions/checkout@v4 @@ -78,12 +96,37 @@ jobs: - name: "Run unit test suite" run: | ./test-unit.sh + test-integration: name: "Integration tests" needs: set-up-dependencies runs-on: ubuntu-latest timeout-minutes: 5 + services: + postgres: + image: postgres:11.6-alpine + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: "" + POSTGRES_SSLMODE: "disable" + ports: + - 5432:5432 + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + + env: + CI: true + DATABASE_HOST: 127.0.0.1 + DATABASE_PORT: 5432 + DATABASE_NAME: communication_management_test + DATABASE_USER: postgres + DATABASE_PASSWORD: "" + DATABASE_SSLMODE: "disable" + steps: + - name: "Create test db" + run: | + psql -U postgres -h 127.0.0.1 -d postgres -tc "CREATE DATABASE communication_management_test;" + - name: "Checkout code" uses: actions/checkout@v4 diff --git a/tests/conftest.py b/tests/conftest.py index 70540e8d..ddb82898 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,10 @@ import app.utils.database as database -import pytest import dotenv +import os +import pytest -dotenv.load_dotenv(".env.test") +if not bool(os.getenv("CI")): + dotenv.load_dotenv(".env.test") @pytest.hookimpl(hookwrapper=True) From 97472a445bed367a2d96156b53366bcee78de3d1 Mon Sep 17 00:00:00 2001 From: Steve Laing Date: Mon, 20 Jan 2025 16:47:11 +0000 Subject: [PATCH 10/10] Refactor persistence tests We now have local dev and test dbs so improve tests by asserting persistence has worked as expected. --- .../test_status_create_endpoint.py | 20 +++ .../notify/app/services/test_datastore.py | 133 ++++++++++++------ 2 files changed, 107 insertions(+), 46 deletions(-) diff --git a/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py b/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py index 6300dec1..bc15649f 100644 --- a/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py +++ b/tests/integration/notify/app/route_handlers/test_status_create_endpoint.py @@ -1,5 +1,7 @@ from app import create_app from app.validators.request_validator import API_KEY_HEADER_NAME, SIGNATURE_HEADER_NAME, signature_secret +from datetime import datetime, timedelta +import app.utils.database as database import app.utils.hmac_signature as hmac_signature import hashlib import hmac @@ -54,3 +56,21 @@ def test_status_create_request_validation_succeeds(setup, client, message_status assert response.status_code == 200 assert response.get_json() == {"status": "success"} + + +def test_status_create_saves_records(setup, client, message_status_post_body): + """Test that valid requests are saved to the database.""" + signature = hmac_signature.create_digest(signature_secret(), json.dumps(message_status_post_body, sort_keys=True)) + + headers = {API_KEY_HEADER_NAME: "api_key", SIGNATURE_HEADER_NAME: signature} + + response = client.post('/api/status/create', json=message_status_post_body, headers=headers) + + assert response.status_code == 200 + + with database.connection() as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM message_statuses") + record = cursor.fetchone() + assert record[0] - datetime.now() < timedelta(seconds=1) + assert record[1] == message_status_post_body diff --git a/tests/unit/notify/app/services/test_datastore.py b/tests/unit/notify/app/services/test_datastore.py index 1247daf6..498dc7d3 100644 --- a/tests/unit/notify/app/services/test_datastore.py +++ b/tests/unit/notify/app/services/test_datastore.py @@ -1,74 +1,115 @@ import app.services.datastore as datastore +import app.utils.database as database +from datetime import datetime, timedelta +import json import pytest -from psycopg2 import sql @pytest.fixture -def mock_connection(mocker): - return mocker.patch("app.utils.database.connection") - - -@pytest.fixture -def mock_cursor(mocker, mock_connection): - return mock_connection().__enter__().cursor().__enter__() - - -@pytest.fixture -def batch_message_data(autouse=True): +def batch_message_data() -> dict[str, str | dict]: return { - "batch_id": "0b3b3b3b-3b3b-3b3b-3b3b-3b3b3b3b3b3b", - "details": "Test details", - "message_reference": "0b3b3b3b-3b3b-3b-3b3b-3b3b3b3b3b3b", + "batch_id": "499c8396-16a0-417c-849e-f0062940cd2a", + "details": json.dumps({"test": "details"}), + "message_reference": "ee43e0ae-c2ca-4c44-8ddb-266c6dfd3b5e", "nhs_number": "1234567890", - "recipient_id": "e3e7b3b3-3b3b-3b-3b3b-3b3b3b3b3b3b", - "status": "test_status", + "recipient_id": "a1a77bf2-d5e2-430b-85ea-ac0ba8a59132", + "status": "sent", } @pytest.fixture -def message_status_data(autouse=True): +def message_status_data() -> dict[str, str | dict]: return { - "idempotency_key": "0b3b3b3b-3b3b-3b3b-3b3b-3b3b3b3b3b3b", + "details": json.dumps({"test": "details"}), + "idempotency_key": "47652cc9-8f76-423b-9923-273af024d264", #gitleaks:allow "message_id": "0x0x0x0xabx0x0", - "message_reference": "0b3b3b3b-3b3b-3b3b-3b3b-3b3b3b3b3b3b", - "details": "Test details", - "status": "test_status", + "message_reference": "5bd25347-f941-461f-952f-773540ad86c9", + "status": "delivered", } -def test_create_batch_message_record(mock_cursor): +def test_create_batch_message_record(batch_message_data): """Test the SQL execution of batch message record creation.""" datastore.create_batch_message_record(batch_message_data) - mock_cursor.execute.assert_called_with(datastore.INSERT_BATCH_MESSAGE, batch_message_data) - mock_cursor.fetchone.assert_called_once() + with database.connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT * FROM batch_messages WHERE batch_id = %s", + (batch_message_data["batch_id"],) + ) + row = cur.fetchone() + assert row[0] == batch_message_data["batch_id"] + assert row[1] - datetime.now() < timedelta(seconds=1) + assert row[2:7] == ( + json.loads(batch_message_data["details"]), + batch_message_data["message_reference"], + batch_message_data["nhs_number"], + batch_message_data["recipient_id"], + batch_message_data["status"] + ) -def test_create_batch_message_record_with_error(mock_cursor): - """Test the SQL execution of batch message record creation with an error.""" - mock_cursor.execute.side_effect = Exception("Test error") - with pytest.raises(Exception): - assert datastore.create_batch_message_record(batch_message_data) is False +def test_create_batch_message_record_error(batch_message_data): + """Test the error handling of batch message record creation.""" + batch_message_data["batch_id"] = "invalid" - mock_cursor.execute.assert_called_with(datastore.INSERT_BATCH_MESSAGE, batch_message_data) - mock_cursor.fetchone.assert_not_called() + assert not datastore.create_batch_message_record(batch_message_data) -def test_create_message_status_record(mock_cursor): +def test_create_message_status_record(message_status_data): """Test the SQL execution of message status record creation.""" datastore.create_status_record("MessageStatus", message_status_data) - mock_cursor.execute.assert_called_with(datastore.INSERT_STATUS.format(table=sql.Identifier("message_statuses")), message_status_data) - mock_cursor.fetchone.assert_called_once() - - -def test_create_message_status_record_with_error(mock_cursor): - """Test the SQL execution of message status record creation with an error.""" - mock_cursor.execute.side_effect = Exception("Test error") - - with pytest.raises(Exception): - assert datastore.create_status_record("MessageStatus", message_status_data) is False - - mock_cursor.execute.assert_called_with(datastore.INSERT_STATUS.format(table=sql.Identifier("message_statuses")), message_status_data) - mock_cursor.fetchone.assert_not_called() + with database.connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT * FROM message_statuses WHERE idempotency_key = %s", + (message_status_data["idempotency_key"],) + ) + row = cur.fetchone() + assert row[0] - datetime.now() < timedelta(seconds=1) + assert row[1:6] == ( + json.loads(message_status_data["details"]), + message_status_data["idempotency_key"], + message_status_data["message_id"], + message_status_data["message_reference"], + message_status_data["status"] + ) + + +def test_create_message_status_record_error(message_status_data): + """Test the error handling of message status record creation.""" + message_status_data["status"] = "invalid" + + assert not datastore.create_status_record("MessageStatus", message_status_data) + + +def test_create_channel_status_record(message_status_data): + """Test the SQL execution of channel status record creation.""" + datastore.create_status_record("ChannelStatus", message_status_data) + channel_status_data = message_status_data + + with database.connection() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT * FROM channel_statuses WHERE idempotency_key = %s", + (channel_status_data["idempotency_key"],) + ) + row = cur.fetchone() + assert row[0] - datetime.now() < timedelta(seconds=1) + assert row[1:6] == ( + json.loads(message_status_data["details"]), + channel_status_data["idempotency_key"], + channel_status_data["message_id"], + channel_status_data["message_reference"], + channel_status_data["status"] + ) + + +def test_create_channel_status_record_error(message_status_data): + """Test the error handling of channel status record creation.""" + message_status_data["status"] = "invalid" + + assert not datastore.create_status_record("ChannelStatus", message_status_data)