Skip to content

Commit

Permalink
Add hew model iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Feb 12, 2025
1 parent 0b4c4d7 commit 85d1446
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
36 changes: 36 additions & 0 deletions pyrenew_hew/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,42 @@
Pyrenew-HEW utilities
"""

from itertools import chain, combinations
from typing import Iterable


def powerset(iterable: Iterable) -> Iterable:
"""
Subsequences of the iterable from shortest to longest,
considering only unique elements.
Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
"""
s = set(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))


def hew_models(with_null: bool = True) -> Iterable:
"""
Return an iterable of the Pyrenew-HEW models
as their lowercase letters.
Parameters
----------
with_null
Include the null model ("pyrenew_null"), represented as
the empty tuple `()`? Default ``True``.
Returns
-------
Iterable
An iterable yielding tuples of model letters.
"""
result = powerset(("h", "e", "w"))
if not with_null:
result = filter(None, result)
return result


def hew_letters_from_flags(
fit_ed_visits: bool = False,
Expand Down
52 changes: 52 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Iterable

import pytest

from pyrenew_hew.util import (
hew_letters_from_flags,
hew_models,
powerset,
pyrenew_model_name_from_flags,
)

Expand Down Expand Up @@ -40,3 +44,51 @@ def test_hew_naming_from_flags(
)
== f"pyrenew_{expected_letters}"
)


@pytest.mark.parametrize(
"test_items",
[
range(10),
["a", "b", "c"],
[None, "a", "b"],
[None, None, "a", "b"],
["a", "b", "a", "a"],
[1, 1, 1.5, 2],
],
)
def test_powerset(test_items):
pset_iter = powerset(test_items)
pset = set(pset_iter)
assert isinstance(pset_iter, Iterable)
assert set([(item,) for item in test_items]).issubset(pset)
assert len(pset) == 2 ** len(set(test_items))
assert () in pset


def test_hew_model_iterator():
expected = [
(),
("h",),
("e",),
("w",),
(
"e",
"w",
),
(
"h",
"e",
),
(
"h",
"w",
),
("h", "e", "w"),
]
assert set([tuple(sorted(i)) for i in hew_models()]) == set(
[tuple(sorted(i)) for i in expected]
)
assert set([tuple(sorted(i)) for i in hew_models(False)]) == set(
[tuple(sorted(i)) for i in filter(None, expected)]
)

0 comments on commit 85d1446

Please sign in to comment.