Skip to content

Commit

Permalink
Modified time_series_dataset.py to be compatible with resampling_stra…
Browse files Browse the repository at this point in the history
…tegy.py
  • Loading branch information
nabenabe0928 committed Feb 19, 2021
1 parent c6d046b commit 87534fd
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions autoPyTorch/datasets/time_series_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
HoldoutValTypes,
get_cross_validators,
get_holdout_validators
CrossValFuncs,
HoldOutFuncs
)

TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported
Expand Down Expand Up @@ -60,8 +60,8 @@ def __init__(self,
train_transforms=train_transforms,
val_transforms=val_transforms,
)
self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation)
self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation)
self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation)


def _check_time_series_forecasting_inputs(target_variables: Tuple[int],
Expand Down Expand Up @@ -117,13 +117,13 @@ def __init__(self,
val=val,
task_type="time_series_classification")
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
self.cross_validators = get_cross_validators(
self.cross_validators = CrossValFuncs.get_cross_validators(
CrossValTypes.stratified_k_fold_cross_validation,
CrossValTypes.k_fold_cross_validation,
CrossValTypes.shuffle_split_cross_validation,
CrossValTypes.stratified_shuffle_split_cross_validation
)
self.holdout_validators = get_holdout_validators(
self.holdout_validators = HoldOutFuncs.get_holdout_validators(
HoldoutValTypes.holdout_validation,
HoldoutValTypes.stratified_holdout_validation
)
Expand All @@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
val=val,
task_type="time_series_regression")
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
self.cross_validators = get_cross_validators(
self.cross_validators = CrossValFuncs.get_cross_validators(
CrossValTypes.k_fold_cross_validation,
CrossValTypes.shuffle_split_cross_validation
)
self.holdout_validators = get_holdout_validators(
self.holdout_validators = HoldOutFuncs.get_holdout_validators(
HoldoutValTypes.holdout_validation
)

Expand Down

0 comments on commit 87534fd

Please sign in to comment.