Skip to content

Commit 2d0c4fa

Browse files
committed
Improve caching in referenceless
1 parent fc24bf5 commit 2d0c4fa

File tree

2 files changed

+52
-32
lines changed

2 files changed

+52
-32
lines changed

micall/tests/test_referenceless_contig_stitcher.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pytest
2-
from micall.utils.referenceless_contig_stitcher import stitch_consensus
3-
from micall.utils.contig_stitcher_contigs import Contig
2+
from micall.utils.referenceless_contig_stitcher import stitch_consensus, ContigWithAligner
43
from micall.utils.contig_stitcher_context import StitcherContext
54

65

@@ -54,7 +53,7 @@ def disable_acceptable_prob_check(monkeypatch):
5453
],
5554
)
5655
def test_stitch_simple_cases(seqs, expected):
57-
contigs = [Contig(None, seq) for seq in seqs]
56+
contigs = [ContigWithAligner(None, seq) for seq in seqs]
5857
with StitcherContext.fresh():
5958
consenses = tuple(contig.seq for contig in stitch_consensus(contigs))
6059
assert consenses == expected

micall/utils/referenceless_contig_stitcher.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from Bio.SeqRecord import SeqRecord
66
import logging
77
from mappy import Aligner
8+
from functools import cached_property
89

910
from micall.utils.contig_stitcher_context import StitcherContext
1011
from micall.utils.consensus_aligner import Alignment
@@ -19,10 +20,30 @@
1920
ACCEPTABLE_STITCHING_PROB = Fraction(1, 20)
2021

2122

23+
@dataclass(frozen=True)
class ContigWithAligner(Contig):
    """A Contig that carries a lazily built Aligner indexed on its own sequence."""

    @cached_property
    def aligner(self) -> Aligner:
        # cached_property stores the result on the instance after the first
        # access, so later lookups reuse the same Aligner object.
        reference = str(self.seq)
        return Aligner(seq=reference, preset='map-ont', bw=500, bw_long=500)

    @staticmethod
    def make(contig: Contig) -> 'ContigWithAligner':
        """Build a ContigWithAligner from the name and sequence of `contig`."""
        return ContigWithAligner(seq=contig.seq, name=contig.name)

    @staticmethod
    def empty() -> 'ContigWithAligner':
        """Return the ContigWithAligner made from `Contig.empty()`."""
        return ContigWithAligner.make(Contig.empty())

    def map_overlap(self, overlap: str) -> Iterator[Alignment]:
        """Yield only the primary alignments of `overlap` mapped onto this contig."""
        hits = self.aligner.map(overlap)
        yield from (hit for hit in hits if hit.is_primary)
41+
42+
2243
@dataclass(frozen=True)
2344
class ContigsPath:
2445
# Contig representing all combined contigs in the path.
25-
whole: Contig
46+
whole: ContigWithAligner
2647

2748
# IDs of contigs that comprise this path.
2849
parts_ids: FrozenSet[int]
@@ -44,7 +65,7 @@ def is_empty(self) -> bool:
4465

