Skip to content

Commit

Permalink
added hound_isd_bed
Browse files Browse the repository at this point in the history
  • Loading branch information
Anya Korsakova committed Sep 25, 2024
1 parent 16815fa commit 233043e
Showing 1 changed file with 315 additions and 0 deletions.
315 changes: 315 additions & 0 deletions src/baskerville/scripts/hound_isd_bed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
#!/usr/bin/env python
# Copyright 2017 Calico LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
from optparse import OptionParser

import json
import os

import h5py
import numpy as np
import pandas as pd

from baskerville import bed
from baskerville import dataset
from baskerville import dna
from baskerville import seqnn
from baskerville import snps

"""
hound_ism_bed.py
Perform an in silico saturation mutagenesis of sequences in a BED file.
"""


def main():
usage = "usage: %prog [options] <params_file> <model_file> <bed_file>"
parser = OptionParser(usage)
parser.add_option(
"-d",
dest="mut_down",
default=0,
type="int",
help="Nucleotides downstream of center sequence to mutate [Default: %default]",
)
parser.add_option(
"-f",
dest="genome_fasta",
default=None,
help="Genome FASTA for sequences [Default: %default]",
)
parser.add_option(
"-l",
dest="mut_len",
default=0,
type="int",
help="Length of center sequence to mutate [Default: %default]",
)
parser.add_option(
"-s",
dest="del_len",
default=1,
type="int",
help="Deletion size for ISD [Default: %default]",
)
parser.add_option(
"-o",
dest="out_dir",
default="sat_mut",
help="Output directory [Default: %default]",
)
parser.add_option(
"-p",
dest="processes",
default=None,
type="int",
help="Number of processes, passed by multi script",
)
parser.add_option(
"--rc",
dest="rc",
default=False,
action="store_true",
help="Ensemble forward and reverse complement predictions [Default: %default]",
)
parser.add_option(
"--shifts",
dest="shifts",
default="0",
help="Ensemble prediction shifts [Default: %default]",
)
parser.add_option(
"--stats",
dest="snp_stats",
default="logSUM",
help="Comma-separated list of stats to save. [Default: %default]",
)
parser.add_option(
"-t",
dest="targets_file",
default=None,
type="str",
help="File specifying target indexes and labels in table format",
)
parser.add_option(
"-u",
dest="mut_up",
default=0,
type="int",
help="Nucleotides upstream of center sequence to mutate [Default: %default]",
)
parser.add_option(
"--untransform_old",
dest="untransform_old",
default=False,
action="store_true",
help="Untransform old models [Default: %default]",
)
(options, args) = parser.parse_args()

if len(args) == 3:
# single worker
params_file = args[0]
model_file = args[1]
bed_file = args[2]
else:
parser.error("Must provide parameter and model files and BED file")

if not os.path.isdir(options.out_dir):
os.mkdir(options.out_dir)

options.shifts = [int(shift) for shift in options.shifts.split(",")]
options.snp_stats = [snp_stat for snp_stat in options.snp_stats.split(",")]

if options.mut_up > 0 or options.mut_down > 0:
options.mut_len = options.mut_up + options.mut_down
else:
assert options.mut_len > 0
options.mut_up = options.mut_len // 2
options.mut_down = options.mut_len - options.mut_up

#################################################################
# read parameters and targets

# read model parameters
with open(params_file) as params_open:
params = json.load(params_open)
params_model = params["model"]

# read targets
if options.targets_file is None:
parser.error("Must provide targets file to clarify stranded datasets")
targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0)

# handle strand pairs
if "strand_pair" in targets_df.columns:
# prep strand
targets_strand_df = dataset.targets_prep_strand(targets_df)

# set strand pairs (using new indexing)
orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0])))
targets_strand_pair = np.array(
[orig_new_index[ti] for ti in targets_df.strand_pair]
)
params_model["strand_pair"] = [targets_strand_pair]

# construct strand sum transform
strand_transform = dataset.make_strand_transform(targets_df, targets_strand_df)
else:
targets_strand_df = targets_df
strand_transform = None
num_targets = targets_strand_df.shape[0]

#################################################################
# setup model

seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(model_file)
seqnn_model.build_slice(targets_df.index)
seqnn_model.build_ensemble(options.rc)

#################################################################
# sequence dataset

# read sequences from BED
seqs_dna, seqs_coords = bed.make_bed_seqs(
bed_file, options.genome_fasta, params_model["seq_length"], stranded=True
)
num_seqs = len(seqs_dna)

# determine mutation region limits
seq_mid = params_model["seq_length"] // 2
mut_start = seq_mid - options.mut_up
mut_end = mut_start + options.mut_len

model_stride = seqnn_model.model_strides[0]

#################################################################
# setup output

scores_h5_file = "%s/scores.h5" % options.out_dir
if os.path.isfile(scores_h5_file):
os.remove(scores_h5_file)
scores_h5 = h5py.File(scores_h5_file, "w")
scores_h5.create_dataset("seqs", dtype="bool", shape=(num_seqs, options.mut_len, 4))
for snp_stat in options.snp_stats:
scores_h5.create_dataset(
snp_stat, dtype="float16", shape=(num_seqs, options.mut_len, 4, num_targets)
)

# store mutagenesis sequence coordinates
scores_chr = []
scores_start = []
scores_end = []
scores_strand = []
for seq_chr, seq_start, seq_end, seq_strand in seqs_coords:
scores_chr.append(seq_chr)
scores_strand.append(seq_strand)
if seq_strand == "+":
score_start = seq_start + mut_start
score_end = score_start + options.mut_len
else:
score_end = seq_end - mut_start
score_start = score_end - options.mut_len
scores_start.append(score_start)
scores_end.append(score_end)

scores_h5.create_dataset("chr", data=np.array(scores_chr, dtype="S"))
scores_h5.create_dataset("start", data=np.array(scores_start))
scores_h5.create_dataset("end", data=np.array(scores_end))
scores_h5.create_dataset("strand", data=np.array(scores_strand, dtype="S"))

#################################################################
# predict scores, write output

# make list of shifts for reference stitching
ref_shifts = []
for shift in options.shifts:
ref_shifts.append(shift)
ref_shifts.append(shift - options.del_len)

for si, seq_dna in enumerate(seqs_dna):
print("Predicting %d" % si, flush=True)

# 1 hot code DNA
ref_1hot = dna.dna_1hot(seq_dna)
ref_1hot = np.expand_dims(ref_1hot, axis=0)

# save sequence
scores_h5["seqs"][si] = ref_1hot[0, mut_start:mut_end].astype("bool")

# predict reference
ref_preds = []
for shift in ref_shifts:
# shift sequence and predict
ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift)
ref_preds_shift = seqnn_model.predict_transform(
ref_1hot_shift,
targets_df,
strand_transform,
options.untransform_old,
)
ref_preds.append(ref_preds_shift)

# for mutation positions
for mi in range(mut_start, mut_end):
# for each nucleotide
# copy and modify
alt_1hot = np.copy(ref_1hot)

# left-matched shift: delete 1 nucleotide at position mi
dna.hot1_delete(alt_1hot[0], mi, options.del_len)

# predict alternate
alt_preds = []
for shift in options.shifts:
# shift sequence and predict
alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift)
alt_preds_shift = seqnn_model.predict_transform(
alt_1hot_shift,
targets_df,
strand_transform,
options.untransform_old,
)
alt_preds.append(alt_preds_shift)
alt_preds = np.array(alt_preds)

out_seq_len = alt_preds.shape[1]
out_seq_center = out_seq_len // 2

# stitch reference predictions at the mutated nucleotide
snp_seq_bin = out_seq_center + (mi - options.mut_up) // model_stride

ref_preds_stitch = snps.stitch_preds(ref_preds, options.shifts, snp_seq_bin)
ref_preds_stitch = np.array(ref_preds_stitch)

ism_scores = snps.compute_scores(
ref_preds_stitch, alt_preds, options.snp_stats, None
)
for snp_stat in options.snp_stats:
scores_h5[snp_stat][si, mi - mut_start, 0] = ism_scores[
snp_stat
]

# close output HDF5
scores_h5.close()


################################################################################
# __main__
################################################################################
if __name__ == "__main__":
main()

0 comments on commit 233043e

Please sign in to comment.