From 9bb4e55c6167c02cb53391db17b5bd68d9bf289c Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Sat, 20 Feb 2021 03:15:39 +0900 Subject: [PATCH 1/5] [Fork from #105] Made CrossValFuncs and HoldOutFuncs class to group the functions --- autoPyTorch/datasets/base_dataset.py | 24 ++- autoPyTorch/datasets/resampling_strategy.py | 193 +++++++++++++------- 2 files changed, 134 insertions(+), 83 deletions(-) diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index 88d960d66..cb8f0841a 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -1,5 +1,5 @@ from abc import ABCMeta -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, Callable import numpy as np @@ -13,18 +13,16 @@ from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES from autoPyTorch.datasets.resampling_strategy import ( - CROSS_VAL_FN, + CrossValFuncs, CrossValTypes, DEFAULT_RESAMPLING_PARAMETERS, - HOLDOUT_FN, HoldoutValTypes, - get_cross_validators, - get_holdout_validators, - is_stratified, + HoldOutFuncs, ) from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset] +SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]] def check_valid_data(data: Any) -> None: @@ -104,8 +102,8 @@ def __init__( if not hasattr(train_tensors[0], 'shape'): type_check(train_tensors, val_tensors) self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors - self.cross_validators: Dict[str, CROSS_VAL_FN] = {} - self.holdout_validators: Dict[str, HOLDOUT_FN] = {} + self.cross_validators: Dict[str, SplitFunc] = {} + self.holdout_validators: Dict[str, SplitFunc] = {} self.rng = np.random.RandomState(seed=seed) self.shuffle = shuffle self.resampling_strategy = resampling_strategy @@ -126,8 +124,8 @@ def __init__( self.is_small_preprocess = True # Make sure cross validation splits are created once - self.cross_validators = get_cross_validators(*CrossValTypes) - self.holdout_validators = get_holdout_validators(*HoldoutValTypes) + self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes) + self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes) self.splits = self.get_splits_from_resampling_strategy() # We also need to be able to transform the data, be it for pre-processing @@ -175,7 +173,7 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]: Returns: A transformed single point prediction """ - + X = self.train_tensors[0].iloc[[index]] if hasattr(self.train_tensors[0], 'loc') \ else self.train_tensors[0][index] @@ -255,7 +253,7 @@ def create_cross_val_splits( if not isinstance(cross_val_type, CrossValTypes): raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.') kwargs = {} - if is_stratified(cross_val_type): + if cross_val_type.is_stratified(): # we need additional information about the data for stratification kwargs["stratify"] = self.train_tensors[-1] splits = self.cross_validators[cross_val_type.name]( @@ -290,7 +288,7 @@ def create_holdout_val_split( if not isinstance(holdout_val_type, HoldoutValTypes): raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.') kwargs = {} - if is_stratified(holdout_val_type): + if 
holdout_val_type.is_stratified(): # we need additional information about the data for stratification kwargs["stratify"] = self.train_tensors[-1] train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs) diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py index 1d0bc3077..7d5a29039 100644 --- a/autoPyTorch/datasets/resampling_strategy.py +++ b/autoPyTorch/datasets/resampling_strategy.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Callable import numpy as np @@ -15,8 +15,12 @@ from typing_extensions import Protocol +SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]] + + # Use callback protocol as workaround, since callable with function fields count 'self' as argument class CROSS_VAL_FN(Protocol): + """TODO: deprecate soon""" def __call__(self, num_splits: int, indices: np.ndarray, @@ -25,26 +29,59 @@ def __call__(self, class HOLDOUT_FN(Protocol): + """TODO: deprecate soon""" def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] ) -> Tuple[np.ndarray, np.ndarray]: ... class CrossValTypes(IntEnum): + """The type of cross validation + + This class is used to specify the cross validation function + and is not supposed to be instantiated. + + Examples: This class is supposed to be used as follows + >>> cv_type = CrossValTypes.k_fold_cross_validation + >>> print(cv_type.name) + + k_fold_cross_validation + + >>> for cross_val_type in CrossValTypes: + print(cross_val_type.name, cross_val_type.value) + + stratified_k_fold_cross_validation 1 + k_fold_cross_validation 2 + stratified_shuffle_split_cross_validation 3 + shuffle_split_cross_validation 4 + time_series_cross_validation 5 + """ stratified_k_fold_cross_validation = 1 k_fold_cross_validation = 2 stratified_shuffle_split_cross_validation = 3 shuffle_split_cross_validation = 4 time_series_cross_validation = 5 + def is_stratified(self) -> bool: + stratified = [self.stratified_k_fold_cross_validation, + self.stratified_shuffle_split_cross_validation] + return getattr(self, self.name) in stratified + class HoldoutValTypes(IntEnum): + """The type of hold out validation (refer to CrossValTypes' doc-string)""" holdout_validation = 6 stratified_holdout_validation = 7 + def is_stratified(self) -> bool: + stratified = [self.stratified_holdout_validation] + return getattr(self, self.name) in stratified + +"""TODO: deprecate soon""" RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes] +"""TODO: deprecate soon""" DEFAULT_RESAMPLING_PARAMETERS = { HoldoutValTypes.holdout_validation: { 'val_share': 0.33, @@ -67,15 +104,8 @@ class HoldoutValTypes(IntEnum): } # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] -def get_cross_validators(*cross_val_types: CrossValTypes) -> Dict[str, CROSS_VAL_FN]: - cross_validators = {} # type: Dict[str, CROSS_VAL_FN] - for cross_val_type in cross_val_types: - cross_val_fn = globals()[cross_val_type.name] - cross_validators[cross_val_type.name] = cross_val_fn - return cross_validators - - def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]: + """TODO: deprecate soon""" holdout_validators = {} # type: Dict[str, HOLDOUT_FN] for holdout_val_type in holdout_val_types: holdout_val_fn = globals()[holdout_val_type.name] @@ -84,70 +114,93 @@ def get_holdout_validators(*holdout_val_types: 
HoldoutValTypes) -> Dict[str, HOL def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool: + """TODO: deprecate soon""" if isinstance(val_type, str): return val_type.lower().startswith("stratified") else: return val_type.name.lower().startswith("stratified") -def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: - train, val = train_test_split(indices, test_size=val_share, shuffle=False) - return train, val - - -def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \ - -> Tuple[np.ndarray, np.ndarray]: - train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"]) - return train, val - - -def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - cv = ShuffleSplit(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits - - -def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - cv = StratifiedShuffleSplit(n_splits=num_splits) - splits = list(cv.split(indices, kwargs["stratify"])) - return splits - - -def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - cv = StratifiedKFold(n_splits=num_splits) - splits = list(cv.split(indices, kwargs["stratify"])) - return splits - - -def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]: - """ - Standard k fold cross validation. - - :param indices: array of indices to be split - :param num_splits: number of cross validation splits - :return: list of tuples of training and validation indices - """ - cv = KFold(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits - - -def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - """ - Returns train and validation indices respecting the temporal ordering of the data. 
- Dummy example: [0, 1, 2, 3] with 3 folds yields - [0] [1] - [0, 1] [2] - [0, 1, 2] [3] - - :param indices: array of indices to be split - :param num_splits: number of cross validation splits - :return: list of tuples of training and validation indices - """ - cv = TimeSeriesSplit(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits +class HoldOutFuncs(): + @staticmethod + def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: + train, val = train_test_split(indices, test_size=val_share, shuffle=False) + return train, val + + @staticmethod + def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \ + -> Tuple[np.ndarray, np.ndarray]: + train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"]) + return train, val + + @classmethod + def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, SplitFunc]: + + holdout_validators = { + holdout_val_type.name: getattr(cls, holdout_val_type.name) + for holdout_val_type in holdout_val_types + } + return holdout_validators + + +class CrossValFuncs(): + @staticmethod + def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + -> List[Tuple[np.ndarray, np.ndarray]]: + cv = ShuffleSplit(n_splits=num_splits) + splits = list(cv.split(indices)) + return splits + + @staticmethod + def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + -> List[Tuple[np.ndarray, np.ndarray]]: + cv = StratifiedShuffleSplit(n_splits=num_splits) + splits = list(cv.split(indices, kwargs["stratify"])) + return splits + + @staticmethod + def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + -> List[Tuple[np.ndarray, np.ndarray]]: + cv = StratifiedKFold(n_splits=num_splits) + splits = list(cv.split(indices, kwargs["stratify"])) + return splits + + @staticmethod + def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Standard k fold cross validation. + + :param indices: array of indices to be split + :param num_splits: number of cross validation splits + :return: list of tuples of training and validation indices + """ + cv = KFold(n_splits=num_splits) + splits = list(cv.split(indices)) + return splits + + @staticmethod + def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Returns train and validation indices respecting the temporal ordering of the data. 
+ Dummy example: [0, 1, 2, 3] with 3 folds yields + [0] [1] + [0, 1] [2] + [0, 1, 2] [3] + + :param indices: array of indices to be split + :param num_splits: number of cross validation splits + :return: list of tuples of training and validation indices + """ + cv = TimeSeriesSplit(n_splits=num_splits) + splits = list(cv.split(indices)) + return splits + + @classmethod + def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, SplitFunc]: + cross_validators = { + cross_val_type.name: getattr(cls, cross_val_type.name) + for cross_val_type in cross_val_types + } + return cross_validators From ffde17705964f0309e57780a5973545ab6817ce8 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Sat, 20 Feb 2021 03:24:04 +0900 Subject: [PATCH 2/5] Modified time_series_dataset.py to be compatible with resampling_strategy.py --- autoPyTorch/datasets/time_series_dataset.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py index 7b0435d19..5f4e2edf1 100644 --- a/autoPyTorch/datasets/time_series_dataset.py +++ b/autoPyTorch/datasets/time_series_dataset.py @@ -8,8 +8,8 @@ from autoPyTorch.datasets.resampling_strategy import ( CrossValTypes, HoldoutValTypes, - get_cross_validators, - get_holdout_validators + CrossValFuncs, + HoldOutFuncs ) TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported @@ -60,8 +60,8 @@ def __init__(self, train_transforms=train_transforms, val_transforms=val_transforms, ) - self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation) - self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation) + self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation) + self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation) def _check_time_series_forecasting_inputs(target_variables: Tuple[int], @@ -117,13 +117,13 @@ def __init__(self, val=val, task_type="time_series_classification") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.stratified_k_fold_cross_validation, CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation, CrossValTypes.stratified_shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( + self.holdout_validators = HoldOutFuncs.get_holdout_validators( HoldoutValTypes.holdout_validation, HoldoutValTypes.stratified_holdout_validation ) @@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np. val=val, task_type="time_series_regression") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( + self.holdout_validators = HoldOutFuncs.get_holdout_validators( HoldoutValTypes.holdout_validation ) From c4700999d0368394bbfb044946c5cc4878304854 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Tue, 23 Feb 2021 02:48:57 +0900 Subject: [PATCH 3/5] [fix]: back to the renamed version of CROSS_VAL_FN from temporal SplitFunc typing. 
--- autoPyTorch/datasets/base_dataset.py | 9 +++--- autoPyTorch/datasets/resampling_strategy.py | 32 ++++----------------- 2 files changed, 10 insertions(+), 31 deletions(-) diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index cb8f0841a..966cd8df1 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -1,5 +1,5 @@ from abc import ABCMeta -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, Callable +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast import numpy as np @@ -15,14 +15,15 @@ from autoPyTorch.datasets.resampling_strategy import ( CrossValFuncs, CrossValTypes, + CrossValFunc, DEFAULT_RESAMPLING_PARAMETERS, HoldoutValTypes, HoldOutFuncs, + HoldOutFunc ) from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset] -SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]] def check_valid_data(data: Any) -> None: @@ -102,8 +103,8 @@ def __init__( if not hasattr(train_tensors[0], 'shape'): type_check(train_tensors, val_tensors) self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors - self.cross_validators: Dict[str, SplitFunc] = {} - self.holdout_validators: Dict[str, SplitFunc] = {} + self.cross_validators: Dict[str, CrossValFunc] = {} + self.holdout_validators: Dict[str, HoldOutFunc] = {} self.rng = np.random.RandomState(seed=seed) self.shuffle = shuffle self.resampling_strategy = resampling_strategy diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py index 7d5a29039..860adadaa 100644 --- a/autoPyTorch/datasets/resampling_strategy.py +++ b/autoPyTorch/datasets/resampling_strategy.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -15,12 +15,8 @@ from typing_extensions import Protocol -SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]] - - # Use callback protocol as workaround, since callable with function fields count 'self' as argument -class CROSS_VAL_FN(Protocol): - """TODO: deprecate soon""" +class CrossValFunc(Protocol): def __call__(self, num_splits: int, indices: np.ndarray, @@ -28,8 +24,7 @@ def __call__(self, ... -class HOLDOUT_FN(Protocol): - """TODO: deprecate soon""" +class HoldOutFunc(Protocol): def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] ) -> Tuple[np.ndarray, np.ndarray]: ... 
@@ -104,23 +99,6 @@ def is_stratified(self) -> bool: } # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] -def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]: - """TODO: deprecate soon""" - holdout_validators = {} # type: Dict[str, HOLDOUT_FN] - for holdout_val_type in holdout_val_types: - holdout_val_fn = globals()[holdout_val_type.name] - holdout_validators[holdout_val_type.name] = holdout_val_fn - return holdout_validators - - -def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool: - """TODO: deprecate soon""" - if isinstance(val_type, str): - return val_type.lower().startswith("stratified") - else: - return val_type.name.lower().startswith("stratified") - - class HoldOutFuncs(): @staticmethod def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: @@ -134,7 +112,7 @@ def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwarg return train, val @classmethod - def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, SplitFunc]: + def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, HoldOutFunc]: holdout_validators = { holdout_val_type.name: getattr(cls, holdout_val_type.name) @@ -198,7 +176,7 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: return splits @classmethod - def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, SplitFunc]: + def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]: cross_validators = { cross_val_type.name: getattr(cls, cross_val_type.name) for cross_val_type in cross_val_types From d05696ba55bee7a1d80ac7c8f833be3b399b94e7 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Tue, 23 Feb 2021 03:51:35 +0900 Subject: [PATCH 4/5] added TODO --- autoPyTorch/datasets/resampling_strategy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py index 860adadaa..1c8fea5fd 100644 --- a/autoPyTorch/datasets/resampling_strategy.py +++ b/autoPyTorch/datasets/resampling_strategy.py @@ -64,6 +64,7 @@ def is_stratified(self) -> bool: class HoldoutValTypes(IntEnum): + """TODO: change to enum using functools.partial""" """The type of hold out validation (refer to CrossValTypes' doc-string)""" holdout_validation = 6 stratified_holdout_validation = 7 From 0ab670c74963c7767511cd88a7fc3e573d213171 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Tue, 23 Feb 2021 05:00:57 +0900 Subject: [PATCH 5/5] Changed CrossValTypes to have val functions directly --- autoPyTorch/datasets/base_dataset.py | 10 +- autoPyTorch/datasets/resampling_strategy.py | 209 ++++++++++---------- autoPyTorch/datasets/time_series_dataset.py | 17 +- 3 files changed, 122 insertions(+), 114 deletions(-) diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index 966cd8df1..7ebbcda2f 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -13,13 +13,11 @@ from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES from autoPyTorch.datasets.resampling_strategy import ( - CrossValFuncs, CrossValTypes, CrossValFunc, DEFAULT_RESAMPLING_PARAMETERS, HoldoutValTypes, - HoldOutFuncs, - HoldOutFunc + HoldOutValFunc ) from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix @@ -104,7 +102,7 @@ def __init__( 
type_check(train_tensors, val_tensors) self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors self.cross_validators: Dict[str, CrossValFunc] = {} - self.holdout_validators: Dict[str, HoldOutFunc] = {} + self.holdout_validators: Dict[str, HoldOutValFunc] = {} self.rng = np.random.RandomState(seed=seed) self.shuffle = shuffle self.resampling_strategy = resampling_strategy @@ -125,8 +123,8 @@ def __init__( self.is_small_preprocess = True # Make sure cross validation splits are created once - self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes) - self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes) + self.cross_validators = CrossValTypes.get_validators(*CrossValTypes) + self.holdout_validators = HoldoutValTypes.get_validators(*HoldoutValTypes) self.splits = self.get_splits_from_resampling_strategy() # We also need to be able to transform the data, be it for pre-processing diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py index 1c8fea5fd..a2d595eae 100644 --- a/autoPyTorch/datasets/resampling_strategy.py +++ b/autoPyTorch/datasets/resampling_strategy.py @@ -1,4 +1,5 @@ -from enum import IntEnum +from enum import Enum +from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -17,6 +18,7 @@ # Use callback protocol as workaround, since callable with function fields count 'self' as argument class CrossValFunc(Protocol): + """TODO: This class is not required anymore, because CrossValTypes class does not require get_validators()""" def __call__(self, num_splits: int, indices: np.ndarray, @@ -24,128 +26,51 @@ def __call__(self, ... -class HoldOutFunc(Protocol): +class HoldoutValFunc(Protocol): def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] ) -> Tuple[np.ndarray, np.ndarray]: ... -class CrossValTypes(IntEnum): - """The type of cross validation - - This class is used to specify the cross validation function - and is not supposed to be instantiated. 
- - Examples: This class is supposed to be used as follows - >>> cv_type = CrossValTypes.k_fold_cross_validation - >>> print(cv_type.name) - - k_fold_cross_validation - - >>> for cross_val_type in CrossValTypes: - print(cross_val_type.name, cross_val_type.value) - - stratified_k_fold_cross_validation 1 - k_fold_cross_validation 2 - stratified_shuffle_split_cross_validation 3 - shuffle_split_cross_validation 4 - time_series_cross_validation 5 - """ - stratified_k_fold_cross_validation = 1 - k_fold_cross_validation = 2 - stratified_shuffle_split_cross_validation = 3 - shuffle_split_cross_validation = 4 - time_series_cross_validation = 5 - - def is_stratified(self) -> bool: - stratified = [self.stratified_k_fold_cross_validation, - self.stratified_shuffle_split_cross_validation] - return getattr(self, self.name) in stratified - - -class HoldoutValTypes(IntEnum): - """TODO: change to enum using functools.partial""" - """The type of hold out validation (refer to CrossValTypes' doc-string)""" - holdout_validation = 6 - stratified_holdout_validation = 7 - - def is_stratified(self) -> bool: - stratified = [self.stratified_holdout_validation] - return getattr(self, self.name) in stratified - - -"""TODO: deprecate soon""" -RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes] - -"""TODO: deprecate soon""" -DEFAULT_RESAMPLING_PARAMETERS = { - HoldoutValTypes.holdout_validation: { - 'val_share': 0.33, - }, - HoldoutValTypes.stratified_holdout_validation: { - 'val_share': 0.33, - }, - CrossValTypes.k_fold_cross_validation: { - 'num_splits': 3, - }, - CrossValTypes.stratified_k_fold_cross_validation: { - 'num_splits': 3, - }, - CrossValTypes.shuffle_split_cross_validation: { - 'num_splits': 3, - }, - CrossValTypes.time_series_cross_validation: { - 'num_splits': 3, - }, -} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] - - -class HoldOutFuncs(): +class HoldoutValFuncs(): @staticmethod - def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: + def holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \ + -> Tuple[np.ndarray, np.ndarray]: train, val = train_test_split(indices, test_size=val_share, shuffle=False) return train, val @staticmethod - def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \ + def stratified_holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \ -> Tuple[np.ndarray, np.ndarray]: - train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"]) + train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=stratify) return train, val - @classmethod - def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, HoldOutFunc]: - - holdout_validators = { - holdout_val_type.name: getattr(cls, holdout_val_type.name) - for holdout_val_type in holdout_val_types - } - return holdout_validators - class CrossValFuncs(): @staticmethod - def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \ -> List[Tuple[np.ndarray, np.ndarray]]: cv = ShuffleSplit(n_splits=num_splits) splits = list(cv.split(indices)) return splits @staticmethod - def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + def 
stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, + stratify: Optional[Any] = None) \ -> List[Tuple[np.ndarray, np.ndarray]]: cv = StratifiedShuffleSplit(n_splits=num_splits) - splits = list(cv.split(indices, kwargs["stratify"])) + splits = list(cv.split(indices, stratify)) return splits @staticmethod - def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \ -> List[Tuple[np.ndarray, np.ndarray]]: cv = StratifiedKFold(n_splits=num_splits) - splits = list(cv.split(indices, kwargs["stratify"])) + splits = list(cv.split(indices, stratify)) return splits @staticmethod - def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + def k_fold_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \ -> List[Tuple[np.ndarray, np.ndarray]]: """ Standard k fold cross validation. @@ -159,7 +84,7 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) return splits @staticmethod - def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ + def time_series_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \ -> List[Tuple[np.ndarray, np.ndarray]]: """ Returns train and validation indices respecting the temporal ordering of the data. @@ -176,10 +101,96 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: splits = list(cv.split(indices)) return splits - @classmethod - def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]: - cross_validators = { - cross_val_type.name: getattr(cls, cross_val_type.name) - for cross_val_type in cross_val_types - } - return cross_validators + +class CrossValTypes(Enum): + """The type of cross validation + + This class is used to specify the cross validation function + and is not supposed to be instantiated. + + Examples: This class is supposed to be used as follows + >>> cv_type = CrossValTypes.k_fold_cross_validation + >>> print(cv_type.name) + + k_fold_cross_validation + + >>> print(cv_type.value) + + functools.partial() + + >>> for cross_val_type in CrossValTypes: + print(cross_val_type.name) + + stratified_k_fold_cross_validation + k_fold_cross_validation + stratified_shuffle_split_cross_validation + shuffle_split_cross_validation + time_series_cross_validation + + Additionally, CrossValTypes. can be called directly. 
+ """ + stratified_k_fold_cross_validation = partial(CrossValFuncs.stratified_k_fold_cross_validation) + k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation) + stratified_shuffle_split_cross_validation = partial(CrossValFuncs.stratified_shuffle_split_cross_validation) + shuffle_split_cross_validation = partial(CrossValFuncs.shuffle_split_cross_validation) + time_series_cross_validation = partial(CrossValFuncs.time_series_cross_validation) + + def is_stratified(self) -> bool: + stratified = [self.stratified_k_fold_cross_validation, + self.stratified_shuffle_split_cross_validation] + return getattr(self, self.name) in stratified + + def __call__(self, num_splits: int, indices: np.ndarray, stratify: Optional[Any] + ) -> Tuple[np.ndarray, np.ndarray]: + """TODO: doc-string and test files""" + self.value(num_splits=num_splits, indices=indices, stratify=stratify) + + @staticmethod + def get_validators(*choices: CrossValFunc): + """TODO: to be compatible, it is here now, but will be deprecated soon.""" + return {choice.name: choice.value for choice in choices} + + +class HoldoutValTypes(Enum): + """The type of hold out validation (refer to CrossValTypes' doc-string)""" + holdout_validation = partial(HoldoutValFuncs.holdout_validation) + stratified_holdout_validation = partial(HoldoutValFuncs.stratified_holdout_validation) + + def is_stratified(self) -> bool: + stratified = [self.stratified_holdout_validation] + return getattr(self, self.name) in stratified + + def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] + ) -> Tuple[np.ndarray, np.ndarray]: + self.value(val_share=val_share, indices=indices, stratify=stratify) + + @staticmethod + def get_validators(*choices: HoldoutValFunc): + """TODO: to be compatible, it is here now, but will be deprecated soon.""" + return {choice.name: choice.value for choice in choices} + + +"""TODO: deprecate soon (Will rename CrossValTypes -> CrossValFunc)""" +RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes] + +"""TODO: deprecate soon""" +DEFAULT_RESAMPLING_PARAMETERS = { + HoldoutValTypes.holdout_validation: { + 'val_share': 0.33, + }, + HoldoutValTypes.stratified_holdout_validation: { + 'val_share': 0.33, + }, + CrossValTypes.k_fold_cross_validation: { + 'num_splits': 3, + }, + CrossValTypes.stratified_k_fold_cross_validation: { + 'num_splits': 3, + }, + CrossValTypes.shuffle_split_cross_validation: { + 'num_splits': 3, + }, + CrossValTypes.time_series_cross_validation: { + 'num_splits': 3, + }, +} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py index 5f4e2edf1..bb26af48f 100644 --- a/autoPyTorch/datasets/time_series_dataset.py +++ b/autoPyTorch/datasets/time_series_dataset.py @@ -7,9 +7,7 @@ from autoPyTorch.datasets.base_dataset import BaseDataset from autoPyTorch.datasets.resampling_strategy import ( CrossValTypes, - HoldoutValTypes, - CrossValFuncs, - HoldOutFuncs + HoldoutValTypes ) TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported @@ -60,8 +58,9 @@ def __init__(self, train_transforms=train_transforms, val_transforms=val_transforms, ) - self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation) - self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation) + """Comment: Do we really need those two? 
They are already defined in BaseDataset"""
+        self.cross_validators = CrossValTypes.get_validators(CrossValTypes.time_series_cross_validation)
+        self.holdout_validators = HoldoutValTypes.get_validators(HoldoutValTypes.holdout_validation)


 def _check_time_series_forecasting_inputs(target_variables: Tuple[int],
@@ -117,13 +116,13 @@ def __init__(self,
                                        val=val,
                                        task_type="time_series_classification")
         super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
-        self.cross_validators = CrossValFuncs.get_cross_validators(
+        self.cross_validators = CrossValTypes.get_validators(
             CrossValTypes.stratified_k_fold_cross_validation,
             CrossValTypes.k_fold_cross_validation,
             CrossValTypes.shuffle_split_cross_validation,
             CrossValTypes.stratified_shuffle_split_cross_validation
         )
-        self.holdout_validators = HoldOutFuncs.get_holdout_validators(
+        self.holdout_validators = HoldoutValTypes.get_validators(
             HoldoutValTypes.holdout_validation,
             HoldoutValTypes.stratified_holdout_validation
         )
@@ -135,11 +134,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
                                        val=val,
                                        task_type="time_series_regression")
         super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
-        self.cross_validators = CrossValFuncs.get_cross_validators(
+        self.cross_validators = CrossValTypes.get_validators(
             CrossValTypes.k_fold_cross_validation,
             CrossValTypes.shuffle_split_cross_validation
         )
-        self.holdout_validators = HoldOutFuncs.get_holdout_validators(
+        self.holdout_validators = HoldoutValTypes.get_validators(
             HoldoutValTypes.holdout_validation
         )
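
The final patch turns CrossValTypes and HoldoutValTypes into enums whose members wrap the split functions in functools.partial, so that a member exposes is_stratified() and can be called directly. Below is a minimal, self-contained sketch of that pattern for reference; it is not part of the patch series, the names ToyCrossValTypes and the two module-level split helpers are made up for illustration, and it assumes __call__ returns the result of self.value(...) (the __call__ bodies in the hunks above call self.value(...) without returning it).

from enum import Enum
from functools import partial
from typing import Any, List, Optional, Tuple

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold


def k_fold_cross_validation(num_splits: int, indices: np.ndarray,
                            stratify: Optional[Any] = None) -> List[Tuple[np.ndarray, np.ndarray]]:
    # Plain k-fold over the given indices; `stratify` is accepted but ignored.
    return list(KFold(n_splits=num_splits).split(indices))


def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray,
                                       stratify: Optional[Any] = None) -> List[Tuple[np.ndarray, np.ndarray]]:
    # Stratified k-fold; `stratify` carries the class labels.
    return list(StratifiedKFold(n_splits=num_splits).split(indices, stratify))


class ToyCrossValTypes(Enum):
    # Wrapping the functions in functools.partial keeps them from being turned
    # into methods of the Enum class, so they survive as member values.
    k_fold_cross_validation = partial(k_fold_cross_validation)
    stratified_k_fold_cross_validation = partial(stratified_k_fold_cross_validation)

    def is_stratified(self) -> bool:
        return self.name.startswith("stratified")

    def __call__(self, num_splits: int, indices: np.ndarray,
                 stratify: Optional[Any] = None) -> List[Tuple[np.ndarray, np.ndarray]]:
        # Forward to the wrapped split function and return its splits.
        return self.value(num_splits=num_splits, indices=indices, stratify=stratify)


if __name__ == "__main__":
    indices = np.arange(10)
    labels = np.array([0, 1] * 5)

    cv = ToyCrossValTypes.stratified_k_fold_cross_validation
    kwargs = {"stratify": labels} if cv.is_stratified() else {}
    for train_idx, val_idx in cv(num_splits=2, indices=indices, **kwargs):
        print(train_idx, val_idx)

Calling a member directly like this is what could eventually replace the self.cross_validators / self.holdout_validators dictionaries in BaseDataset, as the TODO attached to get_validators suggests.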