Skip to content

Commit

Permalink
Illustrate screening success
Browse files Browse the repository at this point in the history
  • Loading branch information
Nardus committed Jun 1, 2021
1 parent e99e44d commit 3ed542c
Show file tree
Hide file tree
Showing 3 changed files with 369 additions and 0 deletions.
100 changes: 100 additions & 0 deletions Scripts/Plotting/MakeSupplement_EffectByGenomeType_Example.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
library(dplyr)
library(tidyr)
library(ggplot2)
library(cowplot)
library(ggbeeswarm)

source(file.path('Utils', 'plot_utils.R'))
source(file.path('Scripts/Plotting/PlottingConstants.R'))


shapley_vals_individual <- readRDS(file.path('Plots', 'Intermediates', 'figure2_virus_shapley_vals.rds'))
featureset_importance <- readRDS(file.path('Plots', 'Intermediates', 'figure2_feature_set_importance.rds'))
genome_types <- readRDS('Plots/Intermediates/figure1_merged_taxonomy.rds')


# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# ---- Example of an effect which works across genome types --------------------------------------
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
top_feature <- featureset_importance$Feature[featureset_importance$Rank == 1]

stopifnot(top_feature == 'GenomicDensity_Housekeeping_CTG.Bias_Coding')
top_feature_name <- 'CTG bias similarity (housekeeping)'

example_vals <- shapley_vals_individual %>%
filter(Feature == top_feature) %>%
left_join(genome_types, by = c('LatestSppName' = 'Species')) %>%
mutate(Class = if_else(Class == 'No human infections', as.character(Class), 'Infects humans'))

#ks.test(example_vals$FeatureValue[example_vals$Class == 'Infects humans'], example_vals$FeatureValue[example_vals$Class != 'InfectsHumans'])


axis_lims <- range(example_vals$FeatureValue) # Make sure plots are aligned
cut_points <- quantile(example_vals$FeatureValue, probs = c(0.25, 0.75)) #c(0.55, 1.95)
colours <- ZOONOTIC_STATUS_COLOURS[1:2]
names(colours) <- c('No human infections', 'Infects humans')

hist_plot <- ggplot(example_vals, aes(x = FeatureValue, fill = Class)) +
geom_histogram(bins = 150) +
geom_vline(xintercept = cut_points, linetype = 2, colour = LINE_COLOUR) +
facet_grid(rows = vars(Class)) +
scale_fill_manual(values = colours, guide = FALSE) +
xlim(axis_lims) +
labs(y = 'Count') +
PLOT_THEME +
theme(axis.text.x = element_blank(),
axis.title.x = element_blank(),
axis.ticks.x = element_blank(),
panel.spacing = unit(3, 'pt'),
plot.margin = margin(t = 5.5, r = 5.5, b = 0, l = 5.5))


shap_plot <- ggplot(example_vals, aes(x = FeatureValue, y = SHAP_mean, colour = Class)) +
geom_point(shape = 1) +
geom_hline(yintercept = 0, linetype = 2, colour = LINE_COLOUR) +
geom_vline(xintercept = cut_points, linetype = 2, colour = LINE_COLOUR) +
scale_colour_manual(values = colours, guide = FALSE) +
xlim(axis_lims) +
labs(y = 'Contribution to log odds\n(SHAP value)') +
PLOT_THEME +
theme(axis.text.x = element_blank(),
axis.title.x = element_blank(),
axis.ticks.x = element_blank(),
plot.margin = margin(t = 5.5, r = 5.5, b = 0, l = 5.5))




dist_plot <- ggplot(example_vals, aes(x = GenomeType, y = FeatureValue, fill = Class)) +
geom_boxplot(colour = LINE_COLOUR, position = position_dodge(width = 0.9)) +
geom_hline(yintercept = cut_points, linetype = 2, colour = LINE_COLOUR) +
coord_flip() +
#facet_grid(rows = vars(GenomeType), scales = 'free_y') +
scale_fill_manual(values = colours, guide = FALSE) +
ylim(axis_lims) +
labs(x = 'Genome type', y = top_feature_name) +
PLOT_THEME +
theme(panel.spacing = unit(0, 'pt'),
strip.text = element_blank())


