Skip to content

Commit e79ef48

Browse files
Refactor (#3)
* refactor in progress * vignette start + pkgdown
1 parent 3a61cf8 commit e79ef48

34 files changed

+1794
-501
lines changed

.Rbuildignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
^renv$
2+
^renv\.lock$
3+
^doc$
4+
^Meta$
5+
^_pkgdown\.yml$
6+
^docs$
7+
^pkgdown$
8+
^\.github$

.github/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.html

.github/workflows/pkgdown.yaml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
2+
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
3+
on:
4+
push:
5+
branches: [main, master]
6+
pull_request:
7+
branches: [main, master]
8+
release:
9+
types: [published]
10+
workflow_dispatch:
11+
12+
name: pkgdown
13+
14+
jobs:
15+
pkgdown:
16+
runs-on: ubuntu-latest
17+
# Only restrict concurrency for non-PR jobs
18+
concurrency:
19+
group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }}
20+
env:
21+
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
22+
permissions:
23+
contents: write
24+
steps:
25+
- uses: actions/checkout@v3
26+
27+
- uses: r-lib/actions/setup-pandoc@v2
28+
29+
- uses: r-lib/actions/setup-r@v2
30+
with:
31+
use-public-rspm: true
32+
33+
- uses: r-lib/actions/setup-r-dependencies@v2
34+
with:
35+
extra-packages: any::pkgdown, local::.
36+
needs: website
37+
38+
- name: Build site
39+
run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE)
40+
shell: Rscript {0}
41+
42+
- name: Deploy to GitHub pages 🚀
43+
if: github.event_name != 'pull_request'
44+
uses: JamesIves/[email protected]
45+
with:
46+
clean: false
47+
branch: gh-pages
48+
folder: docs

.gitignore

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
inst/doc
2+
/doc/
3+
/Meta/
4+
/renv/
5+
.Rproj.user
6+
.Rhistory
7+
.Rprofile
8+
.RData
9+
.Ruserdata
10+
.DS_Store
11+
renv/
12+
docs

DESCRIPTION

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,16 @@ Imports:
1717
future.apply,
1818
slurmworkflow,
1919
fs,
20+
lhs,
2021
dplyr
2122
Remotes:
2223
github::EpiModel/slurmworkflow
2324
Encoding: UTF-8
2425
Roxygen: list(markdown = TRUE)
2526
RoxygenNote: 7.2.3
27+
Suggests:
28+
ggplot2,
29+
knitr,
30+
rmarkdown
31+
VignetteBuilder: knitr
32+
URL: https://epimodel.github.io/swfcalib/

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,11 @@
33
export(calibration_step1)
44
export(calibration_step2)
55
export(calibration_step3)
6+
export(determ_end_thresh)
7+
export(determ_poly_end)
68
export(load_sideload)
9+
export(make_proposer_se_range)
10+
export(make_shrink_proposer)
11+
export(render_assessment)
712
export(save_sideload)
13+
importFrom(dplyr,.data)

