Skip to content

Commit a024481

Browse files
Merge pull request #92 from databricks-industry-solutions/debug-sktime
added prophet via sktime
2 parents 2ece8ad + 0279af4 commit a024481

File tree

9 files changed

+80
-33
lines changed

9 files changed

+80
-33
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Get started now!
1212

1313
## What's New
1414

15+
- Feb 2025: [Prophet](https://www.sktime.net/en/stable/api_reference/auto_generated/sktime.forecasting.fbprophet.Prophet.html) is available for univariate forecasting via `SKTimeProphet`. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/local_univariate_daily.py).
1516
- Feb 2025: Added a post evaluation notebook that shows how to run fine-grained model selection after running MMF. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/post-evaluation-analysis.ipynb).
1617
- Jan 2025: [TimesFM](https://github.com/google-research/timesfm) is available for univariate and covariate forecasting. Try the notebooks: [univariate](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/foundation_daily.py) and [covariate](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/external_regressors/foundation_external_regressors_daily.py).
1718
- Jan 2025: [Chronos Bolt](https://github.com/amazon-science/chronos-forecasting) models are available for univariate forecasting. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/daily/foundation_daily.py).
@@ -51,6 +52,7 @@ active_models = [
5152
"RFableNNETAR",
5253
"RFableEnsemble",
5354
"RDynamicHarmonicRegression",
55+
"SKTimeProphet",
5456
"SKTimeTBats",
5557
"SKTimeLgbmDsDt",
5658
]

examples/daily/local_univariate_daily.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def transform_group(df):
152152
"RFableEnsemble",
153153
"RDynamicHarmonicRegression",
154154
"SKTimeTBats",
155+
"SKTimeProphet",
155156
"SKTimeLgbmDsDt",
156157
]
157158

examples/external_regressors/local_univariate_external_regressors_daily.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
"RFableEnsemble",
128128
"RDynamicHarmonicRegression",
129129
"SKTimeTBats",
130+
"SKTimeProphet",
130131
"SKTimeLgbmDsDt",
131132
]
132133

examples/hourly/local_univariate_hourly.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def transform_group(df):
121121
# MAGIC %md ### Models
122122
# MAGIC Let's configure a list of models we are going to apply to our time series for evaluation and forecasting. A comprehensive list of all supported models is available in [mmf_sa/models/models_conf.yaml](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/models_conf.yaml). Look for the models where `model_type: local`; these are the local models we import from [statsforecast](https://github.com/Nixtla/statsforecast) and [sktime](https://github.com/sktime/sktime). Check their documentation for a description of each model.
123123
# MAGIC
124-
# MAGIC *Note that hourly forecasting is currently not supported for `r fable` and `sktime` models.*
124+
# MAGIC *Note that hourly forecasting is currently not supported for `r fable` models.*
125125

126126
# COMMAND ----------
127127

@@ -140,6 +140,9 @@ def transform_group(df):
140140
"StatsForecastCrostonClassic",
141141
"StatsForecastCrostonOptimized",
142142
"StatsForecastCrostonSBA",
143+
"SKTimeTBats",
144+
"SKTimeProphet",
145+
"SKTimeLgbmDsDt",
143146
]
144147

145148
# COMMAND ----------

examples/m5/local_univariate_daily_m5.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
"RFableEnsemble",
5858
"RDynamicHarmonicRegression",
5959
"SKTimeTBats",
60+
"SKTimeProphet",
6061
"SKTimeLgbmDsDt",
6162
]
6263

examples/monthly/local_univariate_monthly.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def transform_group(df):
161161
"RFableEnsemble",
162162
"RDynamicHarmonicRegression",
163163
"SKTimeTBats",
164+
"SKTimeProphet",
164165
"SKTimeLgbmDsDt",
165166
]
166167

examples/weekly/local_univariate_weekly.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def transform_group(df):
150150
"RFableEnsemble",
151151
"RDynamicHarmonicRegression",
152152
"SKTimeTBats",
153+
"SKTimeProphet",
153154
"SKTimeLgbmDsDt",
154155
]
155156

mmf_sa/models/models_conf.yaml

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,27 +166,40 @@ models:
166166
model_spec:
167167
fourier_terms:
168168

169-
SKTimeLgbmDsDt:
169+
SKTimeTBats:
170170
module: mmf_sa.models.sktime.SKTimeForecastingPipeline
171-
model_class: SKTimeLgbmDsDt
171+
model_class: SKTimeTBats
172172
framework: SKTime
173173
model_type: local
174174
enable_gcv: false
175175
model_spec:
176-
deseasonalise_model: multiplicative
176+
box_cox: True
177+
use_trend: True
177178
season_length: 7
178-
detrend_poly_degree: 2
179179

180-
SKTimeTBats:
180+
SKTimeProphet:
181181
module: mmf_sa.models.sktime.SKTimeForecastingPipeline
182-
model_class: SKTimeTBats
182+
model_class: SKTimeProphet
183183
framework: SKTime
184184
model_type: local
185185
enable_gcv: false
186186
model_spec:
187-
box_cox: True
188-
use_trend: True
187+
growth: linear
188+
yearly_seasonality: auto
189+
weekly_seasonality: auto
190+
daily_seasonality: auto
191+
seasonality_mode: additive
192+
193+
SKTimeLgbmDsDt:
194+
module: mmf_sa.models.sktime.SKTimeForecastingPipeline
195+
model_class: SKTimeLgbmDsDt
196+
framework: SKTime
197+
model_type: local
198+
enable_gcv: false
199+
model_spec:
200+
deseasonalise_model: multiplicative
189201
season_length: 7
202+
detrend_poly_degree: 2
190203

