Skip to content

Commit

Permalink
[fix]: back to the renamed version of CROSS_VAL_FN from temporal Spli…
Browse files Browse the repository at this point in the history
…tFunc typing.
  • Loading branch information
nabenabe0928 committed Feb 22, 2021
1 parent ffde177 commit c470099
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 31 deletions.
9 changes: 5 additions & 4 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, Callable
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import numpy as np

Expand All @@ -15,14 +15,15 @@
from autoPyTorch.datasets.resampling_strategy import (
CrossValFuncs,
CrossValTypes,
CrossValFunc,
DEFAULT_RESAMPLING_PARAMETERS,
HoldoutValTypes,
HoldOutFuncs,
HoldOutFunc
)
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix

BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]]


def check_valid_data(data: Any) -> None:
Expand Down Expand Up @@ -102,8 +103,8 @@ def __init__(
if not hasattr(train_tensors[0], 'shape'):
type_check(train_tensors, val_tensors)
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
self.cross_validators: Dict[str, SplitFunc] = {}
self.holdout_validators: Dict[str, SplitFunc] = {}
self.cross_validators: Dict[str, CrossValFunc] = {}
self.holdout_validators: Dict[str, HoldOutFunc] = {}
self.rng = np.random.RandomState(seed=seed)
self.shuffle = shuffle
self.resampling_strategy = resampling_strategy
Expand Down
32 changes: 5 additions & 27 deletions autoPyTorch/datasets/resampling_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

Expand All @@ -15,21 +15,16 @@
from typing_extensions import Protocol


SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]]


# Use callback protocol as workaround, since callable with function fields count 'self' as argument
class CROSS_VAL_FN(Protocol):
"""TODO: deprecate soon"""
class CrossValFunc(Protocol):
def __call__(self,
num_splits: int,
indices: np.ndarray,
stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
...


class HOLDOUT_FN(Protocol):
"""TODO: deprecate soon"""
class HoldOutFunc(Protocol):
def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
) -> Tuple[np.ndarray, np.ndarray]:
...
Expand Down Expand Up @@ -104,23 +99,6 @@ def is_stratified(self) -> bool:
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]


def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]:
"""TODO: deprecate soon"""
holdout_validators = {} # type: Dict[str, HOLDOUT_FN]
for holdout_val_type in holdout_val_types:
holdout_val_fn = globals()[holdout_val_type.name]
holdout_validators[holdout_val_type.name] = holdout_val_fn
return holdout_validators


def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool:
"""TODO: deprecate soon"""
if isinstance(val_type, str):
return val_type.lower().startswith("stratified")
else:
return val_type.name.lower().startswith("stratified")


class HoldOutFuncs():
@staticmethod
def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
Expand All @@ -134,7 +112,7 @@ def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwarg
return train, val

@classmethod
def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, SplitFunc]:
def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, HoldOutFunc]:

holdout_validators = {
holdout_val_type.name: getattr(cls, holdout_val_type.name)
Expand Down Expand Up @@ -198,7 +176,7 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
return splits

@classmethod
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, SplitFunc]:
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]:
cross_validators = {
cross_val_type.name: getattr(cls, cross_val_type.name)
for cross_val_type in cross_val_types
Expand Down

0 comments on commit c470099

Please sign in to comment.