Skip to content

Commit

Permalink
Fix n_retrain option in benchmark_forecast.py. (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
aadyotb committed Apr 18, 2022
1 parent dd47e85 commit de605a4
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion benchmark_forecast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021 salesforce.com, inc.
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Expand Down Expand Up @@ -305,9 +305,11 @@ def train_model(
elif retrain_type == "sliding_window_retrain":
retrain_freq = math.ceil(test_window_len / int(n_retrain))
train_window = train_window_len
horizon = min(retrain_freq, horizon)
elif retrain_type == "expanding_window_retrain":
retrain_freq = math.ceil(test_window_len / int(n_retrain))
train_window = None
horizon = min(retrain_freq, horizon)
else:
raise ValueError(
"the retrain_type should be without_retrain, sliding_window_retrain or expanding_window_retrain"
Expand Down

0 comments on commit de605a4

Please sign in to comment.