From de605a464f122d97af3ea15b8ef2c1c6751d8eaf Mon Sep 17 00:00:00 2001 From: Aadyot Bhatnagar Date: Mon, 18 Apr 2022 15:40:08 -0700 Subject: [PATCH] Fix n_retrain option in benchmark_forecast.py. (#86) --- benchmark_forecast.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmark_forecast.py b/benchmark_forecast.py index d70500f5c..2a75a3a02 100644 --- a/benchmark_forecast.py +++ b/benchmark_forecast.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021 salesforce.com, inc. +# Copyright (c) 2022 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause @@ -305,9 +305,11 @@ def train_model( elif retrain_type == "sliding_window_retrain": retrain_freq = math.ceil(test_window_len / int(n_retrain)) train_window = train_window_len + horizon = min(retrain_freq, horizon) elif retrain_type == "expanding_window_retrain": retrain_freq = math.ceil(test_window_len / int(n_retrain)) train_window = None + horizon = min(retrain_freq, horizon) else: raise ValueError( "the retrain_type should be without_retrain, sliding_window_retrain or expanding_window_retrain"