Skip to content

Commit

Permalink
Changed CrossValTypes to have val functions directly
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Feb 22, 2021
1 parent d05696b commit 0ab670c
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 114 deletions.
10 changes: 4 additions & 6 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
from autoPyTorch.datasets.resampling_strategy import (
CrossValFuncs,
CrossValTypes,
CrossValFunc,
DEFAULT_RESAMPLING_PARAMETERS,
HoldoutValTypes,
HoldOutFuncs,
HoldOutFunc
HoldOutValFunc
)
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix

Expand Down Expand Up @@ -104,7 +102,7 @@ def __init__(
type_check(train_tensors, val_tensors)
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
self.cross_validators: Dict[str, CrossValFunc] = {}
self.holdout_validators: Dict[str, HoldOutFunc] = {}
self.holdout_validators: Dict[str, HoldOutValFunc] = {}
self.rng = np.random.RandomState(seed=seed)
self.shuffle = shuffle
self.resampling_strategy = resampling_strategy
Expand All @@ -125,8 +123,8 @@ def __init__(
self.is_small_preprocess = True

# Make sure cross validation splits are created once
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
self.cross_validators = CrossValTypes.get_validators(*CrossValTypes)
self.holdout_validators = HoldoutValTypes.get_validators(*HoldoutValTypes)
self.splits = self.get_splits_from_resampling_strategy()

# We also need to be able to transform the data, be it for pre-processing
Expand Down
209 changes: 110 additions & 99 deletions autoPyTorch/datasets/resampling_strategy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import IntEnum
from enum import Enum
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand All @@ -17,135 +18,59 @@

# Use callback protocol as workaround, since callable with function fields count 'self' as argument
class CrossValFunc(Protocol):
"""TODO: This class is not required anymore, because CrossValTypes class does not require get_validators()"""
def __call__(self,
num_splits: int,
indices: np.ndarray,
stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
...


class HoldOutFunc(Protocol):
class HoldoutValFunc(Protocol):
def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
) -> Tuple[np.ndarray, np.ndarray]:
...


class CrossValTypes(IntEnum):
"""The type of cross validation
This class is used to specify the cross validation function
and is not supposed to be instantiated.
Examples: This class is supposed to be used as follows
>>> cv_type = CrossValTypes.k_fold_cross_validation
>>> print(cv_type.name)
k_fold_cross_validation
>>> for cross_val_type in CrossValTypes:
print(cross_val_type.name, cross_val_type.value)
stratified_k_fold_cross_validation 1
k_fold_cross_validation 2
stratified_shuffle_split_cross_validation 3
shuffle_split_cross_validation 4
time_series_cross_validation 5
"""
stratified_k_fold_cross_validation = 1
k_fold_cross_validation = 2
stratified_shuffle_split_cross_validation = 3
shuffle_split_cross_validation = 4
time_series_cross_validation = 5

def is_stratified(self) -> bool:
stratified = [self.stratified_k_fold_cross_validation,
self.stratified_shuffle_split_cross_validation]
return getattr(self, self.name) in stratified


class HoldoutValTypes(IntEnum):
"""TODO: change to enum using functools.partial"""
"""The type of hold out validation (refer to CrossValTypes' doc-string)"""
holdout_validation = 6
stratified_holdout_validation = 7

def is_stratified(self) -> bool:
stratified = [self.stratified_holdout_validation]
return getattr(self, self.name) in stratified


"""TODO: deprecate soon"""
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]

"""TODO: deprecate soon"""
DEFAULT_RESAMPLING_PARAMETERS = {
HoldoutValTypes.holdout_validation: {
'val_share': 0.33,
},
HoldoutValTypes.stratified_holdout_validation: {
'val_share': 0.33,
},
CrossValTypes.k_fold_cross_validation: {
'num_splits': 3,
},
CrossValTypes.stratified_k_fold_cross_validation: {
'num_splits': 3,
},
CrossValTypes.shuffle_split_cross_validation: {
'num_splits': 3,
},
CrossValTypes.time_series_cross_validation: {
'num_splits': 3,
},
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]


class HoldOutFuncs():
class HoldoutValFuncs():
@staticmethod
def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
def holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
-> Tuple[np.ndarray, np.ndarray]:
train, val = train_test_split(indices, test_size=val_share, shuffle=False)
return train, val

@staticmethod
def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \
def stratified_holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
-> Tuple[np.ndarray, np.ndarray]:
train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"])
train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=stratify)
return train, val

@classmethod
def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, HoldOutFunc]:

holdout_validators = {
holdout_val_type.name: getattr(cls, holdout_val_type.name)
for holdout_val_type in holdout_val_types
}
return holdout_validators


class CrossValFuncs():
@staticmethod
def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = ShuffleSplit(n_splits=num_splits)
splits = list(cv.split(indices))
return splits

@staticmethod
def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray,
stratify: Optional[Any] = None) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = StratifiedShuffleSplit(n_splits=num_splits)
splits = list(cv.split(indices, kwargs["stratify"]))
splits = list(cv.split(indices, stratify))
return splits

@staticmethod
def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = StratifiedKFold(n_splits=num_splits)
splits = list(cv.split(indices, kwargs["stratify"]))
splits = list(cv.split(indices, stratify))
return splits

@staticmethod
def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
def k_fold_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
-> List[Tuple[np.ndarray, np.ndarray]]:
"""
Standard k fold cross validation.
Expand All @@ -159,7 +84,7 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any)
return splits

@staticmethod
def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
def time_series_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
-> List[Tuple[np.ndarray, np.ndarray]]:
"""
Returns train and validation indices respecting the temporal ordering of the data.
Expand All @@ -176,10 +101,96 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
splits = list(cv.split(indices))
return splits

