# Fields compressed on every device item, root or not (see compress_device_fields)
ROOT_FIELDS_TO_COMPRESS = [TAGS]
# Fields that decompress_device_fields additionally decompresses on non-root items
NON_ROOT_FIELDS_TO_COMPRESS = ["questionnaire_responses"]
BATCH_GET_SIZE = 100
# Fields that callers of query_by_tag are not allowed to drop from the result.
# Frozen so that this module-level guard cannot be mutated by accident.
MANDATORY_DEVICE_FIELDS = frozenset(
    {"name", "device_type", "product_team_id", "ods_code"}
)
36
37
37
38
38
39
class TooManyResults (Exception ):
@@ -42,36 +43,38 @@ class TooManyResults(Exception):
42
43
def compress_device_fields(data: Event | dict, fields_to_compress=None) -> dict:
    """
    Return a copy of 'data', restricted to Device fields, with the requested
    fields compressed.

    Args:
        data: either a dict, or an Event which is converted to a dict
            (non-recursively) first. The input itself is not modified.
        fields_to_compress: optional extra field names to compress, in
            addition to ROOT_FIELDS_TO_COMPRESS which are always compressed.

    Returns:
        A new dict in which every requested field that is present and
        non-empty has been replaced by its compressed form.
    """
    _data = copy(data) if isinstance(data, dict) else asdict(data, recurse=False)

    # Pop unknown keys
    unknown_keys = _data.keys() - set(Device.__fields__)
    for k in unknown_keys:
        _data.pop(k)

    # De-duplicate (preserving order) so that a field which is both passed in
    # by the caller and listed in ROOT_FIELDS_TO_COMPRESS is compressed exactly
    # once; compressing it twice would not be undone by a single decompression
    candidate_fields = dict.fromkeys(
        [*(fields_to_compress or []), *ROOT_FIELDS_TO_COMPRESS]
    )

    for field in candidate_fields:
        # Only proceed if the field is present and not empty
        if not _data.get(field):
            continue
        if field == TAGS:
            # Tags are doubly compressed: first compress each tag in the list
            _data[field] = [pkl_dumps_gzip(tag) for tag in _data[field]]
        # Compress the entire field (which includes the doubly compressed tags)
        _data[field] = pkl_dumps_gzip(_data[field])
    return _data
61
62
62
63
63
64
def decompress_device_fields(device: dict):
    """
    Decompress, in place, the compressed fields of a device item and return
    the same dict.

    Fields that are missing or empty are left untouched. Fields listed in
    ROOT_FIELDS_TO_COMPRESS are always considered; fields listed in
    NON_ROOT_FIELDS_TO_COMPRESS are only considered when the item's 'root'
    value is missing or falsy.
    """
    for name in ROOT_FIELDS_TO_COMPRESS:
        compressed = device.get(name)
        if not compressed:
            # Nothing stored under this field, so there is nothing to decompress
            continue
        # Outer layer: the field as a whole
        device[name] = pkl_loads_gzip(compressed)
        if name == TAGS:
            # Inner layer: tags are doubly compressed, so each entry of the
            # decompressed list is itself still compressed
            device[name] = [pkl_loads_gzip(tag) for tag in device[name]]

    # Root items (truthy 'root') have no further fields to decompress
    if device.get("root"):
        return device

    for name in NON_ROOT_FIELDS_TO_COMPRESS:
        compressed = device.get(name)
        if compressed:
            device[name] = pkl_loads_gzip(compressed)

    return device
76
79
77
80
@@ -553,30 +556,64 @@ def read_inactive(self, *key_parts: str) -> Device:
553
556
_device = unmarshall (item )
554
557
return Device (** decompress_device_fields (_device ))
555
558
556
def query_by_tag(
    self, fields_to_drop: list[str] | None = None, **kwargs
) -> list[Device]:
    """
    Query the device by predefined tags, optionally dropping specific
    fields from the query result:

        repository.query_by_tag(foo="123", bar="456")
        repository.query_by_tag(fields_to_drop=["field1", "field2"], foo="123")

    NB: the DeviceTag enforces that values (but not keys) are case insensitive

    Args:
        fields_to_drop: optional names of Device fields to exclude from the
            query result. Neither the fields in MANDATORY_DEVICE_FIELDS nor
            'id' (which the result is sorted by) may be dropped.
        **kwargs: the tag components, passed through to DeviceTag.

    Returns:
        The matching devices, sorted by 'id'.

    Raises:
        ValueError: if 'fields_to_drop' contains a field that must be kept.
        TooManyResults: if the query response is paginated.
    """
    tag_value = DeviceTag(**kwargs).value
    pk = TableKey.DEVICE_TAG.key(tag_value)

    query_params = {
        "ExpressionAttributeValues": {":pk": marshall_value(pk)},
        "KeyConditionExpression": "pk = :pk",
        "TableName": self.table_name,
    }

    # If fields to drop are provided, create a ProjectionExpression
    if fields_to_drop:
        requested_drops = set(fields_to_drop)

        # Ensure no mandatory fields are dropped. 'id' is protected as well
        # because the results are sorted by it below; dropping it would
        # otherwise only fail with a KeyError after the query has been made
        protected_fields = MANDATORY_DEVICE_FIELDS | {"id"}
        dropped_mandatory_fields = requested_drops & protected_fields
        if dropped_mandatory_fields:
            # Sorted so that the message is deterministic
            raise ValueError(
                f"Cannot drop mandatory fields: {', '.join(sorted(dropped_mandatory_fields))}"
            )

        fields_to_return = set(Device.get_all_fields()) - requested_drops

        # DynamoDB ProjectionExpression, specifying which fields to return
        # (sorted so that the same arguments always build the same request)
        query_params.update(
            _dynamodb_projection_expression(sorted(fields_to_return))
        )

    # Perform the DynamoDB query
    response = self.client.query(**query_params)

    # Not yet implemented: pagination
    if "LastEvaluatedKey" in response:
        raise TooManyResults(f"Too many results for query '{kwargs}'")

    # Convert to Device, sorted by 'id'
    compressed_devices = map(unmarshall, response["Items"])
    devices_as_dict = map(decompress_device_fields, compressed_devices)

    return [Device(**d) for d in sorted(devices_as_dict, key=lambda d: d["id"])]
602
+
603
+
604
+ def _dynamodb_projection_expression (updated_fields : list [str ]):
605
+ expression_attribute_names = {}
606
+ update_clauses = []
607
+
608
+ for field_name in updated_fields :
609
+ field_name_placeholder = f"#{ field_name } "
610
+
611
+ update_clauses .append (field_name_placeholder )
612
+ expression_attribute_names [field_name_placeholder ] = field_name
613
+
614
+ projection_expression = ", " .join (update_clauses )
615
+
616
+ return dict (
617
+ ProjectionExpression = projection_expression ,
618
+ ExpressionAttributeNames = expression_attribute_names ,
619
+ )
0 commit comments