Skip to content

Commit

Permalink
fix sporadic forecasting test failure (#1280)
Browse files Browse the repository at this point in the history
* fix: change ets seasonal_periods search space

* feat: add ETSModel params validation

* fix: add logging on action of params change

* fix: resolve gensim dependency to a specific commit
  • Loading branch information
Lopa10ko committed Apr 5, 2024
1 parent b711ebe commit cd915a8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,18 +250,24 @@ class ExpSmoothingImplementation(ModelImplementation):
def __init__(self, params: OperationParameters):
super().__init__(params)
self.model = None
if self.params.get("seasonal"):
self.seasonal_periods = int(self.params.get("seasonal_periods"))
if self.params.get('seasonal'):
self.seasonal_periods = int(self.params.get('seasonal_periods'))
else:
self.seasonal_periods = None

def fit(self, input_data):
endog = input_data.features.astype('float64')

# check ets params according to statsmodels restrictions
if self._check_and_correct_params(endog):
self.log.info(f'Changed the following ETSModel parameters: {self.params.changed_parameters}')

self.model = ETSModel(
input_data.features.astype("float64"),
error=self.params.get("error"),
trend=self.params.get("trend"),
seasonal=self.params.get("seasonal"),
damped_trend=self.params.get("damped_trend") if self.params.get("trend") else None,
endog=endog,
error=self.params.get('error'),
trend=self.params.get('trend'),
seasonal=self.params.get('seasonal'),
damped_trend=self.params.get('damped_trend') if self.params.get('trend') else None,
seasonal_periods=self.seasonal_periods
)
self.model = self.model.fit(disp=False)
Expand Down Expand Up @@ -312,3 +318,21 @@ def predict_for_fit(self, input_data: InputData) -> OutputData:
predict=predict,
data_type=DataTypesEnum.table)
return output_data

def _check_and_correct_params(self, endog: np.ndarray) -> bool:
ets_components = ['error', 'trend', 'seasonal']
params_changed = False
if any(self.params.get(component) == 'mul' for component in ets_components):
if np.any(endog <= 0):
for component in ets_components:
if self.params.get(component) == 'mul':
self.params.update(**{f'{component}': 'add'})
params_changed = True

if self.params.get('trend') == 'mul' \
and self.params.get('damped_trend') \
and not self.params.get('seasonal'):
self.params.update(trend='add')
params_changed = True

return params_changed
2 changes: 1 addition & 1 deletion fedot/core/pipelines/tuning/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def get_parameters_dict(self):
'type': 'categorical'},
'seasonal_periods': {
'hyperopt-dist': hp.uniform,
'sampling-scope': [1, 100],
'sampling-scope': [2, 100],
'type': 'continuous'}
},
'glm': {
Expand Down
2 changes: 1 addition & 1 deletion other_requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ opencv-python >= 4.5.5.64
Pillow >= 8.2.0

# Texts
gensim >= 4.1.2
gensim @ git+https://github.com/piskvorky/gensim.git@ad68ee3
nltk >= 3.5

# Misc
Expand Down

0 comments on commit cd915a8

Please sign in to comment.