Skip to content

Commit

Permalink
src/lineage_processing.jl: Performance optimization for distance calculation.
Browse files Browse the repository at this point in the history
  • Loading branch information
mashu committed Oct 18, 2024
1 parent e52f6e9 commit 52b3783
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/lineage_processing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Compute the distance between two `LongDNA{4}` sequences using the specified dist
"""
function compute_distance(::HammingDistance, x::LongDNA{4}, y::LongDNA{4})::Float64
    # Hamming distance is only defined for sequences of equal length.
    @assert length(x) == length(y)
    # Count differing positions directly on the sequences. The previous
    # `evaluate(Hamming(), String(x), String(y))` form built two temporary
    # strings per call, and leaving both `return`s in place made this faster
    # path unreachable.
    # The `::Float64` return annotation converts the integer count.
    return mismatches(x, y)
end

"""
Expand All @@ -37,7 +37,7 @@ Compute the distance between two `LongDNA{4}` sequences using the specified dist
"""
function compute_distance(::NormalizedHammingDistance, x::LongDNA{4}, y::LongDNA{4})::Float64
    # Normalization by a single length only makes sense for equal-length inputs.
    @assert length(x) == length(y)
    # Fraction of positions that differ. Because the lengths are asserted equal,
    # `length(x)` is the same divisor as the former `max(length(x), length(y))`.
    # Counting with `mismatches` avoids building two temporary strings per call;
    # with both `return`s present this path was unreachable.
    # NOTE: two empty sequences give 0 / 0, i.e. NaN, exactly as before.
    return mismatches(x, y) / length(x)
end

"""
Expand All @@ -54,15 +54,19 @@ end
Compute pairwise distances between sequences using the specified distance metric.
"""
function compute_pairwise_distance(metric::Union{DistanceMetric, NormalizedDistanceMetric}, sequences::Vector{LongDNA{4}})::Matrix{Float64}
function compute_pairwise_distance(
metric::M,
sequences::AbstractVector{S}
)::Matrix{Float64} where {M <: Union{DistanceMetric, NormalizedDistanceMetric}, S <: LongSequence{DNAAlphabet{4}}}
n = length(sequences)
@assert n > 0 "No sequences provided for distance calculation"
dist_matrix = zeros(Float64, n, n)

Threads.@threads for i in 1:n
for j in i+1:n
dist = compute_distance(metric, sequences[i], sequences[j])
dist_matrix[i, j] = dist
dist_matrix[j, i] = dist
@inbounds dist = compute_distance(metric, sequences[i], sequences[j])
@inbounds dist_matrix[i, j] = dist
@inbounds dist_matrix[j, i] = dist
end
end
return dist_matrix
Expand Down Expand Up @@ -90,15 +94,17 @@ function process_lineages(df::DataFrame;
distance_metric::Union{DistanceMetric, NormalizedDistanceMetric} = NormalizedHammingDistance(),
clustering_method::ClusteringMethod = HierarchicalClustering(0.1),
linkage::Symbol = :single)::DataFrame
# Convert upfront to LongDNA{4} for performance
df.cdr3 = LongDNA{4}.(df.cdr3)

grouped = groupby(df, [:v_call_first, :j_call_first, :cdr3_length])
processed_groups = Vector{DataFrame}()

prog = Progress(length(grouped), desc="Processing lineages")
@inbounds for (group_id, group) in enumerate(grouped)
next!(prog)
if nrow(group) > 1
sequences = LongDNA{4}.(group.cdr3)
dist_matrix = compute_pairwise_distance(distance_metric, sequences)
dist_matrix = compute_pairwise_distance(distance_metric, group.cdr3)
group[!, :cluster] = perform_clustering(clustering_method, linkage, dist_matrix)
else
group[!, :cluster] .= 1
Expand Down

0 comments on commit 52b3783

Please sign in to comment.