Skip to content

Commit

Permalink
Introduce construct.model parameter
Browse files Browse the repository at this point in the history
construct.model is a function that takes modeling data (NA in
post.period) and returns a bsts model.

This invocation:

    go <- function(data) {
      y <- data[, 1]
      sdy <- sd(y, na.rm = TRUE)
      sd.prior <- SdPrior(sigma.guess = 0.01 * sdy,
                          upper.limit = sdy,
                          sample.size = 32)
      ss <- AddLocalLevel(list(), y, sigma.prior = sd.prior)
      bsts.model <- bsts(y, state.specification = ss, niter = 100,
                         seed = 1, ping = 0)
    }

    impact <- CausalImpact(data, pre.period, post.period,
                           construct.model = go)

is equivalent to:

    model.args <- list(niter = 100)
    impact <- CausalImpact(data, pre.period, post.period, model.args)

. This change provides an interface that has RunWithData's flexibility
with pre- and post-period, and RunWithBstsModel's full configurability
of the model. For example, the caller can now specify a custom model
where the post-period is before the end of the data, which isn't
possible with the previous interface.
  • Loading branch information
luciferous committed Jul 25, 2021
1 parent 7e0f59f commit fdf2936
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
28 changes: 22 additions & 6 deletions R/impact_analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ FormatInputPrePostPeriod <- function(pre.period, post.period, data) {

FormatInputForCausalImpact <- function(data, pre.period, post.period,
model.args, bsts.model,
post.period.response, alpha) {
post.period.response, alpha,
construct.model) {
# Checks and formats all input arguments supplied to CausalImpact(). See the
# documentation of CausalImpact() for details.
#
Expand All @@ -147,6 +148,7 @@ FormatInputForCausalImpact <- function(data, pre.period, post.period,
# bsts.model: fitted bsts model (instead of data)
# post.period.response: observed response in the post-period
# alpha: tail-area for posterior intervals
# construct.model: custom model constructor
#
# Returns:
# list of checked (and possibly reformatted) input arguments
Expand Down Expand Up @@ -213,7 +215,8 @@ CausalImpact <- function(data = NULL,
model.args = NULL,
bsts.model = NULL,
post.period.response = NULL,
alpha = 0.05) {
alpha = 0.05,
construct.model = NULL) {
# CausalImpact() performs causal inference through counterfactual
# predictions using a Bayesian structural time-series model.
#
Expand Down Expand Up @@ -278,6 +281,8 @@ CausalImpact <- function(data = NULL,
# alpha: Desired tail-area probability for posterior intervals.
# Defaults to 0.05, which will produce central 95\% intervals.
#
# construct.model: Custom model constructor.
#
# Returns:
# A CausalImpact object. This is a list of:
# series: observed data, counterfactual, pointwise and cumulative impact
Expand Down Expand Up @@ -341,7 +346,8 @@ CausalImpact <- function(data = NULL,
# Check input
checked <- FormatInputForCausalImpact(data, pre.period, post.period,
model.args, bsts.model,
post.period.response, alpha)
post.period.response, alpha,
construct.model)
data <- checked$data
pre.period <- checked$pre.period
post.period <- checked$post.period
Expand All @@ -352,7 +358,7 @@ CausalImpact <- function(data = NULL,

# Depending on input, dispatch to the appropriate Run* method()
if (!is.null(data)) {
impact <- RunWithData(data, pre.period, post.period, model.args, alpha)
impact <- RunWithData(data, pre.period, post.period, model.args, alpha, construct.model)
# Return pre- and post-period in the time unit of the time series.
times <- time(data)
impact$model$pre.period <- times[pre.period]
Expand All @@ -364,7 +370,7 @@ CausalImpact <- function(data = NULL,
return(impact)
}

RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
RunWithData <- function(data, pre.period, post.period, model.args, alpha, construct.model) {
# Runs an impact analysis on top of a fitted bsts model.
#
# Args:
Expand All @@ -375,6 +381,7 @@ RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
# limits.
# model.args: list of model arguments
# alpha: tail-probabilities of posterior intervals
# construct.model: custom model constructor
#
# Returns:
# See CausalImpact().
Expand Down Expand Up @@ -409,7 +416,16 @@ RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
window(data.modeling[, 1], start = pre.period[2] + 1) <- NA

# Construct model and perform inference
bsts.model <- ConstructModel(data.modeling, model.args)
if (!is.null(construct.model)) {
checked <- FormatInputForConstructModel(data.modeling, model.args)
y <- checked$data[, 1]
# If the series is ill-conditioned, abort inference and return NULL
bsts.model <- if (ObservationsAreIllConditioned(y)) NULL else {
construct.model(checked$data)
}
} else {
bsts.model <- ConstructModel(data.modeling, model.args)
}

# Compile posterior inferences
if (!is.null(bsts.model)) {
Expand Down
27 changes: 27 additions & 0 deletions tests/testthat/test-impact-analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,33 @@ test_that("CausalImpact.RunWithData.MissingTimePoint", {
expect_equal(indices, time(series)[-17])
})

test_that("CausalImpact.RunWithData.CustomConstructModel", {
# Test daily data (zoo object)
data <- zoo(cbind(rnorm(200), rnorm(200), rnorm(200)),
seq.Date(as.Date("2014-01-01"), as.Date("2014-01-01") + 199,
by = 1))

pre.period <- as.Date(c("2014-01-01", "2014-04-10")) # 100 days
post.period <- as.Date(c("2014-04-11", "2014-07-09")) # 90 days

go <- function(data) {
y <- data[, 1]
sdy <- sd(y, na.rm = TRUE)
sd.prior <- SdPrior(sigma.guess = 0.01 * sdy,
upper.limit = sdy,
sample.size = 32)
ss <- AddLocalLevel(list(), y, sigma.prior = sd.prior)
bsts.model <- bsts(y, state.specification = ss, niter = 100,
seed = 1, ping = 0)
}

suppressWarnings(impact <- CausalImpact(data, pre.period, post.period,
construct.model = go))
expect_equal(time(impact$model$bsts.model$original.series), 1:200)
expect_equal(time(impact$series), time(data))
CallAllS3Methods(impact)
})

test_that("CausalImpact.RunWithBstsModel", {

# Test on a healthy bsts object
Expand Down

0 comments on commit fdf2936

Please sign in to comment.