Skip to content

Commit 126c6ef

Browse files
committed
wrote method to recurse through datamodel
1 parent 9537006 commit 126c6ef

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

src/phenopacket_mapper/data_standards/data_model.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from dataclasses import dataclass, field
1212
from pathlib import Path
13-
from typing import Union, List, Literal, Dict, Optional, Any, Callable, Tuple
13+
from typing import Union, List, Literal, Dict, Optional, Any, Callable, Tuple, Iterable
1414
import warnings
1515

1616
import pandas as pd
@@ -522,3 +522,28 @@ def __getattr__(self, var_name: str) -> Union[DataField, DataSection, 'OrGroup']
522522
if f.id == var_name:
523523
return f
524524
raise AttributeError(f"'OrGroup' object has no attribute '{var_name}'")
525+
526+
527+
def recursive_collect_all_members_data_model(
528+
data_model: Union[DataModel, DataSection, OrGroup, DataField]
529+
) -> Iterable[Union[DataSection, OrGroup, DataField]]:
530+
"""Recursively collect all members of a DataModel, DataSection, OrGroup, or DataField
531+
532+
:param data_model: DataModel, DataSection, OrGroup, or DataField to collect all members from
533+
:return: Iterable of DataSection, OrGroup, and DataField members
534+
"""
535+
if isinstance(data_model, DataModel):
536+
for f in data_model.fields:
537+
yield from recursive_collect_all_members_data_model(f)
538+
elif isinstance(data_model, DataSection):
539+
yield data_model
540+
for f in data_model.fields:
541+
yield from recursive_collect_all_members_data_model(f)
542+
elif isinstance(data_model, OrGroup):
543+
yield data_model
544+
for f in data_model.fields:
545+
yield from recursive_collect_all_members_data_model(f)
546+
elif isinstance(data_model, DataField):
547+
yield data_model
548+
else:
549+
raise ValueError(f"Unsupported data_model type: {type(data_model)}")
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
3+
from phenopacket_mapper import DataModel
4+
from phenopacket_mapper.data_standards import DataField, DataSection, OrGroup
5+
from phenopacket_mapper.data_standards.data_model import recursive_collect_all_members_data_model
6+
7+
df1 = DataField(
8+
name="test_field_1",
9+
specification=str,
10+
)
11+
12+
df2 = DataField(
13+
name="test_field_2",
14+
specification=int,
15+
)
16+
17+
df3 = DataField(
18+
name="test_field_3",
19+
specification=bool,
20+
)
21+
22+
ds1 = DataSection(
23+
name="test_section_1",
24+
fields=(df1, df2)
25+
)
26+
27+
og1 = OrGroup(
28+
name="test_or_group_1",
29+
fields=(df1, df2)
30+
)
31+
32+
33+
@pytest.mark.parametrize(
34+
"data_model, members",
35+
[
36+
(
37+
DataModel(
38+
name="test",
39+
fields=(df1, df2)
40+
),
41+
[df1, df2]
42+
), # tabular data model
43+
44+
(
45+
DataModel(
46+
name="test",
47+
fields=(ds1, df3)
48+
),
49+
[df1, df2, ds1, df3]
50+
), # hierarchical with section data model
51+
52+
(
53+
DataModel(
54+
name="test",
55+
fields=(og1, df3)
56+
),
57+
[df1, df2, og1, df3]
58+
), # hierarchical with or group data model
59+
]
60+
)
61+
def test_recursive_collect_all_members_data_model(data_model, members):
62+
assert set(recursive_collect_all_members_data_model(data_model)) == set(members)

0 commit comments

Comments
 (0)