[ADD] Optimise on custom metric #486

Open · wants to merge 5 commits into development (base branch)
3 changes: 3 additions & 0 deletions .github/workflows/pre-commit.yaml
@@ -34,8 +34,11 @@ jobs:
run: |
git submodule update --init --recursive

# flake8 does not restrict the version for importlib_metadata, however, it is incompatible.
# See https://github.com/python/importlib_metadata/issues/406
- name: Install pre-commit
run: |
pip install "importlib_metadata<5.0.0"
pip install pre-commit
pre-commit install

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -10,7 +10,7 @@ repos:
]
files: autoPyTorch/.*
exclude: autoPyTorch/ensemble/
-   - repo: https://gitlab.com/pycqa/flake8
+   - repo: https://github.com/PyCQA/flake8
rev: 3.8.3
hooks:
- id: flake8
@@ -119,12 +119,12 @@ def _prepare_model(self,
                       X_train: np.ndarray,
                       y_train: np.ndarray
                       ) -> None:
-       if not self.is_classification:
+       if hasattr(AutoPyTorchToCatboostMetrics, self.metric.name):
+           self.config['eval_metric'] = AutoPyTorchToCatboostMetrics[self.metric.name].value
+       if not self.is_classification:
            # CatBoost Cannot handle a random state object, just the seed
            self.model = CatBoostRegressor(**self.config, random_state=self.random_state.get_state()[1][0])
        else:
-           self.config['eval_metric'] = AutoPyTorchToCatboostMetrics[self.metric.name].value
            # CatBoost Cannot handle a random state object, just the seed
            self.model = CatBoostClassifier(**self.config, random_state=self.random_state.get_state()[1][0])
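Not part of the diff: a minimal sketch of the guard introduced above, assuming AutoPyTorchToCatboostMetrics is an Enum that maps autoPyTorch metric names to CatBoost eval_metric strings (the stand-in members below are illustrative). A user-registered metric with no CatBoost counterpart is simply skipped, so CatBoost keeps its default eval_metric for the task.

from enum import Enum

class AutoPyTorchToCatboostMetrics(Enum):
    # Stand-in for the real mapping referenced above; members here are illustrative.
    accuracy = "Accuracy"
    balanced_accuracy = "BalancedAccuracy"

config = {}
metric_name = "custom_metric"  # a user-registered metric with no CatBoost equivalent
if hasattr(AutoPyTorchToCatboostMetrics, metric_name):
    config['eval_metric'] = AutoPyTorchToCatboostMetrics[metric_name].value
print(config)  # {} : eval_metric stays unset, so CatBoost uses its default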

9 changes: 7 additions & 2 deletions autoPyTorch/pipeline/components/training/metrics/metrics.py
@@ -1,5 +1,5 @@
from functools import partial
-from typing import List, Union
+from typing import Dict, List, Union

import numpy as np

@@ -13,7 +13,12 @@

from smac.utils.constants import MAXINT

-from autoPyTorch.pipeline.components.training.metrics.base import make_metric
+from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric, make_metric
+
+
+CLASSIFICATION_METRICS: Dict[str, autoPyTorchMetric]
+FORECASTING_METRICS: Dict[str, autoPyTorchMetric]
+REGRESSION_METRICS: Dict[str, autoPyTorchMetric]

# Standard regression scores
mean_absolute_error = make_metric('mean_absolute_error',
58 changes: 40 additions & 18 deletions autoPyTorch/pipeline/components/training/metrics/utils.py
@@ -20,6 +20,31 @@
)


def add_metric(
    metric: autoPyTorchMetric,
    task_type: str,
) -> None:
    """
    Adds an `autoPyTorchMetric` such that it can be used for
    the search as well as for model evaluation.

    Args:
        metric (autoPyTorchMetric):
            Metric to be added.
        task_type (str):
            The current task type, in string.
    """

    if STRING_TO_TASK_TYPES[task_type] in REGRESSION_TASKS:
        metrics = REGRESSION_METRICS
    elif STRING_TO_TASK_TYPES[task_type] in CLASSIFICATION_TASKS:
        metrics = CLASSIFICATION_METRICS
    elif STRING_TO_TASK_TYPES[task_type] in FORECASTING_TASKS:
        metrics = FORECASTING_METRICS

    metrics[metric.name] = metric
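Not part of the diff: a minimal usage sketch of the helper above, registering a regression metric from user code. The metric name and the worst_possible_result convention are assumptions; the example script added later in this PR registers a classification metric in the same way.

import sklearn.metrics

from autoPyTorch.pipeline.components.training.metrics.base import make_metric
from autoPyTorch.pipeline.components.training.metrics.utils import add_metric

custom_mape = make_metric(
    name='custom_mape',  # illustrative name
    score_func=sklearn.metrics.mean_absolute_percentage_error,
    worst_possible_result=1e10,  # assumed convention: a large finite error as the worst case
    greater_is_better=False,
)
add_metric(metric=custom_mape, task_type='tabular_regression')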


def sanitize_array(array: np.ndarray) -> np.ndarray:
"""
Replace NaN and Inf (there should not be any!)
@@ -66,14 +91,14 @@ def get_metrics(dataset_properties: Dict[str, Any],

    Args:
        dataset_properties: Dict[str, Any]
            contains information about the dataset and task type
        names: Optional[Iterable[str]]
            names of metrics to return
        all_supported_metrics: bool
            if true, returns all metrics that are relevant to task_type

    Returns:
        List[autoPyTorchMetric]
    """
    assert 'task_type' in dataset_properties, \
        "Expected dataset_properties to have task_type got {}".format(dataset_properties.keys())
@@ -189,27 +214,24 @@ def calculate_loss(
    """
    Returns a loss (a magnitude that allows casting the
    optimization problem, as a minimization one) for the
-   given Auto-Sklearn Scorer object
-   Parameters
-   ----------
-   solution: np.ndarray
+   given autoPyTorchMetric object
+
+   Args:
+       target (np.ndarray):
            The ground truth of the targets
-       prediction: np.ndarray
+       prediction (np.ndarray):
            The best estimate from the model, of the given targets
-       task_type: int
+       task_type (int):
            To understand if the problem task is classification
            or regression
-       metric: Scorer
-           Object that host a function to calculate how good the
-           prediction is according to the solution.
-       scoring_functions: List[Scorer]
+       metrics (Iterable[autoPyTorchMetric]):
            A list of metrics to calculate multiple losses
        score_kwargs: Dict
            additional arguments for computing scores
-   Returns
-   -------
-   float or Dict[str, float]
-   A loss function for each of the provided scorer objects
+
+   Returns:
+       Dict[str, float]:
+           A loss score for each of the provided metrics
    """
    score = calculate_score(
        target=target,
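Not part of the diff: a rough sketch of calling calculate_loss with a user-registered metric, following the updated docstring above. The STRING_TO_TASK_TYPES lookup for the integer task type and the plain label predictions are assumptions.

import numpy as np

from autoPyTorch.constants import STRING_TO_TASK_TYPES
from autoPyTorch.pipeline.components.training.metrics.base import make_metric
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_loss

# A made-up binary problem and a simple accuracy-style metric.
target = np.array([0, 1, 1, 0])
prediction = np.array([0, 1, 0, 0])
custom_metric = make_metric(
    name='custom_metric',
    score_func=lambda y_test, y_pred: np.mean(y_pred == y_test),
    worst_possible_result=0,
    greater_is_better=True,
)

losses = calculate_loss(
    target=target,
    prediction=prediction,
    task_type=STRING_TO_TASK_TYPES['tabular_classification'],
    metrics=[custom_metric],
)
print(losses)  # a loss value per metric name, as described in the docstring above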
96 changes: 96 additions & 0 deletions examples/40_advanced/example_add_custom_metric.py
@@ -0,0 +1,96 @@
"""
======================
Tabular Classification
======================

The following example shows how to optimize
AutoPyTorch on a custom metric
"""
import os
import tempfile as tmp
import warnings

os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

import sklearn.datasets
import sklearn.model_selection

from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.metrics import CLASSIFICATION_METRICS
from autoPyTorch.pipeline.components.training.metrics.base import make_metric
from autoPyTorch.pipeline.components.training.metrics.utils import add_metric


############################################################################
# Data Loading
# ============
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
X,
y,
random_state=1,
)


###############################################################################
# Define and add custom score function
# ====================================
def score_function(y_test, y_pred):
    return sum(y_pred == y_test) / y_pred.shape[0]

Review thread on score_function:

dengdifan (Contributor), Dec 19, 2022:
For Time Series Tasks, additional kwargs must be attached:
https://github.com/sktime/sktime/blob/main/sktime/performance_metrics/forecasting/_functions.py#L179
Should we create a new example for this, or do we add another docstring here?

Author (Contributor) reply:
I'll add it to the same example. We can divide this example into two parts: one for tabular tasks and one for time series tasks.
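A minimal sketch of the point raised in the thread above (illustrative only, not part of this PR): a forecasting-compatible score function can simply accept the extra keyword arguments forwarded by the forecasting metrics. The kwargs handling and the worst_possible_result value are assumptions; the exact signature follows the sktime-style functions linked above.

import numpy as np


def forecasting_score_function(y_test, y_pred, **kwargs):
    # Accept (and here ignore) the additional keyword arguments that
    # forecasting metrics receive, e.g. horizon weights or seasonal periodicity.
    return np.mean(np.abs(np.asarray(y_pred, dtype=float) - np.asarray(y_test, dtype=float)))


custom_forecasting_metric = make_metric(
    name="custom_forecasting_metric",
    score_func=forecasting_score_function,
    worst_possible_result=1e10,  # assumed convention: a large finite error as the worst case
    greater_is_better=False,
)
add_metric(metric=custom_forecasting_metric, task_type="time_series_forecasting")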

print("#"*80)
print(f"Current metrics available for classification: {list(CLASSIFICATION_METRICS.keys())}")
custom_metric = make_metric(name="custom_metric", score_func=score_function, worst_possible_result=0, greater_is_better=True)

add_metric(metric=custom_metric, task_type="tabular_classification")
print("#"*80)
print(f"Metrics available for classification after adding custom metric: {list(CLASSIFICATION_METRICS.keys())}")


############################################################################
# Build and fit a classifier
# ==========================
api = TabularClassificationTask(
# To maintain logs of the run, you can uncomment the
# Following lines
# temporary_directory='./tmp/autoPyTorch_example_tmp_01',
# output_directory='./tmp/autoPyTorch_example_out_01',
# delete_tmp_folder_after_terminate=False,
# delete_output_folder_after_terminate=False,
seed=42,
)

######################################################################################
# Search for an ensemble of machine learning algorithms optimised on the custom metric
# ====================================================================================
api.search(
X_train=X_train,
y_train=y_train,
X_test=X_test.copy(),
y_test=y_test.copy(),
dataset_name='Australian',
optimize_metric='custom_metric',
total_walltime_limit=300,
func_eval_time_limit_secs=50,
memory_limit=None,
)

############################################################################
# Print the final ensemble performance
# ====================================

y_pred = api.predict(X_test)
score = api.score(y_pred, y_test)
print(score)

# Print statistics from search
print(api.sprint_statistics())

# Print the final ensemble built by AutoPyTorch
print(api.show_models())
48 changes: 47 additions & 1 deletion test/test_pipeline/test_metrics.py
@@ -17,6 +17,9 @@
TIMESERIES_FORECASTING
)
from autoPyTorch.metrics import (
CLASSIFICATION_METRICS,
FORECASTING_METRICS,
REGRESSION_METRICS,
accuracy,
balanced_accuracy,
compute_mase_coefficient,
@@ -30,7 +33,12 @@
autoPyTorchMetric,
make_metric
)
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_loss, calculate_score, get_metrics
from autoPyTorch.pipeline.components.training.metrics.utils import (
add_metric,
calculate_loss,
calculate_score,
get_metrics
)


@pytest.mark.parametrize('output_type', ['multiclass',
@@ -354,3 +362,41 @@ def test_compute_mase_coefficient():
    past_target = np.zeros(12)
    assert compute_mase_coefficient(past_target, 15) == 1.
    assert compute_mase_coefficient(past_target, 5) == 1.


@pytest.mark.parametrize('task_type,task_metrics_dict,output_type', [
('tabular_classification', CLASSIFICATION_METRICS, 'multiclass'),
('tabular_classification', CLASSIFICATION_METRICS, 'multiclass-multioutput'),
('tabular_classification', CLASSIFICATION_METRICS, 'binary'),
('tabular_regression', REGRESSION_METRICS, 'continuous'),
('tabular_regression', REGRESSION_METRICS, 'continuous-multioutput'),
('time_series_forecasting', FORECASTING_METRICS, 'continuous'),
('time_series_forecasting', FORECASTING_METRICS, 'continuous-multioutput'),
])
def test_add_metric(task_type, task_metrics_dict, output_type):
    """
    Tests if the custom score function is added to the metrics for the given task.
    Also checks if the added metric is returned by the `get_metrics` function.

    Args:
        task_type (str)
        task_metrics_dict (Dict[str, autoPyTorchMetric])
        output_type (str)
    """
    def score_function(y_test, y_pred):
        return sum(y_pred == y_test) / y_pred.shape[0]

    custom_metric = make_metric(
        name="custom_metric",
        score_func=score_function,
        worst_possible_result=0,
        greater_is_better=True
    )

    add_metric(metric=custom_metric, task_type=task_type)
    assert 'custom_metric' in list(task_metrics_dict.keys())
    metrics = get_metrics(
        dataset_properties={'task_type': task_type, 'output_type': output_type},
        names=['custom_metric']
    )
    assert custom_metric in metrics
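Not part of the PR: if all_supported_metrics=True indeed returns every metric registered for the task, as the get_metrics docstring suggests, the test could additionally assert that the custom metric shows up there. A sketch, indented as a continuation of test_add_metric:

    all_metrics = get_metrics(
        dataset_properties={'task_type': task_type, 'output_type': output_type},
        all_supported_metrics=True,
    )
    assert any(metric.name == 'custom_metric' for metric in all_metrics)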