Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Score fixed #485

Open
wants to merge 4 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,7 +1781,8 @@ def predict(
def score(
self,
y_pred: np.ndarray,
y_test: Union[np.ndarray, pd.DataFrame]
y_test: Union[np.ndarray, pd.DataFrame],
**score_kwargs: Any,
) -> Dict[str, float]:
"""Calculate the score on the test set.
Calculate the evaluation measure on the test set.
Expand All @@ -1791,6 +1792,10 @@ def score(
The test predictions
y_test (np.ndarray):
The test ground truth labels.
score_kwargs: Any
additional arguments for computing the scores. Some metrics might require special arguments. Currently,
this argument is required by time series forecasting tasks. For detailed information, please check
autoPyTorch/api/time_series_forecasting.py

Returns:
Dict[str, float]:
Expand All @@ -1804,7 +1809,8 @@ def score(
"Please check the log file for related errors. ")
return calculate_score(target=y_test, prediction=y_pred,
task_type=STRING_TO_TASK_TYPES[self.task_type],
metrics=[self._metric])
metrics=[self._metric],
**score_kwargs)

def __getstate__(self) -> Dict[str, Any]:
# Cannot serialize a client!
Expand Down
63 changes: 63 additions & 0 deletions autoPyTorch/api/time_series_forecasting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import warnings

import numpy as np

Expand Down Expand Up @@ -590,3 +591,65 @@ def update_sliding_window_size(self, n_prediction_steps: int) -> None:
],
default_value=int(np.ceil(1.25 * base_window_size)),
)

def score(
    self,
    y_pred: np.ndarray,
    y_test: Union[np.ndarray, pd.DataFrame],
    y_test_past: Optional[List[Union[np.ndarray, pd.DataFrame]]] = None,
    sp: Optional[int] = None,
    n_prediction_steps: Optional[int] = None,
    **score_kwargs: Any,
) -> Dict[str, float]:
    """Calculate the evaluation measure on the test set.

    Forecasting metrics may need additional arguments (seasonality, forecast horizon, MASE scaling
    coefficient). They can either be given by the user or are filled in from ``self.dataset``.

    NOTE: the MASE metric scales the errors with a coefficient derived from the past target values.
    Pass those values as ``y_test_past`` to have the coefficient computed here with
    autoPyTorch.pipeline.components.training.metrics.metrics.compute_mase_coefficient. If
    ``y_test_past`` is omitted while the metric name ends with 'MASE_Forecasting', a warning is
    emitted and no coefficient is forwarded to the metric.

    Args:
        y_pred (np.ndarray):
            The test predictions. Must have the same shape as y_test after conversion with np.asarray.
        y_test (Union[np.ndarray, pd.DataFrame]):
            The test ground truth labels.
        y_test_past (Optional[List[Union[np.ndarray, pd.DataFrame]]]):
            Past target values, one entry per element of y_test (``len(y_test_past) == len(y_test)``).
            Each entry is passed to compute_mase_coefficient together with ``sp``.
        sp (Optional[int]):
            Seasonality used for the MASE coefficient and forwarded to the metric.
            Defaults to ``self.dataset.seasonality``.
        n_prediction_steps (Optional[int]):
            Forecast horizon forwarded to the metric. Defaults to ``self.dataset.n_prediction_steps``.
        score_kwargs: Any
            additional arguments for computing the scores. Some metrics might require special arguments

    Returns:
        Dict[str, float]:
            Value of the evaluation metric calculated on the test set.

    Raises:
        AssertionError:
            If y_pred and y_test differ in shape, or if y_test_past and y_test differ in length.
    """
    # Convert once; the converted arrays are reused for the shape checks and the parent call.
    y_pred = np.asarray(y_pred)
    y_test = np.asarray(y_test)
    assert y_pred.shape == y_test.shape, f"y_pred and y_test must have the same shape! But they are " \
                                         f"{y_pred.shape} and {y_test.shape}!"

    # Fall back to the values stored on the dataset when the caller does not override them.
    if sp is None:
        sp = self.dataset.seasonality
    if n_prediction_steps is None:
        n_prediction_steps = self.dataset.n_prediction_steps
    score_kwargs_forecasting: Dict[str, Any] = {
        'sp': sp,
        'n_prediction_steps': n_prediction_steps,
    }

    if y_test_past is not None:
        assert len(y_test_past) == len(y_test), f'The length of y_test_past must be equal to the length of ' \
                                                f'y_test. But they are {len(y_test_past)} and {len(y_test)}'
        # Imported locally, as in the original implementation.
        from autoPyTorch.pipeline.components.training.metrics.metrics import compute_mase_coefficient
        mase_coefficient = np.asarray([compute_mase_coefficient(y_past, sp) for y_past in y_test_past])
        if y_pred.ndim > 2:
            # Match the shape of mase_coefficient and y_test: add an axis at position 1
            # so the per-entry coefficient broadcasts along that axis.
            mase_coefficient = np.expand_dims(mase_coefficient, 1)
        score_kwargs_forecasting['mase_coefficient'] = mase_coefficient
    elif self._metric.name.endswith('MASE_Forecasting'):
        warnings.warn("To compute MASE losses, the past target values must be provided. Here we simply ignore "
                      "the scaling coefficient and the loss degenerate to a MAE loss")

    return super(TimeSeriesForecastingTask, self).score(y_pred, y_test,
                                                        **score_kwargs_forecasting,
                                                        **score_kwargs)
2 changes: 1 addition & 1 deletion autoPyTorch/datasets/time_series_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self,
Y: np.ndarray,
start_time: Optional[pd.DatetimeIndex] = None,
freq: str = '1Y',
time_feature_transform: List[TimeFeature] = [ConstantTransform],
time_feature_transform: List[TimeFeature] = [ConstantTransform()],
X_test: Optional[np.ndarray] = None,
Y_test: Optional[np.ndarray] = None,
train_transforms: Optional[torchvision.transforms.Compose] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,7 @@ def __init__(self, backend: Backend, queue: Queue,
self.num_sequences = self.datamanager.num_sequences
self.num_targets = self.datamanager.num_targets
self.seq_length_min = np.min(self.num_sequences)
seasonality = SEASONALITY_MAP.get(self.datamanager.freq, 1)
if isinstance(seasonality, list):
seasonality = min(seasonality) # Use to calculate MASE
self.seasonality = int(seasonality) # type: ignore[call-overload]
self.seasonality = self.datamanager.seasonality
Copy link
Contributor

@ravinkohli ravinkohli Dec 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the AbstractEvaluator class, we use an _init_datamanager_info method (see here). At the end of this method we delete the datamanager object to reduce memory usage. Could you also update the time series forecasting evaluator, maybe by overriding the parent class function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can also do it as a separate PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The data manager is still required elsewhere by the time series evaluator, because I defined the time series dataset's __getitem__ differently. Maybe we could fix that in the future, but for now I would prefer to keep self.datamanager in the evaluator.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay sure.


self.max_budget = max_budget
self.min_num_test_instances = min_num_test_instances
Expand Down
5 changes: 5 additions & 0 deletions examples/20_basics/example_time_series_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import warnings
import copy

import numpy as np

os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
Expand Down Expand Up @@ -91,3 +93,6 @@
# test_sets2 = api.dataset.generate_test_seqs()

pred = api.predict(test_sets)
# To compute the scores with AutoPyTorch, the ground truth value must be of shape [n_seq, seq_length, n_output]
# or [n_seq * seq_length, n_output]
score = api.score(np.expand_dims(pred, -1), np.expand_dims(np.asarray(y_test), -1))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
include_package_data=True,
extras_require={
"forecasting": [
"gluonts>=0.10.0",
"gluonts>=0.11.0",
"sktime",
"pytorch-forecasting",
],
Expand Down
13 changes: 11 additions & 2 deletions test/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
from autoPyTorch.optimizer.smbo import AutoMLSMBO
from autoPyTorch.pipeline.base_pipeline import BasePipeline
from autoPyTorch.pipeline.components.setup.traditional_ml.traditional_learner import _traditional_learners
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy

from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy, mean_MASE_forecasting

CV_NUM_SPLITS = 2
HOLDOUT_NUM_SPLITS = 1
Expand Down Expand Up @@ -578,13 +577,23 @@ def test_time_series_forecasting(forecasting_toy_dataset, resampling_strategy, b
y_pred = estimator.predict(X_test)

assert np.shape(y_pred) == np.shape(y_test)
score_mse = estimator.score(np.expand_dims(np.asarray(y_pred), -1),
np.expand_dims(np.asarray(y_test), -1))
assert isinstance(score_mse, dict)

estimator._metric = mean_MASE_forecasting
score_mase = estimator.score(np.expand_dims(np.asarray(y_pred), -1),
np.expand_dims(np.asarray(y_test), -1),
y_test_past=y_train)

# Test refit on dummy data
estimator.refit(dataset=backend.load_datamanager())
# Make sure that a configuration space is stored in the estimator
assert isinstance(estimator.get_search_space(), CS.ConfigurationSpace)




@pytest.mark.parametrize('openml_id', (
1590, # Adult to test NaN in categorical columns
))
Expand Down
9 changes: 4 additions & 5 deletions test/test_datasets/test_time_series_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Callable, List, Tuple

from gluonts.time_feature import Constant as ConstantTransform
from gluonts.time_feature import DayOfMonth
from gluonts.time_feature import day_of_month

import numpy as np

Expand All @@ -12,7 +12,6 @@

import torch


from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes
from autoPyTorch.datasets.time_series_dataset import (
TimeSeriesForecastingDataset,
Expand Down Expand Up @@ -40,7 +39,7 @@ def setUp(self) -> None:

self.x_test_data = rng.rand(self.n_prediction_steps, 5)
self.y_test = rng.rand(self.n_prediction_steps, 1)
self.time_feature_transform = [DayOfMonth(), ConstantTransform(10.0)]
self.time_feature_transform = [day_of_month, ConstantTransform()]
self.known_future_features_index = [0, 2]
self.seq_uni = TimeSeriesSequence(X=None, Y=self.y,
n_prediction_steps=self.n_prediction_steps,
Expand Down Expand Up @@ -122,8 +121,8 @@ def test_uni_get_update_time_features(self):
self.assertEqual(len(self.seq_uni._cached_time_features), len(self.y))
self.assertTrue(list(past_features.shape) == [3 + 1, len(self.time_feature_transform)])
self.assertTrue(list(future_features.shape) == [self.n_prediction_steps, len(self.time_feature_transform)])
self.assertTrue(torch.all(past_features[:, 1] == 10.))
self.assertTrue(torch.all(future_features[:, 1] == 10.))
self.assertTrue(torch.all(past_features[:, 1] == 0.))
self.assertTrue(torch.all(future_features[:, 1] == 0.))

def test_uni_to_test_set(self):
self.seq_uni.transform_time_features = True
Expand Down