-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
369 additions
and
0 deletions.
There are no files selected for viewing
100 changes: 100 additions & 0 deletions
100
Scripts/Plotting/MakeSupplement_EffectByGenomeType_Example.R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
library(dplyr) | ||
library(tidyr) | ||
library(ggplot2) | ||
library(cowplot) | ||
library(ggbeeswarm) | ||
|
||
source(file.path('Utils', 'plot_utils.R')) | ||
source(file.path('Scripts/Plotting/PlottingConstants.R')) | ||
|
||
|
||
shapley_vals_individual <- readRDS(file.path('Plots', 'Intermediates', 'figure2_virus_shapley_vals.rds')) | ||
featureset_importance <- readRDS(file.path('Plots', 'Intermediates', 'figure2_feature_set_importance.rds')) | ||
genome_types <- readRDS('Plots/Intermediates/figure1_merged_taxonomy.rds') | ||
|
||
|
||
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= | ||
# ---- Example of an effect which works across genome types -------------------------------------- | ||
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= | ||
top_feature <- featureset_importance$Feature[featureset_importance$Rank == 1] | ||
|
||
stopifnot(top_feature == 'GenomicDensity_Housekeeping_CTG.Bias_Coding') | ||
top_feature_name <- 'CTG bias similarity (housekeeping)' | ||
|
||
example_vals <- shapley_vals_individual %>% | ||
filter(Feature == top_feature) %>% | ||
left_join(genome_types, by = c('LatestSppName' = 'Species')) %>% | ||
mutate(Class = if_else(Class == 'No human infections', as.character(Class), 'Infects humans')) | ||
|
||
#ks.test(example_vals$FeatureValue[example_vals$Class == 'Infects humans'], example_vals$FeatureValue[example_vals$Class != 'InfectsHumans']) | ||
|
||
|
||
axis_lims <- range(example_vals$FeatureValue) # Make sure plots are aligned | ||
cut_points <- quantile(example_vals$FeatureValue, probs = c(0.25, 0.75)) #c(0.55, 1.95) | ||
colours <- ZOONOTIC_STATUS_COLOURS[1:2] | ||
names(colours) <- c('No human infections', 'Infects humans') | ||
|
||
hist_plot <- ggplot(example_vals, aes(x = FeatureValue, fill = Class)) + | ||
geom_histogram(bins = 150) + | ||
geom_vline(xintercept = cut_points, linetype = 2, colour = LINE_COLOUR) + | ||
facet_grid(rows = vars(Class)) + | ||
scale_fill_manual(values = colours, guide = FALSE) + | ||
xlim(axis_lims) + | ||
labs(y = 'Count') + | ||
PLOT_THEME + | ||
theme(axis.text.x = element_blank(), | ||
axis.title.x = element_blank(), | ||
axis.ticks.x = element_blank(), | ||
panel.spacing = unit(3, 'pt'), | ||
plot.margin = margin(t = 5.5, r = 5.5, b = 0, l = 5.5)) | ||
|
||
|
||
shap_plot <- ggplot(example_vals, aes(x = FeatureValue, y = SHAP_mean, colour = Class)) + | ||
geom_point(shape = 1) + | ||
geom_hline(yintercept = 0, linetype = 2, colour = LINE_COLOUR) + | ||
geom_vline(xintercept = cut_points, linetype = 2, colour = LINE_COLOUR) + | ||
scale_colour_manual(values = colours, guide = FALSE) + | ||
xlim(axis_lims) + | ||
labs(y = 'Contribution to log odds\n(SHAP value)') + | ||
PLOT_THEME + | ||
theme(axis.text.x = element_blank(), | ||
axis.title.x = element_blank(), | ||
axis.ticks.x = element_blank(), | ||
plot.margin = margin(t = 5.5, r = 5.5, b = 0, l = 5.5)) | ||
|
||
|
||
|
||
|
||
dist_plot <- ggplot(example_vals, aes(x = GenomeType, y = FeatureValue, fill = Class)) + | ||
geom_boxplot(colour = LINE_COLOUR, position = position_dodge(width = 0.9)) + | ||
geom_hline(yintercept = cut_points, linetype = 2, colour = LINE_COLOUR) + | ||
coord_flip() + | ||
#facet_grid(rows = vars(GenomeType), scales = 'free_y') + | ||
scale_fill_manual(values = colours, guide = FALSE) + | ||
ylim(axis_lims) + | ||
labs(x = 'Genome type', y = top_feature_name) + | ||
PLOT_THEME + | ||
theme(panel.spacing = unit(0, 'pt'), | ||
strip.text = element_blank()) | ||
|
||
|
||
p <- plot_grid(hist_plot, shap_plot, dist_plot, | ||
nrow = 3, rel_heights = c(1.5, 1, 2), | ||
align = 'v', axis = 'lr') | ||
|
||
ggsave2(file.path('Plots', 'Supplement_EffectByGenomeType_Example.pdf'), p, width = 7, height = 8, units = 'in') | ||
|
||
|
||
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= | ||
# ---- Values for legend -------------------------------------------------------------------------- | ||
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= | ||
cat("Cutpoints: c1 =", cut_points[1], "| c2 =", cut_points[2]) | ||
|
||
example_vals %>% | ||
group_by(Class) %>% | ||
summarise(below_c1 = sum(FeatureValue < cut_points[1]), | ||
above_c2 = sum(FeatureValue > cut_points[2])) %>% | ||
ungroup() %>% | ||
mutate(below_c1_prop = below_c1/sum(below_c1), | ||
above_c2_prop = above_c2/sum(above_c2)) %>% | ||
print() |
128 changes: 128 additions & 0 deletions
128
Scripts/Plotting/MakeSupplement_RelatednessModelRanks.R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# | ||
# Plot an overview of ranks for training set viruses, using ranks from relatedness-based models | ||
# - Very similar to "MakeSupplement_TrainingSetRanks.R", which produces the equivalent plot for | ||
# ranks from the main model | ||
# | ||
|
||
library(dplyr) | ||
library(tidyr) | ||
library(stringr) | ||
library(readr) | ||
library(readxl) | ||
library(ggplot2) | ||
library(cowplot) | ||
|
||
source(file.path('Scripts', 'Plotting', 'PlottingConstants.R')) | ||
source(file.path('Utils', 'plot_utils.R')) | ||
|
||
training_data <- readRDS(file.path('CalculatedData', 'SplitData_Training.rds')) | ||
|
||
merged_taxonomy <- readRDS(file.path('Plots', 'Intermediates', 'figure1_merged_taxonomy.rds')) | ||
zoo_status <- readRDS(file.path('Plots', 'Intermediates', 'figure1_zoo_status.rds')) | ||
|
||
taxonomy_predictions <- readRDS(file.path("RunData", "Taxonomy_LongRun", "Taxonomy_LongRun_Bagged_predictions.rds")) | ||
pn_predictions <- readRDS(file.path("RunData", "PN_LongRun", "PN_LongRun_Bagged_predictions.rds")) | ||
|
||
|
||
## Prepare data | ||
genome_type_order <- unique(merged_taxonomy$GenomeType) | ||
dna <- genome_type_order[grepl('DNA', genome_type_order)] | ||
rna <- genome_type_order[!grepl('DNA', genome_type_order)] | ||
genome_type_order <- rev(c(sort(dna), sort(rna))) # RNA viruses at top of plot | ||
|
||
zoo_status <- zoo_status %>% | ||
rename(infection_class = .data$Class) | ||
|
||
name_matches <- training_data %>% | ||
select(.data$UniversalName, .data$LatestSppName) | ||
|
||
stopifnot(n_distinct(name_matches$UniversalName) == nrow(name_matches)) | ||
|
||
|
||
## Add priority categories | ||
prioritize <- function(lower_bound, upper_bound, median, cutoff) { | ||
stopifnot(length(cutoff) == 1) | ||
stopifnot(cutoff > 0 & cutoff < 1) | ||
|
||
p <- if_else(lower_bound > cutoff, 'Very high', | ||
if_else(median > cutoff, 'High', | ||
if_else(upper_bound > cutoff, 'Medium', | ||
'Low'))) | ||
factor(p, levels = c('Low', 'Medium', 'High', 'Very high')) | ||
} | ||
|
||
|
||
## Plots | ||
plot_ranks <- function(prediction_data) { | ||
prediction_data <- prediction_data %>% | ||
left_join(name_matches, by = "UniversalName") %>% | ||
left_join(merged_taxonomy, by = c('LatestSppName' = 'Species')) %>% | ||
left_join(zoo_status, by = "LatestSppName") | ||
|
||
cutoff <- find_balanced_cutoff(observed_labels = prediction_data$InfectsHumans, | ||
predicted_score = prediction_data$BagScore) | ||
|
||
cat("Using cutoff:", cutoff, "\n") | ||
|
||
ranks <- prediction_data %>% | ||
mutate(priority = prioritize(lower_bound = .data$BagScore_Lower, | ||
upper_bound = .data$BagScore_Upper, | ||
median = .data$BagScore, | ||
cutoff = cutoff)) %>% | ||
|
||
arrange(.data$BagScore) %>% | ||
mutate(species = factor(.data$LatestSppName, levels = .data$LatestSppName), | ||
Family = factor(.data$Family, levels = sort(unique(.data$Family), decreasing = TRUE)), | ||
priority = factor(.data$priority, levels = c('Low', 'Medium', 'High', 'Very high')), | ||
infection_class = factor(.data$infection_class, | ||
levels = c('Human virus', 'Zoonotic', 'No human infections')), | ||
GenomeType = factor(.data$GenomeType, levels = genome_type_order)) | ||
|
||
|
||
## Main panel | ||
main_panel <- ggplot(ranks, aes(x = species, y = BagScore, colour = infection_class)) + | ||
geom_errorbar(aes(ymin = BagScore_Lower, ymax = BagScore_Upper), width = 0.5) + | ||
geom_step(group = 1, colour = 'grey10', size = 0.6) + | ||
geom_hline(yintercept = cutoff, linetype = 2, colour = 'grey10') + | ||
scale_colour_manual(values = ZOONOTIC_STATUS_COLOURS) + | ||
scale_y_continuous(expand = expand_scale(add = c(0.02, 0.02))) + | ||
labs(x = NULL, y = SCORE_LABEL_2LINE, colour = 'Current status') + | ||
PLOT_THEME + | ||
theme(panel.grid.major.x = element_blank(), | ||
axis.ticks.x = element_blank(), | ||
axis.text.x = element_blank(), | ||
plot.margin = margin(t = 5.5, r = 5.5, b = 1, l = 5.5)) | ||
|
||
|
||
## Indicate families | ||
family_indicator_plot <- ggplot(ranks, aes(x = species, y = Family, fill = priority)) + | ||
geom_tile() + | ||
facet_grid(rows = vars(GenomeType), scales = 'free', space = 'free') + | ||
scale_fill_manual(values = PRIORITY_COLOURS, drop = FALSE) + | ||
PLOT_THEME + | ||
labs(x = 'Virus species (ranked)', fill = 'Zoonotic potential') + | ||
theme(axis.text.x = element_blank(), | ||
axis.ticks.x = element_blank(), | ||
axis.text.y = element_text(lineheight = 0.65, face = 'italic'), | ||
panel.spacing = unit(0.05, 'lines'), | ||
panel.background = element_rect(fill = 'grey25'), | ||
strip.background = element_rect(fill = 'white'), | ||
strip.text.y = element_text(angle = 0), | ||
plot.margin = margin(t = 0, r = 5.5, b = 5.5, l = 5.5)) | ||
|
||
|
||
plot_grid(main_panel, family_indicator_plot, | ||
nrow = 2, rel_heights = c(1, 5), | ||
align = 'v', axis = 'lr') | ||
} | ||
|
||
|
||
taxonomy_plot <- plot_ranks(taxonomy_predictions) | ||
pn_plot <- plot_ranks(pn_predictions) | ||
|
||
|
||
## Save | ||
combined_plot <- plot_grid(taxonomy_plot, pn_plot, | ||
nrow = 2, labels = c("A", "B")) | ||
|
||
ggsave2(file.path('Plots', 'Supplement_RelatednessModelRanks.pdf'), combined_plot, width = 7, height = 8.5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
## Plot a comparison of accumulation curves (i.e. fig 1D, but for alternate models too) | ||
|
||
library(dplyr) | ||
library(tidyr) | ||
library(ggplot2) | ||
library(scales) | ||
library(cowplot) | ||
|
||
source(file.path('Scripts', 'Plotting', 'PlottingConstants.R')) | ||
|
||
set.seed(3105221) | ||
|
||
## Data | ||
best_testing <- readRDS(file.path('Plots', 'Intermediates', 'figure1_virus_testing.rds')) | ||
|
||
tax_preds <- readRDS("RunData/Taxonomy_LongRun/Taxonomy_LongRun_Bagged_predictions.rds") %>% | ||
mutate(Rank = rank(-.data$BagScore)) %>% | ||
arrange(.data$Rank) | ||
|
||
pn_preds <- readRDS("RunData/PN_LongRun/PN_LongRun_Bagged_predictions.rds") %>% | ||
mutate(Rank = rank(-.data$BagScore)) %>% | ||
arrange(.data$Rank) | ||
|
||
|
||
## Calculate accumulation curves | ||
simulate_testing <- function(rank_data) { | ||
virus_testing <- data.frame() | ||
total_viruses <- nrow(rank_data) | ||
pos_viruses <- sum(rank_data$InfectsHumans == "True") | ||
|
||
for (i in 1:nrow(rank_data)) { | ||
current_row <- rank_data[i, ] | ||
viruses_before <- rank_data[rank_data$Rank <= current_row$Rank, ] | ||
|
||
virus_testing <- rbind(virus_testing, data.frame( | ||
Rank = current_row$Rank, | ||
prop_screened = nrow(viruses_before)/total_viruses, | ||
prop_found = sum(viruses_before$InfectsHumans == "True")/pos_viruses | ||
)) | ||
} | ||
|
||
virus_testing | ||
} | ||
|
||
tax_testing <- simulate_testing(tax_preds) | ||
|
||
pn_testing <- simulate_testing(pn_preds) | ||
|
||
|
||
combined_testing <- list("Taxonomic" = tax_testing, | ||
"Phylogenetic\nneighbourhood" = pn_testing, | ||
"All genome\ncomposition\nfeature sets" = best_testing) | ||
|
||
|
||
## Simulate random screening | ||
observed_prevalence <- sum(tax_preds$InfectsHumans == "True") / nrow(tax_preds) | ||
|
||
screen_randomly <- function(original_labels = tax_preds$InfectsHumans) { | ||
n <- length(original_labels) | ||
|
||
labels <- sample(original_labels) | ||
|
||
screened <- vector(length = n) | ||
found <- vector(length = n) | ||
|
||
for (i in 1:n) { | ||
screened[i] <- i / n | ||
found[i] <- sum(labels[1:i] == "True") / sum(labels == "True") | ||
} | ||
|
||
data.frame(prop_screened = screened, | ||
prop_found = found) | ||
} | ||
|
||
random_testing <- replicate(1000, screen_randomly(), simplify = FALSE) %>% | ||
bind_rows(.id = "replicate") | ||
|
||
|
||
## Find point at which 50% of human-infecting viruses are found: | ||
get_point <- function(data, example_points) { | ||
positive_accumulation_fun <- ecdf(data$prop_found) | ||
|
||
data.frame(prop_found = example_points, | ||
prop_screened = positive_accumulation_fun(example_points)) | ||
} | ||
|
||
examples <- lapply(combined_testing, get_point, example_points = 0.5) %>% | ||
bind_rows(.id = "model") | ||
|
||
combined_testing <- combined_testing %>% | ||
bind_rows(.id = "model") | ||
|
||
|
||
## Reduction in effort: | ||
best_model = "All genome\ncomposition\nfeature sets" | ||
|
||
reduction_arrow_data <- examples %>% | ||
filter(.data$model == best_model | | ||
.data$prop_screened == min(.data$prop_screened[.data$model != best_model])) %>% | ||
summarise(lower_val = min(.data$prop_screened), | ||
upper_val = max(.data$prop_screened)) | ||
|
||
reduction_label_data <- reduction_arrow_data %>% | ||
summarise(reduction = .data$upper_val/.data$lower_val, | ||
label_pos = (.data$upper_val - .data$lower_val)/2 + .data$lower_val) | ||
|
||
|
||
## Plot | ||
p <- ggplot(combined_testing) + | ||
geom_step(aes(x = prop_screened, y = prop_found, group = replicate), | ||
colour = 'grey90', size = 0.2, | ||
data = random_testing) + | ||
|
||
geom_hline(yintercept = 0.5, linetype = 3, colour = LINE_COLOUR) + | ||
geom_segment(aes(x = prop_screened, xend = prop_screened, y = -Inf, yend = prop_found), | ||
linetype = 3, colour = LINE_COLOUR, data = examples) + | ||
|
||
geom_text(aes(label = sprintf("%2.2f-fold\ndifference\nin effort", reduction), x = label_pos, y = 0.08), | ||
colour = LINE_COLOUR, size = 2.5, data = reduction_label_data) + | ||
geom_segment(aes(x = lower_val, xend = upper_val, y = 0.155, yend = 0.155), | ||
arrow = arrow(length = unit(0.3, "lines"), type = "closed", ends = "both"), | ||
colour = LINE_COLOUR, data = reduction_arrow_data) + | ||
|
||
geom_step(aes(x = prop_screened, y = prop_found, colour = model), size = 1) + | ||
|
||
scale_x_continuous(labels = percent_format(), expand = expand_scale()) + | ||
scale_y_continuous(labels = percent_format(accuracy = 1), expand = expand_scale()) + | ||
scale_colour_brewer(palette = "Set2", direction = -1) + | ||
|
||
labs(x = "Proportion of viruses screened", y = "Proportion of known human-infecting viruses encountered", colour = "Feature set") + | ||
coord_equal() + | ||
PLOT_THEME + | ||
theme(legend.key.height = unit(2, "lines")) | ||
|
||
|
||
ggsave2("Plots/Supplement_ScreeningSuccessRate.pdf", width = 6, height = 4) | ||
|
||
|
||
## Mentioned in text: | ||
cat("\nScreening required to find 50% of human-infecting viruses:\n") | ||
print(examples) |