R/assessment_plots.R

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
make_rmse_plot <- function(job_assess) {
2+
d <- job_assess$measures
3+
ggplot2::ggplot(d, ggplot2::aes(x = .data$iteration, y = .data$rmse_mean)) +
4+
ggplot2::geom_line() +
5+
ggplot2::geom_ribbon(
6+
ggplot2::aes(ymin = rmse_mean - rmse_sd, ymax = rmse_mean + rmse_sd),
7+
alpha = 0.3
8+
) +
9+
ggplot2::scale_y_log10() +
10+
ggplot2::theme_light() +
11+
ggplot2::labs(
12+
title = paste0("RMSE Evolution for: ", job_assess$infos$job_id),
13+
x = "Iteration",
14+
y = "RMSE \n(log10 scale)"
15+
)
16+
}
17+
18+
make_param_volume_plot <- function(job_assess) {
19+
d <- job_assess$measures
20+
ggplot2::ggplot(d, ggplot2::aes(x = .data$iteration,
21+
y = .data$param_volume)) +
22+
ggplot2::geom_line() +
23+
ggplot2::scale_y_log10() +
24+
ggplot2::theme_light() +
25+
ggplot2::labs(
26+
title =
27+
paste0("Parameter Space Volume Evolution for: ",
28+
job_assess$infos$job_id),
29+
x = "Iteration",
30+
y = "Parameter Space Volume \n(log10 scale)"
31+
)
32+
}
33+
34+
make_param_spread_plot <- function(job_assess, param) {
35+
d <- job_assess$measures
36+
d[["y"]] <- d[[paste0("spread__", param)]]
37+
ggplot2::ggplot(d, ggplot2::aes(x = .data$iteration, y = .data$y)) +
38+
ggplot2::geom_line() +
39+
ggplot2::scale_y_log10() +
40+
ggplot2::theme_light() +
41+
ggplot2::labs(
42+
title = paste0("Spread of Parameter: ", param),
43+
x = "Iteration",
44+
y = "Spread \n(log10 scale)"
45+
)
46+
}
47+
48+
make_target_err_plot <- function(job_assess, target) {
49+
d <- job_assess$measures
50+
d[["y"]] <- d[[paste0("mean_err__", target)]]
51+
d[["ys"]] <- d[[paste0("sd_err__", target)]]
52+
ggplot2::ggplot(d, ggplot2::aes(x = .data$iteration, y = .data$y)) +
53+
ggplot2::geom_line() +
54+
ggplot2::geom_ribbon(
55+
ggplot2::aes(ymin = .data$y - .data$ys, ymax = .data$y + .data$ys),
56+
alpha = 0.3
57+
) +
58+
ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
59+
ggplot2::theme_light() +
60+
ggplot2::labs(
61+
title = paste0("Mean Error on: ", target),
62+
x = "Iteration",
63+
y = "Error \n(mean + sd)"
64+
)
65+
}
66+
67+
make_job_plots <- function(job_assess) {
68+
out <- list()
69+
infos <- job_assess$infos
70+
out$rmse <- make_rmse_plot(job_assess)
71+
out$volume <- make_param_volume_plot(job_assess)
72+
out$params <- lapply(infos$params, make_param_spread_plot, job = job_assess)
73+
names(out$params) <- infos$params
74+
out$targets <- lapply(infos$targets, make_target_err_plot, job = job_assess)
75+
names(out$targets) <- infos$targets
76+
out
77+
}
78+
79+
make_wave_plots <- function(wave_assess) {
80+
out <- lapply(wave_assess, make_job_plots)
81+
names(out) <- vapply(wave_assess, function(x) x$infos$job_id, character(1))
82+
out
83+
}
84+
85+
make_assessments_plots <- function(assessments) {
86+
out <- lapply(assessments, make_wave_plots)
87+
names(out) <- names(assessments)
88+
out
89+
}

R/assessment_rmd.R

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
make_wave_rmd <- function(assessments, wave_num) {
2+
wave <- assessments[[paste0("wave", wave_num)]]
3+
cat("# Wave", wave_num, "\n\n")
4+
for (i in seq_along(wave)) {
5+
make_job_rmd(wave[[i]])
6+
}
7+
}
8+
9+
make_job_rmd <- function(job_assess) {
10+
cat("##", job_assess$infos$job_id, "\n\n")
11+
12+
cat("### Targets and Parameters", "\n\n")
13+
dplyr::tibble(
14+
target_name = job_assess$infos$targets,
15+
target_value = job_assess$infos$targets_val
16+
) |> knitr::kable(align = "ll") |> print()
17+
18+
dplyr::tibble(
19+
parameter = job_assess$infos$params,
20+
initial_range = vapply(
21+
job_assess$infos$params_ranges,
22+
\(x) paste0(x[1], " - ", x[2]),
23+
""
24+
)
25+
) |> knitr::kable(align = "ll") |> print()
26+
27+
cat("\n\n")
28+
29+
cat("### Parameter Space and RMSE Evolution", "\n\n")
30+
31+
make_param_volume_plot(job_assess) |> print()
32+
make_rmse_plot(job_assess) |> print()
33+
cat("\n\n")
34+
35+
cat("### Parameter Spreads", "\n\n")
36+
for (p in job_assess$infos$params) {
37+
make_param_spread_plot(job_assess, p) |> print()
38+
cat("\n\n")
39+
}
40+
41+
cat("### Target Errors", "\n\n")
42+
for (t in job_assess$infos$targets) {
43+
make_target_err_plot(job_assess, t) |> print()
44+
cat("\n\n")
45+
}
46+
cat("\n\n")
47+
48+
}
49+
50+
#' Generate an html report of the auto-calibration
51+
#'
52+
#' The report contains descriptions of the parameters spaces and residual errors
53+
#' over the duration of the calibration.
54+
#'
55+
#' @param path_to_assessments Path to an `assessments.rds` file generated by an
56+
#' `swfcalib` process.
57+
#' @param output_filename Name of the html report (default = "assessment.html")
58+
#' @param output_dir Directory where to store the report (default = current
59+
#' working directory)
60+
#'
61+
#' @export
62+
render_assessment <- function(path_to_assessments,
63+
output_filename = "assessment.html",
64+
output_dir = NULL) {
65+
if (is.null(output_dir)) output_dir <- getwd()
66+
rmarkdown::render(
67+
system.file("rmd/assessment.Rmd", package = "swfcalib"),
68+
output_file = output_filename,
69+
output_dir = output_dir,
70+
knit_root_dir = getwd(),
71+
params = list(path_to_assessments = path_to_assessments)
72+
)
73+
}

