diff --git a/pyrenew_hew/util.py b/pyrenew_hew/util.py index e465c0c0..a4bff42b 100644 --- a/pyrenew_hew/util.py +++ b/pyrenew_hew/util.py @@ -2,6 +2,42 @@ Pyrenew-HEW utilities """ +from itertools import chain, combinations +from typing import Iterable + + +def powerset(iterable: Iterable) -> Iterable: + """ + Subsequences of the iterable from shortest to longest, + considering only unique elements. + + Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes + """ + s = set(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +def hew_models(with_null: bool = True) -> Iterable: + """ + Return an iterable of the Pyrenew-HEW models + as their lowercase letters. + + Parameters + ---------- + with_null + Include the null model ("pyrenew_null"), represented as + the empty tuple `()`? Default ``True``. + + Returns + ------- + Iterable + An iterable yielding tuples of model letters. + """ + result = powerset(("h", "e", "w")) + if not with_null: + result = filter(None, result) + return result + def hew_letters_from_flags( fit_ed_visits: bool = False, diff --git a/tests/test_util.py b/tests/test_util.py index 08c898df..9628edb3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,7 +1,11 @@ +from typing import Iterable + import pytest from pyrenew_hew.util import ( hew_letters_from_flags, + hew_models, + powerset, pyrenew_model_name_from_flags, ) @@ -40,3 +44,51 @@ def test_hew_naming_from_flags( ) == f"pyrenew_{expected_letters}" ) + + +@pytest.mark.parametrize( + "test_items", + [ + range(10), + ["a", "b", "c"], + [None, "a", "b"], + [None, None, "a", "b"], + ["a", "b", "a", "a"], + [1, 1, 1.5, 2], + ], +) +def test_powerset(test_items): + pset_iter = powerset(test_items) + pset = set(pset_iter) + assert isinstance(pset_iter, Iterable) + assert set([(item,) for item in test_items]).issubset(pset) + assert len(pset) == 2 ** len(set(test_items)) + assert () in pset + + +def test_hew_model_iterator(): + expected = [ + (), + ("h",), + ("e",), + ("w",), + ( + "e", + "w", + ), + ( + "h", + "e", + ), + ( + "h", + "w", + ), + ("h", "e", "w"), + ] + assert set([tuple(sorted(i)) for i in hew_models()]) == set( + [tuple(sorted(i)) for i in expected] + ) + assert set([tuple(sorted(i)) for i in hew_models(False)]) == set( + [tuple(sorted(i)) for i in filter(None, expected)] + )