Skip to content

Commit 93d52ec

Browse files
authored
Generalization to demultiplexing. (#113)
* Moved core logic of array-based demultiplexing to the annotate command.
* Demultiplex now splits a BAM based on a read tag (e.g. YN for model name).
* Added a test for retrieving model information from a BAM header, JSON file, or a specified model name.
1 parent f921f1c commit 93d52ec

File tree

4 files changed

+274
-235
lines changed

4 files changed

+274
-235
lines changed

src/longbow/annotate/command.py

Lines changed: 79 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@
5858
"-m",
5959
"--model",
6060
type=str,
61-
default=longbow.utils.constants.DEFAULT_MODEL,
61+
multiple=True,
6262
show_default=True,
63-
help="The model to use for annotation. If the given value is a pre-configured model name, then that "
63+
help="The model(s) to use for annotation. If the given value is a pre-configured model name, then that "
6464
"model will be used. Otherwise, the given value will be treated as a file name and Longbow will attempt to "
6565
"read in the file and create a LibraryModel from it. Longbow will assume the contents are the configuration "
6666
"of a LibraryModel as per LibraryModel.to_json()."
@@ -122,13 +122,8 @@ def main(pbi, threads, output_bam, model, chunk, min_length, max_length, min_rq,
122122
threads = mp.cpu_count() if threads <= 0 or threads > mp.cpu_count() else threads
123123
logger.info(f"Running with {threads} worker subprocess(es)")
124124

125-
# Get our model:
126-
if LibraryModel.has_prebuilt_model(model):
127-
m = LibraryModel.build_pre_configured_model(model)
128-
else:
129-
logger.info(f"Loading model from json file: %s", model)
130-
m = LibraryModel.from_json_file(model)
131-
logger.info(f"Using %s: %s", model, m.description)
125+
# Get our model(s):
126+
lb_models = bam_utils.load_models(model, input_bam)
132127

133128
pbi = f"{input_bam.name}.pbi" if pbi is None else pbi
134129
read_count = None
@@ -169,7 +164,7 @@ def main(pbi, threads, output_bam, model, chunk, min_length, max_length, min_rq,
169164

170165
for i in range(threads):
171166
p = mp.Process(
172-
target=_worker_segmentation_fn, args=(input_data_queue, results, i, m, min_length, max_length, min_rq)
167+
target=_worker_segmentation_fn, args=(input_data_queue, results, i, lb_models, min_length, max_length, min_rq)
173168
)
174169
p.start()
175170
worker_pool.append(p)
@@ -184,13 +179,13 @@ def main(pbi, threads, output_bam, model, chunk, min_length, max_length, min_rq,
184179
bam_file.seek(start_offset)
185180

186181
# Get our header from the input bam file:
187-
out_header = bam_utils.create_bam_header_with_program_group(logger.name, bam_file.header, models=[m])
182+
out_header = bam_utils.create_bam_header_with_program_group(logger.name, bam_file.header, models=lb_models)
188183

189184
# Start output worker:
190185
res = manager.dict({"num_reads_annotated": 0, "num_sections": 0})
191186
output_worker = mp.Process(
192187
target=_write_thread_fn,
193-
args=(results, out_header, output_bam, not sys.stdin.isatty(), res, read_count, m)
188+
args=(results, out_header, output_bam, not sys.stdin.isatty(), res, read_count, lb_models)
194189
)
195190
output_worker.start()
196191

@@ -239,9 +234,11 @@ def get_segments(read):
239234
]
240235

241236

242-
def _write_thread_fn(out_queue, out_bam_header, out_bam_file_name, disable_pbar, res, read_count, model):
237+
def _write_thread_fn(out_queue, out_bam_header, out_bam_file_name, disable_pbar, res, read_count, lb_models):
243238
"""Thread / process fn to write out all our data."""
244239

240+
lb_models_dict = {k.name:k for k in lb_models}
241+
245242
with pysam.AlignmentFile(
246243
out_bam_file_name, "wb", header=out_bam_header
247244
) as out_bam_file, tqdm.tqdm(
@@ -267,37 +264,38 @@ def _write_thread_fn(out_queue, out_bam_header, out_bam_file_name, disable_pbar,
267264
continue
268265

269266
# Unpack data:
270-
read, ppath, logp, is_rc = raw_data
267+
read, ppath, logp, is_rc, model_name = raw_data
271268

272-
# Condense the output annotations so we can write them out with indices:
273-
segments = bam_utils.collapse_annotations(ppath)
269+
if read is not None:
270+
# Condense the output annotations so we can write them out with indices:
271+
segments = bam_utils.collapse_annotations(ppath)
274272

275-
read = pysam.AlignedSegment.fromstring(read, out_bam_header)
273+
read = pysam.AlignedSegment.fromstring(read, out_bam_header)
276274

277-
# Obligatory log message:
278-
logger.debug(
279-
"Path for read %s (%2.2f)%s: %s",
280-
read.query_name,
281-
logp,
282-
" (RC)" if is_rc else "",
283-
segments,
284-
)
275+
# Obligatory log message:
276+
logger.debug(
277+
"Path for read %s (%2.2f)%s: %s",
278+
read.query_name,
279+
logp,
280+
" (RC)" if is_rc else "",
281+
segments,
282+
)
285283

286-
# Write out our read:
287-
bam_utils.write_annotated_read(read, segments, is_rc, logp, model, ssw_aligner, out_bam_file)
284+
# Write out our read:
285+
bam_utils.write_annotated_read(read, segments, is_rc, logp, lb_models_dict[model_name], ssw_aligner, out_bam_file)
288286

289-
# Increment our counters:
290-
res["num_reads_annotated"] += 1
291-
res["num_sections"] += len(segments)
287+
# Increment our counters:
288+
res["num_reads_annotated"] += 1
289+
res["num_sections"] += len(segments)
292290

293291
pbar.update(1)
294292

295293

296-
def _worker_segmentation_fn(in_queue, out_queue, worker_num, model, min_length, max_length, min_rq):
294+
def _worker_segmentation_fn(in_queue, out_queue, worker_num, lb_models, min_length, max_length, min_rq):
297295
"""Function to run in each subthread / subprocess.
298296
Segments each read and place the segments in the output queue."""
299297

300-
num_reads_segmented = 0
298+
num_reads_processed, num_reads_segmented = 0, 0
301299

302300
while True:
303301
# Wait until we get some data.
@@ -320,42 +318,69 @@ def _worker_segmentation_fn(in_queue, out_queue, worker_num, model, min_length,
320318
)
321319

322320
# Check for min/max length and min quality:
323-
321+
segment_info = None, None, None, None, None
324322
if len(read.query_sequence) < min_length:
325-
logger.warning(f"Read is shorter than min length. "
326-
f"Skipping: {read.query_name} ({len(read.query_sequence)} < {min_length})")
327-
continue
323+
logger.debug(f"Read is shorter than min length. "
324+
f"Skipping: {read.query_name} ({len(read.query_sequence)} < {min_length})")
328325
elif len(read.query_sequence) > max_length:
329-
logger.warning(f"Read is longer than max length. "
330-
f"Skipping: {read.query_name} ({len(read.query_sequence)} > {max_length})")
331-
continue
326+
logger.debug(f"Read is longer than max length. "
327+
f"Skipping: {read.query_name} ({len(read.query_sequence)} > {max_length})")
332328
elif read.get_tag("rq") < min_rq:
333-
logger.warning(f"Read quality is below the minimum. "
334-
f"Skipping: {read.query_name} ({read.get_tag('rq')} < {min_rq})")
335-
continue
336-
337-
# Process and place our data on the output queue:
338-
segment_info = _segment_read(read, model)
329+
logger.debug(f"Read quality is below the minimum. "
330+
f"Skipping: {read.query_name} ({read.get_tag('rq')} < {min_rq})")
331+
else:
332+
# Process and place our data on the output queue:
333+
segment_info = _annotate_and_assign_read_to_model(read, lb_models)
334+
num_reads_segmented += 1
339335

340336
out_queue.put(segment_info)
341-
num_reads_segmented += 1
337+
num_reads_processed += 1
338+
339+
logger.debug(f"Worker %d: Num reads segmented/processed: %d/%d", worker_num, num_reads_segmented, num_reads_processed)
340+
342341

343-
logger.debug(f"Worker %d: Num reads segmented: %d", worker_num, num_reads_segmented)
342+
def _annotate_and_assign_read_to_model(read, model_list):
343+
"""Annotate the given read with all given models and assign the read to the model with the best score."""
344344

345+
best_model = ""
346+
best_logp = -math.inf
347+
best_path = None
348+
best_fit_is_rc = False
349+
model_scores = dict()
350+
for model in model_list:
345351

346-
def _segment_read(read, model):
352+
_, ppath, logp, is_rc = _annotate_read(read, model)
353+
354+
model_scores[model.name] = logp
355+
356+
if logp > best_logp:
357+
best_model = model.name
358+
best_logp = logp
359+
best_path = ppath
360+
best_fit_is_rc = is_rc
361+
362+
# Provide some info as to which model was chosen:
363+
if logger.isEnabledFor(logging.DEBUG):
364+
logger.debug("%s model scores: %s", read.query_name, str(model_scores))
365+
logger.debug(
366+
"Sequence %s scored best with model%s: %s (%2.4f)",
367+
read.query_name,
368+
" in RC " if best_fit_is_rc else "",
369+
best_model,
370+
best_logp
371+
)
372+
373+
return read.to_string(), best_path, best_logp, best_fit_is_rc, best_model
374+
375+
376+
def _annotate_read(read, model):
347377
is_rc = False
348378
logp, ppath = model.annotate(read.query_sequence)
349379

350380
rc_logp, rc_ppath = model.annotate(bam_utils.reverse_complement(read.query_sequence))
351-
352-
# print(f"Forward Path: {logp}: {ppath}")
353-
# print(f"Reverse Path: {rc_logp}: {rc_ppath}")
354-
355381
if rc_logp > logp:
356382
logp = rc_logp
357383
ppath = rc_ppath
358384
is_rc = True
359-
logger.debug("Sequence scored better in RC: %s", read.query_name)
360385

361-
return read.to_string(), ppath, logp, is_rc
386+
return read.to_string(), ppath, logp, is_rc

0 commit comments

Comments
 (0)