191204
NeuralForecastRNN:
192205
module: mmf_sa.models.neuralforecast.NeuralForecastPipeline

mmf_sa/models/sktime/SKTimeForecastingPipeline.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ForecastingGridSearchCV,
99
)
1010
from sktime.forecasting.tbats import TBATS
11+
from sktime.forecasting.fbprophet import Prophet
1112
from sktime.forecasting.compose import make_reduction
1213
from sktime.forecasting.compose import TransformedTargetForecaster
1314
from sktime.transformations.series.detrend import Detrender, ConditionalDeseasonalizer
@@ -47,7 +48,9 @@ def prepare_data(self, df: pd.DataFrame) -> pd.DataFrame:
4748
return df
4849

4950
def fit(self, x, y=None):
50-
if self.params.get("enable_gcv", False) and self.model is None and self.param_grid:
51+
if (self.params.get("enable_gcv", False)
52+
and self.model is None
53+
and self.param_grid):
5154
_model = self.create_model()
5255
cv = SlidingWindowSplitter(
5356
initial_window=int(len(x) - self.params.prediction_length * 4),
@@ -68,8 +71,8 @@ def predict(self, hist_df: pd.DataFrame, val_df: pd.DataFrame = None):
6871
ForecastingHorizon(np.arange(1, self.params.prediction_length + 1))
6972
)
7073
date_idx = pd.date_range(
71-
_df.index.max().to_timestamp(freq=self.params.freq) + pd.DateOffset(days=1),
72-
_df.index.max().to_timestamp(freq=self.params.freq) + pd.DateOffset(days=self.params.prediction_length),
74+
_df.index.max().to_timestamp(freq=self.params.freq) + self.one_ts_offset,
75+
_df.index.max().to_timestamp(freq=self.params.freq) + self.prediction_length_offset,
7376
freq=self.params.freq,
7477
name=self.params.date_col,
7578
)
@@ -82,6 +85,48 @@ def forecast(self, x, spark=None):
8285
return self.predict(x)
8386

8487

88+
class SKTimeTBats(SKTimeForecastingPipeline):
    """Forecasting pipeline backed by sktime's ``TBATS`` forecaster.

    The forecaster's settings are taken from ``self.model_spec``
    (``season_length``, ``use_trend`` and ``box_cox``).
    """

    def __init__(self, params):
        super().__init__(params)

    def create_model(self) -> BaseForecaster:
        """Build a ``TBATS`` forecaster from the configured model spec.

        Returns:
            A ``TBATS`` instance whose seasonal period is ``season_length``
            cast to ``int``, with ``n_jobs`` fixed to ``-1``.
        """
        spec = self.model_spec
        seasonal_period = int(spec.get("season_length"))
        return TBATS(
            sp=seasonal_period,
            use_trend=spec.get("use_trend"),
            use_box_cox=spec.get("box_cox"),
            n_jobs=-1,
        )

    def create_param_grid(self):
        """Return candidate TBATS settings keyed by constructor argument."""
        on_off = [True, False]
        grid = {}
        grid["use_trend"] = list(on_off)
        grid["use_box_cox"] = list(on_off)
        grid["sp"] = [1, 7, 14]
        return grid
107+
108+
class SKTimeProphet(SKTimeForecastingPipeline):
    """Forecasting pipeline backed by sktime's ``Prophet`` forecaster.

    The forecaster's settings are taken from ``self.model_spec``
    (``growth``, ``yearly_seasonality``, ``weekly_seasonality``,
    ``daily_seasonality`` and ``seasonality_mode``); the series frequency
    comes from ``self.params.freq``.
    """

    def __init__(self, params):
        super().__init__(params)

    def create_model(self) -> BaseForecaster:
        """Build a ``Prophet`` forecaster from the configured model spec.

        Returns:
            A ``Prophet`` instance configured with the frequency, trend
            type and seasonality options found in the pipeline settings.
        """
        model = Prophet(
            freq=self.params.freq,
            growth=self.model_spec.get("growth"),
            yearly_seasonality=self.model_spec.get("yearly_seasonality"),
            weekly_seasonality=self.model_spec.get("weekly_seasonality"),
            daily_seasonality=self.model_spec.get("daily_seasonality"),
            seasonality_mode=self.model_spec.get("seasonality_mode"),
        )
        return model

    def create_param_grid(self):
        """Return candidate Prophet settings keyed by constructor argument.

        Only growth values that Prophet accepts are listed. Prophet
        recognises "linear", "logistic" and "flat"; "logarithmic" is not a
        valid option and makes every candidate using it fail.
        """
        return {
            # "logistic" is deliberately left out: it requires a growth
            # capacity, which create_model() does not configure.
            "growth": ["linear", "flat"],
            "seasonality_mode": ["additive", "multiplicative"],
        }
128+
129+
85130
class SKTimeLgbmDsDt(SKTimeForecastingPipeline):
86131
def __init__(self, params):
87132
super().__init__(params)
@@ -126,24 +171,3 @@ def create_param_grid(self):
126171
self.params.prediction_length * 2,
127172
],
128173
}
129-
130-
131-
class SKTimeTBats(SKTimeForecastingPipeline):
132-
def __init__(self, params):
133-
super().__init__(params)
134-
135-
def create_model(self) -> BaseForecaster:
136-
model = TBATS(
137-
sp=int(self.model_spec.get("season_length")),
138-
use_trend=self.model_spec.get("use_trend"),
139-
use_box_cox=self.model_spec.get("box_cox"),
140-
n_jobs=-1,
141-
)
142-
return model
143-
144-
def create_param_grid(self):
145-
return {
146-
"use_trend": [True, False],
147-
"use_box_cox": [True, False],
148-
"sp": [1, 7, 14],
149-
}

0 commit comments

Comments
 (0)