Skip to content

Commit 10d9fed

Browse files
committed
Merge branch 'feature/PI-477-query_by_tag_ignore_chunky_fields' into release/2024-09-13
2 parents 77f3163 + 5f816ad commit 10d9fed

File tree

9 files changed

+142
-43
lines changed

9 files changed

+142
-43
lines changed

src/api/searchSdsDevice/src/v1/steps.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from event.step_chain import StepChain
88
from pydantic import ValidationError
99

10+
FIELDS_TO_DROP = ["tags"]
11+
1012

1113
def parse_event_query(data, cache):
1214
event = APIGatewayProxyEvent(data[StepChain.INIT])
@@ -31,7 +33,7 @@ def query_devices(data, cache) -> List[dict]:
3133
device_repo = DeviceRepository(
3234
table_name=cache["DYNAMODB_TABLE"], dynamodb_client=cache["DYNAMODB_CLIENT"]
3335
)
34-
results = device_repo.query_by_tag(**query_params)
36+
results = device_repo.query_by_tag(fields_to_drop=FIELDS_TO_DROP, **query_params)
3537
return [result.state() for result in results]
3638

3739

src/api/searchSdsDevice/tests/test_index.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def _create_device(device, product_team, params):
4141

4242
questionnaire_response = questionnaire.respond(responses=response)
4343
cpmdevice.add_questionnaire_response(questionnaire_response=questionnaire_response)
44-
cpmdevice.add_tag(**params)
44+
tag_params = [params]
45+
cpmdevice.add_tags(tags=tag_params)
4546
return cpmdevice
4647

4748

src/api/searchSdsEndpoint/src/v1/steps.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from event.step_chain import StepChain
88
from pydantic import ValidationError
99

10+
FIELDS_TO_DROP = ["tags"]
11+
1012

1113
def parse_event_query(data, cache):
1214
event = APIGatewayProxyEvent(data[StepChain.INIT])
@@ -32,7 +34,7 @@ def query_endpoints(data, cache) -> List[dict]:
3234
device_repo = DeviceRepository(
3335
table_name=cache["DYNAMODB_TABLE"], dynamodb_client=cache["DYNAMODB_CLIENT"]
3436
)
35-
results = device_repo.query_by_tag(**query_params)
37+
results = device_repo.query_by_tag(fields_to_drop=FIELDS_TO_DROP, **query_params)
3638
return [result.state() for result in results]
3739

3840

src/api/searchSdsEndpoint/tests/test_index.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def _create_device(device, product_team, params):
4141

4242
questionnaire_response = questionnaire.respond(responses=response)
4343
cpmdevice.add_questionnaire_response(questionnaire_response=questionnaire_response)
44-
cpmdevice.add_tag(**params)
44+
tag_params = [params]
45+
cpmdevice.add_tags(tags=tag_params)
4546
return cpmdevice
4647

4748

src/etl/sds/worker/load_bulk/tests/test_load_bulk_worker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ def all_devices(self) -> Generator[Device, None, None]:
3131
for device in devices:
3232
if not device.get("root"):
3333
continue
34-
device["tags"] = [
35-
pkl_loads_gzip(tag) for tag in pkl_loads_gzip(device["tags"])
36-
]
34+
if device.get("tags"): # Only compress if tags not empty
35+
device["tags"] = [
36+
pkl_loads_gzip(tag) for tag in pkl_loads_gzip(device["tags"])
37+
]
3738
yield Device(**device)
3839

3940
def count(self, by: DeviceType | DeviceKeyType):

src/layers/domain/core/device/v2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,10 @@ def state(self) -> dict:
450450
def is_active(self):
451451
return self.status is Status.ACTIVE
452452

453+
@classmethod
454+
def get_all_fields(cls) -> set[str]:
455+
return set(cls.__fields__.keys())
456+
453457