p <- plot_grid(hist_plot, shap_plot, dist_plot,
nrow = 3, rel_heights = c(1.5, 1, 2),
align = 'v', axis = 'lr')

ggsave2(file.path('Plots', 'Supplement_EffectByGenomeType_Example.pdf'), p, width = 7, height = 8, units = 'in')


# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# ---- Values for legend --------------------------------------------------------------------------
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
cat("Cutpoints: c1 =", cut_points[1], "| c2 =", cut_points[2])

example_vals %>%
group_by(Class) %>%
summarise(below_c1 = sum(FeatureValue < cut_points[1]),
above_c2 = sum(FeatureValue > cut_points[2])) %>%
ungroup() %>%
mutate(below_c1_prop = below_c1/sum(below_c1),
above_c2_prop = above_c2/sum(above_c2)) %>%
print()
128 changes: 128 additions & 0 deletions Scripts/Plotting/MakeSupplement_RelatednessModelRanks.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#
# Plot an overview of ranks for training set viruses, using ranks from relatedness-based models
# - Very similar to "MakeSupplement_TrainingSetRanks.R", which produces the equivalent plot for
# ranks from the main model
#

library(dplyr)
library(tidyr)
library(stringr)
library(readr)
library(readxl)
library(ggplot2)
library(cowplot)

source(file.path('Scripts', 'Plotting', 'PlottingConstants.R'))
source(file.path('Utils', 'plot_utils.R'))

training_data <- readRDS(file.path('CalculatedData', 'SplitData_Training.rds'))

merged_taxonomy <- readRDS(file.path('Plots', 'Intermediates', 'figure1_merged_taxonomy.rds'))
zoo_status <- readRDS(file.path('Plots', 'Intermediates', 'figure1_zoo_status.rds'))

taxonomy_predictions <- readRDS(file.path("RunData", "Taxonomy_LongRun", "Taxonomy_LongRun_Bagged_predictions.rds"))
pn_predictions <- readRDS(file.path("RunData", "PN_LongRun", "PN_LongRun_Bagged_predictions.rds"))


## Prepare data
genome_type_order <- unique(merged_taxonomy$GenomeType)
dna <- genome_type_order[grepl('DNA', genome_type_order)]
rna <- genome_type_order[!grepl('DNA', genome_type_order)]
genome_type_order <- rev(c(sort(dna), sort(rna))) # RNA viruses at top of plot

zoo_status <- zoo_status %>%
rename(infection_class = .data$Class)

name_matches <- training_data %>%
select(.data$UniversalName, .data$LatestSppName)

stopifnot(n_distinct(name_matches$UniversalName) == nrow(name_matches))


## Add priority categories
prioritize <- function(lower_bound, upper_bound, median, cutoff) {
stopifnot(length(cutoff) == 1)
stopifnot(cutoff > 0 & cutoff < 1)

p <- if_else(lower_bound > cutoff, 'Very high',
if_else(median > cutoff, 'High',
if_else(upper_bound > cutoff, 'Medium',
'Low')))
factor(p, levels = c('Low', 'Medium', 'High', 'Very high'))
}


