diff --git a/src/layers/etl_utils/ldif/ldif.py b/src/layers/etl_utils/ldif/ldif.py index c461a972..66ada7f7 100644 --- a/src/layers/etl_utils/ldif/ldif.py +++ b/src/layers/etl_utils/ldif/ldif.py @@ -1,7 +1,8 @@ import re +from collections import defaultdict from io import BytesIO from types import FunctionType -from typing import IO, TYPE_CHECKING, Callable, Generator +from typing import IO, TYPE_CHECKING, Callable, Generator, Protocol from etl_utils.ldif.model import DistinguishedName from smart_open import open as _smart_open @@ -67,6 +68,13 @@ def ldif_dump(fp: IO, obj: list[PARSED_RECORD]) -> str: ) +class _StreamBlock(Protocol): + def flush(self) -> str: ... + def reset(self): ... + def parse(self, line: bytes): ... + def __bool__(self): ... + + class StreamBlock: def __init__(self, filter_terms: list[tuple[str, str]]): self.data = BytesIO() @@ -80,7 +88,7 @@ def flush(self) -> str: self.data.write(self.buffer) self.reset() - def reset(self) -> str: + def reset(self): self.buffer = bytes() self.keep = False @@ -93,20 +101,41 @@ def __bool__(self): return bool(self.buffer) and self.keep -def filter_ldif_from_s3_by_property( - s3_path, filter_terms: list[tuple[str, str]], s3_client: "S3Client" -) -> memoryview: - """ - Efficiently streams a file from S3 directly into a bytes memoryview, - filtering out any LDIF record without any (attribute_name, attribute_value) - matching at least one of the filter terms. +class GroupedStreamBlock: + def __init__(self, group_field: str, filter_terms: list[tuple[str, str]]): + self.data = defaultdict(BytesIO) + self.filters: list[FunctionType] = [ + re.compile(rf"(?i)^({key}): ({value})\n$".encode()).match + for key, value in filter_terms + ] + self.group_filter = re.compile(rf"(?i)^({group_field}): (.*)\n$".encode()).match + self.reset() - The output of this function can then be parsed using' - 'parse_ldif(file_opener=BytesIO, path_or_data=filtered_ldif)' - """ + def flush(self) -> str: + if self.group is None: + raise Exception + self.data[self.group].write(self.buffer) + self.reset() - stream_block = StreamBlock(filter_terms) + def reset(self) -> str: + self.buffer = bytes() + self.keep = False + self.group = None + + def parse(self, line: bytes): + group_match = self.group_filter(line) + if group_match: + (_, self.group) = group_match.groups() + + if not self.keep and any(filter(line) for filter in self.filters): + self.keep = True + self.buffer += line + def __bool__(self): + return bool(self.buffer) and self.keep + + +def stream_to_block(s3_path: str, s3_client: "S3Client", stream_block: _StreamBlock): with _smart_open(s3_path, mode="rb", transport_params={"client": s3_client}) as f: for line in f.readlines(): line_is_empty = line.strip() == EMPTY_BYTESTRING @@ -118,4 +147,46 @@ def filter_ldif_from_s3_by_property( if stream_block: stream_block.flush() + + +def filter_ldif_from_s3_by_property( + s3_path, filter_terms: list[tuple[str, str]], s3_client: "S3Client" +) -> memoryview: + """ + Efficiently streams a file from S3 directly into a bytes memoryview, + filtering out any LDIF record without any (attribute_name, attribute_value) + matching at least one of the filter terms. + + The output of this function can then be parsed using' + 'parse_ldif(file_opener=BytesIO, path_or_data=filtered_ldif)' + """ + stream_block = StreamBlock(filter_terms) + stream_to_block(s3_path=s3_path, s3_client=s3_client, stream_block=stream_block) return stream_block.data.getbuffer() + + +def filter_and_group_ldif_from_s3_by_property( + s3_path, + group_field: str, + filter_terms: list[tuple[str, str]], + s3_client: "S3Client", +) -> memoryview: + """ + Efficiently streams a file from S3 directly into a bytes memoryview, + filtering out any LDIF record without any (attribute_name, attribute_value) + matching at least one of the filter terms, and then also grouping records + by the group_field. + + The output of this function can then be parsed using' + 'parse_ldif(file_opener=BytesIO, path_or_data=filtered_and_grouped_ldif)' + """ + + stream_block = GroupedStreamBlock( + group_field=group_field, filter_terms=filter_terms + ) + stream_to_block(s3_path=s3_path, s3_client=s3_client, stream_block=stream_block) + + data = BytesIO() + for group in stream_block.data.values(): + data.write(group.getbuffer()) + return data.getbuffer() diff --git a/src/layers/etl_utils/ldif/tests/test_ldif.py b/src/layers/etl_utils/ldif/tests/test_ldif.py index 77146232..006d33ec 100644 --- a/src/layers/etl_utils/ldif/tests/test_ldif.py +++ b/src/layers/etl_utils/ldif/tests/test_ldif.py @@ -5,6 +5,7 @@ import pytest from etl_utils.ldif.ldif import ( DistinguishedName, + filter_and_group_ldif_from_s3_by_property, filter_ldif_from_s3_by_property, ldif_dump, parse_ldif, @@ -247,6 +248,54 @@ }, ) +LDIF_TO_FILTER_AND_GROUP_EXAMPLE = """ +dn: uniqueIdentifier=AAA1 +myField: AAA +myOtherField: 123 + +dn: uniqueIdentifier=BBB1 +myfield: BBB +myOtherField: 123 + +dn: uniqueIdentifier=BBB2 +myfield: BBB +myOtherField: 123 + +dn: uniqueIdentifier=AAA2 +myfield: AAA +myOtherField: 123 + +dn: uniqueIdentifier=AAA3 +myField: AAA +myOtherField: 234 + +dn: uniqueIdentifier=BBB3 +myfield: BBB +myOtherField: 123 +""" + +FILTERED_AND_GROUPED_LDIF_TO_FILTER_AND_GROUP_EXAMPLE = """ +dn: uniqueIdentifier=AAA1 +myField: AAA +myOtherField: 123 + +dn: uniqueIdentifier=AAA2 +myfield: AAA +myOtherField: 123 + +dn: uniqueIdentifier=BBB1 +myfield: BBB +myOtherField: 123 + +dn: uniqueIdentifier=BBB2 +myfield: BBB +myOtherField: 123 + +dn: uniqueIdentifier=BBB3 +myfield: BBB +myOtherField: 123 +""" + @pytest.mark.parametrize( ("raw_distinguished_name", "parsed_distinguished_name"), @@ -322,6 +371,25 @@ def test_filter_ldif_from_s3_by_property(mocked_open): ) +@mock.patch( + "etl_utils.ldif.ldif._smart_open", + return_value=BytesIO(LDIF_TO_FILTER_AND_GROUP_EXAMPLE.encode()), +) +def test_filter_and_group_ldif_from_s3_by_property(mocked_open): + with mock_aws(): + s3_client = boto3.client("s3") + filtered_ldif = filter_and_group_ldif_from_s3_by_property( + s3_client=s3_client, + s3_path="s3://dummy_bucket/dummy_key", + group_field="myField", + filter_terms=[("myOtherField", "123")], + ) + assert ( + filtered_ldif.tobytes().decode() + == FILTERED_AND_GROUPED_LDIF_TO_FILTER_AND_GROUP_EXAMPLE + ) + + @pytest.mark.parametrize( ["raw_ldif", "parsed_ldif"], [