4566
@staticmethod
4667
def empty() -> 'ContigsPath':
47-
return ContigsPath(Contig.empty(), frozenset(),
68+
return ContigsPath(ContigWithAligner.empty(), frozenset(),
4869
probability=Fraction(1),
4970
pessimisstic_probability=ACCEPTABLE_STITCHING_PROB)
5071

@@ -58,29 +79,29 @@ class Overlap:
5879
shift: int
5980

6081

61-
def get_overlap(finder: OverlapFinder, left: Contig, right: Contig) -> Optional[Overlap]:
82+
GET_OVERLAP_CACHE: MutableMapping[Tuple[int, int], Optional[Overlap]] = {}
83+
84+
85+
def get_overlap(finder: OverlapFinder, left: ContigWithAligner, right: ContigWithAligner) -> Optional[Overlap]:
    """
    Find the maximum overlap between two contigs, memoized per contig pair.

    Returns None when either sequence is empty or when find_maximum_overlap
    reports a shift of 0; otherwise returns an Overlap holding that shift.
    Results (including None) are stored in GET_OVERLAP_CACHE under the
    (left.id, right.id) key and reused on later calls.
    """
    if len(left.seq) == 0 or len(right.seq) == 0:
        return None

    # NOTE(review): assumes contig ids stay unique for as long as
    # GET_OVERLAP_CACHE lives, and that all calls use an equivalent
    # finder -- confirm against Contig.id and the callers.
    key = (left.id, right.id)
    try:
        # Look the key up directly rather than via `.get(key)`:
        # a cached None is a legitimate "no overlap" result.
        return GET_OVERLAP_CACHE[key]
    except KeyError:
        pass

    shift = find_maximum_overlap(left.seq, right.seq, finder=finder)
    ret = None if shift == 0 else Overlap(shift)
    GET_OVERLAP_CACHE[key] = ret
    return ret
7899

79100

80101
def try_combine_contigs(finder: OverlapFinder,
81102
max_acceptable_prob: Fraction,
82-
a: Contig, b: Contig,
83-
) -> Optional[Tuple[Contig, Fraction]]:
103+
a: ContigWithAligner, b: ContigWithAligner,
104+
) -> Optional[Tuple[ContigWithAligner, Fraction]]:
84105
# TODO: Memoize this function.
85106
# Two-layer caching seems most optimal:
86107
# first by key=contig.id, then by key=contig.seq.
@@ -128,25 +149,25 @@ def try_combine_contigs(finder: OverlapFinder,
128149
right_initial_overlap = right.seq[:abs(shift)]
129150

130151
if len(left_initial_overlap) < len(right_initial_overlap):
131-
left_overlap_alignments = map_overlap_onto_candidate(str(right_initial_overlap), str(left.seq))
152+
left_overlap_alignments = left.map_overlap(str(right_initial_overlap))
132153
left_cutoff = min((al.r_st for al in left_overlap_alignments), default=None)
133154
if left_cutoff is None:
134155
logger.debug("Overlap alignment between %s and %s failed.", a.unique_name, b.unique_name)
135156
return None
136157

137-
right_overlap_alignments = map_overlap_onto_candidate(str(left_initial_overlap), str(right.seq))
158+
right_overlap_alignments = right.map_overlap(str(left_initial_overlap))
138159
right_cutoff = max((al.r_en for al in right_overlap_alignments), default=None)
139160
if right_cutoff is None:
140161
logger.debug("Overlap alignment between %s and %s failed.", a.unique_name, b.unique_name)
141162
return None
142163
else:
143-
right_overlap_alignments = map_overlap_onto_candidate(str(left_initial_overlap), str(right.seq))
164+
right_overlap_alignments = right.map_overlap(str(left_initial_overlap))
144165
right_cutoff = max((al.r_en for al in right_overlap_alignments), default=None)
145166
if right_cutoff is None:
146167
logger.debug("Overlap alignment between %s and %s failed.", a.unique_name, b.unique_name)
147168
return None
148169

149-
left_overlap_alignments = map_overlap_onto_candidate(str(right_initial_overlap), str(left.seq))
170+
left_overlap_alignments = left.map_overlap(str(right_initial_overlap))
150171
left_cutoff = min((al.r_st for al in left_overlap_alignments), default=None)
151172
if left_cutoff is None:
152173
logger.debug("Overlap alignment between %s and %s failed.", a.unique_name, b.unique_name)
@@ -186,7 +207,7 @@ def try_combine_contigs(finder: OverlapFinder,
186207
right_overlap_chunk = ''.join(x for x in aligned_right[max_concordance_index:] if x != '-')
187208

188209
result_seq = left_remainder + left_overlap_chunk + right_overlap_chunk + right_remainder
189-
result_contig = Contig(None, result_seq)
210+
result_contig = ContigWithAligner(None, result_seq)
190211

191212
logger.debug("Joined %s and %s together in a contig %s with lengh %s.",
192213
a.unique_name, b.unique_name,
@@ -198,7 +219,7 @@ def try_combine_contigs(finder: OverlapFinder,
198219
def extend_by_1(finder: OverlapFinder,
199220
max_acceptable_prob: Fraction,
200221
path: ContigsPath,
201-
candidate: Contig,
222+
candidate: ContigWithAligner,
202223
) -> Iterator[ContigsPath]:
203224
if path.has_contig(candidate):
204225
return
@@ -217,7 +238,7 @@ def extend_by_1(finder: OverlapFinder,
217238

218239
def calc_extension(finder: OverlapFinder,
219240
max_acceptable_prob: Fraction,
220-
contigs: Sequence[Contig],
241+
contigs: Sequence[ContigWithAligner],
221242
path: ContigsPath,
222243
) -> Iterator[ContigsPath]:
223244

@@ -228,7 +249,7 @@ def calc_extension(finder: OverlapFinder,
228249
def calc_multiple_extensions(finder: OverlapFinder,
229250
max_acceptable_prob: Fraction,
230251
paths: Iterable[ContigsPath],
231-
contigs: Sequence[Contig],
252+
contigs: Sequence[ContigWithAligner],
232253
) -> Iterator[ContigsPath]:
233254
for path in paths:
234255
yield from calc_extension(finder, max_acceptable_prob, contigs, path)
@@ -249,7 +270,7 @@ def filter_extensions(existing: MutableMapping[str, ContigsPath],
249270
yield from ret.values()
250271

251272

252-
def calculate_all_paths(contigs: Sequence[Contig]) -> Iterator[ContigsPath]:
273+
def calculate_all_paths(contigs: Sequence[ContigWithAligner]) -> Iterator[ContigsPath]:
253274
max_acceptable_prob = ACCEPTABLE_STITCHING_PROB
254275
existing: MutableMapping[str, ContigsPath] = {}
255276
finder = OverlapFinder.make('ACTG')
@@ -290,12 +311,12 @@ def calculate_all_paths(contigs: Sequence[Contig]) -> Iterator[ContigsPath]:
290311
max_acceptable_prob = max(x.pessimisstic_probability for x in paths)
291312

292313

293-
def find_most_probable_path(contigs: Sequence[Contig]) -> ContigsPath:
314+
def find_most_probable_path(contigs: Sequence[ContigWithAligner]) -> ContigsPath:
    """Return the path from calculate_all_paths(contigs) that maximizes ContigsPath.score."""
    return max(calculate_all_paths(contigs), key=ContigsPath.score)
296317

297318

298-
def stitch_consensus(contigs: Iterable[Contig]) -> Iterable[Contig]:
319+
def stitch_consensus(contigs: Iterable[ContigWithAligner]) -> Iterator[ContigWithAligner]:
299320
remaining = tuple(contigs)
300321
while remaining:
301322
most_probable = find_most_probable_path(remaining)
@@ -304,7 +325,7 @@ def stitch_consensus(contigs: Iterable[Contig]) -> Iterable[Contig]:
304325
if not most_probable.has_contig(contig))
305326

306327

307-
def write_contigs(output_fasta: TextIO, contigs: Iterable[Contig]):
328+
def write_contigs(output_fasta: TextIO, contigs: Iterable[ContigWithAligner]):
308329
records = (SeqRecord(Seq.Seq(contig.seq),
309330
description='',
310331
id=contig.unique_name,
@@ -313,9 +334,9 @@ def write_contigs(output_fasta: TextIO, contigs: Iterable[Contig]):
313334
SeqIO.write(records, output_fasta, "fasta")
314335

315336

316-
def read_contigs(input_fasta: TextIO) -> Iterable[Contig]:
337+
def read_contigs(input_fasta: TextIO) -> Iterable[ContigWithAligner]:
    """Parse `input_fasta` as FASTA and yield one ContigWithAligner per record."""
    records = SeqIO.parse(input_fasta, "fasta")
    yield from (ContigWithAligner(name=entry.name, seq=entry.seq) for entry in records)
319340

320341

321342
def referenceless_contig_stitcher(input_fasta: TextIO,

0 commit comments

Comments
 (0)