diff --git a/src/etl/sds/worker/load/tests/test_load_worker.py b/src/etl/sds/worker/load/tests/test_load_worker.py index c9e548b27..11468af09 100644 --- a/src/etl/sds/worker/load/tests/test_load_worker.py +++ b/src/etl/sds/worker/load/tests/test_load_worker.py @@ -173,7 +173,7 @@ def test_load_worker_pass( response = load.handler(event={}, context=None) assert response == { "stage_name": "load", - "processed_records": 4 * n_initial_unprocessed, + "processed_records": 2 * n_initial_unprocessed, "unprocessed_records": 0, "error_message": None, } diff --git a/src/etl/sds/worker/transform/tests/test_transform_worker.py b/src/etl/sds/worker/transform/tests/test_transform_worker.py index 761da5a1d..fb7090d96 100644 --- a/src/etl/sds/worker/transform/tests/test_transform_worker.py +++ b/src/etl/sds/worker/transform/tests/test_transform_worker.py @@ -108,9 +108,8 @@ def test_transform_worker_pass_dupe_check_mock( response = transform.handler(event={}, context=None) assert response == { "stage_name": "transform", - # 9 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created, - # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent - "processed_records": n_initial_processed + 9 * n_initial_unprocessed, + # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created + "processed_records": n_initial_processed + 5 * n_initial_unprocessed, "unprocessed_records": 0, "error_message": None, } @@ -123,7 +122,7 @@ def test_transform_worker_pass_dupe_check_mock( # Confirm that everything has now been processed, and that there is no # unprocessed data left in the bucket - assert n_final_processed == n_initial_processed + 9 * n_initial_unprocessed + assert n_final_processed == n_initial_processed + 5 * n_initial_unprocessed assert n_final_unprocessed == 0 @@ -148,9 +147,8 @@ def test_transform_worker_pass_no_dupes( assert response == { "stage_name": "transform", - # 9 x 
initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created, - # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent - "processed_records": n_initial_processed + 9 * n_initial_unprocessed, + # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created + "processed_records": n_initial_processed + 5 * n_initial_unprocessed, "unprocessed_records": 0, "error_message": None, } @@ -163,7 +161,7 @@ def test_transform_worker_pass_no_dupes( # Confirm that everything has now been processed, and that there is no # unprocessed data left in the bucket - assert n_final_processed == n_initial_processed + 9 * n_initial_unprocessed + assert n_final_processed == n_initial_processed + 5 * n_initial_unprocessed assert n_final_unprocessed == 0 @@ -191,9 +189,8 @@ def test_transform_worker_pass_no_dupes_max_records( n_unprocessed_records_expected = ( n_unprocessed_records - n_newly_processed_records_expected ) - # 9 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created, - # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent - n_total_processed_records_expected += 9 * n_newly_processed_records_expected + # 5 x initial unprocessed because 5 events are created for each input record + n_total_processed_records_expected += 5 * n_newly_processed_records_expected # Execute the transform worker with mock.patch("etl.sds.worker.transform.transform.reject_duplicate_keys"): @@ -217,7 +214,7 @@ def test_transform_worker_pass_no_dupes_max_records( # Confirm that everything has now been processed, and that there is no # unprocessed data left in the bucket - assert n_final_processed == n_initial_processed + 9 * n_initial_unprocessed + assert n_final_processed == n_initial_processed + 5 * n_initial_unprocessed assert n_final_unprocessed == 0 @@ -304,9 +301,8 @@ def test_transform_worker_bad_record( assert response == { "stage_name": "transform", - # 9 x 
initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created, - # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent - "processed_records": n_initial_processed + (9 * bad_record_index), + # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created + "processed_records": n_initial_processed + (5 * bad_record_index), "unprocessed_records": n_initial_unprocessed - bad_record_index, "error_message": [ "The following errors were encountered", @@ -327,9 +323,8 @@ def test_transform_worker_bad_record( # Confirm that there are still unprocessed records, and that there may have been # some records processed successfully assert n_final_unprocessed > 0 - # 9 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created, - # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent - assert n_final_processed == n_initial_processed + (9 * bad_record_index) + # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created + assert n_final_processed == n_initial_processed + (5 * bad_record_index) assert n_final_unprocessed == n_initial_unprocessed - bad_record_index diff --git a/src/layers/domain/core/device.py b/src/layers/domain/core/device.py index 74d27b6ca..bc2baf573 100644 --- a/src/layers/domain/core/device.py +++ b/src/layers/domain/core/device.py @@ -1,6 +1,5 @@ from collections import defaultdict -from datetime import UTC -from datetime import datetime as DateTime +from datetime import datetime from enum import StrEnum, auto from itertools import chain from typing import Any, Optional @@ -158,35 +157,24 @@ class Device(AggregateRoot): status: DeviceStatus = Field(default=DeviceStatus.ACTIVE) product_team_id: UUID ods_code: str - created_on: DateTime = Field( - default_factory=lambda: DateTime.now(UTC), immutable=True - ) - updated_on: Optional[DateTime] = 
Field(default=None) - deleted_on: Optional[DateTime] = Field(default=None) + created_on: datetime = Field(default_factory=datetime.utcnow, immutable=True) + updated_on: Optional[datetime] = Field(default=None) + deleted_on: Optional[datetime] = Field(default=None) keys: dict[str, DeviceKey] = Field(default_factory=dict, exclude=True) questionnaire_responses: dict[str, list[QuestionnaireResponse]] = Field( default_factory=lambda: defaultdict(list), exclude=True ) indexes: set[tuple[str, str, Any]] = Field(default_factory=set, exclude=True) - def add_event(self, event: Event) -> Event: - """ - Override the base add_event to ensure that the 'updated_on' timestamp - is always updated when events (other than DeviceUpdatedEvent) occur. - """ - _event = super().add_event(event=event) - self.update() - return _event - def update(self, **kwargs) -> DeviceUpdatedEvent: if "updated_on" not in kwargs: - kwargs["updated_on"] = DateTime.now(UTC) + kwargs["updated_on"] = datetime.utcnow() device_data = self._update(data=kwargs) event = DeviceUpdatedEvent(**device_data) - return super().add_event(event) + return self.add_event(event) def delete(self) -> DeviceUpdatedEvent: - deletion_datetime = DateTime.now(UTC) + deletion_datetime = datetime.utcnow() return self.update( status=DeviceStatus.INACTIVE, updated_on=deletion_datetime, diff --git a/src/layers/domain/core/product_team.py b/src/layers/domain/core/product_team.py index 16e876d68..da901b8d1 100644 --- a/src/layers/domain/core/product_team.py +++ b/src/layers/domain/core/product_team.py @@ -42,6 +42,6 @@ def create_device( ods_code=self.ods_code, ) device_created_event = DeviceCreatedEvent(**device.dict(), _trust=_trust) - super(Device, device).add_event(device_created_event) - super().add_event(device_created_event) + device.add_event(device_created_event) + self.add_event(device_created_event) return device diff --git a/src/layers/domain/core/tests/test_device.py b/src/layers/domain/core/tests/test_device.py index 
a2c5e1599..8933b0ea6 100644 --- a/src/layers/domain/core/tests/test_device.py +++ b/src/layers/domain/core/tests/test_device.py @@ -293,19 +293,3 @@ def test_device_add_index_no_such_question(device: Device): with pytest.raises(QuestionNotFoundError): device.add_index(questionnaire_id="foo/1", question_name="question1") - - -def test_device_add_event(device: Device): - created_on_before = device.created_on - assert isinstance(device.created_on, datetime) - assert device.updated_on is None - assert device.deleted_on is None - - device.add_event("an_event") - - assert isinstance(device.created_on, datetime) - assert isinstance(device.updated_on, datetime) - assert device.deleted_on is None - - assert device.created_on == created_on_before - assert device.updated_on > device.created_on diff --git a/src/layers/domain/repository/repository.py b/src/layers/domain/repository/repository.py index 3e443f3e0..f9f8b0cdb 100644 --- a/src/layers/domain/repository/repository.py +++ b/src/layers/domain/repository/repository.py @@ -26,22 +26,6 @@ def batched(iterable: T, n: int) -> Generator[T, None, None]: piece = list(islice(i, n)) -def _split_transactions_by_key( - transact_items: list[TransactItem], n_max: int -) -> Generator[list[TransactItem], None, None]: - buffer, keys = [], set() - for transact_item in transact_items: - transaction_statement = transact_item.Put or transact_item.Delete - item = transaction_statement.Key or transaction_statement.Item - key = (item["pk"]["S"], item["sk"]["S"]) - if key in keys: - yield from batched(buffer, n=n_max) - buffer, keys = [], set() - buffer.append(transact_item) - keys.add(key) - yield from batched(buffer, n=n_max) - - class Repository(Generic[ModelType]): def __init__(self, table_name, model: type[ModelType], dynamodb_client): self.table_name = table_name @@ -55,12 +39,11 @@ def generate_transaction_statements(event) -> TransactItem: return handler(event=event) responses = [] - transact_items = map(generate_transaction_statements, 
entity.events) - for _transact_items in _split_transactions_by_key( - transact_items, n_max=batch_size - ): - transaction = Transaction(TransactItems=_transact_items) - with handle_client_errors(commands=_transact_items): + for events in batched(entity.events, n=batch_size): + transact_items = list(map(generate_transaction_statements, events)) + transaction = Transaction(TransactItems=transact_items) + + with handle_client_errors(commands=transact_items): _response = self.client.transact_write_items( **transaction.dict(exclude_none=True) ) diff --git a/src/layers/domain/repository/tests/test_repository.py b/src/layers/domain/repository/tests/test_repository.py index e12e7ded2..9dbb5855b 100644 --- a/src/layers/domain/repository/tests/test_repository.py +++ b/src/layers/domain/repository/tests/test_repository.py @@ -1,6 +1,6 @@ import pytest from attr import asdict, dataclass -from domain.repository.errors import AlreadyExistsError +from domain.repository.errors import AlreadyExistsError, UnhandledTransaction from domain.repository.marshall import marshall, marshall_value, unmarshall from domain.repository.repository import Repository from domain.repository.transaction import ( @@ -142,9 +142,17 @@ def test_repository_raise_already_exists_from_single_transaction( MyEventAdd(field="123"), ], ) - with pytest.raises(AlreadyExistsError) as exc: + with pytest.raises(UnhandledTransaction) as exc: repository.write(my_item) - assert str(exc.value) == "Item already exists" + assert str(exc.value) == "\n".join( + ( + "ValidationException: Transaction request cannot include multiple operations on one item", + f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "prefix:456"}}, "sk": {{"S": "prefix:456"}}, "field": {{"S": "456"}}}}}}}}', + f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "123"}}, "sk": {{"S": "123"}}, "field": {{"S": "123"}}}}, "ConditionExpression": "attribute_not_exists(pk) AND attribute_not_exists(sk) AND 
attribute_not_exists(pk_1) AND attribute_not_exists(sk_1) AND attribute_not_exists(pk_2) AND attribute_not_exists(sk_2)"}}}}', + f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "prefix:345"}}, "sk": {{"S": "prefix:345"}}, "field": {{"S": "345"}}}}}}}}', + f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "123"}}, "sk": {{"S": "123"}}, "field": {{"S": "123"}}}}, "ConditionExpression": "attribute_not_exists(pk) AND attribute_not_exists(sk) AND attribute_not_exists(pk_1) AND attribute_not_exists(sk_1) AND attribute_not_exists(pk_2) AND attribute_not_exists(sk_2)"}}}}', + ) + ) @pytest.mark.integration diff --git a/src/layers/sds/cpm_translation/modify/tests/test_modify_device.py b/src/layers/sds/cpm_translation/modify/tests/test_modify_device.py index 5ba326d91..c4b8f723b 100644 --- a/src/layers/sds/cpm_translation/modify/tests/test_modify_device.py +++ b/src/layers/sds/cpm_translation/modify/tests/test_modify_device.py @@ -1,5 +1,5 @@ import pytest -from domain.core.device import Device, DeviceType, DeviceUpdatedEvent +from domain.core.device import Device, DeviceType from domain.core.questionnaire import ( Questionnaire, QuestionnaireResponse, @@ -198,10 +198,9 @@ def test_update_device_metadata( ) assert _device is device - assert len(_device.events) == len(events_before) + 2 + assert len(_device.events) == len(events_before) + 1 assert all(event in _device.events for event in events_before) - assert isinstance(_device.events[-1], DeviceUpdatedEvent) - assert isinstance(_device.events[-2], QuestionnaireResponseUpdatedEvent) + assert isinstance(_device.events[-1], QuestionnaireResponseUpdatedEvent) device_metadata = _device.questionnaire_responses[ questionnaire_response.questionnaire.id diff --git a/src/layers/sds/cpm_translation/tests/test_cpm_translation.py b/src/layers/sds/cpm_translation/tests/test_cpm_translation.py index 19200c9dd..9a8178aa7 100644 --- 
a/src/layers/sds/cpm_translation/tests/test_cpm_translation.py +++ b/src/layers/sds/cpm_translation/tests/test_cpm_translation.py @@ -38,13 +38,9 @@ EXPECTED_EVENTS = [ "device_created_event", "device_key_added_event", - "device_updated_event", "questionnaire_instance_event", - "device_updated_event", "questionnaire_response_added_event", - "device_updated_event", "device_index_added_event", - "device_updated_event", ]