diff --git a/R/impact_analysis.R b/R/impact_analysis.R index 00ecf86..0a566c4 100644 --- a/R/impact_analysis.R +++ b/R/impact_analysis.R @@ -135,7 +135,8 @@ FormatInputPrePostPeriod <- function(pre.period, post.period, data) { FormatInputForCausalImpact <- function(data, pre.period, post.period, model.args, bsts.model, - post.period.response, alpha) { + post.period.response, alpha, + construct.model) { # Checks and formats all input arguments supplied to CausalImpact(). See the # documentation of CausalImpact() for details. # @@ -147,6 +148,7 @@ FormatInputForCausalImpact <- function(data, pre.period, post.period, # bsts.model: fitted bsts model (instead of data) # post.period.response: observed response in the post-period # alpha: tail-area for posterior intervals + # construct.model: custom model constructor # # Returns: # list of checked (and possibly reformatted) input arguments @@ -213,7 +215,8 @@ CausalImpact <- function(data = NULL, model.args = NULL, bsts.model = NULL, post.period.response = NULL, - alpha = 0.05) { + alpha = 0.05, + construct.model = NULL) { # CausalImpact() performs causal inference through counterfactual # predictions using a Bayesian structural time-series model. # @@ -278,6 +281,8 @@ CausalImpact <- function(data = NULL, # alpha: Desired tail-area probability for posterior intervals. # Defaults to 0.05, which will produce central 95\% intervals. # + # construct.model: Custom model constructor. + # # Returns: # A CausalImpact object. This is a list of: # series: observed data, counterfactual, pointwise and cumulative impact @@ -341,7 +346,8 @@ CausalImpact <- function(data = NULL, # Check input checked <- FormatInputForCausalImpact(data, pre.period, post.period, model.args, bsts.model, - post.period.response, alpha) + post.period.response, alpha, + construct.model) data <- checked$data pre.period <- checked$pre.period post.period <- checked$post.period @@ -352,7 +358,7 @@ CausalImpact <- function(data = NULL, # Depending on input, dispatch to the appropriate Run* method() if (!is.null(data)) { - impact <- RunWithData(data, pre.period, post.period, model.args, alpha) + impact <- RunWithData(data, pre.period, post.period, model.args, alpha, construct.model) # Return pre- and post-period in the time unit of the time series. times <- time(data) impact$model$pre.period <- times[pre.period] @@ -364,7 +370,7 @@ CausalImpact <- function(data = NULL, return(impact) } -RunWithData <- function(data, pre.period, post.period, model.args, alpha) { +RunWithData <- function(data, pre.period, post.period, model.args, alpha, construct.model) { # Runs an impact analysis on top of a fitted bsts model. # # Args: @@ -375,6 +381,7 @@ RunWithData <- function(data, pre.period, post.period, model.args, alpha) { # limits. # model.args: list of model arguments # alpha: tail-probabilities of posterior intervals + # construct.model: custom model constructor # # Returns: # See CausalImpact(). @@ -409,7 +416,16 @@ RunWithData <- function(data, pre.period, post.period, model.args, alpha) { window(data.modeling[, 1], start = pre.period[2] + 1) <- NA # Construct model and perform inference - bsts.model <- ConstructModel(data.modeling, model.args) + if (!is.null(construct.model)) { + checked <- FormatInputForConstructModel(data.modeling, model.args) + y <- checked$data[, 1] + # If the series is ill-conditioned, abort inference and return NULL + bsts.model <- if (ObservationsAreIllConditioned(y)) NULL else { + construct.model(checked$data) + } + } else { + bsts.model <- ConstructModel(data.modeling, model.args) + } # Compile posterior inferences if (!is.null(bsts.model)) { diff --git a/tests/testthat/test-impact-analysis.R b/tests/testthat/test-impact-analysis.R index f6d1340..2c19bb7 100644 --- a/tests/testthat/test-impact-analysis.R +++ b/tests/testthat/test-impact-analysis.R @@ -593,6 +593,33 @@ test_that("CausalImpact.RunWithData.MissingTimePoint", { expect_equal(indices, time(series)[-17]) }) +test_that("CausalImpact.RunWithData.CustomConstructModel", { + # Test daily data (zoo object) + data <- zoo(cbind(rnorm(200), rnorm(200), rnorm(200)), + seq.Date(as.Date("2014-01-01"), as.Date("2014-01-01") + 199, + by = 1)) + + pre.period <- as.Date(c("2014-01-01", "2014-04-10")) # 100 days + post.period <- as.Date(c("2014-04-11", "2014-07-09")) # 90 days + + go <- function(data) { + y <- data[, 1] + sdy <- sd(y, na.rm = TRUE) + sd.prior <- SdPrior(sigma.guess = 0.01 * sdy, + upper.limit = sdy, + sample.size = 32) + ss <- AddLocalLevel(list(), y, sigma.prior = sd.prior) + bsts.model <- bsts(y, state.specification = ss, niter = 100, + seed = 1, ping = 0) + } + + suppressWarnings(impact <- CausalImpact(data, pre.period, post.period, + construct.model = go)) + expect_equal(time(impact$model$bsts.model$original.series), 1:200) + expect_equal(time(impact$series), time(data)) + CallAllS3Methods(impact) +}) + test_that("CausalImpact.RunWithBstsModel", { # Test on a healthy bsts object