Skip to content

Commit

Permalink
implemented parallelization with n_jobs and n_blocks for hamming and …
Browse files Browse the repository at this point in the history
…tcrdist distance metrics
  • Loading branch information
felixpetschko committed May 6, 2024
1 parent 9ee1a2b commit d68a10b
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 29 deletions.
109 changes: 81 additions & 28 deletions src/scirpy/ir_dist/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,10 @@ def _seqs2mat(


class NumbaDistanceCalculator(abc.ABC):
def __init__(self, n_jobs: int = 1):
def __init__(self, n_jobs: int = 1, n_blocks: int = 1):
super().__init__()
self.n_jobs = n_jobs
self.n_blocks = n_blocks

@abc.abstractmethod
def _metric_mat(
Expand All @@ -425,6 +426,42 @@ def _metric_mat(
is_symmetric: bool = False,
start_column: int = 0,
) -> tuple[list[np.ndarray], list[np.ndarray], np.ndarray]:
"""
This function should be implemented by the derived class in a way sucht that it computes the pairwise distances
for sequences in seqs_mat1 and seqs_mat2 based on a certain distance metric. The result should be a distance matrix
that is returned in the form of the data, indices and intptr arrays of a (scipy) compressed sparse row matrix.
If this function is used to compute a block of a bigger result matrix, is_symmetric and start_column
can be used to only compute the part of the block that would be part of the upper triangular matrix of the
result matrix.
Parameters
----------
seqs_mat1/2:
Matrix containing sequences created by seqs2mat with padding to accomodate
sequences of different lengths (-1 padding)
seqs_L1/2:
A vector containing the length of each sequence in the respective seqs_mat matrix,
without the padding in seqs_mat
is_symmetric:
Determines whether the final result matrix is symmetric, assuming that this function is
only used to compute a block of a bigger result matrix
start_column:
Determines at which column the calculation should be started. This is only used if this function is
used to compute a block of a bigger result matrix that is symmetric
Returns
-------
data_rows:
List with arrays containing the non-zero data values of the result matrix per row,
needed to create the final scipy CSR result matrix later
indices_rows:
List with arrays containing the non-zero entry column indeces of the result matrix per row,
needed to create the final scipy CSR result matrix later
row_element_counts:
Array with integers that indicate the amount of non-zero values of the result matrix per row,
needed to create the final scipy CSR result matrix later
"""
pass

def _calc_dist_mat_block(
Expand Down Expand Up @@ -470,15 +507,14 @@ def calc_dist_mat(self, seqs: Sequence[str], seqs2: Optional[Sequence[str]] = No
seqs = np.array(seqs)
seqs2 = np.array(seqs2)
is_symmetric = np.array_equal(seqs, seqs2)
n_blocks = self.n_jobs * 2

if False: # self.n_jobs > 1: --- only for intermediate version set to False
split_seqs = np.array_split(seqs, n_blocks)

if self.n_blocks > 1:
split_seqs = np.array_split(seqs, self.n_blocks)
start_columns = np.cumsum([0] + [len(seq) for seq in split_seqs[:-1]])
arguments = [(split_seqs[x], seqs2, is_symmetric, start_columns[x]) for x in range(n_blocks)]
arguments = [(split_seqs[x], seqs2, is_symmetric, start_columns[x]) for x in range(self.n_blocks)]

delayed_jobs = [joblib.delayed(self._calc_dist_mat_block)(*args) for args in arguments]
results = list(_parallelize_with_joblib(delayed_jobs, total=len(arguments), n_jobs=self.n_jobs))
results = joblib.Parallel(return_as="list")(delayed_jobs)
distance_matrix_csr = scipy.sparse.vstack(results)
else:
distance_matrix_csr = self._calc_dist_mat_block(seqs, seqs2, is_symmetric)
Expand All @@ -496,9 +532,10 @@ class HammingDistanceCalculator(NumbaDistanceCalculator):
def __init__(
self,
n_jobs: int = 1,
n_blocks: int = 1,
cutoff: int = 2,
):
super().__init__(n_jobs=n_jobs)
super().__init__(n_jobs=n_jobs, n_blocks=n_blocks)
self.cutoff = cutoff

def _hamming_mat(
Expand Down Expand Up @@ -618,14 +655,15 @@ def __init__(
ctrim: int = 2,
fixed_gappos: bool = True,
n_jobs: int = 1,
n_blocks: int = 1,
):
self.dist_weight = dist_weight
self.gap_penalty = gap_penalty
self.ntrim = ntrim
self.ctrim = ctrim
self.fixed_gappos = fixed_gappos
self.cutoff = cutoff
super().__init__(n_jobs=n_jobs)
super().__init__(n_jobs=n_jobs, n_blocks=n_blocks)

def _tcrdist_mat(
self,
Expand Down Expand Up @@ -686,36 +724,45 @@ def _tcrdist_mat(
dist_mat_weighted = self.tcr_nb_distance_matrix * dist_weight
start_column *= is_symmetric

@nb.jit(nopython=True, parallel=False, nogil=True)
nb.set_num_threads(self.n_jobs)
num_threads = nb.get_num_threads()
print("numba threads: ", num_threads)

@nb.jit(nopython=True, parallel=True, nogil=True)
def _nb_tcrdist_mat():
assert seqs_mat1.shape[0] == seqs_L1.shape[0]
assert seqs_mat2.shape[0] == seqs_L2.shape[0]

num_rows = seqs_mat1.shape[0]
num_cols = seqs_mat2.shape[0]

data_rows = nb.typed.List()
indices_rows = nb.typed.List()
row_element_counts = np.zeros(seqs_mat1.shape[0])
row_element_counts = np.zeros(num_rows)

empty_row = np.zeros(0)
for _ in range(0, seqs_mat1.shape[0]):
data_rows.append(empty_row)
indices_rows.append(empty_row)
for _ in range(0, num_rows):
data_rows.append([empty_row])
indices_rows.append([empty_row])

data_row = np.zeros(seqs_mat2.shape[0])
indices_row = np.zeros(seqs_mat2.shape[0])
for row_index in range(seqs_mat1.shape[0]):
for row_index in nb.prange(num_rows):
data_row = np.empty(num_cols)
indices_row = np.empty(num_cols)

row_end_index = 0
for col_index in range(start_column + row_index * is_symmetric, seqs_mat2.shape[0]):
q_L = seqs_L1[row_index]
s_L = seqs_L2[col_index]
seq1_len = seqs_L1[row_index]

for col_index in range(start_column + row_index * is_symmetric, num_cols):
distance = 1
seq2_len = seqs_L2[col_index]

if q_L == s_L:
for i in range(ntrim, q_L - ctrim):
if seq1_len == seq2_len:
for i in range(ntrim, seq1_len - ctrim):
distance += dist_mat_weighted[seqs_mat1[row_index, i], seqs_mat2[col_index, i]]

else:
short_len = min(q_L, s_L)
len_diff = abs(q_L - s_L)
short_len = min(seq1_len, seq2_len)
len_diff = abs(seq1_len - seq2_len)
if fixed_gappos:
min_gappos = min(6, 3 + (short_len - 5) // 2)
max_gappos = min_gappos
Expand All @@ -736,7 +783,7 @@ def _nb_tcrdist_mat():

for c_i in range(ctrim, remainder):
tmp_dist += dist_mat_weighted[
seqs_mat1[row_index, q_L - 1 - c_i], seqs_mat2[col_index, s_L - 1 - c_i]
seqs_mat1[row_index, seq1_len - 1 - c_i], seqs_mat2[col_index, seq2_len - 1 - c_i]
]

if tmp_dist < min_dist or min_dist == -1:
Expand All @@ -752,14 +799,20 @@ def _nb_tcrdist_mat():
indices_row[row_end_index] = col_index
row_end_index += 1

data_rows[row_index] = data_row[0:row_end_index].copy()
indices_rows[row_index] = indices_row[0:row_end_index].copy()
data_rows[row_index][0] = data_row[0:row_end_index].copy()
indices_rows[row_index][0] = indices_row[0:row_end_index].copy()
row_element_counts[row_index] = row_end_index
return data_rows, indices_rows, row_element_counts

data_rows, indices_rows, row_element_counts = _nb_tcrdist_mat()
data_rows_flat = []
indices_rows_flat = []

return data_rows, indices_rows, row_element_counts
for i in range(len(data_rows)):
data_rows_flat.append(data_rows[i][0])
indices_rows_flat.append(indices_rows[i][0])

return data_rows_flat, indices_rows_flat, row_element_counts

_metric_mat = _tcrdist_mat

Expand Down
3 changes: 2 additions & 1 deletion src/scirpy/tests/test_ir_dist_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def test_tcrdist_reference():
fixed_gappos=True,
cutoff=15,
n_jobs=2,
n_blocks=2,
)
res = tcrdist_calculator.calc_dist_mat(seqs, seqs)

Expand All @@ -660,7 +661,7 @@ def test_hamming_reference():
seqs = np.load(TESTDATA / "hamming_test_data/hamming_WU3k_seqs.npy")
reference_result = scipy.sparse.load_npz(TESTDATA / "hamming_test_data/hamming_WU3k_csr_result.npz")

hamming_calculator = HammingDistanceCalculator(2, 2)
hamming_calculator = HammingDistanceCalculator(2, 2, 2)
res = hamming_calculator.calc_dist_mat(seqs, seqs)

assert np.array_equal(res.data, reference_result.data)
Expand Down

0 comments on commit d68a10b

Please sign in to comment.