Add multi target classification (#441)
* First pass at the multi-target classifier. Core functionality works, but some other tests are failing

* Updated base model to support custom metrics on multi-target

* Fix initialization of the metrics param config for multi-target

* Update pytests to include multi-target classification

* Preliminary fix for combine_prediction

* Documentation updates

* Linter cleanup

* Bugfix for metrics in multi-target classification

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added new tutorial for multi-target classification

* Minor update to documentation for multi-target classification

---------

Co-authored-by: Yony Bresler <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Manu Joseph V <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
5 people authored Sep 17, 2024
1 parent fc6060e commit 25691f5
Showing 22 changed files with 2,160 additions and 61 deletions.
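In short, DataConfig can now take several classification targets at once. A minimal usage sketch of the new behavior, assuming illustrative column names and synthetic data (none of this is from the repository):

    import pandas as pd
    from sklearn.datasets import make_classification

    from pytorch_tabular import TabularModel
    from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
    from pytorch_tabular.models import CategoryEmbeddingModelConfig

    # Toy frame with six numeric features and two categorical targets.
    X, y = make_classification(n_samples=500, n_features=6, n_informative=4)
    train = pd.DataFrame(X, columns=[f"num_{i}" for i in range(6)])
    train["target_a"] = y
    train["target_b"] = (X[:, 0] > 0).astype(int)

    data_config = DataConfig(
        target=["target_a", "target_b"],  # multiple classification targets now allowed
        continuous_cols=[f"num_{i}" for i in range(6)],
    )
    model = TabularModel(
        data_config=data_config,
        model_config=CategoryEmbeddingModelConfig(task="classification"),
        optimizer_config=OptimizerConfig(),
        trainer_config=TrainerConfig(max_epochs=5),
    )
    model.fit(train=train)
    pred_df = model.predict(train)  # per-target columns, e.g. target_a_prediction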
2 changes: 1 addition & 1 deletion README.md
@@ -100,7 +100,7 @@ from pytorch_tabular.config import (
data_config = DataConfig(
target=[
"target"
- ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
+ ], # target should always be a list.
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
2 changes: 1 addition & 1 deletion docs/gs_usage.md
@@ -14,7 +14,7 @@ from pytorch_tabular.config import (
data_config = DataConfig(
target=[
"target"
- ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
+ ], # target should always be a list.
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
@@ -532,7 +532,7 @@
"data_config = DataConfig(\n",
" target=[\n",
" target_col\n",
" ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
" ], # target should always be a list\n",
" continuous_cols=num_col_names,\n",
" categorical_cols=cat_col_names,\n",
")\n",
2,008 changes: 2,008 additions & 0 deletions docs/tutorials/15-Multi Target Classification.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions examples/__only_for_dev__/adhoc_scaffold.py
@@ -53,8 +53,7 @@ def print_metrics(y_true, y_pred, tag):
from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig # noqa: E402

data_config = DataConfig(
- # target should always be a list. Multi-targets are only supported for regression.
- # Multi-Task Classification is not implemented
+ # target should always be a list.
target=["target"],
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
3 changes: 1 addition & 2 deletions examples/__only_for_dev__/to_test_dae.py
@@ -145,8 +145,7 @@ def print_metrics(y_true, y_pred, tag):
lr = 1e-3

data_config = DataConfig(
- # target should always be a list. Multi-targets are only supported for regression.
- # Multi-Task Classification is not implemented
+ # target should always be a list.
target=[target_name],
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
6 changes: 6 additions & 0 deletions src/pytorch_tabular/config/config.py
@@ -197,6 +197,8 @@ class InferredConfig:
output_dim (Optional[int]): The number of output targets
+ output_cardinality (Optional[int]): The number of unique values in classification output
categorical_cardinality (Optional[List[int]]): The number of unique values in categorical features
embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a
@@ -216,6 +218,10 @@ class InferredConfig:
default=None,
metadata={"help": "The number of output targets"},
)
+ output_cardinality: Optional[List[int]] = field(
+ default=None,
+ metadata={"help": "The number of unique values in classification output"},
+ )
categorical_cardinality: Optional[List[int]] = field(
default=None,
metadata={"help": "The number of unique values in categorical features"},
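For classification, output_cardinality records the number of classes per target, and output_dim becomes their sum, so the network emits one logit block per target. A small sketch of the relationship, with made-up targets:

    import pandas as pd

    # Hypothetical frame with two targets of 3 and 2 classes respectively.
    df = pd.DataFrame({"t1": ["a", "b", "c", "a"], "t2": [0, 1, 0, 1]})
    output_cardinality = df[["t1", "t2"]].fillna("NA").nunique().tolist()  # [3, 2]
    output_dim = sum(output_cardinality)  # 5 logits: 3 for t1, then 2 for t2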
91 changes: 72 additions & 19 deletions src/pytorch_tabular/models/base_model.py
@@ -122,23 +122,43 @@ def __init__(
config.metrics_params.append(vars(metric))
if config.task == "classification":
config.metrics_prob_input = self.custom_metrics_prob_inputs
+ for i, mp in enumerate(config.metrics_params):
+ mp.sub_params_list = []
+ for j, num_classes in enumerate(inferred_config.output_cardinality):
+ config.metrics_params[i].sub_params_list.append(
+ OmegaConf.create(
+ {
+ "task": mp.get("task", "multiclass"),
+ "num_classes": mp.get("num_classes", num_classes),
+ }
+ )
+ )
+
# Updating default metrics in config
elif config.task == "classification":
# Adding metric_params to config for classification task
for i, mp in enumerate(config.metrics_params):
- # For classification task, output_dim == number of classses
- config.metrics_params[i]["task"] = mp.get("task", "multiclass")
- config.metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim)
- if config.metrics[i] in (
- "accuracy",
- "precision",
- "recall",
- "precision_recall",
- "specificity",
- "f1_score",
- "fbeta_score",
- ):
- config.metrics_params[i]["top_k"] = mp.get("top_k", 1)
+ mp.sub_params_list = []
+ for j, num_classes in enumerate(inferred_config.output_cardinality):
+ # config.metrics_params[i][j]["task"] = mp.get("task", "multiclass")
+ # config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes)
+
+ config.metrics_params[i].sub_params_list.append(
+ OmegaConf.create(
+ {"task": mp.get("task", "multiclass"), "num_classes": mp.get("num_classes", num_classes)}
+ )
+ )
+
+ if config.metrics[i] in (
+ "accuracy",
+ "precision",
+ "recall",
+ "precision_recall",
+ "specificity",
+ "f1_score",
+ "fbeta_score",
+ ):
+ config.metrics_params[i].sub_params_list[j]["top_k"] = mp.get("top_k", 1)

if self.custom_optimizer is not None:
config.optimizer = str(self.custom_optimizer.__class__.__name__)
@@ -267,7 +287,22 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
)
else:
# TODO loss fails with batch size of 1?
- computed_loss = self.loss(y_hat.squeeze(), y.squeeze()) + reg_loss
+ computed_loss = reg_loss
+ start_index = 0
+ for i in range(len(self.hparams.output_cardinality)):
+ end_index = start_index + self.hparams.output_cardinality[i]
+ _loss = self.loss(y_hat[:, start_index:end_index], y[:, i])
+ computed_loss += _loss
+ if self.hparams.output_dim > 1:
+ self.log(
+ f"{tag}_loss_{i}",
+ _loss,
+ on_epoch=True,
+ on_step=False,
+ logger=True,
+ prog_bar=False,
+ )
+ start_index = end_index
self.log(
f"{tag}_loss",
computed_loss,
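The reworked calculate_loss slices the concatenated logits into one block per target, scores each block against its own label column, and sums the per-target losses. A standalone sketch of that slicing logic, with shapes and cardinalities assumed for illustration:

    import torch
    import torch.nn as nn

    loss_fn = nn.CrossEntropyLoss()
    output_cardinality = [3, 2]                       # classes per target (assumed)
    y_hat = torch.randn(8, sum(output_cardinality))   # (batch, 5) concatenated logits
    y = torch.randint(0, 2, (8, 2))                   # (batch, n_targets) integer labels

    computed_loss = torch.tensor(0.0)
    start_index = 0
    for i, card in enumerate(output_cardinality):
        end_index = start_index + card
        # Score this target's logit block against its own label column.
        computed_loss = computed_loss + loss_fn(y_hat[:, start_index:end_index], y[:, i])
        start_index = end_index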
@@ -325,11 +360,29 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
_metrics.append(_metric)
avg_metric = torch.stack(_metrics, dim=0).sum()
else:
- y_hat = nn.Softmax(dim=-1)(y_hat.squeeze())
- if prob_inp:
- avg_metric = metric(y_hat, y.squeeze(), **metric_params)
- else:
- avg_metric = metric(torch.argmax(y_hat, dim=-1), y.squeeze(), **metric_params)
+ _metrics = []
+ start_index = 0
+ for i, cardinality in enumerate(self.hparams.output_cardinality):
+ end_index = start_index + cardinality
+ y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze())
+ if prob_inp:
+ _metric = metric(y_hat_i, y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i])
+ else:
+ _metric = metric(
+ torch.argmax(y_hat_i, dim=-1), y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i]
+ )
+ if len(self.hparams.output_cardinality) > 1:
+ self.log(
+ f"{tag}_{metric_str}_{i}",
+ _metric,
+ on_epoch=True,
+ on_step=False,
+ logger=True,
+ prog_bar=False,
+ )
+ _metrics.append(_metric)
+ start_index = end_index
+ avg_metric = torch.stack(_metrics, dim=0).sum()
metrics.append(avg_metric)
self.log(
f"{tag}_{metric_str}",
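calculate_metrics applies the same slicing: each target's logit block is softmaxed and scored with that target's entry in sub_params_list, per-target values are logged when there are multiple targets, and their sum is reported as the aggregate. A rough standalone equivalent using torchmetrics, with values assumed for illustration:

    import torch
    from torchmetrics.functional import accuracy

    output_cardinality = [3, 2]
    y_hat = torch.randn(8, 5)               # concatenated logits for both targets
    y = torch.randint(0, 2, (8, 2))         # one label column per target

    per_target, start = [], 0
    for i, card in enumerate(output_cardinality):
        probs = torch.softmax(y_hat[:, start:start + card], dim=-1)
        per_target.append(accuracy(probs, y[:, i], task="multiclass", num_classes=card))
        start += card
    avg_metric = torch.stack(per_target).sum()   # the value logged as {tag}_{metric}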
27 changes: 20 additions & 7 deletions src/pytorch_tabular/tabular_datamodule.py
@@ -282,13 +282,21 @@ def _update_config(self, config) -> InferredConfig:
if config.task == "regression":
# self._output_dim_reg = len(config.target) if config.target else None if self.train is not None:
output_dim = len(config.target) if config.target else None
+ output_cardinality = None
elif config.task == "classification":
# self._output_dim_clf = len(np.unique(self.train_dataset.y)) if config.target else None
if self.train is not None:
- output_dim = len(np.unique(self.train[config.target[0]])) if config.target else None
+ output_cardinality = (
+ self.train[config.target].fillna("NA").nunique().tolist() if config.target else None
+ )
+ output_dim = sum(output_cardinality)
else:
- output_dim = len(np.unique(self.train_dataset.y)) if config.target else None
+ output_cardinality = (
+ self.train_dataset.data[config.target].fillna("NA").nunique().tolist() if config.target else None
+ )
+ output_dim = sum(output_cardinality)
elif config.task == "ssl":
+ output_cardinality = None
output_dim = None
else:
raise ValueError(f"{config.task} is an unsupported task.")
@@ -308,6 +316,7 @@ def _update_config(self, config) -> InferredConfig:
categorical_dim=categorical_dim,
continuous_dim=continuous_dim,
output_dim=output_dim,
+ output_cardinality=output_cardinality,
categorical_cardinality=categorical_cardinality,
embedding_dims=embedding_dims,
)
@@ -381,11 +390,14 @@ def _label_encode_target(self, data: DataFrame, stage: str) -> DataFrame:
if self.config.task != "classification":
return data
if stage == "fit" or self.label_encoder is None:
- self.label_encoder = LabelEncoder()
- data[self.config.target[0]] = self.label_encoder.fit_transform(data[self.config.target[0]])
+ self.label_encoder = [None] * len(self.config.target)
+ for i in range(len(self.config.target)):
+ self.label_encoder[i] = LabelEncoder()
+ data[self.config.target[i]] = self.label_encoder[i].fit_transform(data[self.config.target[i]])
else:
- if self.config.target[0] in data.columns:
- data[self.config.target[0]] = self.label_encoder.transform(data[self.config.target[0]])
+ for i in range(len(self.config.target)):
+ if self.config.target[i] in data.columns:
+ data[self.config.target[i]] = self.label_encoder[i].transform(data[self.config.target[i]])
return data

def _target_transform(self, data: DataFrame, stage: str) -> DataFrame:
@@ -818,7 +830,8 @@ def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
# TODO Is the target encoding necessary?
if len(set(self.target) - set(df.columns)) > 0:
if self.config.task == "classification":
- df.loc[:, self.target] = np.array([self.label_encoder.classes_[0]] * len(df)).reshape(-1, 1)
+ for i in range(len(self.target)):
+ df.loc[:, self.target[i]] = np.array([self.label_encoder[i].classes_[0]] * len(df)).reshape(-1, 1)
else:
df.loc[:, self.target] = np.zeros((len(df), len(self.target)))
df, _ = self.preprocess_data(df, stage="inference")
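With this change the datamodule keeps one scikit-learn LabelEncoder per target instead of a single encoder. A minimal sketch of the fit/transform round trip, with made-up column names:

    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    targets = ["t1", "t2"]
    data = pd.DataFrame({"t1": ["cat", "dog", "cat"], "t2": ["x", "y", "y"]})

    label_encoder = [LabelEncoder() for _ in targets]
    for i, col in enumerate(targets):
        data[col] = label_encoder[i].fit_transform(data[col])
    # Each encoder later decodes only its own target:
    print(label_encoder[0].inverse_transform([0, 1]))  # ['cat' 'dog']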
31 changes: 15 additions & 16 deletions src/pytorch_tabular/tabular_model.py
@@ -211,9 +211,6 @@ def num_params(self):

def _run_validation(self):
"""Validates the Config params and throws errors if something is wrong."""
- if self.config.task == "classification":
- if len(self.config.target) > 1:
- raise NotImplementedError("Multi-Target Classification is not implemented.")
if self.config.task == "regression":
if self.config.target_range is not None:
if (
@@ -1291,12 +1288,16 @@ def _format_predicitons(
pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)

elif self.config.task == "classification":
- point_predictions = nn.Softmax(dim=-1)(point_predictions).numpy()
- for i, class_ in enumerate(self.datamodule.label_encoder.classes_):
- pred_df[f"{class_}_probability"] = point_predictions[:, i]
- pred_df["prediction"] = self.datamodule.label_encoder.inverse_transform(
- np.argmax(point_predictions, axis=1)
- )
+ start_index = 0
+ for i, target_col in enumerate(self.config.target):
+ end_index = start_index + self.datamodule._inferred_config.output_cardinality[i]
+ prob_prediction = nn.Softmax(dim=-1)(point_predictions[:, start_index:end_index]).numpy()
+ start_index = end_index
+ for j, class_ in enumerate(self.datamodule.label_encoder[i].classes_):
+ pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[:, j]
+ pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[i].inverse_transform(
+ np.argmax(prob_prediction, axis=1)
+ )
warnings.warn(
"Classification prediction column will be renamed to"
" `{target_col}_prediction` in the next release to maintain"
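Prediction output is renamed accordingly: probabilities go to {target}_{class}_probability columns and hard labels to {target}_prediction, one set per target. A sketch of how such a frame is assembled from concatenated logits, with class names and cardinalities assumed:

    import numpy as np
    import pandas as pd
    import torch
    import torch.nn as nn

    point_predictions = torch.randn(4, 5)          # logits for targets of 3 + 2 classes
    cardinalities = [3, 2]
    class_lists = [["a", "b", "c"], ["x", "y"]]    # stand-ins for label_encoder[i].classes_

    pred_df, start = pd.DataFrame(), 0
    for i, (card, classes) in enumerate(zip(cardinalities, class_lists)):
        probs = nn.Softmax(dim=-1)(point_predictions[:, start:start + card]).numpy()
        start += card
        for j, class_ in enumerate(classes):
            pred_df[f"target{i}_{class_}_probability"] = probs[:, j]
        pred_df[f"target{i}_prediction"] = np.array(classes)[np.argmax(probs, axis=1)]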
@@ -2046,23 +2047,21 @@ def _combine_predictions(
elif callable(aggregate):
bagged_pred = aggregate(pred_prob_l)
if self.config.task == "classification":
- classes = self.datamodule.label_encoder.classes_
+ # FIXME need to iterate .label_encoder[x]
+ classes = self.datamodule.label_encoder[0].classes_
if aggregate == "hard_voting":
pred_df = pd.DataFrame(
np.concatenate(pred_prob_l, axis=1),
- columns=[
- f"{c}_probability_fold_{i}"
- for i in range(len(pred_prob_l))
- for c in self.datamodule.label_encoder.classes_
- ],
+ columns=[f"{c}_probability_fold_{i}" for i in range(len(pred_prob_l)) for c in classes],
index=pred_idx,
)
pred_df["prediction"] = classes[final_pred]
else:
final_pred = classes[np.argmax(bagged_pred, axis=1)]
pred_df = pd.DataFrame(
bagged_pred,
columns=[f"{c}_probability" for c in self.datamodule.label_encoder.classes_],
# FIXME
columns=[f"{c}_probability" for c in self.datamodule.label_encoder[0].classes_],
index=pred_idx,
)
pred_df["prediction"] = final_pred
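Note that the bagging path in _combine_predictions still assumes a single target, hence the FIXME comments: fold probabilities are aggregated and decoded with the first encoder's classes. A rough sketch of the soft-voting branch, with made-up fold outputs:

    import numpy as np

    # Five folds of hypothetical probabilities, shape (n_samples, n_classes) each.
    pred_prob_l = [np.random.rand(4, 3) for _ in range(5)]
    classes = np.array(["a", "b", "c"])   # stands in for label_encoder[0].classes_

    bagged_pred = np.mean(pred_prob_l, axis=0)           # soft voting across folds
    final_pred = classes[np.argmax(bagged_pred, axis=1)]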
4 changes: 3 additions & 1 deletion tests/test_autoint.py
@@ -78,6 +78,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]


@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -91,6 +92,7 @@ def test_regression(
@pytest.mark.parametrize("batch_norm_continuous_input", [True, False])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -100,7 +102,7 @@
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
4 changes: 3 additions & 1 deletion tests/test_categorical_embedding.py
@@ -124,6 +124,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]


@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -136,6 +137,7 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -146,7 +148,7 @@
return

data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
2 changes: 1 addition & 1 deletion tests/test_common.py
@@ -753,7 +753,7 @@ def test_cross_validate_regression(
[
"accuracy",
None,
- lambda y_true, y_pred: accuracy_score(y_true, y_pred["prediction"].values),
+ lambda y_true, y_pred: accuracy_score(y_true, y_pred["target_prediction"].values),
],
)
@pytest.mark.parametrize("return_oof", [True])
4 changes: 3 additions & 1 deletion tests/test_danet.py
@@ -80,6 +80,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]


@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -91,14 +92,15 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
+ multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
normalize_continuous_features,
):
(train, test, target) = classification_data
data_config = DataConfig(
- target=target,
+ target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,