Skip to content

Commit

Permalink
Fix opencl
Browse files Browse the repository at this point in the history
Addresses several issues with the opencl tests:
- Model object does not correctly identify whether model was compiled
with stan_opencl. Therefore, models were not recompiled within the tests
when recompilation is needed. Resolved this by supplying a different
exe_file for each tests so recompilation will always happen.
- Tests referenced a metadata attribute that was not defined.
I've added checks for opencl related attributes that are set when
opencl is functioning properly.
- Version test used a data input that was never created.
  • Loading branch information
katrinabrock committed Aug 8, 2024
1 parent 895e1d0 commit e8ad5f5
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions tests/testthat/test-opencl.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ fit <- testing_fit("bernoulli", method = "sample", seed = 123, chains = 1)

test_that("all methods error when opencl_ids is used with non OpenCL model", {
stan_file <- testing_stan_file("bernoulli")
mod <- cmdstan_model(stan_file = stan_file)
exe_file <- tempfile(pattern = "bernoulli-")
mod <- cmdstan_model(stan_file = stan_file, exe_file = exe_file)
expect_error(
mod$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0), chains = 1),
"'opencl_ids' is set but the model was not compiled with for use with OpenCL.",
Expand All @@ -22,7 +23,8 @@ test_that("all methods error when opencl_ids is used with non OpenCL model", {
fixed = TRUE
)
stan_file_gq <- testing_stan_file("bernoulli_ppc")
mod_gq <- cmdstan_model(stan_file = stan_file_gq)
exe_file_gq <- tempfile(pattern = "bernoulli_ppc-")
mod_gq <- cmdstan_model(stan_file = stan_file_gq, exe_file_gq)
expect_error(
mod_gq$generate_quantities(fitted_params = fit, data = testing_data("bernoulli"), opencl_ids = c(0, 0)),
"'opencl_ids' is set but the model was not compiled with for use with OpenCL.",
Expand All @@ -33,7 +35,8 @@ test_that("all methods error when opencl_ids is used with non OpenCL model", {
test_that("all methods error on invalid opencl_ids", {
skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true"))
stan_file <- testing_stan_file("bernoulli")
mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE))
exe_file <- tempfile(pattern = "bernoulli-")
mod <- cmdstan_model(stan_file = stan_file, exe_file = exe_file, cpp_options = list(stan_opencl = TRUE))
utils::capture.output(
expect_warning(
mod$sample(data = testing_data("bernoulli"), opencl_ids = c(1000, 1000), chains = 1),
Expand All @@ -56,7 +59,8 @@ test_that("all methods error on invalid opencl_ids", {
)
)
stan_file_gq <- testing_stan_file("bernoulli_ppc")
mod_gq <- cmdstan_model(stan_file = stan_file_gq, cpp_options = list(stan_opencl = TRUE))
exe_file_gq <- tempfile(pattern = "bernoulli_ppc-")
mod_gq <- cmdstan_model(stan_file = stan_file_gq, exe_file = exe_file_gq, cpp_options = list(stan_opencl = TRUE))
utils::capture.output(
expect_warning(
mod_gq$generate_quantities(fitted_params = fit, data = testing_data("bernoulli"), opencl_ids = c(1000, 1000)),
Expand All @@ -69,38 +73,50 @@ test_that("all methods error on invalid opencl_ids", {
test_that("all methods run with valid opencl_ids", {
skip_if_not(Sys.getenv("CMDSTANR_OPENCL_TESTS") %in% c("1", "true"))
stan_file <- testing_stan_file("bernoulli")
mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE))
exe_file <- tempfile(pattern = "bernoulli-")
mod <- cmdstan_model(stan_file = stan_file, exe_file = exe_file, cpp_options = list(stan_opencl = TRUE))
expect_sample_output(
fit <- mod$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0), chains = 1)
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))
expect_false(is.null(fit$metadata()$device))
expect_false(is.null(fit$metadata()$platform))

stan_file_gq <- testing_stan_file("bernoulli_ppc")
mod_gq <- cmdstan_model(stan_file = stan_file_gq, cpp_options = list(stan_opencl = TRUE))
exe_file_gq <- tempfile(pattern = "bernoulli_ppc-")
mod_gq <- cmdstan_model(stan_file = stan_file_gq, exe_file = exe_file_gq, cpp_options = list(stan_opencl = TRUE))
expect_gq_output(
fit <- mod_gq$generate_quantities(fitted_params = fit, data = testing_data("bernoulli"), opencl_ids = c(0, 0)),
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))
expect_false(is.null(fit$metadata()$device))
expect_false(is.null(fit$metadata()$platform))

expect_sample_output(
fit <- mod$sample(data = testing_data("bernoulli"), opencl_ids = c(0, 0))
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))
expect_false(is.null(fit$metadata()$device))
expect_false(is.null(fit$metadata()$platform))

expect_optim_output(
fit <- mod$optimize(data = testing_data("bernoulli"), opencl_ids = c(0, 0))
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))
expect_false(is.null(fit$metadata()$device))
expect_false(is.null(fit$metadata()$platform))

expect_vb_output(
fit <- mod$variational(data = testing_data("bernoulli"), opencl_ids = c(0, 0))
)
expect_false(is.null(fit$metadata()$opencl_platform_name))
expect_false(is.null(fit$metadata()$opencl_ids_name))
expect_false(is.null(fit$metadata()$opencl_device_name))
expect_false(is.null(fit$metadata()$device))
expect_false(is.null(fit$metadata()$platform))
})

test_that("error for runtime selection of OpenCL devices if version less than 2.26", {
Expand All @@ -111,7 +127,7 @@ test_that("error for runtime selection of OpenCL devices if version less than 2.
mod <- cmdstan_model(stan_file = stan_file, cpp_options = list(stan_opencl = TRUE),
force_recompile = TRUE)
expect_error(
mod$sample(data = data_list, chains = 1, refresh = 0, opencl_ids = c(1,1)),
mod$sample(data = testing_data("bernoulli"), chains = 1, refresh = 0, opencl_ids = c(1,1)),
"Runtime selection of OpenCL devices is only supported with CmdStan version 2.26 or newer",
fixed = TRUE
)
Expand Down

0 comments on commit e8ad5f5

Please sign in to comment.