@classmethod
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]:
cross_validators = {
cross_val_type.name: getattr(cls, cross_val_type.name)
for cross_val_type in cross_val_types
}
return cross_validators

class CrossValTypes(Enum):
"""The type of cross validation
This class is used to specify the cross validation function
and is not supposed to be instantiated.
Examples: This class is supposed to be used as follows
>>> cv_type = CrossValTypes.k_fold_cross_validation
>>> print(cv_type.name)
k_fold_cross_validation
>>> print(cv_type.value)
functools.partial(<function CrossValTypes.k_fold_cross_validation at ...>)
>>> for cross_val_type in CrossValTypes:
print(cross_val_type.name)
stratified_k_fold_cross_validation
k_fold_cross_validation
stratified_shuffle_split_cross_validation
shuffle_split_cross_validation
time_series_cross_validation
Additionally, CrossValTypes.<function> can be called directly.
"""
stratified_k_fold_cross_validation = partial(CrossValFuncs.stratified_k_fold_cross_validation)
k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
stratified_shuffle_split_cross_validation = partial(CrossValFuncs.stratified_shuffle_split_cross_validation)
shuffle_split_cross_validation = partial(CrossValFuncs.shuffle_split_cross_validation)
time_series_cross_validation = partial(CrossValFuncs.time_series_cross_validation)

def is_stratified(self) -> bool:
stratified = [self.stratified_k_fold_cross_validation,
self.stratified_shuffle_split_cross_validation]
return getattr(self, self.name) in stratified

def __call__(self, num_splits: int, indices: np.ndarray, stratify: Optional[Any]
) -> Tuple[np.ndarray, np.ndarray]:
"""TODO: doc-string and test files"""
self.value(num_splits=num_splits, indices=indices, stratify=stratify)

@staticmethod
def get_validators(*choices: CrossValFunc):
"""TODO: to be compatible, it is here now, but will be deprecated soon."""
return {choice.name: choice.value for choice in choices}


class HoldoutValTypes(Enum):
"""The type of hold out validation (refer to CrossValTypes' doc-string)"""
holdout_validation = partial(HoldoutValFuncs.holdout_validation)
stratified_holdout_validation = partial(HoldoutValFuncs.stratified_holdout_validation)

def is_stratified(self) -> bool:
stratified = [self.stratified_holdout_validation]
return getattr(self, self.name) in stratified

def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
) -> Tuple[np.ndarray, np.ndarray]:
self.value(val_share=val_share, indices=indices, stratify=stratify)

@staticmethod
def get_validators(*choices: HoldoutValFunc):
"""TODO: to be compatible, it is here now, but will be deprecated soon."""
return {choice.name: choice.value for choice in choices}


"""TODO: deprecate soon (Will rename CrossValTypes -> CrossValFunc)"""
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]

"""TODO: deprecate soon"""
DEFAULT_RESAMPLING_PARAMETERS = {
HoldoutValTypes.holdout_validation: {
'val_share': 0.33,
},
HoldoutValTypes.stratified_holdout_validation: {
'val_share': 0.33,
},
CrossValTypes.k_fold_cross_validation: {
'num_splits': 3,
},
CrossValTypes.stratified_k_fold_cross_validation: {
'num_splits': 3,
},
CrossValTypes.shuffle_split_cross_validation: {
'num_splits': 3,
},
CrossValTypes.time_series_cross_validation: {
'num_splits': 3,
},
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
17 changes: 8 additions & 9 deletions autoPyTorch/datasets/time_series_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
HoldoutValTypes,
CrossValFuncs,
HoldOutFuncs
HoldoutValTypes
)

TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported
Expand Down Expand Up @@ -60,8 +58,9 @@ def __init__(self,
train_transforms=train_transforms,
val_transforms=val_transforms,
)
self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation)
"""Comment: Do we really need those two? They are already defined in BaseDataset"""
self.cross_validators = CrossValTypes.get_cross_validators(CrossValTypes.time_series_cross_validation)
self.holdout_validators = HoldoutValTypes.get_holdout_validators(HoldoutValTypes.holdout_validation)


def _check_time_series_forecasting_inputs(target_variables: Tuple[int],
Expand Down Expand Up @@ -117,13 +116,13 @@ def __init__(self,
val=val,
task_type="time_series_classification")
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
self.cross_validators = CrossValFuncs.get_cross_validators(
self.cross_validators = CrossValTypes.get_cross_validators(
CrossValTypes.stratified_k_fold_cross_validation,
CrossValTypes.k_fold_cross_validation,
CrossValTypes.shuffle_split_cross_validation,
CrossValTypes.stratified_shuffle_split_cross_validation
)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(
self.holdout_validators = HoldoutValTypes.get_holdout_validators(
HoldoutValTypes.holdout_validation,
HoldoutValTypes.stratified_holdout_validation
)
Expand All @@ -135,11 +134,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
val=val,
task_type="time_series_regression")
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
self.cross_validators = CrossValFuncs.get_cross_validators(
self.cross_validators = CrossValTypes.get_cross_validators(
CrossValTypes.k_fold_cross_validation,
CrossValTypes.shuffle_split_cross_validation
)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(
self.holdout_validators = HoldoutValTypes.get_holdout_validators(
HoldoutValTypes.holdout_validation
)

Expand Down

0 comments on commit 0ab670c

Please sign in to comment.