Skip to content

Commit

Permalink
Fix Prophet.resample_time_stamps bug (#112)
Browse files Browse the repository at this point in the history
* add unit test for Prophet.resample_time_stamps

test failing at this commit because of the bug

* fix bug in Prophet.resample_time_stamps
  • Loading branch information
rafaelleinio committed Jul 14, 2022
1 parent e0a79f0 commit 7047d12
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion merlion/models/forecast/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _train(self, train_data: pd.DataFrame, train_config=None):

def resample_time_stamps(self, time_stamps: Union[int, List[int]], time_series_prev: TimeSeries = None):
if isinstance(time_stamps, (int, float)):
times = pd.date_range(start=self.last_train_time, freq=self.timedelta, periods=int(time_stamps))[1:]
times = pd.date_range(start=self.last_train_time, freq=self.timedelta, periods=int(time_stamps + 1))[1:]
time_stamps = to_timestamp(times)
return time_stamps

Expand Down
23 changes: 23 additions & 0 deletions tests/forecast/test_prophet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest

import pandas as pd
import numpy as np

from merlion.models.forecast.prophet import Prophet, ProphetConfig
from merlion.utils.resample import to_timestamp


class TestProphet(unittest.TestCase):
def test_resample_time_stamps(self):
# arrange
config = ProphetConfig()
prophet = Prophet(config)
prophet.last_train_time = pd._libs.tslibs.timestamps.Timestamp(year=2022, month=1, day=1)
prophet.timedelta = pd._libs.tslibs.timedeltas.Timedelta(days=1)
target = np.array([to_timestamp(pd._libs.tslibs.timestamps.Timestamp(year=2022, month=1, day=2))])

# act
output = prophet.resample_time_stamps(time_stamps=1)

# assert
assert output == target

0 comments on commit 7047d12

Please sign in to comment.