diff --git a/src/naibr/__main__.py b/src/naibr/__main__.py
index c3eed80..52b8de2 100644
--- a/src/naibr/__main__.py
+++ b/src/naibr/__main__.py
@@ -58,6 +58,7 @@
     input_candidates,
     is_proper_chrom,
     parallel_execute,
+    UnionDict,
     write_novel_adjacencies,
 )
 
@@ -286,27 +287,6 @@ def parse_args(args):
     return file_configs
 
 
-class UnionDict(collections.defaultdict):
-    """Extends defaultdict to allow combining with other UnionDicts or defaultdicts"""
-
-    def combine(self, other):
-        if not isinstance(other, type(self)) and not isinstance(other, collections.defaultdict):
-            raise TypeError(f"Can only combine with other UnionDicts or defaultdicts ({type(other)})")
-
-        if self.default_factory is not other.default_factory:
-            raise ValueError("Can only combine UnionDicts with the same default_factory")
-
-        if self.default_factory is list:
-            for k, v in other.items():
-                self[k].extend(v)
-
-        elif self.default_factory is set:
-            for k, v in other.items():
-                self[k].update(v)
-        else:
-            raise ValueError("Can only combine UnionDicts with default_factory of list or set")
-
-
 def run(configs):
     starttime = time.time()
 
diff --git a/src/naibr/utils.py b/src/naibr/utils.py
index 1d2ea45..8fd736b 100644
--- a/src/naibr/utils.py
+++ b/src/naibr/utils.py
@@ -606,3 +606,30 @@ def parse_naibr(lines):
     cands = [c for c in cands if abs(c[1] - c[3]) > min_sv or c[0] != c[2]]
     logger.info(f"Found {n_total:,} candidates in file of which {len(cands):,} are long enough")
     return cands
+
+
+class UnionDict(defaultdict):
+    """Extends defaultdict to allow combining with other UnionDicts or defaultdicts"""
+
+    def combine(self, other):
+        """Merge ``other`` into this dict in place, key by key.
+
+        List values are extended and set values are updated. Raises TypeError
+        if ``other`` is not a defaultdict, and ValueError if the two default
+        factories differ or are neither ``list`` nor ``set``.
+        """
+        if not isinstance(other, type(self)) and not isinstance(other, defaultdict):
+            raise TypeError(f"Can only combine with other UnionDicts or defaultdicts ({type(other)})")
+
+        if self.default_factory is not other.default_factory:
+            raise ValueError("Can only combine UnionDicts with the same default_factory")
+
+        if self.default_factory is list:
+            for k, v in other.items():
+                self[k].extend(v)
+
+        elif self.default_factory is set:
+            for k, v in other.items():
+                self[k].update(v)
+        else:
+            raise ValueError("Can only combine UnionDicts with default_factory of list or set")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 029b64d..f3445e5 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,8 @@
 from io import StringIO
 
-from naibr.utils import input_candidates
+import pytest
+
+from naibr.utils import input_candidates, UnionDict
 
 
 def test_input_candates_bepde_format():
@@ -40,3 +42,48 @@ def test_input_candates_filter_too_short():
     s += "chr1\t10000\t10000\tchr1\t14000\t14000\t-+\n"
     with StringIO(s) as f:
         assert len(input_candidates(f, min_sv=1000)) == 2
+
+
+def test_union_dict_combine_lists():
+    d1 = UnionDict(list)
+    d1["a"].extend([1, 2, 3])
+    d1["b"].extend([7, 8, 9])
+
+    d2 = UnionDict(list)
+    d2["a"].extend([3, 4, 5])
+    d2["b"].extend([8, 9, 10])
+
+    d1.combine(d2)
+    assert sorted(d1["a"]) == sorted([1, 2, 3, 3, 4, 5])
+    assert sorted(d1["b"]) == sorted([7, 8, 8, 9, 9, 10])
+
+
+def test_union_dict_combine_sets():
+    # update() mutates in place; union() would return a new set and discard it
+    d1 = UnionDict(set)
+    d1["a"].update({1, 2, 3})
+    d1["b"].update({7, 8, 9})
+
+    d2 = UnionDict(set)
+    d2["a"].update({3, 4, 5})
+    d2["b"].update({8, 9, 10})
+
+    d1.combine(d2)
+    assert d1["a"] == {1, 2, 3, 4, 5}
+    assert d1["b"] == {7, 8, 9, 10}
+
+
+def test_union_dict_combine_different_fails():
+    d1 = UnionDict(set)
+    d1["a"].update({1, 2, 3})
+    d1["b"].update({7, 8, 9})
+
+    d2 = UnionDict(list)
+    d2["a"].extend([3, 4, 5])
+    d2["b"].extend([8, 9, 10])
+
+    with pytest.raises(ValueError):
+        d1.combine(d2)
+
+    with pytest.raises(ValueError):
+        d2.combine(d1)