From 87534fd6f6a2e64f0e1bfe203412f1655a7011d6 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Sat, 20 Feb 2021 03:24:04 +0900 Subject: [PATCH] Modified time_series_dataset.py to be compatible with resampling_strategy.py --- autoPyTorch/datasets/time_series_dataset.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py index 7b0435d19..5f4e2edf1 100644 --- a/autoPyTorch/datasets/time_series_dataset.py +++ b/autoPyTorch/datasets/time_series_dataset.py @@ -8,8 +8,8 @@ from autoPyTorch.datasets.resampling_strategy import ( CrossValTypes, HoldoutValTypes, - get_cross_validators, - get_holdout_validators + CrossValFuncs, + HoldOutFuncs ) TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported @@ -60,8 +60,8 @@ def __init__(self, train_transforms=train_transforms, val_transforms=val_transforms, ) - self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation) - self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation) + self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation) + self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation) def _check_time_series_forecasting_inputs(target_variables: Tuple[int], @@ -117,13 +117,13 @@ def __init__(self, val=val, task_type="time_series_classification") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.stratified_k_fold_cross_validation, CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation, CrossValTypes.stratified_shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( + self.holdout_validators = HoldOutFuncs.get_holdout_validators( HoldoutValTypes.holdout_validation, HoldoutValTypes.stratified_holdout_validation ) @@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np. val=val, task_type="time_series_regression") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( + self.holdout_validators = HoldOutFuncs.get_holdout_validators( HoldoutValTypes.holdout_validation )