diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py index c56141904..0233b69a4 100644 --- a/autoPyTorch/evaluation/abstract_evaluator.py +++ b/autoPyTorch/evaluation/abstract_evaluator.py @@ -167,42 +167,82 @@ class FixedPipelineParams(NamedTuple): search_space_updates (Optional[HyperparameterSearchSpaceUpdates]): An object used to fine tune the hyperparameter search space of the pipeline """ - def __init__(self, backend: Backend, - queue: Queue, - metric: autoPyTorchMetric, - budget: float, - configuration: Union[int, str, Configuration], - budget_type: str = None, - pipeline_config: Optional[Dict[str, Any]] = None, - seed: int = 1, - output_y_hat_optimization: bool = True, - num_run: Optional[int] = None, - include: Optional[Dict[str, Any]] = None, - exclude: Optional[Dict[str, Any]] = None, - disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, - init_params: Optional[Dict[str, Any]] = None, - logger_port: Optional[int] = None, - all_supported_metrics: bool = True, - search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None - ) -> None: - - self.starttime = time.time() - - self.configuration = configuration - self.backend: Backend = backend - self.queue = queue - - self.include = include - self.exclude = exclude - self.search_space_updates = search_space_updates - - self.metric = metric - - - self._init_datamanager_info() - - # Flag to save target for ensemble - self.output_y_hat_optimization = output_y_hat_optimization + backend: Backend + seed: int + metric: autoPyTorchMetric + budget_type: str # Literal['epochs', 'runtime'] + pipeline_config: Dict[str, Any] + save_y_opt: bool = True + include: Optional[Dict[str, Any]] = None + exclude: Optional[Dict[str, Any]] = None + disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None + logger_port: Optional[int] = None + all_supported_metrics: bool = True + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + + @classmethod + def with_default_pipeline_config( + cls, + pipeline_config: Optional[Dict[str, Any]] = None, + choice: str = 'default', + **kwargs: Any + ) -> 'FixedPipelineParams': + + if 'budget_type' in kwargs: + raise TypeError( + f'{cls.__name__}.with_default_pipeline_config() got multiple values for argument `budget_type`' + ) + + budget_type_choices = ('epochs', 'runtime') + if pipeline_config is None: + pipeline_config = get_default_pipeline_config(choice=choice) + if 'budget_type' not in pipeline_config: + raise ValueError('pipeline_config must have `budget_type`') + + budget_type = pipeline_config['budget_type'] + if pipeline_config['budget_type'] not in budget_type_choices: + raise ValueError(f"budget_type must be in {budget_type_choices}, but got {budget_type}") + + kwargs.update(pipeline_config=pipeline_config, budget_type=budget_type) + return cls(**kwargs) + + +class EvaluatorParams(NamedTuple): + """ + Attributes: + configuration (Union[int, str, Configuration]): + Determines the pipeline to be constructed. A dummy estimator is created for + integer configurations, a traditional machine learning pipeline is created + for string based configuration, and NAS is performed when a configuration + object is passed. + num_run (Optional[int]): + An identifier of the current configuration being fit. This number is unique per + configuration. + init_params (Optional[Dict[str, Any]]): + Optional argument that is passed to each pipeline step. 
It is the equivalent of + kwargs for the pipeline steps. + """ + budget: float + configuration: Union[int, str, Configuration] + num_run: Optional[int] = None + init_params: Optional[Dict[str, Any]] = None + + @classmethod + def with_default_budget( + cls, + budget: float = 0, + choice: str = 'default', + **kwargs: Any + ) -> 'EvaluatorParams': + budget = get_default_budget(choice=choice) if budget == 0 else budget + kwargs.update(budget=budget) + return cls(**kwargs) + + +class AbstractEvaluator(object): + """ + This method defines the interface that pipeline evaluators should follow, when + interacting with SMAC through TargetAlgorithmQuery. An evaluator is an object that: + constructs a pipeline (i.e. a classification or regression estimator) for a given @@ -244,21 +284,33 @@ def _init_miscellaneous(self) -> None: DisableFileOutputParameters.check_compatibility(disable_file_output) self.disable_file_output = disable_file_output else: - if isinstance(self.configuration, int): - self.pipeline_class = DummyClassificationPipeline - elif isinstance(self.configuration, str): - if self.task_type in TABULAR_TASKS: - self.pipeline_class = MyTraditionalTabularClassificationPipeline - else: - raise ValueError("Only tabular tasks are currently supported with traditional methods") - elif isinstance(self.configuration, Configuration): - if self.task_type in TABULAR_TASKS: - self.pipeline_class = autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline - elif self.task_type in IMAGE_TASKS: - self.pipeline_class = autoPyTorch.pipeline.image_classification.ImageClassificationPipeline - else: - raise ValueError('task {} not available'.format(self.task_type)) - self.predict_function = self._predict_proba + self.disable_file_output = [] + + if self.num_folds == 1: # not save cv model when we perform holdout + self.disable_file_output.append('cv_model') + + def _init_dataset_properties(self) -> None: + datamanager: BaseDataset = self.fixed_pipeline_params.backend.load_datamanager() + if datamanager.task_type is None: + raise ValueError(f"Expected dataset {datamanager.__class__.__name__} to have task_type got None") + if datamanager.splits is None: + raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None") + + self.splits = datamanager.splits + self.num_folds: int = len(self.splits) + # Since cv might not finish in time, we take self.pipelines as None by default + self.pipelines: List[Optional[BaseEstimator]] = [None] * self.num_folds + self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type] + self.num_classes = getattr(datamanager, 'num_classes', 1) + self.output_type = datamanager.output_type + + search_space_updates = self.fixed_pipeline_params.search_space_updates + self.dataset_properties = datamanager.get_dataset_properties( + get_dataset_requirements(info=datamanager.get_required_dataset_info(), + include=self.fixed_pipeline_params.include, + exclude=self.fixed_pipeline_params.exclude, + search_space_updates=search_space_updates + )) self.X_train, self.y_train = datamanager.train_tensors self.unique_train_labels = [ @@ -271,6 +323,8 @@ def _init_miscellaneous(self) -> None: if datamanager.test_tensors is not None: self.X_test, self.y_test = datamanager.test_tensors + del datamanager # Delete datamanager to release the memory + def _init_additional_metrics(self) -> None: all_supported_metrics = self.fixed_pipeline_params.all_supported_metrics metric = self.fixed_pipeline_params.metric @@ -282,59 +336,7 @@ def _init_additional_metrics(self) -> 
None: all_supported_metrics=all_supported_metrics) self.metrics_dict = {'additional_metrics': [m.name for m in [metric] + self.additional_metrics]} - def _init_datamanager_info( - self, - ) -> None: - """ - Initialises instance attributes that come from the datamanager. - For example, - X_train, y_train, etc. - """ - - datamanager: BaseDataset = self.backend.load_datamanager() - - assert datamanager.task_type is not None, \ - "Expected dataset {} to have task_type got None".format(datamanager.__class__.__name__) - self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type] - self.output_type = STRING_TO_OUTPUT_TYPES[datamanager.output_type] - self.issparse = datamanager.issparse - - self.X_train, self.y_train = datamanager.train_tensors - - if datamanager.val_tensors is not None: - self.X_valid, self.y_valid = datamanager.val_tensors - else: - self.X_valid, self.y_valid = None, None - - if datamanager.test_tensors is not None: - self.X_test, self.y_test = datamanager.test_tensors - else: - self.X_test, self.y_test = None, None - - self.resampling_strategy = datamanager.resampling_strategy - - self.num_classes: Optional[int] = getattr(datamanager, "num_classes", None) - - self.dataset_properties = datamanager.get_dataset_properties( - get_dataset_requirements(info=datamanager.get_required_dataset_info(), - include=self.include, - exclude=self.exclude, - search_space_updates=self.search_space_updates - )) - self.splits = datamanager.splits - if self.splits is None: - raise AttributeError(f"create_splits on {datamanager.__class__.__name__} must be called " - f"before the instantiation of {self.__class__.__name__}") - - # delete datamanager from memory - del datamanager - - def _init_fit_dictionary( - self, - logger_port: int, - pipeline_config: Dict[str, Any], - metrics_dict: Optional[Dict[str, List[str]]] = None, - ) -> None: + def _init_fit_dictionary(self) -> None: """ Initialises the fit dictionary @@ -617,36 +619,4 @@ def _is_output_possible( if y is not None and not np.all(np.isfinite(y)): return False # Model predictions contains NaNs - Args: - prediction (np.ndarray): - The un-formatted predictions of a pipeline - Y_train (np.ndarray): - The labels from the dataset to give an intuition of the expected - predictions dimensionality - Returns: - (np.ndarray): - The formatted prediction - """ - assert self.num_classes is not None, "Called function on wrong task" - - if self.output_type == MULTICLASS and \ - prediction.shape[1] < self.num_classes: - if Y_train is None: - raise ValueError('Y_train must not be None!') - classes = list(np.unique(Y_train)) - - mapping = dict() - for class_number in range(self.num_classes): - if class_number in classes: - index = classes.index(class_number) - mapping[index] = class_number - new_predictions = np.zeros((prediction.shape[0], self.num_classes), - dtype=np.float32) - - for index in mapping: - class_index = mapping[index] - new_predictions[:, class_index] = prediction[:, index] - - return new_predictions - - return prediction + return True diff --git a/autoPyTorch/evaluation/evaluator.py b/autoPyTorch/evaluation/evaluator.py index a6eef9b2b..887e1548b 100644 --- a/autoPyTorch/evaluation/evaluator.py +++ b/autoPyTorch/evaluation/evaluator.py @@ -111,14 +111,15 @@ def __init__(self, queue: Queue, fixed_pipeline_params: FixedPipelineParams, eva self.train = not isinstance(resampling_strategy, NoResamplingStrategyTypes) self.cross_validation = isinstance(resampling_strategy, CrossValTypes) - if not isinstance(self.resampling_strategy, (CrossValTypes, 
HoldoutValTypes)): - raise ValueError( - f'resampling_strategy for TrainEvaluator must be in ' - f'(CrossValTypes, HoldoutValTypes), but got {self.resampling_strategy}' - ) + if not self.train and fixed_pipeline_params.save_y_opt: + # TODO: Add the test to cover here + # No resampling can not be used for building ensembles. save_y_opt=False ensures it + fixed_pipeline_params = fixed_pipeline_params._replace(save_y_opt=False) + + super().__init__(queue=queue, fixed_pipeline_params=fixed_pipeline_params, evaluator_params=evaluator_params) - self.num_folds: int = len(self.splits) - self.logger.debug("Search space updates :{}".format(self.search_space_updates)) + if self.train: + self.logger.debug("Search space updates :{}".format(self.fixed_pipeline_params.search_space_updates)) def _evaluate_on_split(self, split_id: int) -> EvaluationResults: """ diff --git a/autoPyTorch/evaluation/test_evaluator.py b/autoPyTorch/evaluation/test_evaluator.py deleted file mode 100644 index 4d5b0ae91..000000000 --- a/autoPyTorch/evaluation/test_evaluator.py +++ /dev/null @@ -1,236 +0,0 @@ -from multiprocessing.queues import Queue -from typing import Any, Dict, List, Optional, Tuple, Union - -from ConfigSpace.configuration_space import Configuration - -import numpy as np - -from smac.tae import StatusType - -from autoPyTorch.automl_common.common.utils.backend import Backend -from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes -from autoPyTorch.evaluation.abstract_evaluator import ( - AbstractEvaluator, - fit_and_suppress_warnings -) -from autoPyTorch.evaluation.utils import DisableFileOutputParameters -from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric -from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates - - -__all__ = [ - 'eval_test_function', - 'TestEvaluator' -] - - -class TestEvaluator(AbstractEvaluator): - """ - This class builds a pipeline using the provided configuration. - A pipeline implementing the provided configuration is fitted - using the datamanager object retrieved from disc, via the backend. - After the pipeline is fitted, it is save to disc and the performance estimate - is communicated to the main process via a Queue. It is only compatible - with `NoResamplingStrategyTypes`, i.e, when the training data - is not split and the test set is used for SMBO optimisation. It can not - be used for building ensembles which is ensured by having - `output_y_hat_optimisation`=False - - Attributes: - backend (Backend): - An object to interface with the disk storage. In particular, allows to - access the train and test datasets - queue (Queue): - Each worker available will instantiate an evaluator, and after completion, - it will return the evaluation result via a multiprocessing queue - metric (autoPyTorchMetric): - A scorer object that is able to evaluate how good a pipeline was fit. It - is a wrapper on top of the actual score method (a wrapper on top of scikit - lean accuracy for example) that formats the predictions accordingly. - budget: (float): - The amount of epochs/time a configuration is allowed to run. - budget_type (str): - The budget type, which can be epochs or time - pipeline_config (Optional[Dict[str, Any]]): - Defines the content of the pipeline being evaluated. For example, it - contains pipeline specific settings like logging name, or whether or not - to use tensorboard. - configuration (Union[int, str, Configuration]): - Determines the pipeline to be constructed. 
A dummy estimator is created for - integer configurations, a traditional machine learning pipeline is created - for string based configuration, and NAS is performed when a configuration - object is passed. - seed (int): - A integer that allows for reproducibility of results - output_y_hat_optimization (bool): - Whether this worker should output the target predictions, so that they are - stored on disk. Fundamentally, the resampling strategy might shuffle the - Y_train targets, so we store the split in order to re-use them for ensemble - selection. - num_run (Optional[int]): - An identifier of the current configuration being fit. This number is unique per - configuration. - include (Optional[Dict[str, Any]]): - An optional dictionary to include components of the pipeline steps. - exclude (Optional[Dict[str, Any]]): - An optional dictionary to exclude components of the pipeline steps. - disable_file_output (Optional[List[Union[str, DisableFileOutputParameters]]]): - Used as a list to pass more fine-grained - information on what to save. Must be a member of `DisableFileOutputParameters`. - Allowed elements in the list are: - - + `y_optimization`: - do not save the predictions for the optimization set, - which would later on be used to build an ensemble. Note that SMAC - optimizes a metric evaluated on the optimization set. - + `pipeline`: - do not save any individual pipeline files - + `pipelines`: - In case of cross validation, disables saving the joint model of the - pipelines fit on each fold. - + `y_test`: - do not save the predictions for the test set. - + `all`: - do not save any of the above. - For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`. - init_params (Optional[Dict[str, Any]]): - Optional argument that is passed to each pipeline step. It is the equivalent of - kwargs for the pipeline steps. - logger_port (Optional[int]): - Logging is performed using a socket-server scheme to be robust against many - parallel entities that want to write to the same file. This integer states the - socket port for the communication channel. If None is provided, a traditional - logger is used. - all_supported_metrics (bool): - Whether all supported metric should be calculated for every configuration. 
- search_space_updates (Optional[HyperparameterSearchSpaceUpdates]): - An object used to fine tune the hyperparameter search space of the pipeline - """ - def __init__( - self, - backend: Backend, queue: Queue, - metric: autoPyTorchMetric, - budget: float, - configuration: Union[int, str, Configuration], - budget_type: str = None, - pipeline_config: Optional[Dict[str, Any]] = None, - seed: int = 1, - output_y_hat_optimization: bool = False, - num_run: Optional[int] = None, - include: Optional[Dict[str, Any]] = None, - exclude: Optional[Dict[str, Any]] = None, - disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, - init_params: Optional[Dict[str, Any]] = None, - logger_port: Optional[int] = None, - all_supported_metrics: bool = True, - search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None - ) -> None: - super().__init__( - backend=backend, - queue=queue, - configuration=configuration, - metric=metric, - seed=seed, - output_y_hat_optimization=output_y_hat_optimization, - num_run=num_run, - include=include, - exclude=exclude, - disable_file_output=disable_file_output, - init_params=init_params, - budget=budget, - budget_type=budget_type, - logger_port=logger_port, - all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, - search_space_updates=search_space_updates - ) - - if not isinstance(self.resampling_strategy, (NoResamplingStrategyTypes)): - raise ValueError( - f'resampling_strategy for TestEvaluator must be in ' - f'NoResamplingStrategyTypes, but got {self.resampling_strategy}' - ) - - def fit_predict_and_loss(self) -> None: - - split_id = 0 - train_indices, test_indices = self.splits[split_id] - - self.pipeline = self._get_pipeline() - X = {'train_indices': train_indices, - 'val_indices': test_indices, - 'split_id': split_id, - 'num_run': self.num_run, - **self.fit_dictionary} # fit dictionary - y = None - fit_and_suppress_warnings(self.logger, self.pipeline, X, y) - train_loss, _ = self.predict_and_loss(train=True) - test_loss, test_pred = self.predict_and_loss() - self.Y_optimization = self.y_test - self.finish_up( - loss=test_loss, - train_loss=train_loss, - opt_pred=test_pred, - valid_pred=None, - test_pred=test_pred, - file_output=True, - additional_run_info=None, - status=StatusType.SUCCESS, - ) - - def predict_and_loss( - self, train: bool = False - ) -> Tuple[Dict[str, float], np.ndarray]: - labels = self.y_train if train else self.y_test - feats = self.X_train if train else self.X_test - preds = self.predict_function( - X=feats, - pipeline=self.pipeline, - Y_train=self.y_train # Need this as we need to know all the classes in train splits - ) - loss_dict = self._loss(labels, preds) - - return loss_dict, preds - - -# create closure for evaluating an algorithm -def eval_test_function( - backend: Backend, - queue: Queue, - metric: autoPyTorchMetric, - budget: float, - config: Optional[Configuration], - seed: int, - output_y_hat_optimization: bool, - num_run: int, - include: Optional[Dict[str, Any]], - exclude: Optional[Dict[str, Any]], - disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, - pipeline_config: Optional[Dict[str, Any]] = None, - budget_type: str = None, - init_params: Optional[Dict[str, Any]] = None, - logger_port: Optional[int] = None, - all_supported_metrics: bool = True, - search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None, - instance: str = None, -) -> None: - evaluator = TestEvaluator( - backend=backend, - queue=queue, - metric=metric, 
- configuration=config, - seed=seed, - num_run=num_run, - output_y_hat_optimization=output_y_hat_optimization, - include=include, - exclude=exclude, - disable_file_output=disable_file_output, - init_params=init_params, - budget=budget, - budget_type=budget_type, - logger_port=logger_port, - all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, - search_space_updates=search_space_updates) - - evaluator.fit_predict_and_loss() diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 69ec9f321..4ace9ba0d 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -716,3 +716,29 @@ def test_pipeline_fit_error( assert 'TIMEOUT' in str(run_value.status) assert pipeline is None + + +@pytest.mark.parametrize("ans,task_class", ( + ("continuous", TabularRegressionTask), + ("multiclass", TabularClassificationTask)) +) +def test_task_inference(ans, task_class, backend): + # Get the data and check that contents of data-manager make sense + X = np.random.random((6, 1)) + y = np.array([-10 ** 12, 0, 1, 2, 3, 4], dtype=np.int64) + 10 ** 12 + + estimator = task_class( + backend=backend, + resampling_strategy=HoldoutValTypes.holdout_validation, + resampling_strategy_args=None, + seed=42, + ) + dataset = estimator.get_dataset(X, y) + assert dataset.output_type == ans + + y += 10 ** 12 + 10 # Check if the function catches overflow possibilities + if ans == 'continuous': + with pytest.raises(ValueError): # ValueError due to `Too large value` + estimator.get_dataset(X, y) + else: + estimator.get_dataset(X, y) diff --git a/test/test_datasets/test_resampling_strategies.py b/test/test_datasets/test_resampling_strategies.py index 7f14275a3..473f17182 100644 --- a/test/test_datasets/test_resampling_strategies.py +++ b/test/test_datasets/test_resampling_strategies.py @@ -1,6 +1,15 @@ import numpy as np -from autoPyTorch.datasets.resampling_strategy import CrossValFuncs, HoldOutFuncs +import pytest + +from autoPyTorch.datasets.resampling_strategy import ( + CrossValFuncs, + CrossValTypes, + HoldOutFuncs, + HoldoutValTypes, + NoResamplingStrategyTypes, + check_resampling_strategy +) def test_holdoutfuncs(): @@ -40,3 +49,12 @@ def test_crossvalfuncs(): splits = split.stratified_k_fold_cross_validation(0, 10, X, stratify=y) assert len(splits) == 10 assert all([0 in y[s[1]] for s in splits]) + + +def test_check_resampling_strategy(): + for rs in (CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes): + for rs_func in rs: + check_resampling_strategy(rs_func) + + with pytest.raises(ValueError): + check_resampling_strategy(None) diff --git a/test/test_evaluation/test_evaluators.py b/test/test_evaluation/test_evaluators.py index 449bf8d4a..2371522d8 100644 --- a/test/test_evaluation/test_evaluators.py +++ b/test/test_evaluation/test_evaluators.py @@ -143,26 +143,7 @@ def tearDown(self): if os.path.exists(self.ev_path): shutil.rmtree(self.ev_path) - def test_evaluate_loss(self): - D = get_binary_classification_datamanager() - backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') - backend_api.load_datamanager = lambda: D - fixed_params_dict = self.fixed_params._asdict() - fixed_params_dict.update(backend=backend_api) - evaluator = Evaluator( - queue=multiprocessing.Queue(), - fixed_pipeline_params=FixedPipelineParams(**fixed_params_dict), - evaluator_params=self.eval_params - ) - evaluator.splits = None - with pytest.raises(ValueError): - evaluator.evaluate_loss() - - 
@unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline') - def test_holdout(self, pipeline_mock): - pipeline_mock.fit_dictionary = {'budget_type': 'epochs', 'epochs': 50} - # Binary iris, contains 69 train samples, 31 test samples - D = get_binary_classification_datamanager() + def _get_evaluator(self, pipeline_mock, data): pipeline_mock.predict_proba.side_effect = \ lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1)) pipeline_mock.side_effect = lambda **kwargs: pipeline_mock @@ -170,7 +151,7 @@ def test_holdout(self, pipeline_mock): _queue = multiprocessing.Queue() backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') - backend_api.load_datamanager = lambda: D + backend_api.load_datamanager = lambda: data fixed_params_dict = self.fixed_params._asdict() fixed_params_dict.update(backend=backend_api) @@ -184,56 +165,72 @@ def test_holdout(self, pipeline_mock): evaluator.evaluate_loss() + return evaluator + + def _check_results(self, evaluator, ans): rval = read_queue(evaluator.queue) self.assertEqual(len(rval), 1) result = rval[0]['loss'] self.assertEqual(len(rval[0]), 3) self.assertRaises(queue.Empty, evaluator.queue.get, timeout=1) - + self.assertEqual(result, ans) self.assertEqual(evaluator._save_to_backend.call_count, 1) - self.assertEqual(result, 0.5652173913043479) - self.assertEqual(pipeline_mock.fit.call_count, 1) - # 3 calls because of train, holdout and test set - self.assertEqual(pipeline_mock.predict_proba.call_count, 3) - call_args = evaluator._save_to_backend.call_args - self.assertEqual(call_args[0][0].shape[0], len(D.splits[0][1])) - self.assertIsNone(call_args[0][1]) - self.assertEqual(call_args[0][2].shape[0], D.test_tensors[1].shape[0]) - self.assertEqual(evaluator.pipelines[0].fit.call_count, 1) - @unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline') - def test_cv(self, pipeline_mock): - D = get_binary_classification_datamanager(resampling_strategy=CrossValTypes.k_fold_cross_validation) + def _check_whether_save_y_opt_is_correct(self, resampling_strategy, ans): + backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') + D = get_binary_classification_datamanager(resampling_strategy) + backend_api.load_datamanager = lambda: D + fixed_params_dict = self.fixed_params._asdict() + fixed_params_dict.update(backend=backend_api, save_y_opt=True) + evaluator = Evaluator( + queue=multiprocessing.Queue(), + fixed_pipeline_params=FixedPipelineParams(**fixed_params_dict), + evaluator_params=self.eval_params + ) + assert evaluator.fixed_pipeline_params.save_y_opt == ans - pipeline_mock.predict_proba.side_effect = \ - lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1)) - pipeline_mock.side_effect = lambda **kwargs: pipeline_mock - pipeline_mock.get_additional_run_info.return_value = None + def test_whether_save_y_opt_is_correct_for_no_resampling(self): + self._check_whether_save_y_opt_is_correct(NoResamplingStrategyTypes.no_resampling, False) - _queue = multiprocessing.Queue() + def test_whether_save_y_opt_is_correct_for_resampling(self): + self._check_whether_save_y_opt_is_correct(CrossValTypes.k_fold_cross_validation, True) + + def test_evaluate_loss(self): + D = get_binary_classification_datamanager() backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') backend_api.load_datamanager = lambda: D - fixed_params_dict = self.fixed_params._asdict() fixed_params_dict.update(backend=backend_api) evaluator = Evaluator( - queue=_queue, + 
queue=multiprocessing.Queue(), fixed_pipeline_params=FixedPipelineParams(**fixed_params_dict), evaluator_params=self.eval_params ) - evaluator._save_to_backend = unittest.mock.Mock(spec=evaluator._save_to_backend) - evaluator._save_to_backend.return_value = True + evaluator.splits = None + with pytest.raises(ValueError): + evaluator.evaluate_loss() - evaluator.evaluate_loss() + @unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline') + def test_holdout(self, pipeline_mock): + D = get_binary_classification_datamanager() + evaluator = self._get_evaluator(pipeline_mock, D) + self._check_results(evaluator, ans=0.5652173913043479) - rval = read_queue(evaluator.queue) - self.assertEqual(len(rval), 1) - result = rval[0]['loss'] - self.assertEqual(len(rval[0]), 3) - self.assertRaises(queue.Empty, evaluator.queue.get, timeout=1) + self.assertEqual(pipeline_mock.fit.call_count, 1) + # 3 calls because of train, holdout and test set + self.assertEqual(pipeline_mock.predict_proba.call_count, 3) + call_args = evaluator._save_to_backend.call_args + self.assertEqual(call_args[0][0].shape[0], len(D.splits[0][1])) + self.assertIsNone(call_args[0][1]) + self.assertEqual(call_args[0][2].shape[0], D.test_tensors[1].shape[0]) + self.assertEqual(evaluator.pipelines[0].fit.call_count, 1) + + @unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline') + def test_cv(self, pipeline_mock): + D = get_binary_classification_datamanager(resampling_strategy=CrossValTypes.k_fold_cross_validation) + evaluator = self._get_evaluator(pipeline_mock, D) + self._check_results(evaluator, ans=0.463768115942029) - self.assertEqual(evaluator._save_to_backend.call_count, 1) - self.assertEqual(result, 0.463768115942029) self.assertEqual(pipeline_mock.fit.call_count, 5) # 15 calls because of the training, holdout and # test set (3 sets x 5 folds = 15) @@ -251,38 +248,10 @@ def test_cv(self, pipeline_mock): @unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline') def test_no_resampling(self, pipeline_mock): - pipeline_mock.fit_dictionary = {'budget_type': 'epochs', 'epochs': 10} - # Binary iris, contains 69 train samples, 31 test samples D = get_binary_classification_datamanager(NoResamplingStrategyTypes.no_resampling) - pipeline_mock.predict_proba.side_effect = \ - lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1)) - pipeline_mock.side_effect = lambda **kwargs: pipeline_mock - pipeline_mock.get_additional_run_info.return_value = None - - _queue = multiprocessing.Queue() - backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') - backend_api.load_datamanager = lambda: D - - fixed_params_dict = self.fixed_params._asdict() - fixed_params_dict.update(backend=backend_api) - evaluator = Evaluator( - queue=_queue, - fixed_pipeline_params=FixedPipelineParams(**fixed_params_dict), - evaluator_params=self.eval_params - ) - evaluator._save_to_backend = unittest.mock.Mock(spec=evaluator._save_to_backend) - evaluator._save_to_backend.return_value = True + evaluator = self._get_evaluator(pipeline_mock, D) + self._check_results(evaluator, ans=0.5806451612903225) - evaluator.evaluate_loss() - - rval = read_queue(evaluator.queue) - self.assertEqual(len(rval), 1) - result = rval[0]['loss'] - self.assertEqual(len(rval[0]), 3) - self.assertRaises(queue.Empty, evaluator.queue.get, timeout=1) - - self.assertEqual(evaluator._save_to_backend.call_count, 1) - self.assertEqual(result, 0.5806451612903225) 
self.assertEqual(pipeline_mock.fit.call_count, 1) # 2 calls because of train and test set self.assertEqual(pipeline_mock.predict_proba.call_count, 2) diff --git a/test/test_evaluation/test_tae.py b/test/test_evaluation/test_tae.py index 0a187f6c2..eaf505ad7 100644 --- a/test/test_evaluation/test_tae.py +++ b/test/test_evaluation/test_tae.py @@ -90,6 +90,7 @@ def _create_taq(): backend=unittest.mock.Mock(), seed=1, metric=accuracy, + multi_objectives=["cost"], cost_for_crash=accuracy._cost_of_crash, abort_on_first_run_crash=True, pynisher_context=unittest.mock.Mock() @@ -104,6 +105,15 @@ def test_check_run_info(self): with pytest.raises(ValueError): taq.run_wrapper(run_info) + def test_check_and_get_default_budget(self): + taq = _create_taq() + budget = taq._check_and_get_default_budget() + assert isinstance(budget, float) + + taq.fixed_pipeline_params = taq.fixed_pipeline_params._replace(budget_type='test') + with pytest.raises(ValueError): + taq._check_and_get_default_budget() + def test_cutoff_update_in_run_wrapper(self): taq = _create_taq() run_info = RunInfo( diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py index ce1579caa..213671bb8 100644 --- a/test/test_pipeline/test_tabular_classification.py +++ b/test/test_pipeline/test_tabular_classification.py @@ -524,8 +524,8 @@ def test_train_pipeline_with_runtime_max_reached(fit_dictionary_tabular_dummy): def test_get_pipeline_representation(): pipeline = TabularClassificationPipeline( dataset_properties={ - 'numerical_columns': None, - 'categorical_columns': None, + 'numerical_columns': [], + 'categorical_columns': [], 'task_type': 'tabular_classification' } ) diff --git a/test/test_pipeline/test_tabular_regression.py b/test/test_pipeline/test_tabular_regression.py index e21eb961f..8ef8d26bd 100644 --- a/test/test_pipeline/test_tabular_regression.py +++ b/test/test_pipeline/test_tabular_regression.py @@ -322,8 +322,8 @@ def test_pipeline_score(fit_dictionary_tabular_dummy): def test_get_pipeline_representation(): pipeline = TabularRegressionPipeline( dataset_properties={ - 'numerical_columns': None, - 'categorical_columns': None, + 'numerical_columns': [], + 'categorical_columns': [], 'task_type': 'tabular_classification' } )
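
Note (not part of the patch): a minimal usage sketch of the refactored evaluator API introduced above, assuming a Backend whose datamanager has already been saved to disk and a sampled Configuration are prepared elsewhere; `my_backend` and `my_config` are hypothetical placeholders, and the import path for the `accuracy` metric is assumed from the existing metrics module. Module paths follow the file layout in this diff.

from multiprocessing import Queue

from autoPyTorch.evaluation.abstract_evaluator import EvaluatorParams, FixedPipelineParams
from autoPyTorch.evaluation.evaluator import Evaluator
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy  # assumed import path

# `my_backend` and `my_config` are hypothetical placeholders: a Backend whose
# datamanager was saved to disk beforehand, and a Configuration sampled from the
# pipeline search space (an int selects a dummy pipeline, a str a traditional-ML
# pipeline, as described in the EvaluatorParams docstring).
fixed_params = FixedPipelineParams.with_default_pipeline_config(
    backend=my_backend,
    seed=42,
    metric=accuracy,
)  # budget_type ('epochs' or 'runtime') is taken from the default pipeline config

eval_params = EvaluatorParams.with_default_budget(
    configuration=my_config,
    num_run=1,
)  # budget == 0 falls back to the default budget for the chosen budget_type

evaluator = Evaluator(
    queue=Queue(),
    fixed_pipeline_params=fixed_params,
    evaluator_params=eval_params,
)
evaluator.evaluate_loss()  # results are communicated back through the multiprocessing queue

Bundling the run-invariant arguments in FixedPipelineParams keeps the SMAC-facing signatures stable, while per-run values (configuration, budget, num_run, init_params) travel separately in EvaluatorParams. When the dataset uses NoResamplingStrategyTypes, the Evaluator itself replaces save_y_opt with False, so such runs are never used for ensemble building.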