Skip to content

Commit

Permalink
move UnionDict to utils and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pontushojer committed Oct 12, 2023
1 parent 9a61f8d commit d8d1b26
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 22 deletions.
22 changes: 1 addition & 21 deletions src/naibr/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
input_candidates,
is_proper_chrom,
parallel_execute,
UnionDict,
write_novel_adjacencies,
)

Expand Down Expand Up @@ -286,27 +287,6 @@ def parse_args(args):
return file_configs


class UnionDict(collections.defaultdict):
"""Extends defaultdict to allow combining with other UnionDicts or defaultdicts"""

def combine(self, other):
if not isinstance(other, type(self)) and not isinstance(other, collections.defaultdict):
raise TypeError(f"Can only combine with other UnionDicts or defaultdicts ({type(other)})")

if self.default_factory is not other.default_factory:
raise ValueError("Can only combine UnionDicts with the same default_factory")

if self.default_factory is list:
for k, v in other.items():
self[k].extend(v)

elif self.default_factory is set:
for k, v in other.items():
self[k].update(v)
else:
raise ValueError("Can only combine UnionDicts with default_factory of list or set")


def run(configs):
starttime = time.time()

Expand Down
21 changes: 21 additions & 0 deletions src/naibr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,24 @@ def parse_naibr(lines):
cands = [c for c in cands if abs(c[1] - c[3]) > min_sv or c[0] != c[2]]
logger.info(f"Found {n_total:,} candidates in file of which {len(cands):,} are long enough")
return cands


class UnionDict(defaultdict):
"""Extends defaultdict to allow combining with other UnionDicts or defaultdicts"""

def combine(self, other):
if not isinstance(other, type(self)) and not isinstance(other, defaultdict):
raise TypeError(f"Can only combine with other UnionDicts or defaultdicts ({type(other)})")

if self.default_factory is not other.default_factory:
raise ValueError("Can only combine UnionDicts with the same default_factory")

if self.default_factory is list:
for k, v in other.items():
self[k].extend(v)

elif self.default_factory is set:
for k, v in other.items():
self[k].update(v)
else:
raise ValueError("Can only combine UnionDicts with default_factory of list or set")
48 changes: 47 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from io import StringIO

from naibr.utils import input_candidates
import pytest

from naibr.utils import input_candidates, UnionDict


def test_input_candates_bepde_format():
Expand Down Expand Up @@ -40,3 +42,47 @@ def test_input_candates_filter_too_short():
s += "chr1\t10000\t10000\tchr1\t14000\t14000\t-+\n"
with StringIO(s) as f:
assert len(input_candidates(f, min_sv=1000)) == 2


def test_union_dict_combine_lists():
d1 = UnionDict(list)
d1["a"].extend([1, 2, 3])
d1["b"].extend([7, 8, 9])

d2 = UnionDict(list)
d2["a"].extend([3, 4, 5])
d2["b"].extend([8, 9, 10])

d1.combine(d2)
assert sorted(d1["a"]) == sorted([1, 2, 3, 3, 4, 5])
assert sorted(d1["b"]) == sorted([7, 8, 8, 9, 9, 10])


def test_union_dict_combine_sets():
d1 = UnionDict(set)
d1["a"].union({1, 2, 3})
d1["b"].union({7, 8, 9})

d2 = UnionDict(set)
d2["a"].union({3, 4, 5})
d2["b"].union({8, 9, 10})

d1.combine(d2)
assert d1["a"] - {1, 2, 3, 4, 5} == set()
assert d1["b"] - {7, 8, 9, 10} == set()


def test_union_dict_combine_different_fails():
d1 = UnionDict(set)
d1["a"].union({1, 2, 3})
d1["b"].union({7, 8, 9})

d2 = UnionDict(list)
d2["a"].extend([3, 4, 5])
d2["b"].extend([8, 9, 10])

with pytest.raises(ValueError):
d1.combine(d2)

with pytest.raises(ValueError):
d2.combine(d1)

0 comments on commit d8d1b26

Please sign in to comment.