diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index cb8f0841a..966cd8df1 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -1,5 +1,5 @@ from abc import ABCMeta -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, Callable +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast import numpy as np @@ -15,14 +15,15 @@ from autoPyTorch.datasets.resampling_strategy import ( CrossValFuncs, CrossValTypes, + CrossValFunc, DEFAULT_RESAMPLING_PARAMETERS, HoldoutValTypes, HoldOutFuncs, + HoldOutFunc ) from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset] -SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]] def check_valid_data(data: Any) -> None: @@ -102,8 +103,8 @@ def __init__( if not hasattr(train_tensors[0], 'shape'): type_check(train_tensors, val_tensors) self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors - self.cross_validators: Dict[str, SplitFunc] = {} - self.holdout_validators: Dict[str, SplitFunc] = {} + self.cross_validators: Dict[str, CrossValFunc] = {} + self.holdout_validators: Dict[str, HoldOutFunc] = {} self.rng = np.random.RandomState(seed=seed) self.shuffle = shuffle self.resampling_strategy = resampling_strategy diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py index 7d5a29039..860adadaa 100644 --- a/autoPyTorch/datasets/resampling_strategy.py +++ b/autoPyTorch/datasets/resampling_strategy.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -15,12 +15,8 @@ from typing_extensions import Protocol -SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]] - - # Use callback protocol as workaround, since callable with function fields count 'self' as argument -class CROSS_VAL_FN(Protocol): - """TODO: deprecate soon""" +class CrossValFunc(Protocol): def __call__(self, num_splits: int, indices: np.ndarray, @@ -28,8 +24,7 @@ def __call__(self, ... -class HOLDOUT_FN(Protocol): - """TODO: deprecate soon""" +class HoldOutFunc(Protocol): def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] ) -> Tuple[np.ndarray, np.ndarray]: ... @@ -104,23 +99,6 @@ def is_stratified(self) -> bool: } # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] -def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]: - """TODO: deprecate soon""" - holdout_validators = {} # type: Dict[str, HOLDOUT_FN] - for holdout_val_type in holdout_val_types: - holdout_val_fn = globals()[holdout_val_type.name] - holdout_validators[holdout_val_type.name] = holdout_val_fn - return holdout_validators - - -def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool: - """TODO: deprecate soon""" - if isinstance(val_type, str): - return val_type.lower().startswith("stratified") - else: - return val_type.name.lower().startswith("stratified") - - class HoldOutFuncs(): @staticmethod def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: @@ -134,7 +112,7 @@ def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwarg return train, val @classmethod - def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, SplitFunc]: + def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, HoldOutFunc]: holdout_validators = { holdout_val_type.name: getattr(cls, holdout_val_type.name) @@ -198,7 +176,7 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: return splits @classmethod - def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, SplitFunc]: + def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]: cross_validators = { cross_val_type.name: getattr(cls, cross_val_type.name) for cross_val_type in cross_val_types