From 126c6efb27b07bf580ce748dff8652d467920225 Mon Sep 17 00:00:00 2001 From: frehburg Date: Thu, 17 Oct 2024 17:43:59 +0200 Subject: [PATCH] wrote method to recurse through datamodel --- .../data_standards/data_model.py | 27 +++++++- ...ecursive_collect_all_members_data_model.py | 62 +++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_recursive_collect_all_members_data_model.py diff --git a/src/phenopacket_mapper/data_standards/data_model.py b/src/phenopacket_mapper/data_standards/data_model.py index 1821b72..5cd9a2d 100644 --- a/src/phenopacket_mapper/data_standards/data_model.py +++ b/src/phenopacket_mapper/data_standards/data_model.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Union, List, Literal, Dict, Optional, Any, Callable, Tuple +from typing import Union, List, Literal, Dict, Optional, Any, Callable, Tuple, Iterable import warnings import pandas as pd @@ -522,3 +522,28 @@ def __getattr__(self, var_name: str) -> Union[DataField, DataSection, 'OrGroup'] if f.id == var_name: return f raise AttributeError(f"'OrGroup' object has no attribute '{var_name}'") + + +def recursive_collect_all_members_data_model( + data_model: Union[DataModel, DataSection, OrGroup, DataField] +) -> Iterable[Union[DataSection, OrGroup, DataField]]: + """Recursively collect all members of a DataModel, DataSection, OrGroup, or DataField + + :param data_model: DataModel, DataSection, OrGroup, or DataField to collect all members from + :return: Iterable of DataSection, OrGroup, and DataField members + """ + if isinstance(data_model, DataModel): + for f in data_model.fields: + yield from recursive_collect_all_members_data_model(f) + elif isinstance(data_model, DataSection): + yield data_model + for f in data_model.fields: + yield from recursive_collect_all_members_data_model(f) + elif isinstance(data_model, OrGroup): + yield data_model + for f in data_model.fields: + yield from recursive_collect_all_members_data_model(f) + elif isinstance(data_model, DataField): + yield data_model + else: + raise ValueError(f"Unsupported data_model type: {type(data_model)}") diff --git a/tests/utils/test_recursive_collect_all_members_data_model.py b/tests/utils/test_recursive_collect_all_members_data_model.py new file mode 100644 index 0000000..fe6b146 --- /dev/null +++ b/tests/utils/test_recursive_collect_all_members_data_model.py @@ -0,0 +1,62 @@ +import pytest + +from phenopacket_mapper import DataModel +from phenopacket_mapper.data_standards import DataField, DataSection, OrGroup +from phenopacket_mapper.data_standards.data_model import recursive_collect_all_members_data_model + +df1 = DataField( + name="test_field_1", + specification=str, +) + +df2 = DataField( + name="test_field_2", + specification=int, +) + +df3 = DataField( + name="test_field_3", + specification=bool, +) + +ds1 = DataSection( + name="test_section_1", + fields=(df1, df2) +) + +og1 = OrGroup( + name="test_or_group_1", + fields=(df1, df2) +) + + +@pytest.mark.parametrize( + "data_model, members", + [ + ( + DataModel( + name="test", + fields=(df1, df2) + ), + [df1, df2] + ), # tabular data model + + ( + DataModel( + name="test", + fields=(ds1, df3) + ), + [df1, df2, ds1, df3] + ), # hierarchical with section data model + + ( + DataModel( + name="test", + fields=(og1, df3) + ), + [df1, df2, og1, df3] + ), # hierarchical with or group data model + ] +) +def test_recursive_collect_all_members_data_model(data_model, members): + assert set(recursive_collect_all_members_data_model(data_model)) == set(members)