From 3ed542c18c493975bc923c4593c666747aa43589 Mon Sep 17 00:00:00 2001 From: Nardus Mollentze Date: Tue, 1 Jun 2021 15:21:11 +0100 Subject: [PATCH] Illustrate screening success --- ...akeSupplement_EffectByGenomeType_Example.R | 100 +++++++++++++ .../MakeSupplement_RelatednessModelRanks.R | 128 ++++++++++++++++ .../MakeSupplement_ScreeningSuccessRate.R | 141 ++++++++++++++++++ 3 files changed, 369 insertions(+) create mode 100644 Scripts/Plotting/MakeSupplement_EffectByGenomeType_Example.R create mode 100644 Scripts/Plotting/MakeSupplement_RelatednessModelRanks.R create mode 100644 Scripts/Plotting/MakeSupplement_ScreeningSuccessRate.R diff --git a/Scripts/Plotting/MakeSupplement_EffectByGenomeType_Example.R b/Scripts/Plotting/MakeSupplement_EffectByGenomeType_Example.R new file mode 100644 index 0000000..8c9868b --- /dev/null +++ b/Scripts/Plotting/MakeSupplement_EffectByGenomeType_Example.R @@ -0,0 +1,100 @@ +library(dplyr) +library(tidyr) +library(ggplot2) +library(cowplot) +library(ggbeeswarm) + +source(file.path('Utils', 'plot_utils.R')) +source(file.path('Scripts/Plotting/PlottingConstants.R')) + + +shapley_vals_individual <- readRDS(file.path('Plots', 'Intermediates', 'figure2_virus_shapley_vals.rds')) +featureset_importance <- readRDS(file.path('Plots', 'Intermediates', 'figure2_feature_set_importance.rds')) +genome_types <- readRDS('Plots/Intermediates/figure1_merged_taxonomy.rds') + + +# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= +# ---- Example of an effect which works across genome types -------------------------------------- +# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= +top_feature <- featureset_importance$Feature[featureset_importance$Rank == 1] + +stopifnot(top_feature == 'GenomicDensity_Housekeeping_CTG.Bias_Coding') +top_feature_name <- 'CTG bias similarity (housekeeping)' + +example_vals <- shapley_vals_individual %>% + filter(Feature == top_feature) %>% + left_join(genome_types, by = c('LatestSppName' = 'Species')) %>% + mutate(Class = if_else(Class == 'No human infections', as.character(Class), 'Infects humans')) + +#ks.test(example_vals$FeatureValue[example_vals$Class == 'Infects humans'], example_vals$FeatureValue[example_vals$Class != 'InfectsHumans']) + + +axis_lims <- range(example_vals$FeatureValue) # Make sure plots are aligned +cut_points <- quantile(example_vals$FeatureValue, probs = c(0.25, 0.75)) #c(0.55, 1.95) +colours <- ZOONOTIC_STATUS_COLOURS[1:2] +names(colours) <- c('No human infections', 'Infects humans') + +hist_plot <- ggplot(example_vals, aes(x = FeatureValue, fill = Class)) + + geom_histogram(bins = 150) + + geom_vline(xintercept = cut_points, linetype = 2, colour = LINE_COLOUR) + + facet_grid(rows = vars(Class)) + + scale_fill_manual(values = colours, guide = FALSE) + + xlim(axis_lims) + + labs(y = 'Count') + + PLOT_THEME + + theme(axis.text.x = element_blank(), + axis.title.x = element_blank(), + axis.ticks.x = element_blank(), + panel.spacing = unit(3, 'pt'), + plot.margin = margin(t = 5.5, r = 5.5, b = 0, l = 5.5)) + + +shap_plot <- ggplot(example_vals, aes(x = FeatureValue, y = SHAP_mean, colour = Class)) + + geom_point(shape = 1) + + geom_hline(yintercept = 0, linetype = 2, colour = LINE_COLOUR) + + geom_vline(xintercept = cut_points, linetype = 2, colour = LINE_COLOUR) + + scale_colour_manual(values = colours, guide = FALSE) + + xlim(axis_lims) + + labs(y = 'Contribution to log odds\n(SHAP value)') + + PLOT_THEME + + theme(axis.text.x = element_blank(), + axis.title.x = element_blank(), + axis.ticks.x = element_blank(), + plot.margin = margin(t = 5.5, r = 5.5, b = 0, l = 5.5)) + + + + +dist_plot <- ggplot(example_vals, aes(x = GenomeType, y = FeatureValue, fill = Class)) + + geom_boxplot(colour = LINE_COLOUR, position = position_dodge(width = 0.9)) + + geom_hline(yintercept = cut_points, linetype = 2, colour = LINE_COLOUR) + + coord_flip() + + #facet_grid(rows = vars(GenomeType), scales = 'free_y') + + scale_fill_manual(values = colours, guide = FALSE) + + ylim(axis_lims) + + labs(x = 'Genome type', y = top_feature_name) + + PLOT_THEME + + theme(panel.spacing = unit(0, 'pt'), + strip.text = element_blank()) + + +p <- plot_grid(hist_plot, shap_plot, dist_plot, + nrow = 3, rel_heights = c(1.5, 1, 2), + align = 'v', axis = 'lr') + +ggsave2(file.path('Plots', 'Supplement_EffectByGenomeType_Example.pdf'), p, width = 7, height = 8, units = 'in') + + +# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= +# ---- Values for legend -------------------------------------------------------------------------- +# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= +cat("Cutpoints: c1 =", cut_points[1], "| c2 =", cut_points[2]) + +example_vals %>% + group_by(Class) %>% + summarise(below_c1 = sum(FeatureValue < cut_points[1]), + above_c2 = sum(FeatureValue > cut_points[2])) %>% + ungroup() %>% + mutate(below_c1_prop = below_c1/sum(below_c1), + above_c2_prop = above_c2/sum(above_c2)) %>% + print() diff --git a/Scripts/Plotting/MakeSupplement_RelatednessModelRanks.R b/Scripts/Plotting/MakeSupplement_RelatednessModelRanks.R new file mode 100644 index 0000000..efc2112 --- /dev/null +++ b/Scripts/Plotting/MakeSupplement_RelatednessModelRanks.R @@ -0,0 +1,128 @@ +# +# Plot an overview of ranks for training set viruses, using ranks from relatedness-based models +# - Very similar to "MakeSupplement_TrainingSetRanks.R", which produces the equivalent plot for +# ranks from the main model +# + +library(dplyr) +library(tidyr) +library(stringr) +library(readr) +library(readxl) +library(ggplot2) +library(cowplot) + +source(file.path('Scripts', 'Plotting', 'PlottingConstants.R')) +source(file.path('Utils', 'plot_utils.R')) + +training_data <- readRDS(file.path('CalculatedData', 'SplitData_Training.rds')) + +merged_taxonomy <- readRDS(file.path('Plots', 'Intermediates', 'figure1_merged_taxonomy.rds')) +zoo_status <- readRDS(file.path('Plots', 'Intermediates', 'figure1_zoo_status.rds')) + +taxonomy_predictions <- readRDS(file.path("RunData", "Taxonomy_LongRun", "Taxonomy_LongRun_Bagged_predictions.rds")) +pn_predictions <- readRDS(file.path("RunData", "PN_LongRun", "PN_LongRun_Bagged_predictions.rds")) + + +## Prepare data +genome_type_order <- unique(merged_taxonomy$GenomeType) +dna <- genome_type_order[grepl('DNA', genome_type_order)] +rna <- genome_type_order[!grepl('DNA', genome_type_order)] +genome_type_order <- rev(c(sort(dna), sort(rna))) # RNA viruses at top of plot + +zoo_status <- zoo_status %>% + rename(infection_class = .data$Class) + +name_matches <- training_data %>% + select(.data$UniversalName, .data$LatestSppName) + +stopifnot(n_distinct(name_matches$UniversalName) == nrow(name_matches)) + + +## Add priority categories +prioritize <- function(lower_bound, upper_bound, median, cutoff) { + stopifnot(length(cutoff) == 1) + stopifnot(cutoff > 0 & cutoff < 1) + + p <- if_else(lower_bound > cutoff, 'Very high', + if_else(median > cutoff, 'High', + if_else(upper_bound > cutoff, 'Medium', + 'Low'))) + factor(p, levels = c('Low', 'Medium', 'High', 'Very high')) +} + + +## Plots +plot_ranks <- function(prediction_data) { + prediction_data <- prediction_data %>% + left_join(name_matches, by = "UniversalName") %>% + left_join(merged_taxonomy, by = c('LatestSppName' = 'Species')) %>% + left_join(zoo_status, by = "LatestSppName") + + cutoff <- find_balanced_cutoff(observed_labels = prediction_data$InfectsHumans, + predicted_score = prediction_data$BagScore) + + cat("Using cutoff:", cutoff, "\n") + + ranks <- prediction_data %>% + mutate(priority = prioritize(lower_bound = .data$BagScore_Lower, + upper_bound = .data$BagScore_Upper, + median = .data$BagScore, + cutoff = cutoff)) %>% + + arrange(.data$BagScore) %>% + mutate(species = factor(.data$LatestSppName, levels = .data$LatestSppName), + Family = factor(.data$Family, levels = sort(unique(.data$Family), decreasing = TRUE)), + priority = factor(.data$priority, levels = c('Low', 'Medium', 'High', 'Very high')), + infection_class = factor(.data$infection_class, + levels = c('Human virus', 'Zoonotic', 'No human infections')), + GenomeType = factor(.data$GenomeType, levels = genome_type_order)) + + + ## Main panel + main_panel <- ggplot(ranks, aes(x = species, y = BagScore, colour = infection_class)) + + geom_errorbar(aes(ymin = BagScore_Lower, ymax = BagScore_Upper), width = 0.5) + + geom_step(group = 1, colour = 'grey10', size = 0.6) + + geom_hline(yintercept = cutoff, linetype = 2, colour = 'grey10') + + scale_colour_manual(values = ZOONOTIC_STATUS_COLOURS) + + scale_y_continuous(expand = expand_scale(add = c(0.02, 0.02))) + + labs(x = NULL, y = SCORE_LABEL_2LINE, colour = 'Current status') + + PLOT_THEME + + theme(panel.grid.major.x = element_blank(), + axis.ticks.x = element_blank(), + axis.text.x = element_blank(), + plot.margin = margin(t = 5.5, r = 5.5, b = 1, l = 5.5)) + + + ## Indicate families + family_indicator_plot <- ggplot(ranks, aes(x = species, y = Family, fill = priority)) + + geom_tile() + + facet_grid(rows = vars(GenomeType), scales = 'free', space = 'free') + + scale_fill_manual(values = PRIORITY_COLOURS, drop = FALSE) + + PLOT_THEME + + labs(x = 'Virus species (ranked)', fill = 'Zoonotic potential') + + theme(axis.text.x = element_blank(), + axis.ticks.x = element_blank(), + axis.text.y = element_text(lineheight = 0.65, face = 'italic'), + panel.spacing = unit(0.05, 'lines'), + panel.background = element_rect(fill = 'grey25'), + strip.background = element_rect(fill = 'white'), + strip.text.y = element_text(angle = 0), + plot.margin = margin(t = 0, r = 5.5, b = 5.5, l = 5.5)) + + + plot_grid(main_panel, family_indicator_plot, + nrow = 2, rel_heights = c(1, 5), + align = 'v', axis = 'lr') +} + + +taxonomy_plot <- plot_ranks(taxonomy_predictions) +pn_plot <- plot_ranks(pn_predictions) + + +## Save +combined_plot <- plot_grid(taxonomy_plot, pn_plot, + nrow = 2, labels = c("A", "B")) + +ggsave2(file.path('Plots', 'Supplement_RelatednessModelRanks.pdf'), combined_plot, width = 7, height = 8.5) \ No newline at end of file diff --git a/Scripts/Plotting/MakeSupplement_ScreeningSuccessRate.R b/Scripts/Plotting/MakeSupplement_ScreeningSuccessRate.R new file mode 100644 index 0000000..e147eab --- /dev/null +++ b/Scripts/Plotting/MakeSupplement_ScreeningSuccessRate.R @@ -0,0 +1,141 @@ +## Plot a comparison of accumulation curves (i.e. fig 1D, but for alternate models too) + +library(dplyr) +library(tidyr) +library(ggplot2) +library(scales) +library(cowplot) + +source(file.path('Scripts', 'Plotting', 'PlottingConstants.R')) + +set.seed(3105221) + +## Data +best_testing <- readRDS(file.path('Plots', 'Intermediates', 'figure1_virus_testing.rds')) + +tax_preds <- readRDS("RunData/Taxonomy_LongRun/Taxonomy_LongRun_Bagged_predictions.rds") %>% + mutate(Rank = rank(-.data$BagScore)) %>% + arrange(.data$Rank) + +pn_preds <- readRDS("RunData/PN_LongRun/PN_LongRun_Bagged_predictions.rds") %>% + mutate(Rank = rank(-.data$BagScore)) %>% + arrange(.data$Rank) + + +## Calculate accumulation curves +simulate_testing <- function(rank_data) { + virus_testing <- data.frame() + total_viruses <- nrow(rank_data) + pos_viruses <- sum(rank_data$InfectsHumans == "True") + + for (i in 1:nrow(rank_data)) { + current_row <- rank_data[i, ] + viruses_before <- rank_data[rank_data$Rank <= current_row$Rank, ] + + virus_testing <- rbind(virus_testing, data.frame( + Rank = current_row$Rank, + prop_screened = nrow(viruses_before)/total_viruses, + prop_found = sum(viruses_before$InfectsHumans == "True")/pos_viruses + )) + } + + virus_testing +} + +tax_testing <- simulate_testing(tax_preds) + +pn_testing <- simulate_testing(pn_preds) + + +combined_testing <- list("Taxonomic" = tax_testing, + "Phylogenetic\nneighbourhood" = pn_testing, + "All genome\ncomposition\nfeature sets" = best_testing) + + +## Simulate random screening +observed_prevalence <- sum(tax_preds$InfectsHumans == "True") / nrow(tax_preds) + +screen_randomly <- function(original_labels = tax_preds$InfectsHumans) { + n <- length(original_labels) + + labels <- sample(original_labels) + + screened <- vector(length = n) + found <- vector(length = n) + + for (i in 1:n) { + screened[i] <- i / n + found[i] <- sum(labels[1:i] == "True") / sum(labels == "True") + } + + data.frame(prop_screened = screened, + prop_found = found) +} + +random_testing <- replicate(1000, screen_randomly(), simplify = FALSE) %>% + bind_rows(.id = "replicate") + + +## Find point at which 50% of human-infecting viruses are found: +get_point <- function(data, example_points) { + positive_accumulation_fun <- ecdf(data$prop_found) + + data.frame(prop_found = example_points, + prop_screened = positive_accumulation_fun(example_points)) +} + +examples <- lapply(combined_testing, get_point, example_points = 0.5) %>% + bind_rows(.id = "model") + +combined_testing <- combined_testing %>% + bind_rows(.id = "model") + + +## Reduction in effort: +best_model = "All genome\ncomposition\nfeature sets" + +reduction_arrow_data <- examples %>% + filter(.data$model == best_model | + .data$prop_screened == min(.data$prop_screened[.data$model != best_model])) %>% + summarise(lower_val = min(.data$prop_screened), + upper_val = max(.data$prop_screened)) + +reduction_label_data <- reduction_arrow_data %>% + summarise(reduction = .data$upper_val/.data$lower_val, + label_pos = (.data$upper_val - .data$lower_val)/2 + .data$lower_val) + + +## Plot +p <- ggplot(combined_testing) + + geom_step(aes(x = prop_screened, y = prop_found, group = replicate), + colour = 'grey90', size = 0.2, + data = random_testing) + + + geom_hline(yintercept = 0.5, linetype = 3, colour = LINE_COLOUR) + + geom_segment(aes(x = prop_screened, xend = prop_screened, y = -Inf, yend = prop_found), + linetype = 3, colour = LINE_COLOUR, data = examples) + + + geom_text(aes(label = sprintf("%2.2f-fold\ndifference\nin effort", reduction), x = label_pos, y = 0.08), + colour = LINE_COLOUR, size = 2.5, data = reduction_label_data) + + geom_segment(aes(x = lower_val, xend = upper_val, y = 0.155, yend = 0.155), + arrow = arrow(length = unit(0.3, "lines"), type = "closed", ends = "both"), + colour = LINE_COLOUR, data = reduction_arrow_data) + + + geom_step(aes(x = prop_screened, y = prop_found, colour = model), size = 1) + + + scale_x_continuous(labels = percent_format(), expand = expand_scale()) + + scale_y_continuous(labels = percent_format(accuracy = 1), expand = expand_scale()) + + scale_colour_brewer(palette = "Set2", direction = -1) + + + labs(x = "Proportion of viruses screened", y = "Proportion of known human-infecting viruses encountered", colour = "Feature set") + + coord_equal() + + PLOT_THEME + + theme(legend.key.height = unit(2, "lines")) + + +ggsave2("Plots/Supplement_ScreeningSuccessRate.pdf", width = 6, height = 4) + + +## Mentioned in text: +cat("\nScreening required to find 50% of human-infecting viruses:\n") +print(examples)