Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce construct.model parameter #51

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions R/impact_analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ FormatInputPrePostPeriod <- function(pre.period, post.period, data) {

FormatInputForCausalImpact <- function(data, pre.period, post.period,
model.args, bsts.model,
post.period.response, alpha) {
post.period.response, alpha,
construct.model) {
# Checks and formats all input arguments supplied to CausalImpact(). See the
# documentation of CausalImpact() for details.
#
Expand All @@ -147,6 +148,7 @@ FormatInputForCausalImpact <- function(data, pre.period, post.period,
# bsts.model: fitted bsts model (instead of data)
# post.period.response: observed response in the post-period
# alpha: tail-area for posterior intervals
# construct.model: custom model constructor
#
# Returns:
# list of checked (and possibly reformatted) input arguments
Expand Down Expand Up @@ -213,7 +215,8 @@ CausalImpact <- function(data = NULL,
model.args = NULL,
bsts.model = NULL,
post.period.response = NULL,
alpha = 0.05) {
alpha = 0.05,
construct.model = NULL) {
# CausalImpact() performs causal inference through counterfactual
# predictions using a Bayesian structural time-series model.
#
Expand Down Expand Up @@ -278,6 +281,8 @@ CausalImpact <- function(data = NULL,
# alpha: Desired tail-area probability for posterior intervals.
# Defaults to 0.05, which will produce central 95\% intervals.
#
# construct.model: Custom model constructor.
#
# Returns:
# A CausalImpact object. This is a list of:
# series: observed data, counterfactual, pointwise and cumulative impact
Expand Down Expand Up @@ -341,7 +346,8 @@ CausalImpact <- function(data = NULL,
# Check input
checked <- FormatInputForCausalImpact(data, pre.period, post.period,
model.args, bsts.model,
post.period.response, alpha)
post.period.response, alpha,
construct.model)
data <- checked$data
pre.period <- checked$pre.period
post.period <- checked$post.period
Expand All @@ -352,7 +358,7 @@ CausalImpact <- function(data = NULL,

# Depending on input, dispatch to the appropriate Run* method()
if (!is.null(data)) {
impact <- RunWithData(data, pre.period, post.period, model.args, alpha)
impact <- RunWithData(data, pre.period, post.period, model.args, alpha, construct.model)
# Return pre- and post-period in the time unit of the time series.
times <- time(data)
impact$model$pre.period <- times[pre.period]
Expand All @@ -364,7 +370,7 @@ CausalImpact <- function(data = NULL,
return(impact)
}

RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
RunWithData <- function(data, pre.period, post.period, model.args, alpha, construct.model) {
# Runs an impact analysis on top of a fitted bsts model.
#
# Args:
Expand All @@ -375,6 +381,7 @@ RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
# limits.
# model.args: list of model arguments
# alpha: tail-probabilities of posterior intervals
# construct.model: custom model constructor
#
# Returns:
# See CausalImpact().
Expand Down Expand Up @@ -409,7 +416,16 @@ RunWithData <- function(data, pre.period, post.period, model.args, alpha) {
window(data.modeling[, 1], start = pre.period[2] + 1) <- NA

# Construct model and perform inference
bsts.model <- ConstructModel(data.modeling, model.args)
if (!is.null(construct.model)) {
checked <- FormatInputForConstructModel(data.modeling, model.args)
y <- checked$data[, 1]
# If the series is ill-conditioned, abort inference and return NULL
bsts.model <- if (ObservationsAreIllConditioned(y)) NULL else {
construct.model(checked$data)
}
} else {
bsts.model <- ConstructModel(data.modeling, model.args)
}

# Compile posterior inferences
if (!is.null(bsts.model)) {
Expand Down
27 changes: 27 additions & 0 deletions tests/testthat/test-impact-analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,33 @@ test_that("CausalImpact.RunWithData.MissingTimePoint", {
expect_equal(indices, time(series)[-17])
})

test_that("CausalImpact.RunWithData.CustomConstructModel", {
# Test daily data (zoo object)
data <- zoo(cbind(rnorm(200), rnorm(200), rnorm(200)),
seq.Date(as.Date("2014-01-01"), as.Date("2014-01-01") + 199,
by = 1))

pre.period <- as.Date(c("2014-01-01", "2014-04-10")) # 100 days
post.period <- as.Date(c("2014-04-21", "2014-07-09")) # 80 days

go <- function(data) {
y <- data[, 1]
sdy <- sd(y, na.rm = TRUE)
sd.prior <- SdPrior(sigma.guess = 0.01 * sdy,
upper.limit = sdy,
sample.size = 32)
ss <- AddLocalLevel(list(), y, sigma.prior = sd.prior)
bsts.model <- bsts(y, state.specification = ss, niter = 100,
seed = 1, ping = 0)
}

suppressWarnings(impact <- CausalImpact(data, pre.period, post.period,
construct.model = go))
expect_equal(time(impact$model$bsts.model$original.series), 1:200)
expect_equal(time(impact$series), time(data))
CallAllS3Methods(impact)
})

test_that("CausalImpact.RunWithBstsModel", {

# Test on a healthy bsts object
Expand Down