From 552d3fce3984cac8c8f373f6b668966760cfd8ac Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 12 Dec 2022 15:26:42 +0100
Subject: [PATCH 1/5] add_metric and fix precommit

---
 .pre-commit-config.yaml                       |  2 +-
 .../traditional_learner/learners.py           |  4 +-
 .../components/training/metrics/metrics.py    |  9 ++-
 .../components/training/metrics/utils.py      | 58 +++++++++++++------
 4 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cac475cb1..7b2d8d697 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -10,7 +10,7 @@ repos:
         ]
         files: autoPyTorch/.*
         exclude: autoPyTorch/ensemble/
-  - repo: https://gitlab.com/pycqa/flake8
+  - repo: https://github.com/PyCQA/flake8
     rev: 3.8.3
     hooks:
       - id: flake8
diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/learners.py b/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/learners.py
index fca02aa32..2a1d98cd9 100644
--- a/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/learners.py
+++ b/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/learners.py
@@ -119,12 +119,12 @@ def _prepare_model(self,
                        X_train: np.ndarray,
                        y_train: np.ndarray
                        ) -> None:
-        if not self.is_classification:
+        if hasattr(AutoPyTorchToCatboostMetrics, self.metric.name):
             self.config['eval_metric'] = AutoPyTorchToCatboostMetrics[self.metric.name].value
+        if not self.is_classification:
             # CatBoost Cannot handle a random state object, just the seed
             self.model = CatBoostRegressor(**self.config, random_state=self.random_state.get_state()[1][0])
         else:
-            self.config['eval_metric'] = AutoPyTorchToCatboostMetrics[self.metric.name].value
             # CatBoost Cannot handle a random state object, just the seed
             self.model = CatBoostClassifier(**self.config, random_state=self.random_state.get_state()[1][0])
diff --git a/autoPyTorch/pipeline/components/training/metrics/metrics.py b/autoPyTorch/pipeline/components/training/metrics/metrics.py
index 5fa60a24d..58bb840c9 100644
--- a/autoPyTorch/pipeline/components/training/metrics/metrics.py
+++ b/autoPyTorch/pipeline/components/training/metrics/metrics.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import List, Union
+from typing import Dict, List, Union

 import numpy as np

@@ -13,7 +13,12 @@

 from smac.utils.constants import MAXINT

-from autoPyTorch.pipeline.components.training.metrics.base import make_metric
+from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric, make_metric
+
+
+CLASSIFICATION_METRICS: Dict[str, autoPyTorchMetric]
+FORECASTING_METRICS: Dict[str, autoPyTorchMetric]
+REGRESSION_METRICS: Dict[str, autoPyTorchMetric]

 # Standard regression scores
 mean_absolute_error = make_metric('mean_absolute_error',
diff --git a/autoPyTorch/pipeline/components/training/metrics/utils.py b/autoPyTorch/pipeline/components/training/metrics/utils.py
index e72c1afce..a84fa441c 100644
--- a/autoPyTorch/pipeline/components/training/metrics/utils.py
+++ b/autoPyTorch/pipeline/components/training/metrics/utils.py
@@ -20,6 +20,31 @@
 )


+def add_metric(
+    metric: autoPyTorchMetric,
+    task_type: str,
+) -> None:
+    """
+    Adds an `autoPyTorchMetric` such that it can be used for
+    the search as well as for model evaluation.
+
+    Args:
+        metric (autoPyTorchMetric):
+            Metric to be added.
+        task_type (str):
+            The current task type, as a string.
+    """
+
+    if STRING_TO_TASK_TYPES[task_type] in REGRESSION_TASKS:
+        metrics = REGRESSION_METRICS
+    elif STRING_TO_TASK_TYPES[task_type] in CLASSIFICATION_TASKS:
+        metrics = CLASSIFICATION_METRICS
+    elif STRING_TO_TASK_TYPES[task_type] in FORECASTING_TASKS:
+        metrics = FORECASTING_METRICS
+
+    metrics[metric.name] = metric
+
+
 def sanitize_array(array: np.ndarray) -> np.ndarray:
     """
     Replace NaN and Inf (there should not be any!)
@@ -66,14 +91,14 @@ def get_metrics(dataset_properties: Dict[str, Any],

     Args:
         dataset_properties: Dict[str, Any]
-            contains information about the dataset and task type
+            contains information about the dataset and task type
         names: Optional[Iterable[str]]
-            names of metrics to return
+            names of metrics to return
         all_supported_metrics: bool
-            if true, returns all metrics that are relevant to task_type
+            if true, returns all metrics that are relevant to task_type

     Returns:
-
+        List[autoPyTorchMetric]
     """
     assert 'task_type' in dataset_properties, \
         "Expected dataset_properties to have task_type got {}".format(dataset_properties.keys())
@@ -189,27 +214,24 @@ def calculate_loss(
     """
     Returns a loss (a magnitude that allows casting the
     optimization problem, as a minimization one) for the
-    given Auto-Sklearn Scorer object
-    Parameters
-    ----------
-    solution: np.ndarray
+    given autoPyTorchMetric object
+
+    Args:
+        target (np.ndarray):
             The ground truth of the targets
-    prediction: np.ndarray
+        prediction (np.ndarray):
             The best estimate from the model, of the given targets
-    task_type: int
+        task_type (int):
             To understand if the problem task is classification or
             regression
-    metric: Scorer
-        Object that host a function to calculate how good the
-        prediction is according to the solution.
-    scoring_functions: List[Scorer]
+        metrics (Iterable[autoPyTorchMetric]):
             A list of metrics to calculate multiple losses
         score_kwargs: Dict
             additional arguments for computing scores
-    Returns
-    -------
-    float or Dict[str, float]
-        A loss function for each of the provided scorer objects
+
+    Returns:
+        Dict[str, float]:
+            A loss score for each of the provided metrics
     """
     score = calculate_score(
         target=target,
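A note on the `learners.py` hunk above: `AutoPyTorchToCatboostMetrics[self.metric.name]` raises a `KeyError` for metrics with no CatBoost equivalent — exactly the situation a user-registered metric creates — so the lookup is now guarded with `hasattr`. Because enum members are class attributes, `hasattr` doubles as a membership test. A minimal sketch, using a reduced, hypothetical stand-in for the real mapping:

    from enum import Enum

    # Hypothetical, reduced stand-in for AutoPyTorchToCatboostMetrics;
    # the real enum maps autoPyTorch metric names to CatBoost metric names.
    class AutoPyTorchToCatboostMetrics(Enum):
        accuracy = "Accuracy"
        balanced_accuracy = "BalancedAccuracy"

    print(hasattr(AutoPyTorchToCatboostMetrics, "accuracy"))       # True  -> eval_metric gets set
    print(hasattr(AutoPyTorchToCatboostMetrics, "custom_metric"))  # False -> CatBoost keeps its default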
From 989cac432f0e6cbb6889ca7b74c57f0cc11089b4 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 12 Dec 2022 15:27:04 +0100
Subject: [PATCH 2/5] add test case for add metric

---
 test/test_pipeline/test_metrics.py | 48 +++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/test/test_pipeline/test_metrics.py b/test/test_pipeline/test_metrics.py
index 0a40d84bb..7e837cef8 100644
--- a/test/test_pipeline/test_metrics.py
+++ b/test/test_pipeline/test_metrics.py
@@ -17,6 +17,9 @@
     TIMESERIES_FORECASTING
 )
 from autoPyTorch.metrics import (
+    CLASSIFICATION_METRICS,
+    FORECASTING_METRICS,
+    REGRESSION_METRICS,
     accuracy,
     balanced_accuracy,
     compute_mase_coefficient,
@@ -30,7 +33,12 @@
     autoPyTorchMetric,
     make_metric
 )
-from autoPyTorch.pipeline.components.training.metrics.utils import calculate_loss, calculate_score, get_metrics
+from autoPyTorch.pipeline.components.training.metrics.utils import (
+    add_metric,
+    calculate_loss,
+    calculate_score,
+    get_metrics
+)


 @pytest.mark.parametrize('output_type', ['multiclass',
@@ -354,3 +362,41 @@ def test_compute_mase_coefficient():
     past_target = np.zeros(12)
     assert compute_mase_coefficient(past_target, 15) == 1.
     assert compute_mase_coefficient(past_target, 5) == 1.
+
+
+@pytest.mark.parametrize('task_type,task_metrics_dict,output_type', [
+    ('tabular_classification', CLASSIFICATION_METRICS, 'multiclass'),
+    ('tabular_classification', CLASSIFICATION_METRICS, 'multiclass-multioutput'),
+    ('tabular_classification', CLASSIFICATION_METRICS, 'binary'),
+    ('tabular_regression', REGRESSION_METRICS, 'continuous'),
+    ('tabular_regression', REGRESSION_METRICS, 'continuous-multioutput'),
+    ('time_series_forecasting', FORECASTING_METRICS, 'continuous'),
+    ('time_series_forecasting', FORECASTING_METRICS, 'continuous-multioutput'),
+])
+def test_add_metric(task_type, task_metrics_dict, output_type):
+    """
+    Tests if the custom score function is added to the metrics for the given task.
+    Also checks if the added metric is returned by the `get_metrics` function.
+
+    Args:
+        task_type (str)
+        task_metrics_dict (Dict[str, autoPyTorchMetric])
+        output_type (str)
+    """
+    def score_function(y_test, y_pred):
+        return sum(y_pred == y_test) / y_pred.shape[0]
+
+    custom_metric = make_metric(
+        name="custom_metric",
+        score_func=score_function,
+        worst_possible_result=0,
+        greater_is_better=True
+    )
+
+    add_metric(metric=custom_metric, task_type=task_type)
+    assert 'custom_metric' in list(task_metrics_dict.keys())
+    metrics = get_metrics(
+        dataset_properties={'task_type': task_type, 'output_type': output_type},
+        names=['custom_metric']
+    )
+    assert custom_metric in metrics
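One caveat worth noting about the test above: `add_metric` mutates the module-level metric dictionaries, so `custom_metric` stays registered for every test that runs afterwards. If that ever causes interference between tests, a cleanup fixture along these lines (hypothetical, not part of the patch) would restore isolation:

    import pytest

    from autoPyTorch.metrics import (
        CLASSIFICATION_METRICS,
        FORECASTING_METRICS,
        REGRESSION_METRICS,
    )


    @pytest.fixture(autouse=True)
    def remove_custom_metric():
        # Run the test first, then drop the entry from every registry so
        # later tests see the original, unmodified metric dictionaries.
        yield
        for registry in (CLASSIFICATION_METRICS, REGRESSION_METRICS, FORECASTING_METRICS):
            registry.pop('custom_metric', None)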
From 8ac05640c90bc3aefbef855b45cee3581a740edb Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 12 Dec 2022 15:27:23 +0100
Subject: [PATCH 3/5] add example for add metric

---
 .../40_advanced/example_add_custom_metric.py | 96 +++++++++++++++++++
 1 file changed, 96 insertions(+)
 create mode 100644 examples/40_advanced/example_add_custom_metric.py

diff --git a/examples/40_advanced/example_add_custom_metric.py b/examples/40_advanced/example_add_custom_metric.py
new file mode 100644
index 000000000..7222d4228
--- /dev/null
+++ b/examples/40_advanced/example_add_custom_metric.py
@@ -0,0 +1,96 @@
+"""
+===========================================
+Tabular Classification with a Custom Metric
+===========================================
+
+The following example shows how to register a custom
+metric and have AutoPyTorch optimize for it
+"""
+import os
+import tempfile as tmp
+import warnings
+
+os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
+os.environ['OMP_NUM_THREADS'] = '1'
+os.environ['OPENBLAS_NUM_THREADS'] = '1'
+os.environ['MKL_NUM_THREADS'] = '1'
+
+warnings.simplefilter(action='ignore', category=UserWarning)
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+import sklearn.datasets
+import sklearn.model_selection
+
+from autoPyTorch.api.tabular_classification import TabularClassificationTask
+from autoPyTorch.metrics import CLASSIFICATION_METRICS
+from autoPyTorch.pipeline.components.training.metrics.base import make_metric
+from autoPyTorch.pipeline.components.training.metrics.utils import add_metric
+
+
+############################################################################
+# Data Loading
+# ============
+X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
+    X,
+    y,
+    random_state=1,
+)
+
+
+###############################################################################
+# Define and add custom score function
+# ====================================
+def score_function(y_test, y_pred):
+    return sum(y_pred == y_test) / y_pred.shape[0]
+
+print("#"*80)
+print(f"Current metrics available for classification: {list(CLASSIFICATION_METRICS.keys())}")
+custom_metric = make_metric(name="custom_metric", score_func=score_function, worst_possible_result=0, greater_is_better=True)
+
+add_metric(metric=custom_metric, task_type="tabular_classification")
+print("#"*80)
+print(f"Metrics available for classification after adding custom metric: {list(CLASSIFICATION_METRICS.keys())}")
+
+
+############################################################################
+# Build and fit a classifier
+# ==========================
+api = TabularClassificationTask(
+    # To maintain logs of the run, you can uncomment the
+    # following lines
+    # temporary_directory='./tmp/autoPyTorch_example_tmp_01',
+    # output_directory='./tmp/autoPyTorch_example_out_01',
+    # delete_tmp_folder_after_terminate=False,
+    # delete_output_folder_after_terminate=False,
+    seed=42,
+)
+
+######################################################################################
+# Search for an ensemble of machine learning algorithms optimized on the custom metric
+# ====================================================================================
+api.search(
+    X_train=X_train,
+    y_train=y_train,
+    X_test=X_test.copy(),
+    y_test=y_test.copy(),
+    dataset_name='Australian',
+    optimize_metric='custom_metric',
+    total_walltime_limit=300,
+    func_eval_time_limit_secs=50,
+    memory_limit=None,
+)
+
+############################################################################
+# Print the final ensemble performance
+# ====================================
+
+y_pred = api.predict(X_test)
+score = api.score(y_pred, y_test)
+print(score)
+
+# Print statistics from search
+print(api.sprint_statistics())
+
+# Print the final ensemble built by AutoPyTorch
+print(api.show_models())
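The metric returned by `make_metric` in the example above can also be called directly for a quick sanity check outside the full search — a sketch, assuming the scorer is callable as `metric(y_true, y_pred)` the way the built-in metrics are, with integer-encoded labels:

    import numpy as np

    y_true = np.array([0, 1, 1, 0])
    y_hat = np.array([0, 1, 0, 0])

    # score_function is plain accuracy, so this should print 0.75 (3 of 4 correct).
    print(custom_metric(y_true, y_hat))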
From 3d0afa71be40451ff636d5996bb7c2867ec67d53 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 19 Dec 2022 14:21:11 +0100
Subject: [PATCH 4/5] fix bug with importlib_metadata package version

---
 .github/workflows/pre-commit.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index d9fd438c5..d725cbf54 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -34,8 +34,10 @@ jobs:
         run: |
           git submodule update --init --recursive

+      # flake8 does not cap the importlib_metadata version, but importlib_metadata>=5.0.0 is incompatible with it.
       - name: Install pre-commit
         run: |
+          pip install importlib_metadata < 5.0.0
           pip install pre-commit
           pre-commit install

From 1b1f07b0799251b9e9074345062006cf5c33d5d2 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 19 Dec 2022 14:23:17 +0100
Subject: [PATCH 5/5] silly bug in passing version

---
 .github/workflows/pre-commit.yaml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index d725cbf54..024df93f8 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -35,9 +35,10 @@ jobs:
           git submodule update --init --recursive

       # flake8 does not cap the importlib_metadata version, but importlib_metadata>=5.0.0 is incompatible with it.
+      # See https://github.com/python/importlib_metadata/issues/406
       - name: Install pre-commit
         run: |
-          pip install importlib_metadata < 5.0.0
+          pip install "importlib_metadata<5.0.0"
           pip install pre-commit
           pre-commit install
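A closing note on patches 4 and 5: in `pip install importlib_metadata < 5.0.0`, the shell parses the unquoted `<` as input redirection, so pip is effectively run as `pip install importlib_metadata` with stdin redirected from a (non-existent) file named `5.0.0` — the version constraint never reaches pip. Quoting the requirement, as patch 5 does, hands the whole specifier to pip intact. The word-splitting involved can be seen with Python's `shlex`, which tokenizes like a POSIX shell (though it does not itself perform redirection):

    import shlex

    # Unquoted: the constraint falls apart into separate tokens, and a real
    # shell would then consume '< 5.0.0' as a redirection before pip ever runs.
    print(shlex.split('pip install importlib_metadata < 5.0.0'))
    # -> ['pip', 'install', 'importlib_metadata', '<', '5.0.0']

    # Quoted: the full requirement specifier survives as a single argument.
    print(shlex.split('pip install "importlib_metadata<5.0.0"'))
    # -> ['pip', 'install', 'importlib_metadata<5.0.0']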