diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py index 7b0435d19..5f4e2edf1 100644 --- a/autoPyTorch/datasets/time_series_dataset.py +++ b/autoPyTorch/datasets/time_series_dataset.py @@ -8,8 +8,8 @@ from autoPyTorch.datasets.resampling_strategy import ( CrossValTypes, HoldoutValTypes, - get_cross_validators, - get_holdout_validators + CrossValFuncs, + HoldOutFuncs ) TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported @@ -60,8 +60,8 @@ def __init__(self, train_transforms=train_transforms, val_transforms=val_transforms, ) - self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation) - self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation) + self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation) + self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation) def _check_time_series_forecasting_inputs(target_variables: Tuple[int], @@ -117,13 +117,13 @@ def __init__(self, val=val, task_type="time_series_classification") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.stratified_k_fold_cross_validation, CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation, CrossValTypes.stratified_shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( + self.holdout_validators = HoldOutFuncs.get_holdout_validators( HoldoutValTypes.holdout_validation, HoldoutValTypes.stratified_holdout_validation ) @@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np. val=val, task_type="time_series_regression") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( + self.holdout_validators = HoldOutFuncs.get_holdout_validators( HoldoutValTypes.holdout_validation )