454458
class DeviceEventDeserializer(EventDeserializer):
455459
event_types = (

src/layers/domain/repository/device_repository/tests/v2/test_device_repository_tags_v2.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
from collections import defaultdict
2+
13
import pytest
24
from domain.core.device.v2 import Device, DeviceTag
35
from domain.core.device_key.v2 import DeviceKeyType
4-
from domain.repository.device_repository.v2 import DeviceRepository
6+
from domain.core.enum import Status
7+
from domain.repository.device_repository.v2 import (
8+
MANDATORY_DEVICE_FIELDS,
9+
DeviceRepository,
10+
)
511

612

713
@pytest.mark.integration
@@ -113,3 +119,50 @@ def test__device_repository__add_two_tags_and_then_clear(
113119

114120
assert repository.query_by_tag(shoe_size=123) == []
115121
assert repository.query_by_tag(shoe_size=456) == []
122+
123+
124+
@pytest.mark.integration
125+
@pytest.mark.parametrize(
126+
"field_to_drop, expected_default_value",
127+
[
128+
(["tags"], set()), # If 'tags' is dropped, it should default to an empty set
129+
(["keys"], []), # If 'keys' is dropped, it should default to an empty list
130+
(["status"], Status.ACTIVE), # 'status' should default to Status.ACTIVE
131+
(["updated_on"], None), # 'updated_on' should default to None
132+
(["deleted_on"], None), # 'deleted_on' should default to None
133+
(
134+
["questionnaire_responses"],
135+
defaultdict(dict),
136+
), # 'questionnaire_responses' defaults to an empty dict
137+
],
138+
)
139+
def test__device_repository__drop_fields(
140+
device: Device, repository: DeviceRepository, field_to_drop, expected_default_value
141+
):
142+
repository.write(device)
143+
(_device_123,) = repository.query_by_tag(abc=123)
144+
assert _device_123.dict() == device.dict()
145+
146+
# Query with specific fields to drop
147+
results = repository.query_by_tag(abc=123, fields_to_drop=field_to_drop)
148+
assert len(results) == 1
149+
150+
device_result = results[0]
151+
152+
assert device_result.dict()[field_to_drop[0]] == expected_default_value
153+
assert all(field in device_result.dict() for field in MANDATORY_DEVICE_FIELDS)
154+
155+
156+
@pytest.mark.integration
157+
def test__device_repository__drop_mandatory_fields(
158+
device: Device, repository: DeviceRepository
159+
):
160+
repository.write(device)
161+
(_device_123,) = repository.query_by_tag(abc=123)
162+
assert _device_123.dict() == device.dict()
163+
164+
# Query with mandatory fields to drop
165+
fields_to_drop = list(MANDATORY_DEVICE_FIELDS)
166+
167+
with pytest.raises(ValueError, match="Cannot drop mandatory fields:"):
168+
repository.query_by_tag(abc=123, fields_to_drop=fields_to_drop)

src/layers/domain/repository/device_repository/v2.py

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ROOT_FIELDS_TO_COMPRESS = [TAGS]
3434
NON_ROOT_FIELDS_TO_COMPRESS = ["questionnaire_responses"]
3535
BATCH_GET_SIZE = 100
36+
MANDATORY_DEVICE_FIELDS = {"name", "device_type", "product_team_id", "ods_code"}
3637

3738

3839
class TooManyResults(Exception):
@@ -42,36 +43,38 @@ class TooManyResults(Exception):
4243
def compress_device_fields(data: Event | dict, fields_to_compress=None) -> dict:
4344
_data = copy(data) if isinstance(data, dict) else asdict(data, recurse=False)
4445

45-
# pop unknown keys
46+
# Pop unknown keys
4647
unknown_keys = _data.keys() - set(Device.__fields__)
4748
for k in unknown_keys:
4849
_data.pop(k)
4950

50-
# compress specified keys if they exist in the data
51+
# Compress specified keys if they exist in the data
5152
fields_to_compress = (fields_to_compress or []) + ROOT_FIELDS_TO_COMPRESS
52-
fields_to_compress_that_exist = [f for f in fields_to_compress if f in _data]
53+
fields_to_compress_that_exist = [f for f in fields_to_compress if _data.get(f)]
5354
for field in fields_to_compress_that_exist:
55+
# Only proceed if the field is not empty
5456
if field == TAGS:
55-
# tags are doubly compressed: first compress each tag in the list,
56-
# and then compress the entire list in the line directly after this
57-
# if-block
57+
# Tags are doubly compressed: first compress each tag in the list
5858
_data[field] = [pkl_dumps_gzip(tag) for tag in _data[field]]
59+
# Compress the entire field (which includes the doubly compressed tags)
5960
_data[field] = pkl_dumps_gzip(_data[field])
6061
return _data
6162

6263

6364
def decompress_device_fields(device: dict):
6465
for field in ROOT_FIELDS_TO_COMPRESS:
65-
device[field] = pkl_loads_gzip(device[field])
66-
if field == TAGS:
67-
# tags are doubly compressed, so first decompress the entire tag list
68-
# in the line directly before this if-block, then decompress each tag
69-
# in the list
70-
device[field] = [pkl_loads_gzip(tag) for tag in device[field]]
71-
72-
if device["root"] is False:
66+
if device.get(field): # Check if the field is present and not empty
67+
device[field] = pkl_loads_gzip(device[field]) # First decompression
68+
if field == TAGS: # Tags are doubly compressed.
69+
# Second decompression: Decompress each tag in the list
70+
device[field] = [pkl_loads_gzip(tag) for tag in device[field]]
71+
72+
# Decompress non-root fields if the device is not a root and fields exist
73+
if not device.get("root"): # Use get to handle missing 'root' field
7374
for field in NON_ROOT_FIELDS_TO_COMPRESS:
74-
device[field] = pkl_loads_gzip(device[field])
75+
if device.get(field): # Check if the field is present and non empty
76+
device[field] = pkl_loads_gzip(device[field])
77+
7578
return device
7679

7780

@@ -553,30 +556,64 @@ def read_inactive(self, *key_parts: str) -> Device:
553556
_device = unmarshall(item)
554557
return Device(**decompress_device_fields(_device))
555558

556-
def query_by_tag(self, **kwargs) -> list[Device]:
559+
def query_by_tag(self, fields_to_drop: list[str] = None, **kwargs) -> list[Device]:
557560
"""
558-
Query the device by predefined tags:
559-
560-
repository.query_by_tag(foo="123", bar="456")
561-
562-
NB: the DeviceTag enforces that values (but not keys) are case insensitive
561+
Query the device by predefined tags, optionally dropping specific fields from the query result.
562+
Example: repository.query_by_tag(fields_to_drop=["field1", "field2"], foo="123", bar="456")
563563
"""
564+
564565
tag_value = DeviceTag(**kwargs).value
565566
pk = TableKey.DEVICE_TAG.key(tag_value)
566567

567-
# Initial query to retrieve a list of all the root-device pk's
568-
response = self.client.query(
569-
ExpressionAttributeValues={":pk": marshall_value(pk)},
570-
KeyConditionExpression="pk = :pk",
571-
TableName=self.table_name,
572-
)
568+
query_params = {
569+
"ExpressionAttributeValues": {":pk": marshall_value(pk)},
570+
"KeyConditionExpression": "pk = :pk",
571+
"TableName": self.table_name,
572+
}
573+
574+
# If fields to drop are provided, create a ProjectionExpression
575+
if fields_to_drop:
576+
all_fields = Device.get_all_fields()
577+
578+
# Ensure no mandatory fields are dropped
579+
dropped_mandatory_fields = set(fields_to_drop) & MANDATORY_DEVICE_FIELDS
580+
if dropped_mandatory_fields:
581+
raise ValueError(
582+
f"Cannot drop mandatory fields: {', '.join(dropped_mandatory_fields)}"
583+
)
584+
585+
fields_to_return = all_fields - set(fields_to_drop)
586+
587+
# DynamoDB ProjectionExpression, specifying which fields to return
588+
query_params.update(_dynamodb_projection_expression(fields_to_return))
589+
590+
# Perform the DynamoDB query
591+
response = self.client.query(**query_params)
592+
573593
# Not yet implemented: pagination
574594
if "LastEvaluatedKey" in response:
575595
raise TooManyResults(f"Too many results for query '{kwargs}'")
576596

577-
# Convert to Device, sorted by 'pk', which would have been
578-
# the expected behaviour if tags in the database were
579-
# Device duplicates rather than references
597+
# Convert to Device, sorted by 'pk'
580598
compressed_devices = map(unmarshall, response["Items"])
581599
devices_as_dict = map(decompress_device_fields, compressed_devices)
600+
582601
return [Device(**d) for d in sorted(devices_as_dict, key=lambda d: d["id"])]
602+
603+
604+
def _dynamodb_projection_expression(updated_fields: list[str]):
605+
expression_attribute_names = {}
606+
update_clauses = []
607+
608+
for field_name in updated_fields:
609+
field_name_placeholder = f"#{field_name}"
610+
611+
update_clauses.append(field_name_placeholder)
612+
expression_attribute_names[field_name_placeholder] = field_name
613+
614+
projection_expression = ", ".join(update_clauses)
615+
616+
return dict(
617+
ProjectionExpression=projection_expression,
618+
ExpressionAttributeNames=expression_attribute_names,
619+
)

src/test_helpers/validate_search_response.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ def validate_result_body(result_body, devices, params):
44
for index, result in enumerate(result_body):
55
validate_device(result, devices[index])
66
validate_keys(result["keys"], devices[index])
7-
validate_tags(result["tags"], params)
7+
validate_tags(result["tags"])
88
validate_questionnaire_responses(result, devices[index], params)
99

1010

@@ -18,10 +18,8 @@ def validate_keys(keys, device):
1818
assert key["key_value"] == device["device_key"]
1919

2020

21-
def validate_tags(tags, params):
22-
for tag in tags:
23-
for key, value in params.items():
24-
assert [key, value.lower()] in tag
21+
def validate_tags(tags):
22+
assert tags == [] # The tags field is dropped due to being chunky
2523

2624

2725
def validate_questionnaire_responses(result, device, params):

0 commit comments

Comments
 (0)