From 7047d12b7c9091d613e7b01ce76b636854a53a2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20Leini=C3=B6?= Date: Thu, 14 Jul 2022 00:56:11 -0300 Subject: [PATCH] Fix Prophet.resample_time_stamps bug (#112) * add unit test for Prophet.resample_time_stamps test failing at this commit because of the bug * fix bug in Prophet.resample_time_stamps --- merlion/models/forecast/prophet.py | 2 +- tests/forecast/test_prophet.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 tests/forecast/test_prophet.py diff --git a/merlion/models/forecast/prophet.py b/merlion/models/forecast/prophet.py index bad4d20c0..45a8bb7b2 100644 --- a/merlion/models/forecast/prophet.py +++ b/merlion/models/forecast/prophet.py @@ -209,7 +209,7 @@ def _train(self, train_data: pd.DataFrame, train_config=None): def resample_time_stamps(self, time_stamps: Union[int, List[int]], time_series_prev: TimeSeries = None): if isinstance(time_stamps, (int, float)): - times = pd.date_range(start=self.last_train_time, freq=self.timedelta, periods=int(time_stamps))[1:] + times = pd.date_range(start=self.last_train_time, freq=self.timedelta, periods=int(time_stamps + 1))[1:] time_stamps = to_timestamp(times) return time_stamps diff --git a/tests/forecast/test_prophet.py b/tests/forecast/test_prophet.py new file mode 100644 index 000000000..4fd5d20b0 --- /dev/null +++ b/tests/forecast/test_prophet.py @@ -0,0 +1,23 @@ +import unittest + +import pandas as pd +import numpy as np + +from merlion.models.forecast.prophet import Prophet, ProphetConfig +from merlion.utils.resample import to_timestamp + + +class TestProphet(unittest.TestCase): + def test_resample_time_stamps(self): + # arrange + config = ProphetConfig() + prophet = Prophet(config) + prophet.last_train_time = pd._libs.tslibs.timestamps.Timestamp(year=2022, month=1, day=1) + prophet.timedelta = pd._libs.tslibs.timedeltas.Timedelta(days=1) + target = np.array([to_timestamp(pd._libs.tslibs.timestamps.Timestamp(year=2022, month=1, day=2))]) + + # act + output = prophet.resample_time_stamps(time_stamps=1) + + # assert + assert output == target