## Plots
plot_ranks <- function(prediction_data) {
prediction_data <- prediction_data %>%
left_join(name_matches, by = "UniversalName") %>%
left_join(merged_taxonomy, by = c('LatestSppName' = 'Species')) %>%
left_join(zoo_status, by = "LatestSppName")

cutoff <- find_balanced_cutoff(observed_labels = prediction_data$InfectsHumans,
predicted_score = prediction_data$BagScore)

cat("Using cutoff:", cutoff, "\n")

ranks <- prediction_data %>%
mutate(priority = prioritize(lower_bound = .data$BagScore_Lower,
upper_bound = .data$BagScore_Upper,
median = .data$BagScore,
cutoff = cutoff)) %>%

arrange(.data$BagScore) %>%
mutate(species = factor(.data$LatestSppName, levels = .data$LatestSppName),
Family = factor(.data$Family, levels = sort(unique(.data$Family), decreasing = TRUE)),
priority = factor(.data$priority, levels = c('Low', 'Medium', 'High', 'Very high')),
infection_class = factor(.data$infection_class,
levels = c('Human virus', 'Zoonotic', 'No human infections')),
GenomeType = factor(.data$GenomeType, levels = genome_type_order))


## Main panel
main_panel <- ggplot(ranks, aes(x = species, y = BagScore, colour = infection_class)) +
geom_errorbar(aes(ymin = BagScore_Lower, ymax = BagScore_Upper), width = 0.5) +
geom_step(group = 1, colour = 'grey10', size = 0.6) +
geom_hline(yintercept = cutoff, linetype = 2, colour = 'grey10') +
scale_colour_manual(values = ZOONOTIC_STATUS_COLOURS) +
scale_y_continuous(expand = expand_scale(add = c(0.02, 0.02))) +
labs(x = NULL, y = SCORE_LABEL_2LINE, colour = 'Current status') +
PLOT_THEME +
theme(panel.grid.major.x = element_blank(),
axis.ticks.x = element_blank(),
axis.text.x = element_blank(),
plot.margin = margin(t = 5.5, r = 5.5, b = 1, l = 5.5))


## Indicate families
family_indicator_plot <- ggplot(ranks, aes(x = species, y = Family, fill = priority)) +
geom_tile() +
facet_grid(rows = vars(GenomeType), scales = 'free', space = 'free') +
scale_fill_manual(values = PRIORITY_COLOURS, drop = FALSE) +
PLOT_THEME +
labs(x = 'Virus species (ranked)', fill = 'Zoonotic potential') +
theme(axis.text.x = element_blank(),
axis.ticks.x = element_blank(),
axis.text.y = element_text(lineheight = 0.65, face = 'italic'),
panel.spacing = unit(0.05, 'lines'),
panel.background = element_rect(fill = 'grey25'),
strip.background = element_rect(fill = 'white'),
strip.text.y = element_text(angle = 0),
plot.margin = margin(t = 0, r = 5.5, b = 5.5, l = 5.5))


plot_grid(main_panel, family_indicator_plot,
nrow = 2, rel_heights = c(1, 5),
align = 'v', axis = 'lr')
}


taxonomy_plot <- plot_ranks(taxonomy_predictions)
pn_plot <- plot_ranks(pn_predictions)


## Save
combined_plot <- plot_grid(taxonomy_plot, pn_plot,
nrow = 2, labels = c("A", "B"))

ggsave2(file.path('Plots', 'Supplement_RelatednessModelRanks.pdf'), combined_plot, width = 7, height = 8.5)
141 changes: 141 additions & 0 deletions Scripts/Plotting/MakeSupplement_ScreeningSuccessRate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
## Plot a comparison of accumulation curves (i.e. fig 1D, but for alternate models too)

library(dplyr)
library(tidyr)
library(ggplot2)
library(scales)
library(cowplot)

source(file.path('Scripts', 'Plotting', 'PlottingConstants.R'))

set.seed(3105221)

## Data
best_testing <- readRDS(file.path('Plots', 'Intermediates', 'figure1_virus_testing.rds'))

tax_preds <- readRDS("RunData/Taxonomy_LongRun/Taxonomy_LongRun_Bagged_predictions.rds") %>%
mutate(Rank = rank(-.data$BagScore)) %>%
arrange(.data$Rank)

pn_preds <- readRDS("RunData/PN_LongRun/PN_LongRun_Bagged_predictions.rds") %>%
mutate(Rank = rank(-.data$BagScore)) %>%
arrange(.data$Rank)


