performance optimization for hamming and tcrdist
felixpetschko committed May 6, 2024
1 parent d68a10b commit 0005e63
Showing 1 changed file with 45 additions and 34 deletions.
79 changes: 45 additions & 34 deletions src/scirpy/ir_dist/metrics.py
@@ -553,9 +553,13 @@ def _hamming_mat(

nb.set_num_threads(self.n_jobs)
num_threads = nb.get_num_threads()
print("numba threads: ", num_threads)

@nb.jit(nopython=True, parallel=True, nogil=True)
if(num_threads>1):
jit_parallel = True
else:
jit_parallel = False

@nb.jit(nopython=True, parallel=jit_parallel, nogil=True)
def _nb_hamming_mat():
assert seqs_mat1.shape[0] == seqs_L1.shape[0]
assert seqs_mat2.shape[0] == seqs_L2.shape[0]
@@ -572,10 +576,11 @@ def _nb_hamming_mat():
data_rows.append([empty_row])
indices_rows.append([empty_row])

for row_index in nb.prange(num_rows):
data_row = np.empty(num_cols)
indices_row = np.empty(num_cols)
data_row_matrix = np.empty((num_threads, num_cols))
indices_row_matrix = np.empty((num_threads, num_cols))

for row_index in nb.prange(num_rows):
thread_id = nb.get_thread_id()
row_end_index = 0
seq1_len = seqs_L1[row_index]

@@ -587,25 +592,25 @@ def _nb_hamming_mat():
distance += seqs_mat1[row_index, i] != seqs_mat2[col_index, i]

if distance <= cutoff + 1:
data_row[row_end_index] = distance
indices_row[row_end_index] = col_index
data_row_matrix[thread_id, row_end_index] = distance
indices_row_matrix[thread_id, row_end_index] = col_index
row_end_index += 1

data_rows[row_index][0] = data_row[0:row_end_index].copy()
indices_rows[row_index][0] = indices_row[0:row_end_index].copy()
data_rows[row_index][0] = data_row_matrix[thread_id, 0:row_end_index].copy()
indices_rows[row_index][0] = indices_row_matrix[thread_id, 0:row_end_index].copy()
row_element_counts[row_index] = row_end_index

return data_rows, indices_rows, row_element_counts
data_rows_flat = []
indices_rows_flat = []

data_rows, indices_rows, row_element_counts = _nb_hamming_mat()
data_rows_flat = []
indices_rows_flat = []
for i in range(len(data_rows)):
data_rows_flat.append(data_rows[i][0])
indices_rows_flat.append(indices_rows[i][0])

for i in range(len(data_rows)):
data_rows_flat.append(data_rows[i][0])
indices_rows_flat.append(indices_rows[i][0])
return data_rows_flat, indices_rows_flat, row_element_counts

return data_rows_flat, indices_rows_flat, row_element_counts
data_rows, indices_rows, row_element_counts = _nb_hamming_mat()
return data_rows, indices_rows, row_element_counts

_metric_mat = _hamming_mat
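
The hunk above moves the two per-row work buffers out of the prange loop into a single (num_threads, num_cols) scratch matrix indexed by nb.get_thread_id(), and only requests numba's parallel compilation when more than one thread is actually available. Below is a minimal, self-contained sketch of that pattern with hypothetical names (hamming_rows_sketch, mat1, mat2); it mirrors the structure of the committed kernel but is not scirpy's actual code:

import numba as nb
import numpy as np

def hamming_rows_sketch(mat1, mat2, cutoff, n_jobs=2):
    """Hypothetical sketch of the per-thread scratch-buffer pattern used above.

    mat1 and mat2 are 2-D integer arrays of equal width (numerically encoded,
    padded sequences). For every row of mat1, the sketch collects the columns of
    mat2 whose Hamming distance is within `cutoff`, plus the distances themselves.
    """
    nb.set_num_threads(n_jobs)
    num_threads = nb.get_num_threads()
    # Skip numba's parallel machinery entirely when only one thread is available.
    jit_parallel = num_threads > 1

    @nb.jit(nopython=True, parallel=jit_parallel, nogil=True)
    def _kernel():
        num_rows, num_cols, seq_len = mat1.shape[0], mat2.shape[0], mat1.shape[1]
        row_element_counts = np.zeros(num_rows, dtype=np.int64)

        # List-of-lists container so each prange iteration only mutates its own slot.
        data_rows = []
        indices_rows = []
        empty = np.zeros(0)
        for _ in range(num_rows):
            data_rows.append([empty])
            indices_rows.append([empty])

        # One scratch row per thread instead of two fresh arrays per iteration.
        data_scratch = np.empty((num_threads, num_cols))
        indices_scratch = np.empty((num_threads, num_cols))

        for row_index in nb.prange(num_rows):
            thread_id = nb.get_thread_id()
            end = 0
            for col_index in range(num_cols):
                distance = 0
                for i in range(seq_len):
                    distance += mat1[row_index, i] != mat2[col_index, i]
                if distance <= cutoff:
                    data_scratch[thread_id, end] = distance
                    indices_scratch[thread_id, end] = col_index
                    end += 1
            data_rows[row_index][0] = data_scratch[thread_id, 0:end].copy()
            indices_rows[row_index][0] = indices_scratch[thread_id, 0:end].copy()
            row_element_counts[row_index] = end

        return data_rows, indices_rows, row_element_counts

    return _kernel()

Allocating the scratch matrix once per call, instead of two np.empty(num_cols) arrays per row, removes the per-iteration allocations from the parallel loop, which is the point of the change.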

@@ -726,9 +731,13 @@ def _tcrdist_mat(

nb.set_num_threads(self.n_jobs)
num_threads = nb.get_num_threads()
print("numba threads: ", num_threads)

if(num_threads>1):
jit_parallel = True
else:
jit_parallel = False

@nb.jit(nopython=True, parallel=True, nogil=True)
@nb.jit(nopython=True, parallel=jit_parallel, nogil=True)
def _nb_tcrdist_mat():
assert seqs_mat1.shape[0] == seqs_L1.shape[0]
assert seqs_mat2.shape[0] == seqs_L2.shape[0]
@@ -745,10 +754,11 @@ def _nb_tcrdist_mat():
data_rows.append([empty_row])
indices_rows.append([empty_row])

data_row_matrix = np.empty((num_threads, num_cols))
indices_row_matrix = np.empty((num_threads, num_cols))

for row_index in nb.prange(num_rows):
data_row = np.empty(num_cols)
indices_row = np.empty(num_cols)

thread_id = nb.get_thread_id()
row_end_index = 0
seq1_len = seqs_L1[row_index]

@@ -795,24 +805,25 @@ def _nb_tcrdist_mat():
distance = min_dist + len_diff * gap_penalty + 1

if distance <= cutoff + 1:
data_row[row_end_index] = distance
indices_row[row_end_index] = col_index
data_row_matrix[thread_id, row_end_index] = distance
indices_row_matrix[thread_id, row_end_index] = col_index
row_end_index += 1

data_rows[row_index][0] = data_row[0:row_end_index].copy()
indices_rows[row_index][0] = indices_row[0:row_end_index].copy()
data_rows[row_index][0] = data_row_matrix[thread_id, 0:row_end_index].copy()
indices_rows[row_index][0] = indices_row_matrix[thread_id, 0:row_end_index].copy()
row_element_counts[row_index] = row_end_index
return data_rows, indices_rows, row_element_counts

data_rows, indices_rows, row_element_counts = _nb_tcrdist_mat()
data_rows_flat = []
indices_rows_flat = []
data_rows_flat = []
indices_rows_flat = []

for i in range(len(data_rows)):
data_rows_flat.append(data_rows[i][0])
indices_rows_flat.append(indices_rows[i][0])
for i in range(len(data_rows)):
data_rows_flat.append(data_rows[i][0])
indices_rows_flat.append(indices_rows[i][0])

return data_rows_flat, indices_rows_flat, row_element_counts
return data_rows_flat, indices_rows_flat, row_element_counts

data_rows, indices_rows, row_element_counts = _nb_tcrdist_mat()
return data_rows, indices_rows, row_element_counts

_metric_mat = _tcrdist_mat
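
Both kernels now hand back three per-row results: the kept distances, the matching column indices, and the number of kept entries per row. As an illustration only (this assembly step is an assumption, not part of this diff), output of that shape could be packed into a scipy.sparse.csr_matrix like so:

import numpy as np
import scipy.sparse as sp

def rows_to_csr(data_rows_flat, indices_rows_flat, row_element_counts, num_cols):
    """Assemble the per-row kernel output into a CSR matrix (illustrative only)."""
    num_rows = len(row_element_counts)
    # indptr[i]:indptr[i + 1] delimits row i; it is the cumulative sum of the row counts.
    indptr = np.zeros(num_rows + 1, dtype=np.int64)
    np.cumsum(row_element_counts, out=indptr[1:])
    data = np.concatenate(data_rows_flat)
    indices = np.concatenate(indices_rows_flat).astype(np.int64)
    return sp.csr_matrix((data, indices, indptr), shape=(num_rows, num_cols))

The cumulative sum of row_element_counts yields the CSR indptr, so no per-row bookkeeping needs to survive past this point.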

