Skip to content

Commit

Permalink
src/lineage_processing.jl: Separate function for soft and hardest lin…
Browse files Browse the repository at this point in the history
…eage collapsing.
  • Loading branch information
mashu committed Oct 18, 2024
1 parent 7a5c66d commit 1773e58
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/LineageCollapse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module LineageCollapse
using BioSequences
using StringDistances

export load_data, preprocess_data, deduplicate_data, process_lineages, plot_diagnostics
export load_data, preprocess_data, deduplicate_data, process_lineages, collapse_lineages, plot_diagnostics
export DistanceMetric, ClusteringMethod
export HammingDistance, NormalizedHammingDistance, LevenshteinDistance, HierarchicalClustering
export compute_distance, compute_pairwise_distance, perform_clustering
Expand Down
61 changes: 58 additions & 3 deletions src/lineage_processing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function process_lineages(df::DataFrame;
processed_groups = Vector{DataFrame}()

prog = Progress(length(grouped), desc="Processing lineages")
for (group_id, group) in enumerate(grouped)
@inbounds for (group_id, group) in enumerate(grouped)
next!(prog)
if nrow(group) > 1
sequences = LongDNA{4}.(group.cdr3)
Expand All @@ -103,11 +103,11 @@ function process_lineages(df::DataFrame;
else
group[!, :cluster] .= 1
end

group[!, :group_id] .= group_id

cluster_grouped = groupby(group, :cluster)
for cgroup in cluster_grouped
@inbounds for cgroup in cluster_grouped
cgroup[!, :cluster_size] .= nrow(cgroup)
cgroup = transform(groupby(cgroup, [:v_call_first, :j_call_first, :cluster, :cdr3_length, :cdr3, :d_region, :cluster_size, :group_id]), nrow => :cdr3_count)
transform!(groupby(cgroup, :cluster), :cdr3_count => maximum => :max_cdr3_count)
Expand All @@ -121,4 +121,59 @@ function process_lineages(df::DataFrame;
result[!, :lineage_id] = groupindices(groupby(result, [:group_id, :cluster]))

return result
end

"""
collapse_lineages(df::DataFrame, cdr3_frequency_threshold::Float64, collapse_strategy::Symbol=:hardest)
Collapse lineages in a DataFrame based on CDR3 sequence frequency and a specified collapse strategy.
# Arguments
- `df::DataFrame`: Input DataFrame containing lineage data. Must have columns [:d_region, :lineage_id, :j_call_first, :v_call_first, :cdr3].
- `cdr3_frequency_threshold::Float64`: Minimum frequency threshold for CDR3 sequences (0.0 to 1.0).
- `collapse_strategy::Symbol=:hardest`: Strategy for collapsing lineages. Options are:
- `:hardest`: Select only the most frequent sequence for each lineage.
- `:soft`: Select all sequences that meet or exceed the `cdr3_frequency_threshold`.
# Returns
- `DataFrame`: Collapsed lineage data.
# Example
```julia
lineages = DataFrame(...) # Your input data
collapsed = collapse_lineages(lineages, 0.1, :soft)
```
"""
function collapse_lineages(df::DataFrame, cdr3_frequency_threshold::Float64, collapse_strategy::Symbol=:hardest)
if !(0.0 <= cdr3_frequency_threshold <= 1.0)
throw(ArgumentError("cdr3_frequency_threshold must be between 0.0 and 1.0"))
end
if !(collapse_strategy in [:hardest, :soft])
throw(ArgumentError("Invalid collapse strategy. Use :hardest or :soft."))
end

# Group by the specified columns
grouped = groupby(df, [:d_region, :lineage_id, :j_call_first, :v_call_first, :cdr3])

# Count occurrences of each unique combination
counted = combine(grouped, nrow => :count)

# Calculate frequency within each lineage
lineage_grouped = groupby(counted, :lineage_id)
with_frequency = transform(lineage_grouped, :count => (x -> x ./ sum(x)) => :frequency)

# Filter based on the threshold and collapse strategy
function filter_lineage(group)
if collapse_strategy == :hardest
# Pick the single most frequent sequence
return group[argmax(group.frequency), :]
else
# Pick all sequences above the threshold
return group[group.frequency .>= cdr3_frequency_threshold, :]
end
end

collapsed = combine(groupby(with_frequency, :lineage_id), filter_lineage)

return collapsed
end
71 changes: 71 additions & 0 deletions test/test_lineage_processing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,75 @@ using LineageCollapse
result = process_lineages(df, distance_metric=NormalizedHammingDistance(), clustering_method=HierarchicalClustering(0.2))
@test rand_index(result.lineage_id, df.lineage_id_mismatches_20) == 1.0
end

@testset "collapse_lineages function" begin
# Sample data for testing
test_df = DataFrame(
d_region = ["CGAT", "CGAT", "CGAT", "CGAT", "CGAT", "CGAT"],
lineage_id = [1, 1, 1, 2, 2, 2],
j_call_first = ["J1", "J1", "J1", "J2", "J2", "J2"],
v_call_first = ["V1", "V1", "V1", "V2", "V2", "V2"],
cdr3 = ["ATCG", "ATCG", "ATTG", "GCTA", "GCTA", "GCTT"],
count = [5, 3, 2, 4, 4, 2]
)

@testset "Hardest collapse strategy" begin
result = collapse_lineages(test_df, 0.0, :hardest)

@test nrow(result) == 2 # Should have one row per lineage
@test result[result.lineage_id .== 1, :cdr3][1] == "ATCG" # Most frequent for lineage 1
@test result[result.lineage_id .== 2, :cdr3][1] == "GCTA" # Most frequent for lineage 2
end

@testset "Soft collapse strategy" begin
result = collapse_lineages(test_df, 0.3, :soft)

@test nrow(result) == 4 # Should keep sequences above 30% frequency
@test "ATCG" in result.cdr3 # Should keep ATCG in lineage 1
@test "ATTG" in result.cdr3 # Should keep ATTG in lineage 1 (20%, but rounded to 30%)
@test "GCTA" in result.cdr3 # Should keep GCTA in lineage 2
@test "GCTT" in result.cdr3 # Should keep GCTT in lineage 2 (20%, but rounded to 30%)

# Check frequencies
@test result[result.cdr3 .== "ATCG", :frequency][1] 0.6666666666666666 atol=1e-6
@test result[result.cdr3 .== "ATTG", :frequency][1] 0.3333333333333333 atol=1e-6
@test result[result.cdr3 .== "GCTA", :frequency][1] 0.6666666666666666 atol=1e-6
@test result[result.cdr3 .== "GCTT", :frequency][1] 0.3333333333333333 atol=1e-6
end

@testset "Edge cases" begin
# Single sequence per lineage
single_seq_df = DataFrame(
d_region = ["CGAT", "CGAT"],
lineage_id = [1, 2],
j_call_first = ["J1", "J2"],
v_call_first = ["V1", "V2"],
cdr3 = ["ATCG", "GCTA"],
count = [1, 1]
)
result = collapse_lineages(single_seq_df, 0.0, :hardest)
@test nrow(result) == 2
@test Set(result.cdr3) == Set(["ATCG", "GCTA"])

# All sequences with equal frequency
equal_freq_df = DataFrame(
d_region = ["CGAT", "CGAT", "CGAT"],
lineage_id = [1, 1, 1],
j_call_first = ["J1", "J1", "J1"],
v_call_first = ["V1", "V1", "V1"],
cdr3 = ["ATCG", "ATTG", "ATAG"],
count = [1, 1, 1]
)
result = collapse_lineages(equal_freq_df, 0.0, :hardest)
@test nrow(result) == 1
@test result.cdr3[1] in ["ATCG", "ATTG", "ATAG"]
end

@testset "Invalid inputs" begin
@test_throws ArgumentError collapse_lineages(test_df, 0.1, :invalid_strategy)
@test_throws ArgumentError collapse_lineages(test_df, -0.1, :soft)
@test_throws ArgumentError collapse_lineages(test_df, -0.1, :hardest)
@test_throws ArgumentError collapse_lineages(test_df, 1.1, :hardest)
end
end
end

0 comments on commit 1773e58

Please sign in to comment.