Skip to content

Commit

Permalink
enable non-seasonal forecast
Browse files Browse the repository at this point in the history
  • Loading branch information
Akai01 committed Aug 4, 2022
1 parent 95571ba commit b916a96
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 42 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Encoding: UTF-8
LazyData: true
SystemRequirements: Python (>= 3.6)
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.2
RoxygenNote: 7.2.0
Imports:
dplyr (>= 1.0.7),
forecast (>= 8.15),
Expand Down
12 changes: 9 additions & 3 deletions R/ngbfor.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,15 @@ NGBforecast <- R6::R6Class(
train_loss_monitor = train_loss_monitor,
val_loss_monitor = val_loss_monitor,
early_stopping_rounds = early_stopping_rounds)
method = paste0("NGBforecast", " with ",
gsub("[()]", "", paste0(private$Base)),
"(",max_lag,", ", K, ")")
if(seasonal){
method = paste0("NGBforecast", " with ",
gsub("[()]", "", paste0(private$Base)),
"(",max_lag,", ", K, ")")
} else {
method = paste0("NGBforecast", " with ",
gsub("[()]", "", paste0(private$Base)),
"(",max_lag,", ", 0, ")")
}

fitted_int <- ts(c(rep(NA, max_lag), c(model$predict(X))),
start = start(y), frequency = frequency(y))
Expand Down
14 changes: 7 additions & 7 deletions R/ngbforcv.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,37 +94,37 @@ NGBforecastCV <- R6::R6Class(
tol = 0.0001,
random_state = NULL){

if(class(Dist)!="list"){
if(!inherits(Dist, "list")){
stop(
"Please profide a list of Dist object with at least one spesification"
)
}
if(class(Score)!="list"){
if(!inherits(Score, "list")){
stop(
"Please profide a list of Score object with at least one spesification"
)
}
if(class(Base)!="list"){
if(!inherits(Base, "list")){
stop(
"Please profide a list of Base base learnesr with at least one spesification"
)
}
if(class(n_estimators)!="list"){
if(!inherits(n_estimators, "list")){
stop(
"Please profide a list of n_estimators object with at least one spesification"
)
}
if(class(learning_rate)!="list"){
if(!inherits(learning_rate, "list")){
stop(
"Please profide a list of learning_rate object with at least one spesification"
)
}
if(class(minibatch_frac)!="list"){
if(!inherits(minibatch_frac, "list")){
stop(
"Please profide a list of minibatch_frac object with at least one spesification"
)
}
if(class(col_sample)!="list"){
if(!inherits(col_sample, "list")){
stop(
"Please profide a list of col_sample object with at least one spesification"
)
Expand Down
4 changes: 3 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,9 @@ prepare_data <- function(y,
x <- matrix(0, nrow = c(length_y - max_lag), ncol = ncolx)

x[, seq_len(max_lag)] <- lag_maker(y, max_lag)


fourier_s <- NULL

if (seasonal == TRUE & freq > 1)
{
fourier_s <- forecast::fourier(modified_y, K = K)
Expand Down
42 changes: 21 additions & 21 deletions man/NGBforecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 9 additions & 9 deletions man/NGBforecastCV.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 38 additions & 0 deletions tests/testthat/test-ngbforecastcv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
if(require(testthat)){
test_that("tests for some arguments in NGBforecastCV", {

if_not_ngboost_exist_skip()

dists <- list(Dist("Normal"))
base_learners <- list(sklearner(module = "tree", class = "DecisionTreeRegressor",
max_depth = 6),
sklearner(module = "tree", class = "DecisionTreeRegressor",
max_depth = 7))
scores <- list(Scores("LogScore"))

model <- NGBforecastCV$new(Dist = dists,
Base = base_learners,
Score = scores,
natural_gradient = TRUE,
n_estimators = list(10, 12),
learning_rate = list(0.1),
minibatch_frac = list(0.1),
col_sample = list(0.3),
verbose = FALSE,
verbose_eval = 100,
tol = 1e-5)

params <- model$tune(y = AirPassengers,
seasonal = TRUE,
max_lag = 12,
xreg = NULL,
early_stopping_rounds = NULL,
n_splits = 4L)

out <- class(params)

expect_equal(out, "list")

})

}
85 changes: 85 additions & 0 deletions tests/testthat/test-ngboostfor.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,89 @@ if(require(testthat)){
expect_equal(fc, c(454.3927, 425.1821), tolerance = 4)

})

test_that("tests for non-seasonal NGBforecast", {

if_not_ngboost_exist_skip()

model <- NGBforecast$new(Dist = Dist("LogNormal"),
Base = sklearner(),
Score = Scores("LogScore"),
natural_gradient = TRUE,
n_estimators = 200,
learning_rate = 0.1,
minibatch_frac = 1,
col_sample = 1,
verbose = TRUE,
verbose_eval = 100,
tol = 1e-5)

model$fit(y = AirPassengers, seasonal = FALSE, max_lag = 12, xreg = NULL,
early_stopping_rounds = 10L)


fc <- c(model$forecast(h = 2, level = c(90, 80), xreg = NULL)$mean)

expect_equal(fc, c(468, 461), tolerance = 5)

})

test_that("tests for non-seasonal xreg NGBforecast", {

if_not_ngboost_exist_skip()

model <- NGBforecast$new(Dist = Dist("LogNormal"),
Base = sklearner(),
Score = Scores("LogScore"),
natural_gradient = TRUE,
n_estimators = 200,
learning_rate = 0.1,
minibatch_frac = 1,
col_sample = 1,
verbose = TRUE,
verbose_eval = 100,
tol = 1e-5)

xreg <- matrix(rnorm(length(AirPassengers)*2, mean = 3), ncol = 2)
newxreg <- matrix(rnorm(4, mean = 3), ncol = 2)

model$fit(y = AirPassengers, seasonal = FALSE, max_lag = 12, xreg = xreg,
early_stopping_rounds = 10L)


fc <- c(model$forecast(h = 2, level = c(90, 80), xreg = newxreg)$mean)

expect_equal(fc, c(463, 461), tolerance = 5)

})

test_that("tests for seasonal xreg NGBforecast", {

if_not_ngboost_exist_skip()

model <- NGBforecast$new(Dist = Dist("LogNormal"),
Base = sklearner(),
Score = Scores("LogScore"),
natural_gradient = TRUE,
n_estimators = 200,
learning_rate = 0.1,
minibatch_frac = 1,
col_sample = 1,
verbose = TRUE,
verbose_eval = 100,
tol = 1e-5)

xreg <- matrix(rnorm(length(AirPassengers)*2, mean = 3), ncol = 2)
newxreg <- matrix(rnorm(4, mean = 3), ncol = 2)

model$fit(y = AirPassengers, seasonal = TRUE, max_lag = 12, xreg = xreg,
early_stopping_rounds = 10L)


fc <- c(model$forecast(h = 2, level = c(90, 80), xreg = newxreg)$mean)

expect_equal(fc, c(470, 463), tolerance = 5)

})

}

0 comments on commit b916a96

Please sign in to comment.