From 1773e5816da1bed35cfe9c177a2a4e3532957b2f Mon Sep 17 00:00:00 2001 From: Mateusz Kaduk Date: Fri, 18 Oct 2024 12:50:05 +0200 Subject: [PATCH] src/lineage_processing.jl: Separate function for soft and hardest lineage collapsing. --- src/LineageCollapse.jl | 2 +- src/lineage_processing.jl | 61 ++++++++++++++++++++++++++-- test/test_lineage_processing.jl | 71 +++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 4 deletions(-) diff --git a/src/LineageCollapse.jl b/src/LineageCollapse.jl index c097355..9653c45 100644 --- a/src/LineageCollapse.jl +++ b/src/LineageCollapse.jl @@ -6,7 +6,7 @@ module LineageCollapse using BioSequences using StringDistances - export load_data, preprocess_data, deduplicate_data, process_lineages, plot_diagnostics + export load_data, preprocess_data, deduplicate_data, process_lineages, collapse_lineages, plot_diagnostics export DistanceMetric, ClusteringMethod export HammingDistance, NormalizedHammingDistance, LevenshteinDistance, HierarchicalClustering export compute_distance, compute_pairwise_distance, perform_clustering diff --git a/src/lineage_processing.jl b/src/lineage_processing.jl index aaa8d6f..43f32dd 100644 --- a/src/lineage_processing.jl +++ b/src/lineage_processing.jl @@ -94,7 +94,7 @@ function process_lineages(df::DataFrame; processed_groups = Vector{DataFrame}() prog = Progress(length(grouped), desc="Processing lineages") - for (group_id, group) in enumerate(grouped) + @inbounds for (group_id, group) in enumerate(grouped) next!(prog) if nrow(group) > 1 sequences = LongDNA{4}.(group.cdr3) @@ -103,11 +103,11 @@ function process_lineages(df::DataFrame; else group[!, :cluster] .= 1 end - + group[!, :group_id] .= group_id cluster_grouped = groupby(group, :cluster) - for cgroup in cluster_grouped + @inbounds for cgroup in cluster_grouped cgroup[!, :cluster_size] .= nrow(cgroup) cgroup = transform(groupby(cgroup, [:v_call_first, :j_call_first, :cluster, :cdr3_length, :cdr3, :d_region, :cluster_size, :group_id]), 
nrow => :cdr3_count) transform!(groupby(cgroup, :cluster), :cdr3_count => maximum => :max_cdr3_count) @@ -121,4 +121,59 @@ function process_lineages(df::DataFrame; result[!, :lineage_id] = groupindices(groupby(result, [:group_id, :cluster])) return result +end + +""" + collapse_lineages(df::DataFrame, cdr3_frequency_threshold::Float64, collapse_strategy::Symbol=:hardest) + +Collapse lineages in a DataFrame based on CDR3 sequence frequency and a specified collapse strategy. + +# Arguments +- `df::DataFrame`: Input DataFrame containing lineage data. Must have columns [:d_region, :lineage_id, :j_call_first, :v_call_first, :cdr3]. +- `cdr3_frequency_threshold::Float64`: Minimum within-lineage frequency (0.0 to 1.0) a CDR3 sequence must reach to be kept. Only applied by the `:soft` strategy; `:hardest` validates the range but otherwise ignores it. +- `collapse_strategy::Symbol=:hardest`: Strategy for collapsing lineages. Options are: + - `:hardest`: Select only the most frequent sequence for each lineage. + - `:soft`: Select all sequences that meet or exceed the `cdr3_frequency_threshold`. + +# Returns +- `DataFrame`: Collapsed lineage data with the grouping columns above plus `count` (rows per unique combination) and `frequency` (`count` divided by the lineage total); other input columns are not carried over. + +# Example +```julia +lineages = DataFrame(...) # Your input data +collapsed = collapse_lineages(lineages, 0.1, :soft) +``` +""" +function collapse_lineages(df::DataFrame, cdr3_frequency_threshold::Float64, collapse_strategy::Symbol=:hardest) + if !(0.0 <= cdr3_frequency_threshold <= 1.0) + throw(ArgumentError("cdr3_frequency_threshold must be between 0.0 and 1.0")) + end + if !(collapse_strategy in [:hardest, :soft]) + throw(ArgumentError("Invalid collapse strategy.
Use :hardest or :soft.")) + end + + # Group by the specified columns + grouped = groupby(df, [:d_region, :lineage_id, :j_call_first, :v_call_first, :cdr3]) + + # Count occurrences of each unique combination + counted = combine(grouped, nrow => :count) + + # Calculate frequency within each lineage + lineage_grouped = groupby(counted, :lineage_id) + with_frequency = transform(lineage_grouped, :count => (x -> x ./ sum(x)) => :frequency) + + # Filter based on the threshold and collapse strategy + function filter_lineage(group) + if collapse_strategy == :hardest + # Pick the single most frequent sequence (the threshold is not applied here) + return group[argmax(group.frequency), :] + else + # Pick all sequences at or above the threshold + return group[group.frequency .>= cdr3_frequency_threshold, :] + end + end + + collapsed = combine(groupby(with_frequency, :lineage_id), filter_lineage) + + return collapsed end \ No newline at end of file diff --git a/test/test_lineage_processing.jl b/test/test_lineage_processing.jl index 1c944bf..a9363da 100644 --- a/test/test_lineage_processing.jl +++ b/test/test_lineage_processing.jl @@ -119,4 +119,75 @@ using LineageCollapse result = process_lineages(df, distance_metric=NormalizedHammingDistance(), clustering_method=HierarchicalClustering(0.2)) @test rand_index(result.lineage_id, df.lineage_id_mismatches_20) == 1.0 end + + @testset "collapse_lineages function" begin + # Sample data for testing (the `count` column is ignored; frequencies are derived from row counts) + test_df = DataFrame( + d_region = ["CGAT", "CGAT", "CGAT", "CGAT", "CGAT", "CGAT"], + lineage_id = [1, 1, 1, 2, 2, 2], + j_call_first = ["J1", "J1", "J1", "J2", "J2", "J2"], + v_call_first = ["V1", "V1", "V1", "V2", "V2", "V2"], + cdr3 = ["ATCG", "ATCG", "ATTG", "GCTA", "GCTA", "GCTT"], + count = [5, 3, 2, 4, 4, 2] + ) + + @testset "Hardest collapse strategy" begin + result = collapse_lineages(test_df, 0.0, :hardest) + + @test nrow(result) == 2 # Should have one row per lineage + @test result[result.lineage_id .== 1, :cdr3][1] == "ATCG" # Most frequent for lineage 1 + @test
result[result.lineage_id .== 2, :cdr3][1] == "GCTA" # Most frequent for lineage 2 + end + + @testset "Soft collapse strategy" begin + result = collapse_lineages(test_df, 0.3, :soft) + + @test nrow(result) == 4 # Should keep sequences at or above 30% frequency + @test "ATCG" in result.cdr3 # Should keep ATCG in lineage 1 + @test "ATTG" in result.cdr3 # Should keep ATTG in lineage 1 (1 of 3 rows, ~33%, which is above 30%) + @test "GCTA" in result.cdr3 # Should keep GCTA in lineage 2 + @test "GCTT" in result.cdr3 # Should keep GCTT in lineage 2 (1 of 3 rows, ~33%, which is above 30%) + + # Check frequencies + @test result[result.cdr3 .== "ATCG", :frequency][1] ≈ 0.6666666666666666 atol=1e-6 + @test result[result.cdr3 .== "ATTG", :frequency][1] ≈ 0.3333333333333333 atol=1e-6 + @test result[result.cdr3 .== "GCTA", :frequency][1] ≈ 0.6666666666666666 atol=1e-6 + @test result[result.cdr3 .== "GCTT", :frequency][1] ≈ 0.3333333333333333 atol=1e-6 + end + + @testset "Edge cases" begin + # Single sequence per lineage + single_seq_df = DataFrame( + d_region = ["CGAT", "CGAT"], + lineage_id = [1, 2], + j_call_first = ["J1", "J2"], + v_call_first = ["V1", "V2"], + cdr3 = ["ATCG", "GCTA"], + count = [1, 1] + ) + result = collapse_lineages(single_seq_df, 0.0, :hardest) + @test nrow(result) == 2 + @test Set(result.cdr3) == Set(["ATCG", "GCTA"]) + + # All sequences with equal frequency + equal_freq_df = DataFrame( + d_region = ["CGAT", "CGAT", "CGAT"], + lineage_id = [1, 1, 1], + j_call_first = ["J1", "J1", "J1"], + v_call_first = ["V1", "V1", "V1"], + cdr3 = ["ATCG", "ATTG", "ATAG"], + count = [1, 1, 1] + ) + result = collapse_lineages(equal_freq_df, 0.0, :hardest) + @test nrow(result) == 1 + @test result.cdr3[1] in ["ATCG", "ATTG", "ATAG"] + end + + @testset "Invalid inputs" begin + @test_throws ArgumentError collapse_lineages(test_df, 0.1, :invalid_strategy) + @test_throws ArgumentError collapse_lineages(test_df, -0.1, :soft) + @test_throws ArgumentError collapse_lineages(test_df, -0.1, :hardest) + @test_throws
ArgumentError collapse_lineages(test_df, 1.1, :hardest) + end + end end \ No newline at end of file