From fa70e5eb7e2b4b3ba782b1e5bfe879d5d558abd5 Mon Sep 17 00:00:00 2001 From: elay Date: Wed, 12 Jun 2024 11:13:31 -0700 Subject: [PATCH] update assertions and logger --- pcfuncs/ipban/__init__.py | 4 +- pcfuncs/tests/ipban/test_ipban.py | 101 ++++++++++++++---------------- 2 files changed, 50 insertions(+), 55 deletions(-) diff --git a/pcfuncs/ipban/__init__.py b/pcfuncs/ipban/__init__.py index e26b9595..6e3d66fd 100644 --- a/pcfuncs/ipban/__init__.py +++ b/pcfuncs/ipban/__init__.py @@ -9,12 +9,14 @@ from .config import settings from .models import UpdateBannedIPTask +logger = logging.getLogger(__name__) + def main(mytimer: func.TimerRequest) -> None: utc_timestamp: str = ( datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat() ) - logging.info("Updating the ip ban list at %s", utc_timestamp) + logger.info("Updating the ip ban list at %s", utc_timestamp) credential: DefaultAzureCredential = DefaultAzureCredential() logs_query_client: LogsQueryClient = LogsQueryClient(credential) table_service_client: TableServiceClient = TableServiceClient( diff --git a/pcfuncs/tests/ipban/test_ipban.py b/pcfuncs/tests/ipban/test_ipban.py index c27d28a3..d2a17664 100644 --- a/pcfuncs/tests/ipban/test_ipban.py +++ b/pcfuncs/tests/ipban/test_ipban.py @@ -1,3 +1,5 @@ +import logging +import uuid from typing import Any, Dict, Generator, List, Tuple from unittest.mock import MagicMock @@ -12,47 +14,38 @@ from pytest_mock import MockerFixture MOCK_LOGS_QUERY_RESULT = [("192.168.1.1", 8000), ("192.168.1.4", 12000)] -TEST_BANNED_IP_TABLE = "testblobstoragebannedip" - - -def populate_banned_ip_table(table_client: TableClient) -> List[Dict[str, Any]]: - print("Populating the table") - entities: List[Dict[str, Any]] = [ - { - "PartitionKey": "192.168.1.1", - "RowKey": "192.168.1.1", - "ReadCount": 647, - "Threshold": settings.threshold_read_count_in_gb, - "TimeWindow": settings.time_window_in_hours, - }, - { - "PartitionKey": "192.168.1.2", - "RowKey": "192.168.1.2", - "ReadCount": 214, - "Threshold": settings.threshold_read_count_in_gb, - "TimeWindow": settings.time_window_in_hours, - }, - { - "PartitionKey": "192.168.1.3", - "RowKey": "192.168.1.3", - "ReadCount": 550, - "Threshold": settings.threshold_read_count_in_gb, - "TimeWindow": settings.time_window_in_hours, - }, - ] - for entity in entities: +TEST_ID = uuid.uuid4() +TEST_BANNED_IP_TABLE = f"testblobstoragebannedip-{TEST_ID}" + +logger = logging.getLogger(__name__) +PREPOPULATED_ENTITIES = [ + { + "PartitionKey": "192.168.1.1", + "RowKey": "192.168.1.1", + "ReadCount": 647, + "Threshold": settings.threshold_read_count_in_gb, + "TimeWindow": settings.time_window_in_hours, + }, + { + "PartitionKey": "192.168.1.2", + "RowKey": "192.168.1.2", + "ReadCount": 214, + "Threshold": settings.threshold_read_count_in_gb, + "TimeWindow": settings.time_window_in_hours, + }, + { + "PartitionKey": "192.168.1.3", + "RowKey": "192.168.1.3", + "ReadCount": 550, + "Threshold": settings.threshold_read_count_in_gb, + "TimeWindow": settings.time_window_in_hours, + }, +] + + +def populate_banned_ip_table(table_client: TableClient) -> None: + for entity in PREPOPULATED_ENTITIES: table_client.create_entity(entity) - return entities - - -def clear_table(table_client: TableClient) -> None: - entities = list(table_client.list_entities()) - for entity in entities: - table_client.delete_entity( - partition_key=entity["PartitionKey"], row_key=entity["RowKey"] - ) - entities = list(table_client.list_entities()) - assert len(entities) == 0 @pytest.fixture @@ -70,19 +63,18 @@ def mock_clients( "TableEndpoint=http://azurite:10002/devstoreaccount1;" ) # Use Azurite for unit tests and populate the table with initial data - table_service: TableServiceClient = TableServiceClient.from_connection_string( - CONNECTION_STRING + table_service_client: TableServiceClient = ( + TableServiceClient.from_connection_string(CONNECTION_STRING) ) - table_client: TableClient = table_service.create_table_if_not_exists( + table_client: TableClient = table_service_client.create_table_if_not_exists( table_name=TEST_BANNED_IP_TABLE ) - # Pre-populate the banned ip table populate_banned_ip_table(table_client) yield logs_query_client, table_client - # Clear all entities from the table - clear_table(table_client) + # Delete the test table + table_service_client.delete_table(TEST_BANNED_IP_TABLE) @pytest.fixture @@ -100,21 +92,22 @@ def integration_clients( # Pre-populate the banned ip table populate_banned_ip_table(table_client) yield logs_query_client, table_client - # Clear all entities from the table - clear_table(table_client) + + # Delete the test table + table_service_client.delete_table(TEST_BANNED_IP_TABLE) @pytest.mark.integration def test_update_banned_ip_integration( integration_clients: Tuple[LogsQueryClient, TableClient] ) -> None: - print("Integration test is running") + logger.info(f"Test id: {TEST_ID} - integration test is running") logs_query_client, table_client = integration_clients + assert len(list(table_client.list_entities())) == len(PREPOPULATED_ENTITIES) task: UpdateBannedIPTask = UpdateBannedIPTask(logs_query_client, table_client) # retrieve the logs query result from pc-api-loganalytics logs_query_result: List[LogsTableRow] = task.run() - entities = list(table_client.list_entities()) - assert len(logs_query_result) == len(entities) + assert len(list(table_client.list_entities())) == len(logs_query_result) for ip, expected_read_count in logs_query_result: entity: TableEntity = table_client.get_entity(ip, ip) assert entity["ReadCount"] == expected_read_count @@ -123,12 +116,12 @@ def test_update_banned_ip_integration( def test_update_banned_ip(mock_clients: Tuple[MagicMock, TableClient]) -> None: - print("Unit test is running") + logger.info(f"Test id: {TEST_ID} - unit test is running") mock_logs_query_client, table_client = mock_clients + assert len(list(table_client.list_entities())) == len(PREPOPULATED_ENTITIES) task: UpdateBannedIPTask = UpdateBannedIPTask(mock_logs_query_client, table_client) task.run() - entities = list(table_client.list_entities()) - assert len(entities) == len(MOCK_LOGS_QUERY_RESULT) + assert len(list(table_client.list_entities())) == len(MOCK_LOGS_QUERY_RESULT) for ip, expected_read_count in MOCK_LOGS_QUERY_RESULT: entity = table_client.get_entity(ip, ip) assert entity["ReadCount"] == expected_read_count