[release/2024-06-04] rollback updates after add_event
jaklinger committed Jun 12, 2024
1 parent b4996ec commit 213f400
Showing 9 changed files with 42 additions and 89 deletions.
2 changes: 1 addition & 1 deletion src/etl/sds/worker/load/tests/test_load_worker.py
@@ -173,7 +173,7 @@ def test_load_worker_pass(
     response = load.handler(event={}, context=None)
     assert response == {
         "stage_name": "load",
-        "processed_records": 4 * n_initial_unprocessed,
+        "processed_records": 2 * n_initial_unprocessed,
         "unprocessed_records": 0,
         "error_message": None,
     }
31 changes: 13 additions & 18 deletions src/etl/sds/worker/transform/tests/test_transform_worker.py
@@ -108,9 +108,8 @@ def test_transform_worker_pass_dupe_check_mock(
     response = transform.handler(event={}, context=None)
     assert response == {
         "stage_name": "transform",
-        # 9 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created,
-        # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent
-        "processed_records": n_initial_processed + 9 * n_initial_unprocessed,
+        # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created
+        "processed_records": n_initial_processed + 5 * n_initial_unprocessed,
         "unprocessed_records": 0,
         "error_message": None,
     }
@@ -123,7 +122,7 @@ def test_transform_worker_pass_dupe_check_mock(

     # Confirm that everything has now been processed, and that there is no
     # unprocessed data left in the bucket
-    assert n_final_processed == n_initial_processed + 9 * n_initial_unprocessed
+    assert n_final_processed == n_initial_processed + 5 * n_initial_unprocessed
     assert n_final_unprocessed == 0

@@ -148,9 +147,8 @@ def test_transform_worker_pass_no_dupes(

     assert response == {
         "stage_name": "transform",
-        # 9 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created,
-        # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent
-        "processed_records": n_initial_processed + 9 * n_initial_unprocessed,
+        # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created
+        "processed_records": n_initial_processed + 5 * n_initial_unprocessed,
         "unprocessed_records": 0,
         "error_message": None,
     }
@@ -163,7 +161,7 @@ def test_transform_worker_pass_no_dupes(

     # Confirm that everything has now been processed, and that there is no
     # unprocessed data left in the bucket
-    assert n_final_processed == n_initial_processed + 9 * n_initial_unprocessed
+    assert n_final_processed == n_initial_processed + 5 * n_initial_unprocessed
     assert n_final_unprocessed == 0

@@ -191,9 +189,8 @@ def test_transform_worker_pass_no_dupes_max_records(
     n_unprocessed_records_expected = (
         n_unprocessed_records - n_newly_processed_records_expected
     )
-    # 9 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created,
-    # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent
-    n_total_processed_records_expected += 9 * n_newly_processed_records_expected
+    # 5 x initial unprocessed because 5 events are created for each input record
+    n_total_processed_records_expected += 5 * n_newly_processed_records_expected

     # Execute the transform worker
     with mock.patch("etl.sds.worker.transform.transform.reject_duplicate_keys"):
@@ -217,7 +214,7 @@ def test_transform_worker_pass_no_dupes_max_records(

     # Confirm that everything has now been processed, and that there is no
     # unprocessed data left in the bucket
-    assert n_final_processed == n_initial_processed + 9 * n_initial_unprocessed
+    assert n_final_processed == n_initial_processed + 5 * n_initial_unprocessed
     assert n_final_unprocessed == 0

@@ -304,9 +301,8 @@ def test_transform_worker_bad_record(

     assert response == {
         "stage_name": "transform",
-        # 9 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created,
-        # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent
-        "processed_records": n_initial_processed + (9 * bad_record_index),
+        # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created
+        "processed_records": n_initial_processed + (5 * bad_record_index),
         "unprocessed_records": n_initial_unprocessed - bad_record_index,
         "error_message": [
             "The following errors were encountered",
@@ -327,9 +323,8 @@
     # Confirm that there are still unprocessed records, and that there may have been
     # some records processed successfully
     assert n_final_unprocessed > 0
-    # 9 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created,
-    # all of which are accompanied by a DeviceUpdatedEvent, plus one DeviceCreatedEvent
-    assert n_final_processed == n_initial_processed + (9 * bad_record_index)
+    # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created
+    assert n_final_processed == n_initial_processed + (5 * bad_record_index)
     assert n_final_unprocessed == n_initial_unprocessed - bad_record_index
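The 9-to-5 change in these multipliers can be read off the EXPECTED_EVENTS list in test_cpm_translation.py (last file below): each input record produces five domain events, and before this rollback every event other than the creation event was also paired with a device_updated_event, giving 5 + 4 = 9 per record. A minimal sketch of the tally, using the event names from that test (the zip construction is illustrative only):

# Per-record events after the rollback (5 per record), as listed in
# EXPECTED_EVENTS of test_cpm_translation.py:
events_after = [
    "device_created_event",
    "device_key_added_event",
    "questionnaire_instance_event",
    "questionnaire_response_added_event",
    "device_index_added_event",
]

# Before the rollback, every event after the creation event was paired with
# a device_updated_event, giving 5 + 4 = 9 events per record:
events_before = events_after[:1] + [
    event
    for pair in zip(events_after[1:], ["device_updated_event"] * 4)
    for event in pair
]
assert len(events_after) == 5 and len(events_before) == 9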


26 changes: 7 additions & 19 deletions src/layers/domain/core/device.py
@@ -1,6 +1,5 @@
 from collections import defaultdict
-from datetime import UTC
-from datetime import datetime as DateTime
+from datetime import datetime
 from enum import StrEnum, auto
 from itertools import chain
 from typing import Any, Optional
@@ -158,35 +157,24 @@ class Device(AggregateRoot):
     status: DeviceStatus = Field(default=DeviceStatus.ACTIVE)
     product_team_id: UUID
     ods_code: str
-    created_on: DateTime = Field(
-        default_factory=lambda: DateTime.now(UTC), immutable=True
-    )
-    updated_on: Optional[DateTime] = Field(default=None)
-    deleted_on: Optional[DateTime] = Field(default=None)
+    created_on: datetime = Field(default_factory=datetime.utcnow, immutable=True)
+    updated_on: Optional[datetime] = Field(default=None)
+    deleted_on: Optional[datetime] = Field(default=None)
     keys: dict[str, DeviceKey] = Field(default_factory=dict, exclude=True)
     questionnaire_responses: dict[str, list[QuestionnaireResponse]] = Field(
         default_factory=lambda: defaultdict(list), exclude=True
     )
     indexes: set[tuple[str, str, Any]] = Field(default_factory=set, exclude=True)

-    def add_event(self, event: Event) -> Event:
-        """
-        Override the base add_event to ensure that the 'updated_on' timestamp
-        is always updated when events (other than DeviceUpdatedEvent) occur.
-        """
-        _event = super().add_event(event=event)
-        self.update()
-        return _event
-
     def update(self, **kwargs) -> DeviceUpdatedEvent:
         if "updated_on" not in kwargs:
-            kwargs["updated_on"] = DateTime.now(UTC)
+            kwargs["updated_on"] = datetime.utcnow()
         device_data = self._update(data=kwargs)
         event = DeviceUpdatedEvent(**device_data)
-        return super().add_event(event)
+        return self.add_event(event)

     def delete(self) -> DeviceUpdatedEvent:
-        deletion_datetime = DateTime.now(UTC)
+        deletion_datetime = datetime.utcnow()
         return self.update(
             status=DeviceStatus.INACTIVE,
             updated_on=deletion_datetime,
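For context on the interaction being rolled back: the removed add_event override made every event added to a Device trigger update(), which in turn appended a DeviceUpdatedEvent; update() had to call super().add_event to avoid recursing into itself, and product_team.py (next file) likewise used super(Device, device).add_event to bypass the override when raising DeviceCreatedEvent. A minimal sketch of the pre-rollback behaviour with stand-in classes (not the real AggregateRoot):

class AggregateRoot:                      # stand-in, for illustration only
    def __init__(self):
        self.events = []

    def add_event(self, event):
        self.events.append(event)
        return event


class DeviceWithOverride(AggregateRoot):  # pre-rollback behaviour
    def add_event(self, event):
        _event = super().add_event(event)
        self.update()                     # side effect on every event
        return _event

    def update(self):
        # must bypass the override, or update() would recurse forever
        super().add_event("device_updated_event")


device = DeviceWithOverride()
device.add_event("device_key_added_event")
assert device.events == ["device_key_added_event", "device_updated_event"]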
4 changes: 2 additions & 2 deletions src/layers/domain/core/product_team.py
@@ -42,6 +42,6 @@ def create_device(
             ods_code=self.ods_code,
         )
         device_created_event = DeviceCreatedEvent(**device.dict(), _trust=_trust)
-        super(Device, device).add_event(device_created_event)
-        super().add_event(device_created_event)
+        device.add_event(device_created_event)
+        self.add_event(device_created_event)
         return device
16 changes: 0 additions & 16 deletions src/layers/domain/core/tests/test_device.py
@@ -293,19 +293,3 @@ def test_device_add_index_no_such_question(device: Device):

     with pytest.raises(QuestionNotFoundError):
         device.add_index(questionnaire_id="foo/1", question_name="question1")
-
-
-def test_device_add_event(device: Device):
-    created_on_before = device.created_on
-    assert isinstance(device.created_on, datetime)
-    assert device.updated_on is None
-    assert device.deleted_on is None
-
-    device.add_event("an_event")
-
-    assert isinstance(device.created_on, datetime)
-    assert isinstance(device.updated_on, datetime)
-    assert device.deleted_on is None
-
-    assert device.created_on == created_on_before
-    assert device.updated_on > device.created_on
27 changes: 5 additions & 22 deletions src/layers/domain/repository/repository.py
@@ -26,22 +26,6 @@ def batched(iterable: T, n: int) -> Generator[T, None, None]:
         piece = list(islice(i, n))


-def _split_transactions_by_key(
-    transact_items: list[TransactItem], n_max: int
-) -> Generator[list[TransactItem], None, None]:
-    buffer, keys = [], set()
-    for transact_item in transact_items:
-        transaction_statement = transact_item.Put or transact_item.Delete
-        item = transaction_statement.Key or transaction_statement.Item
-        key = (item["pk"]["S"], item["sk"]["S"])
-        if key in keys:
-            yield from batched(buffer, n=n_max)
-            buffer, keys = [], set()
-        buffer.append(transact_item)
-        keys.add(key)
-    yield from batched(buffer, n=n_max)
-
-
 class Repository(Generic[ModelType]):
     def __init__(self, table_name, model: type[ModelType], dynamodb_client):
         self.table_name = table_name
@@ -55,12 +39,11 @@ def generate_transaction_statements(event) -> TransactItem:
             return handler(event=event)

         responses = []
-        transact_items = map(generate_transaction_statements, entity.events)
-        for _transact_items in _split_transactions_by_key(
-            transact_items, n_max=batch_size
-        ):
-            transaction = Transaction(TransactItems=_transact_items)
-            with handle_client_errors(commands=_transact_items):
+        for events in batched(entity.events, n=batch_size):
+            transact_items = list(map(generate_transaction_statements, events))
+            transaction = Transaction(TransactItems=transact_items)
+
+            with handle_client_errors(commands=transact_items):
                 _response = self.client.transact_write_items(
                     **transaction.dict(exclude_none=True)
                 )
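With _split_transactions_by_key gone, write() now chunks events purely by count: batched yields successive slices of at most n events, and nothing keeps two operations on the same item out of a single transaction (the test change below asserts the resulting DynamoDB error). The body of batched is truncated in the diff context above; a plausible completion, assuming the conventional islice chunking idiom:

from itertools import islice
from typing import Generator, TypeVar

T = TypeVar("T")


def batched(iterable: T, n: int) -> Generator[T, None, None]:
    # Yield successive chunks of at most n items (assumed completion of the
    # helper whose body is truncated in the diff context above).
    i = iter(iterable)
    piece = list(islice(i, n))
    while piece:
        yield piece
        piece = list(islice(i, n))


assert list(batched(range(5), 2)) == [[0, 1], [2, 3], [4]]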
14 changes: 11 additions & 3 deletions src/layers/domain/repository/tests/test_repository.py
@@ -1,6 +1,6 @@
 import pytest
 from attr import asdict, dataclass
-from domain.repository.errors import AlreadyExistsError
+from domain.repository.errors import AlreadyExistsError, UnhandledTransaction
 from domain.repository.marshall import marshall, marshall_value, unmarshall
 from domain.repository.repository import Repository
 from domain.repository.transaction import (
@@ -142,9 +142,17 @@ def test_repository_raise_already_exists_from_single_transaction(
             MyEventAdd(field="123"),
         ],
     )
-    with pytest.raises(AlreadyExistsError) as exc:
+    with pytest.raises(UnhandledTransaction) as exc:
         repository.write(my_item)
-    assert str(exc.value) == "Item already exists"
+    assert str(exc.value) == "\n".join(
+        (
+            "ValidationException: Transaction request cannot include multiple operations on one item",
+            f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "prefix:456"}}, "sk": {{"S": "prefix:456"}}, "field": {{"S": "456"}}}}}}}}',
+            f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "123"}}, "sk": {{"S": "123"}}, "field": {{"S": "123"}}}}, "ConditionExpression": "attribute_not_exists(pk) AND attribute_not_exists(sk) AND attribute_not_exists(pk_1) AND attribute_not_exists(sk_1) AND attribute_not_exists(pk_2) AND attribute_not_exists(sk_2)"}}}}',
+            f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "prefix:345"}}, "sk": {{"S": "prefix:345"}}, "field": {{"S": "345"}}}}}}}}',
+            f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "123"}}, "sk": {{"S": "123"}}, "field": {{"S": "123"}}}}, "ConditionExpression": "attribute_not_exists(pk) AND attribute_not_exists(sk) AND attribute_not_exists(pk_1) AND attribute_not_exists(sk_1) AND attribute_not_exists(pk_2) AND attribute_not_exists(sk_2)"}}}}',
+        )
+    )


 @pytest.mark.integration
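The rewritten assertion exercises a DynamoDB constraint directly: a single TransactWriteItems request may not contain two operations on the same item. A hypothetical minimal repro against boto3 (table name, region, and credentials all assumed):

import boto3

client = boto3.client("dynamodb")  # assumes region and credentials are configured
item = {"pk": {"S": "123"}, "sk": {"S": "123"}, "field": {"S": "123"}}
client.transact_write_items(
    TransactItems=[
        {"Put": {"TableName": "my-table", "Item": item}},  # "my-table" is assumed
        {"Put": {"TableName": "my-table", "Item": item}},  # same key twice
    ]
)
# botocore raises ClientError with "ValidationException: Transaction request
# cannot include multiple operations on one item", which the repository layer
# surfaces here as UnhandledTransaction.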
7 changes: 3 additions & 4 deletions
@@ -1,5 +1,5 @@
 import pytest
-from domain.core.device import Device, DeviceType, DeviceUpdatedEvent
+from domain.core.device import Device, DeviceType
 from domain.core.questionnaire import (
     Questionnaire,
     QuestionnaireResponse,
@@ -198,10 +198,9 @@ def test_update_device_metadata(
     )

     assert _device is device
-    assert len(_device.events) == len(events_before) + 2
+    assert len(_device.events) == len(events_before) + 1
     assert all(event in _device.events for event in events_before)
-    assert isinstance(_device.events[-1], DeviceUpdatedEvent)
-    assert isinstance(_device.events[-2], QuestionnaireResponseUpdatedEvent)
+    assert isinstance(_device.events[-1], QuestionnaireResponseUpdatedEvent)

     device_metadata = _device.questionnaire_responses[
         questionnaire_response.questionnaire.id
4 changes: 0 additions & 4 deletions src/layers/sds/cpm_translation/tests/test_cpm_translation.py
@@ -38,13 +38,9 @@
 EXPECTED_EVENTS = [
     "device_created_event",
     "device_key_added_event",
-    "device_updated_event",
     "questionnaire_instance_event",
-    "device_updated_event",
     "questionnaire_response_added_event",
-    "device_updated_event",
     "device_index_added_event",
-    "device_updated_event",
 ]