R/assessments.R

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
update_assessments <- function(calib_object, results) {
2+
out <- load_assessments(calib_object)
3+
if (nrow(results) == 0) {
4+
save_assessments(calib_object, out)
5+
return(invisible(calib_object))
6+
}
7+
8+
cur_wave <- paste0("wave", get_current_wave(calib_object))
9+
10+
assessments <- future.apply::future_lapply(
11+
get_current_jobs(calib_object),
12+
make_job_assessment,
13+
calib_object = calib_object,
14+
results = results
15+
)
16+
17+
out[[cur_wave]] <- merge_wave_assements(assessments, out[[cur_wave]])
18+
save_assessments(calib_object, out)
19+
invisible(calib_object)
20+
}
21+
22+
merge_wave_assements <- function(new, old) {
23+
old <- if (is.null(old)) list() else old
24+
for (nme in names(new)) {
25+
new[[nme]] <- merge_job_assessment(new[[nme]], old[[nme]])
26+
}
27+
new
28+
}
29+
30+
merge_job_assessment <- function(new, old) {
31+
list(
32+
infos = new$infos,
33+
measures = dplyr::bind_rows(old$measures, new$measures)
34+
)
35+
}
36+
37+
save_assessments <- function(calib_object, assessments) {
38+
saveRDS(assessments, get_assessments_path(calib_object))
39+
}
40+
41+
get_assessments_path <- function(calib_object) {
42+
fs::path(get_root_dir(calib_object), "assessments.rds")
43+
}
44+
45+
load_assessments <- function(calib_object) {
46+
f_path <- get_assessments_path(calib_object)
47+
if (fs::file_exists(f_path)) readRDS(f_path) else list()
48+
}
49+
50+
make_job_assessment <- function(calib_object, job, results) {
51+
out <- list()
52+
53+
out$infos <- job[c("targets", "targets_val", "params")]
54+
out$infos$job_id <- get_job_id(job)
55+
out$infos$params_ranges <- lapply(job$initial_proposals, range)
56+
57+
current_iteration <- get_current_iteration(calib_object)
58+
current_wave <- get_current_wave(calib_object)
59+
d <- dplyr::filter(
60+
results,
61+
.data$.iteration == current_iteration,
62+
.data$.wave == current_wave
63+
)
64+
65+
make_rmse <- function(x, target) sqrt(mean((target - x)^2))
66+
iter_rmse <- apply(d[job$targets], 1, make_rmse, target = job$targets_val)
67+
68+
get_spread <- function(x) diff(range(x))
69+
spreads <- vapply(d[job$params], get_spread, numeric(1))
70+
71+
out$measures <- dplyr::tibble(
72+
iteration = current_iteration,
73+
rmse_mean = mean(iter_rmse),
74+
rmse_sd = sd(iter_rmse),
75+
param_volume = prod(spreads)
76+
)
77+
78+
errors <- Map(function(x, target) target - x, d[job$targets], job$targets_val)
79+
80+
out$measures[paste0("spread__", names(spreads))] <- as.list(spreads)
81+
out$measures[paste0("mean_err__", job$targets)] <- lapply(errors, mean)
82+
out$measures[paste0("sd_err__", job$targets)] <- lapply(errors, sd)
83+
84+
out
85+
}

0 commit comments

Comments
 (0)