## Calculate accumulation curves
simulate_testing <- function(rank_data) {
virus_testing <- data.frame()
total_viruses <- nrow(rank_data)
pos_viruses <- sum(rank_data$InfectsHumans == "True")

for (i in 1:nrow(rank_data)) {
current_row <- rank_data[i, ]
viruses_before <- rank_data[rank_data$Rank <= current_row$Rank, ]

virus_testing <- rbind(virus_testing, data.frame(
Rank = current_row$Rank,
prop_screened = nrow(viruses_before)/total_viruses,
prop_found = sum(viruses_before$InfectsHumans == "True")/pos_viruses
))
}

virus_testing
}

tax_testing <- simulate_testing(tax_preds)

pn_testing <- simulate_testing(pn_preds)


combined_testing <- list("Taxonomic" = tax_testing,
"Phylogenetic\nneighbourhood" = pn_testing,
"All genome\ncomposition\nfeature sets" = best_testing)


## Simulate random screening
observed_prevalence <- sum(tax_preds$InfectsHumans == "True") / nrow(tax_preds)

screen_randomly <- function(original_labels = tax_preds$InfectsHumans) {
n <- length(original_labels)

labels <- sample(original_labels)

screened <- vector(length = n)
found <- vector(length = n)

for (i in 1:n) {
screened[i] <- i / n
found[i] <- sum(labels[1:i] == "True") / sum(labels == "True")
}

data.frame(prop_screened = screened,
prop_found = found)
}

random_testing <- replicate(1000, screen_randomly(), simplify = FALSE) %>%
bind_rows(.id = "replicate")


## Find point at which 50% of human-infecting viruses are found:
get_point <- function(data, example_points) {
positive_accumulation_fun <- ecdf(data$prop_found)

data.frame(prop_found = example_points,
prop_screened = positive_accumulation_fun(example_points))
}

examples <- lapply(combined_testing, get_point, example_points = 0.5) %>%
bind_rows(.id = "model")

combined_testing <- combined_testing %>%
bind_rows(.id = "model")


## Reduction in effort:
best_model = "All genome\ncomposition\nfeature sets"

reduction_arrow_data <- examples %>%
filter(.data$model == best_model |
.data$prop_screened == min(.data$prop_screened[.data$model != best_model])) %>%
summarise(lower_val = min(.data$prop_screened),
upper_val = max(.data$prop_screened))

reduction_label_data <- reduction_arrow_data %>%
summarise(reduction = .data$upper_val/.data$lower_val,
label_pos = (.data$upper_val - .data$lower_val)/2 + .data$lower_val)


## Plot
p <- ggplot(combined_testing) +
geom_step(aes(x = prop_screened, y = prop_found, group = replicate),
colour = 'grey90', size = 0.2,
data = random_testing) +

geom_hline(yintercept = 0.5, linetype = 3, colour = LINE_COLOUR) +
geom_segment(aes(x = prop_screened, xend = prop_screened, y = -Inf, yend = prop_found),
linetype = 3, colour = LINE_COLOUR, data = examples) +

geom_text(aes(label = sprintf("%2.2f-fold\ndifference\nin effort", reduction), x = label_pos, y = 0.08),
colour = LINE_COLOUR, size = 2.5, data = reduction_label_data) +
geom_segment(aes(x = lower_val, xend = upper_val, y = 0.155, yend = 0.155),
arrow = arrow(length = unit(0.3, "lines"), type = "closed", ends = "both"),
colour = LINE_COLOUR, data = reduction_arrow_data) +

geom_step(aes(x = prop_screened, y = prop_found, colour = model), size = 1) +

scale_x_continuous(labels = percent_format(), expand = expand_scale()) +
scale_y_continuous(labels = percent_format(accuracy = 1), expand = expand_scale()) +
scale_colour_brewer(palette = "Set2", direction = -1) +

labs(x = "Proportion of viruses screened", y = "Proportion of known human-infecting viruses encountered", colour = "Feature set") +
coord_equal() +
PLOT_THEME +
theme(legend.key.height = unit(2, "lines"))


ggsave2("Plots/Supplement_ScreeningSuccessRate.pdf", width = 6, height = 4)


## Mentioned in text:
cat("\nScreening required to find 50% of human-infecting viruses:\n")
print(examples)

0 comments on commit 3ed542c

Please sign in to comment.