From c973856db0d27e9ee2d8e2b04183af7171957731 Mon Sep 17 00:00:00 2001 From: Pavol Mulinka Date: Fri, 10 Dec 2021 21:41:18 +0100 Subject: [PATCH 1/5] working_example --- examples/15_AUC_multiclass.ipynb | 461 +++++++++++++++++++++++ pytorch_widedeep/datasets/__init__.py | 4 +- pytorch_widedeep/datasets/_base.py | 87 ++++- pytorch_widedeep/datasets/data/ecoli.csv | 337 +++++++++++++++++ 4 files changed, 886 insertions(+), 3 deletions(-) create mode 100644 examples/15_AUC_multiclass.ipynb create mode 100644 pytorch_widedeep/datasets/data/ecoli.csv diff --git a/examples/15_AUC_multiclass.ipynb b/examples/15_AUC_multiclass.ipynb new file mode 100644 index 00000000..878f4ee9 --- /dev/null +++ b/examples/15_AUC_multiclass.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "731975e2", + "metadata": {}, + "source": [ + "# Hyperparameter tuning with Raytune and visulization using Tensorboard and Weights & Biases" + ] + }, + { + "cell_type": "markdown", + "id": "ee745c58", + "metadata": {}, + "source": [ + "## Initial imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fdab94eb", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-12-10 21:26:26.784170: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2021-12-10 21:26:26.784221: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from torch.optim import SGD, lr_scheduler\n", + "\n", + "from pytorch_widedeep import Trainer\n", + "from pytorch_widedeep.preprocessing import TabPreprocessor\n", + "from pytorch_widedeep.models import TabMlp, WideDeep\n", + "from torchmetrics import AUC\n", + "from 
pytorch_widedeep.initializers import XavierNormal\n", + "from pytorch_widedeep.datasets import load_ecoli\n", + "from pytorch_widedeep.utils import LabelEncoder\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# increase displayed columns in jupyter notebook\n", + "pd.set_option(\"display.max_columns\", 200)\n", + "pd.set_option(\"display.max_rows\", 300)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "07c75f0c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
SequenceNamemcggvhlipchgaacalm1alm2class
0AAT_ECOLI0.490.290.480.50.560.240.35cp
1ACEA_ECOLI0.070.400.480.50.540.350.44cp
2ACEK_ECOLI0.560.400.480.50.490.370.46cp
3ACKA_ECOLI0.590.490.480.50.520.450.36cp
4ADI_ECOLI0.230.320.480.50.550.250.35cp
\n", + "
" + ], + "text/plain": [ + " SequenceName mcg gvh lip chg aac alm1 alm2 class\n", + "0 AAT_ECOLI 0.49 0.29 0.48 0.5 0.56 0.24 0.35 cp\n", + "1 ACEA_ECOLI 0.07 0.40 0.48 0.5 0.54 0.35 0.44 cp\n", + "2 ACEK_ECOLI 0.56 0.40 0.48 0.5 0.49 0.37 0.46 cp\n", + "3 ACKA_ECOLI 0.59 0.49 0.48 0.5 0.52 0.45 0.36 cp\n", + "4 ADI_ECOLI 0.23 0.32 0.48 0.5 0.55 0.25 0.35 cp" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = load_ecoli(as_frame=True)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "1e3f8efc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "cp 143\n", + "im 77\n", + "pp 52\n", + "imU 35\n", + "om 20\n", + "omL 5\n", + "imS 2\n", + "imL 2\n", + "Name: class, dtype: int64" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# imbalance of the classes\n", + "df[\"class\"].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "e4db0d6d", + "metadata": {}, + "outputs": [], + "source": [ + "df = df.loc[~df[\"class\"].isin([\"omL\", \"imS\", \"imL\"])]\n", + "df.reset_index(inplace=True, drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "005531a3", + "metadata": {}, + "outputs": [], + "source": [ + "encoder = LabelEncoder([\"class\"])\n", + "df_enc = encoder.fit_transform(df)\n", + "df_enc[\"class\"] = df_enc[\"class\"]-1" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "214b3071", + "metadata": {}, + "outputs": [], + "source": [ + "# drop columns we won't need in this example\n", + "df_enc = df_enc.drop(columns=[\"SequenceName\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "168c81f1", + "metadata": {}, + "outputs": [], + "source": [ + "df_train, df_valid = train_test_split(df_enc, test_size=0.2, stratify=df_enc[\"class\"], random_state=1)\n", + "df_valid, df_test = 
train_test_split(df_valid, test_size=0.5, stratify=df_valid[\"class\"], random_state=1)" + ] + }, + { + "cell_type": "markdown", + "id": "87e7b8f0", + "metadata": {}, + "source": [ + "## Preparing the data" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "3a7b246b", + "metadata": {}, + "outputs": [], + "source": [ + "continuous_cols = df_enc.drop(columns=[\"class\"]).columns.values.tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "7a2dac24", + "metadata": {}, + "outputs": [], + "source": [ + "# deeptabular\n", + "tab_preprocessor = TabPreprocessor(continuous_cols=continuous_cols, scale=True)\n", + "X_tab_train = tab_preprocessor.fit_transform(df_train)\n", + "X_tab_valid = tab_preprocessor.transform(df_valid)\n", + "X_tab_test = tab_preprocessor.transform(df_test)\n", + "\n", + "# target\n", + "y_train = df_train[\"class\"].values\n", + "y_valid = df_valid[\"class\"].values\n", + "y_test = df_test[\"class\"].values\n", + "\n", + "X_train = {\"X_tab\": X_tab_train, \"target\": y_train}\n", + "X_val = {\"X_tab\": X_tab_valid, \"target\": y_valid}" + ] + }, + { + "cell_type": "markdown", + "id": "7b9f63e2", + "metadata": {}, + "source": [ + "## Define the model" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "511198d4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "WideDeep(\n", + " (deeptabular): Sequential(\n", + " (0): TabMlp(\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (cont_norm): BatchNorm1d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (tab_mlp): MLP(\n", + " (mlp): Sequential(\n", + " (dense_layer_0): Sequential(\n", + " (0): Dropout(p=0.1, inplace=False)\n", + " (1): Linear(in_features=7, out_features=200, bias=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " (dense_layer_1): Sequential(\n", + " (0): Dropout(p=0.1, inplace=False)\n", + " (1): Linear(in_features=200, out_features=100, bias=True)\n", + " 
(2): ReLU(inplace=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (1): Linear(in_features=100, out_features=5, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "deeptabular = TabMlp(\n", + " column_idx=tab_preprocessor.column_idx,\n", + " continuous_cols=tab_preprocessor.continuous_cols,\n", + ")\n", + "model = WideDeep(deeptabular=deeptabular, pred_dim=df_enc[\"class\"].nunique())\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "a5359b0f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n", + " warnings.warn(*args, **kwargs)\n" + ] + } + ], + "source": [ + "auc = AUC()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "34a18ac0", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'AUC' object has no attribute 'num_classes'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_1193/1577559620.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m )\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_val\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mn_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, 
wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + 
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 
61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X_wide, X_tab, X_text, X_img, X_train, X_val, val_split, target, n_epochs, validation_freq, batch_size, custom_dataloader, finetune, finetune_epochs, finetune_max_lr, finetune_deeptabular_gradual, finetune_deeptabular_max_lr, finetune_deeptabular_layers, finetune_deeptext_gradual, finetune_deeptext_max_lr, finetune_deeptext_layers, finetune_deepimage_gradual, finetune_deepimage_max_lr, finetune_deepimage_layers, finetune_routine, stop_after_finetuning, **kwargs)\u001b[0m\n\u001b[1;32m 615\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargett\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 616\u001b[0m 
\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_description\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"epoch %i\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m \u001b[0mtrain_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargett\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 618\u001b[0m \u001b[0mprint_loss_and_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 619\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36m_train_step\u001b[0;34m(self, data, target, batch_idx)\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1169\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1170\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1171\u001b[0m \u001b[0;31m# TODO raise exception if the loss is exploding with non scaled target values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1172\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36m_get_score\u001b[0;34m(self, y_pred, y)\u001b[0m\n\u001b[1;32m 1212\u001b[0m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1213\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"multiclass\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1214\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1215\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1216\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/metrics.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, y_pred, y_true)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mlogs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprefix\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTorchMetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_classes\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_classes\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[operator]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1130\u001b[0;31m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 1131\u001b[0m type(self).__name__, name))\n\u001b[1;32m 1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'AUC' object has no attribute 'num_classes'" + ] + } + ], + "source": [ + "# Optimizers\n", + "deep_opt = SGD(model.deeptabular.parameters(), lr=0.1)\n", + "# LR Scheduler\n", + "deep_sch = lr_scheduler.StepLR(deep_opt, step_size=3)\n", + "# Hyperparameters\n", + "trainer = Trainer(\n", + " model,\n", + " objective=\"multiclass_focal_loss\",\n", + " lr_schedulers={\"deeptabular\": deep_sch},\n", + " initializers={\"deeptabular\": XavierNormal},\n", + " optimizers={\"deeptabular\": deep_opt},\n", + " metrics=[auc],\n", + " verbose=0,\n", + ")\n", + "\n", + "trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=10)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": 
"3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d" + }, + "kernelspec": { + "display_name": "Python 3.8.5 64-bit ('base': conda)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pytorch_widedeep/datasets/__init__.py b/pytorch_widedeep/datasets/__init__.py index b4ba41cb..00eca1e5 100644 --- a/pytorch_widedeep/datasets/__init__.py +++ b/pytorch_widedeep/datasets/__init__.py @@ -1,3 +1,3 @@ -from ._base import load_adult, load_bio_kdd04 +from ._base import load_adult, load_bio_kdd04, load_ecoli -__all__ = ["load_bio_kdd04", "load_adult"] +__all__ = ["load_bio_kdd04", "load_adult", "load_ecoli"] diff --git a/pytorch_widedeep/datasets/_base.py b/pytorch_widedeep/datasets/_base.py index 45e3e76a..23187ed7 100644 --- a/pytorch_widedeep/datasets/_base.py +++ b/pytorch_widedeep/datasets/_base.py @@ -5,7 +5,7 @@ def load_bio_kdd04(as_frame: bool = False): """Load and return the higly imbalanced Protein Homology - Dataset from [KDD cup 2004](https://www.kdd.org/kdd-cup/view/kdd-cup-2004/Data. + Dataset from [KDD cup 2004](https://www.kdd.org/kdd-cup/view/kdd-cup-2004/Data). 
def load_ecoli(as_frame: bool = False):
    """Load and return the highly imbalanced multiclass E. coli dataset.

    The data is the "Protein Localization Sites" dataset from the
    [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/ecoli),
    read from the ``ecoli.csv`` file bundled in ``pytorch_widedeep.datasets.data``.

    Parameters
    ----------
    as_frame: bool, default = False
        If ``True`` the data is returned as a pandas DataFrame, otherwise
        the DataFrame is converted with ``to_numpy()`` and the resulting
        array is returned.

    Returns
    -------
    The content of ``ecoli.csv`` as a pandas DataFrame (``as_frame=True``)
    or as a numpy array (``as_frame=False``).

    Dataset description
    -------------------
    1. Title: Protein Localization Sites

    2. Creator and Maintainer:
        Kenta Nakai
        Institute of Molecular and Cellular Biology
        Osaka, University
        1-3 Yamada-oka, Suita 565 Japan
        nakai@imcb.osaka-u.ac.jp
        http://www.imcb.osaka-u.ac.jp/nakai/psort.html
       Donor: Paul Horton (paulh@cs.berkeley.edu)
       Date: September, 1996
       See also: yeast database

    3. Past Usage.
       Reference: "A Probabilistic Classification System for Predicting the
       Cellular Localization Sites of Proteins", Paul Horton & Kenta Nakai,
       Intelligent Systems in Molecular Biology, 109-115.
       St. Louis, USA 1996.
       Results: 81% for E.coli with an ad hoc structured probability model.
       Also similar accuracy for Binary Decision Tree and Bayesian Classifier
       methods applied by the same authors in unpublished results.

       Predicted Attribute: Localization site of protein (non-numeric).

    4. The references below describe a predecessor to this dataset and its
       development. They also give results (not cross-validated) for
       classification by a rule-based expert system with that version of the
       dataset.

       Reference: "Expert System for Predicting Protein Localization Sites in
       Gram-Negative Bacteria", Kenta Nakai & Minoru Kanehisa,
       PROTEINS: Structure, Function, and Genetics 11:95-110, 1991.

       Reference: "A Knowledge Base for Predicting Protein Localization Sites
       in Eukaryotic Cells", Kenta Nakai & Minoru Kanehisa,
       Genomics 14:897-911, 1992.

    5. Number of Instances: 336 for the E.coli dataset.

    6. Number of Attributes.
       For the E.coli dataset: 8 (7 predictive, 1 name).

    7. Attribute Information.

       1. Sequence Name: Accession number for the SWISS-PROT database.
       2. mcg: McGeoch's method for signal sequence recognition.
       3. gvh: von Heijne's method for signal sequence recognition.
       4. lip: von Heijne's Signal Peptidase II consensus sequence score.
          Binary attribute.
       5. chg: Presence of charge on N-terminus of predicted lipoproteins.
          Binary attribute.
       6. aac: score of discriminant analysis of the amino acid content of
          outer membrane and periplasmic proteins.
       7. alm1: score of the ALOM membrane spanning region prediction program.
       8. alm2: score of ALOM program after excluding putative cleavable
          signal regions from the sequence.

    8. Missing Attribute Values: None.

    9. Class Distribution. The class is the localization site. Please see
       Nakai & Kanehisa referenced above for more details.

       cp  (cytoplasm)                                    143
       im  (inner membrane without signal sequence)        77
       pp  (periplasm)                                     52
       imU (inner membrane, uncleavable signal sequence)   35
       om  (outer membrane)                                20
       omL (outer membrane lipoprotein)                     5
       imL (inner membrane lipoprotein)                     2
       imS (inner membrane, cleavable signal sequence)      2
    """
    data_package = "pytorch_widedeep.datasets.data"

    # resources.path yields a real filesystem path for the packaged CSV for
    # the duration of the ``with`` block, so the file is read inside it.
    with resources.path(data_package, "ecoli.csv") as csv_path:
        ecoli = pd.read_csv(csv_path, sep=",")

    return ecoli if as_frame else ecoli.to_numpy()
+AMY2_ECOLI,0.21,0.34,0.48,0.50,0.51,0.28,0.39,cp +APT_ECOLI,0.20,0.44,0.48,0.50,0.46,0.51,0.57,cp +ARAC_ECOLI,0.42,0.40,0.48,0.50,0.56,0.18,0.30,cp +ASG1_ECOLI,0.42,0.24,0.48,0.50,0.57,0.27,0.37,cp +BTUR_ECOLI,0.25,0.48,0.48,0.50,0.44,0.17,0.29,cp +CAFA_ECOLI,0.39,0.32,0.48,0.50,0.46,0.24,0.35,cp +CAIB_ECOLI,0.51,0.50,0.48,0.50,0.46,0.32,0.35,cp +CFA_ECOLI,0.22,0.43,0.48,0.50,0.48,0.16,0.28,cp +CHEA_ECOLI,0.25,0.40,0.48,0.50,0.46,0.44,0.52,cp +CHEB_ECOLI,0.34,0.45,0.48,0.50,0.38,0.24,0.35,cp +CHEW_ECOLI,0.44,0.27,0.48,0.50,0.55,0.52,0.58,cp +CHEY_ECOLI,0.23,0.40,0.48,0.50,0.39,0.28,0.38,cp +CHEZ_ECOLI,0.41,0.57,0.48,0.50,0.39,0.21,0.32,cp +CRL_ECOLI,0.40,0.45,0.48,0.50,0.38,0.22,0.00,cp +CSPA_ECOLI,0.31,0.23,0.48,0.50,0.73,0.05,0.14,cp +CYNR_ECOLI,0.51,0.54,0.48,0.50,0.41,0.34,0.43,cp +CYPB_ECOLI,0.30,0.16,0.48,0.50,0.56,0.11,0.23,cp +CYPC_ECOLI,0.36,0.39,0.48,0.50,0.48,0.22,0.23,cp +CYSB_ECOLI,0.29,0.37,0.48,0.50,0.48,0.44,0.52,cp +CYSE_ECOLI,0.25,0.40,0.48,0.50,0.47,0.33,0.42,cp +DAPD_ECOLI,0.21,0.51,0.48,0.50,0.50,0.32,0.41,cp +DCP_ECOLI,0.43,0.37,0.48,0.50,0.53,0.35,0.44,cp +DDLA_ECOLI,0.43,0.39,0.48,0.50,0.47,0.31,0.41,cp +DDLB_ECOLI,0.53,0.38,0.48,0.50,0.44,0.26,0.36,cp +DEOC_ECOLI,0.34,0.33,0.48,0.50,0.38,0.35,0.44,cp +DLDH_ECOLI,0.56,0.51,0.48,0.50,0.34,0.37,0.46,cp +EFG_ECOLI,0.40,0.29,0.48,0.50,0.42,0.35,0.44,cp +EFTS_ECOLI,0.24,0.35,0.48,0.50,0.31,0.19,0.31,cp +EFTU_ECOLI,0.36,0.54,0.48,0.50,0.41,0.38,0.46,cp +ENO_ECOLI,0.29,0.52,0.48,0.50,0.42,0.29,0.39,cp +FABB_ECOLI,0.65,0.47,0.48,0.50,0.59,0.30,0.40,cp +FES_ECOLI,0.32,0.42,0.48,0.50,0.35,0.28,0.38,cp +G3P1_ECOLI,0.38,0.46,0.48,0.50,0.48,0.22,0.29,cp +G3P2_ECOLI,0.33,0.45,0.48,0.50,0.52,0.32,0.41,cp +G6PI_ECOLI,0.30,0.37,0.48,0.50,0.59,0.41,0.49,cp +GCVA_ECOLI,0.40,0.50,0.48,0.50,0.45,0.39,0.47,cp +GLNA_ECOLI,0.28,0.38,0.48,0.50,0.50,0.33,0.42,cp +GLPD_ECOLI,0.61,0.45,0.48,0.50,0.48,0.35,0.41,cp +GLYA_ECOLI,0.17,0.38,0.48,0.50,0.45,0.42,0.50,cp +GSHR_ECOLI,0.44,0.35,0.48,0.50,0.55,0.55,0.61,cp 
+GT_ECOLI,0.43,0.40,0.48,0.50,0.39,0.28,0.39,cp +HEM6_ECOLI,0.42,0.35,0.48,0.50,0.58,0.15,0.27,cp +HEMN_ECOLI,0.23,0.33,0.48,0.50,0.43,0.33,0.43,cp +HPRT_ECOLI,0.37,0.52,0.48,0.50,0.42,0.42,0.36,cp +IF1_ECOLI,0.29,0.30,0.48,0.50,0.45,0.03,0.17,cp +IF2_ECOLI,0.22,0.36,0.48,0.50,0.35,0.39,0.47,cp +ILVY_ECOLI,0.23,0.58,0.48,0.50,0.37,0.53,0.59,cp +IPYR_ECOLI,0.47,0.47,0.48,0.50,0.22,0.16,0.26,cp +KAD_ECOLI,0.54,0.47,0.48,0.50,0.28,0.33,0.42,cp +KDSA_ECOLI,0.51,0.37,0.48,0.50,0.35,0.36,0.45,cp +LEU3_ECOLI,0.40,0.35,0.48,0.50,0.45,0.33,0.42,cp +LON_ECOLI,0.44,0.34,0.48,0.50,0.30,0.33,0.43,cp +LPLA_ECOLI,0.42,0.38,0.48,0.50,0.54,0.34,0.43,cp +LYSR_ECOLI,0.44,0.56,0.48,0.50,0.50,0.46,0.54,cp +MALQ_ECOLI,0.52,0.36,0.48,0.50,0.41,0.28,0.38,cp +MALZ_ECOLI,0.36,0.41,0.48,0.50,0.48,0.47,0.54,cp +MASY_ECOLI,0.18,0.30,0.48,0.50,0.46,0.24,0.35,cp +METB_ECOLI,0.47,0.29,0.48,0.50,0.51,0.33,0.43,cp +METC_ECOLI,0.24,0.43,0.48,0.50,0.54,0.52,0.59,cp +METK_ECOLI,0.25,0.37,0.48,0.50,0.41,0.33,0.42,cp +METR_ECOLI,0.52,0.57,0.48,0.50,0.42,0.47,0.54,cp +METX_ECOLI,0.25,0.37,0.48,0.50,0.43,0.26,0.36,cp +MURF_ECOLI,0.35,0.48,0.48,0.50,0.56,0.40,0.48,cp +NADA_ECOLI,0.26,0.26,0.48,0.50,0.34,0.25,0.35,cp +NFRC_ECOLI,0.44,0.51,0.48,0.50,0.47,0.26,0.36,cp +NHAR_ECOLI,0.37,0.50,0.48,0.50,0.42,0.36,0.45,cp +NIRD_ECOLI,0.44,0.42,0.48,0.50,0.42,0.25,0.20,cp +OMPR_ECOLI,0.24,0.43,0.48,0.50,0.37,0.28,0.38,cp +OTC1_ECOLI,0.42,0.30,0.48,0.50,0.48,0.26,0.36,cp +OTC2_ECOLI,0.48,0.42,0.48,0.50,0.45,0.25,0.35,cp +PEPE_ECOLI,0.41,0.48,0.48,0.50,0.51,0.44,0.51,cp +PFLA_ECOLI,0.44,0.28,0.48,0.50,0.43,0.27,0.37,cp +PFLB_ECOLI,0.29,0.41,0.48,0.50,0.48,0.38,0.46,cp +PGK_ECOLI,0.34,0.28,0.48,0.50,0.41,0.35,0.44,cp +PHOB_ECOLI,0.41,0.43,0.48,0.50,0.45,0.31,0.41,cp +PHOH_ECOLI,0.29,0.47,0.48,0.50,0.41,0.23,0.34,cp +PMBA_ECOLI,0.34,0.55,0.48,0.50,0.58,0.31,0.41,cp +PNP_ECOLI,0.36,0.56,0.48,0.50,0.43,0.45,0.53,cp +PROB_ECOLI,0.40,0.46,0.48,0.50,0.52,0.49,0.56,cp +PT1A_ECOLI,0.50,0.49,0.48,0.50,0.49,0.46,0.53,cp 
+PT1_ECOLI,0.52,0.44,0.48,0.50,0.37,0.36,0.42,cp +PTCA_ECOLI,0.50,0.51,0.48,0.50,0.27,0.23,0.34,cp +PTCB_ECOLI,0.53,0.42,0.48,0.50,0.16,0.29,0.39,cp +PTFA_ECOLI,0.34,0.46,0.48,0.50,0.52,0.35,0.44,cp +PTGA_ECOLI,0.40,0.42,0.48,0.50,0.37,0.27,0.27,cp +PTHA_ECOLI,0.41,0.43,0.48,0.50,0.50,0.24,0.25,cp +PTHP_ECOLI,0.30,0.45,0.48,0.50,0.36,0.21,0.32,cp +PTKA_ECOLI,0.31,0.47,0.48,0.50,0.29,0.28,0.39,cp +PTKB_ECOLI,0.64,0.76,0.48,0.50,0.45,0.35,0.38,cp +PTNA_ECOLI,0.35,0.37,0.48,0.50,0.30,0.34,0.43,cp +PTWB_ECOLI,0.57,0.54,0.48,0.50,0.37,0.28,0.33,cp +PTWX_ECOLI,0.65,0.55,0.48,0.50,0.34,0.37,0.28,cp +RHAR_ECOLI,0.51,0.46,0.48,0.50,0.58,0.31,0.41,cp +RHAS_ECOLI,0.38,0.40,0.48,0.50,0.63,0.25,0.35,cp +RIMI_ECOLI,0.24,0.57,0.48,0.50,0.63,0.34,0.43,cp +RIMJ_ECOLI,0.38,0.26,0.48,0.50,0.54,0.16,0.28,cp +RIML_ECOLI,0.33,0.47,0.48,0.50,0.53,0.18,0.29,cp +RNB_ECOLI,0.24,0.34,0.48,0.50,0.38,0.30,0.40,cp +RNC_ECOLI,0.26,0.50,0.48,0.50,0.44,0.32,0.41,cp +RND_ECOLI,0.44,0.49,0.48,0.50,0.39,0.38,0.40,cp +RNE_ECOLI,0.43,0.32,0.48,0.50,0.33,0.45,0.52,cp +SERC_ECOLI,0.49,0.43,0.48,0.50,0.49,0.30,0.40,cp +SLYD_ECOLI,0.47,0.28,0.48,0.50,0.56,0.20,0.25,cp +SOXS_ECOLI,0.32,0.33,0.48,0.50,0.60,0.06,0.20,cp +SYA_ECOLI,0.34,0.35,0.48,0.50,0.51,0.49,0.56,cp +SYC_ECOLI,0.35,0.34,0.48,0.50,0.46,0.30,0.27,cp +SYD_ECOLI,0.38,0.30,0.48,0.50,0.43,0.29,0.39,cp +SYE_ECOLI,0.38,0.44,0.48,0.50,0.43,0.20,0.31,cp +SYFA_ECOLI,0.41,0.51,0.48,0.50,0.58,0.20,0.31,cp +SYFB_ECOLI,0.34,0.42,0.48,0.50,0.41,0.34,0.43,cp +SYGA_ECOLI,0.51,0.49,0.48,0.50,0.53,0.14,0.26,cp +SYGB_ECOLI,0.25,0.51,0.48,0.50,0.37,0.42,0.50,cp +SYH_ECOLI,0.29,0.28,0.48,0.50,0.50,0.42,0.50,cp +SYI_ECOLI,0.25,0.26,0.48,0.50,0.39,0.32,0.42,cp +SYK1_ECOLI,0.24,0.41,0.48,0.50,0.49,0.23,0.34,cp +SYK2_ECOLI,0.17,0.39,0.48,0.50,0.53,0.30,0.39,cp +SYL_ECOLI,0.04,0.31,0.48,0.50,0.41,0.29,0.39,cp +SYM_ECOLI,0.61,0.36,0.48,0.50,0.49,0.35,0.44,cp +SYP_ECOLI,0.34,0.51,0.48,0.50,0.44,0.37,0.46,cp +SYQ_ECOLI,0.28,0.33,0.48,0.50,0.45,0.22,0.33,cp 
+SYR_ECOLI,0.40,0.46,0.48,0.50,0.42,0.35,0.44,cp +SYS_ECOLI,0.23,0.34,0.48,0.50,0.43,0.26,0.37,cp +SYT_ECOLI,0.37,0.44,0.48,0.50,0.42,0.39,0.47,cp +SYV_ECOLI,0.00,0.38,0.48,0.50,0.42,0.48,0.55,cp +SYW_ECOLI,0.39,0.31,0.48,0.50,0.38,0.34,0.43,cp +SYY_ECOLI,0.30,0.44,0.48,0.50,0.49,0.22,0.33,cp +THGA_ECOLI,0.27,0.30,0.48,0.50,0.71,0.28,0.39,cp +THIK_ECOLI,0.17,0.52,0.48,0.50,0.49,0.37,0.46,cp +TYRB_ECOLI,0.36,0.42,0.48,0.50,0.53,0.32,0.41,cp +UBIC_ECOLI,0.30,0.37,0.48,0.50,0.43,0.18,0.30,cp +UGPQ_ECOLI,0.26,0.40,0.48,0.50,0.36,0.26,0.37,cp +USPA_ECOLI,0.40,0.41,0.48,0.50,0.55,0.22,0.33,cp +UVRB_ECOLI,0.22,0.34,0.48,0.50,0.42,0.29,0.39,cp +UVRC_ECOLI,0.44,0.35,0.48,0.50,0.44,0.52,0.59,cp +XGPT_ECOLI,0.27,0.42,0.48,0.50,0.37,0.38,0.43,cp +XYLA_ECOLI,0.16,0.43,0.48,0.50,0.54,0.27,0.37,cp +EMRA_ECOLI,0.06,0.61,0.48,0.50,0.49,0.92,0.37,im +AAS_ECOLI,0.44,0.52,0.48,0.50,0.43,0.47,0.54,im +AMPE_ECOLI,0.63,0.47,0.48,0.50,0.51,0.82,0.84,im +ARAE_ECOLI,0.23,0.48,0.48,0.50,0.59,0.88,0.89,im +ARAH_ECOLI,0.34,0.49,0.48,0.50,0.58,0.85,0.80,im +AROP_ECOLI,0.43,0.40,0.48,0.50,0.58,0.75,0.78,im +ATKB_ECOLI,0.46,0.61,0.48,0.50,0.48,0.86,0.87,im +ATP6_ECOLI,0.27,0.35,0.48,0.50,0.51,0.77,0.79,im +BETT_ECOLI,0.52,0.39,0.48,0.50,0.65,0.71,0.73,im +CODB_ECOLI,0.29,0.47,0.48,0.50,0.71,0.65,0.69,im +CYDA_ECOLI,0.55,0.47,0.48,0.50,0.57,0.78,0.80,im +CYOC_ECOLI,0.12,0.67,0.48,0.50,0.74,0.58,0.63,im +CYOD_ECOLI,0.40,0.50,0.48,0.50,0.65,0.82,0.84,im +DCTA_ECOLI,0.73,0.36,0.48,0.50,0.53,0.91,0.92,im +DHG_ECOLI,0.84,0.44,0.48,0.50,0.48,0.71,0.74,im +DHSC_ECOLI,0.48,0.45,0.48,0.50,0.60,0.78,0.80,im +DHSD_ECOLI,0.54,0.49,0.48,0.50,0.40,0.87,0.88,im +DPPC_ECOLI,0.48,0.41,0.48,0.50,0.51,0.90,0.88,im +DSBB_ECOLI,0.50,0.66,0.48,0.50,0.31,0.92,0.92,im +ENVZ_ECOLI,0.72,0.46,0.48,0.50,0.51,0.66,0.70,im +EXBB_ECOLI,0.47,0.55,0.48,0.50,0.58,0.71,0.75,im +FRDC_ECOLI,0.33,0.56,0.48,0.50,0.33,0.78,0.80,im +FRDD_ECOLI,0.64,0.58,0.48,0.50,0.48,0.78,0.73,im +FTSW_ECOLI,0.54,0.57,0.48,0.50,0.56,0.81,0.83,im 
+GABP_ECOLI,0.47,0.59,0.48,0.50,0.52,0.76,0.79,im +GALP_ECOLI,0.63,0.50,0.48,0.50,0.59,0.85,0.86,im +GLNP_ECOLI,0.49,0.42,0.48,0.50,0.53,0.79,0.81,im +GLPT_ECOLI,0.31,0.50,0.48,0.50,0.57,0.84,0.85,im +GLTP_ECOLI,0.74,0.44,0.48,0.50,0.55,0.88,0.89,im +KDGL_ECOLI,0.33,0.45,0.48,0.50,0.45,0.88,0.89,im +KGTP_ECOLI,0.45,0.40,0.48,0.50,0.61,0.74,0.77,im +LACY_ECOLI,0.71,0.40,0.48,0.50,0.71,0.70,0.74,im +LGT_ECOLI,0.50,0.37,0.48,0.50,0.66,0.64,0.69,im +LLDP_ECOLI,0.66,0.53,0.48,0.50,0.59,0.66,0.66,im +LNT_ECOLI,0.60,0.61,0.48,0.50,0.54,0.67,0.71,im +LSPA_ECOLI,0.83,0.37,0.48,0.50,0.61,0.71,0.74,im +LYSP_ECOLI,0.34,0.51,0.48,0.50,0.67,0.90,0.90,im +MALF_ECOLI,0.63,0.54,0.48,0.50,0.65,0.79,0.81,im +MALG_ECOLI,0.70,0.40,0.48,0.50,0.56,0.86,0.83,im +MCP3_ECOLI,0.60,0.50,1.00,0.50,0.54,0.77,0.80,im +MSBB_ECOLI,0.16,0.51,0.48,0.50,0.33,0.39,0.48,im +MTR_ECOLI,0.74,0.70,0.48,0.50,0.66,0.65,0.69,im +NANT_ECOLI,0.20,0.46,0.48,0.50,0.57,0.78,0.81,im +NHAA_ECOLI,0.89,0.55,0.48,0.50,0.51,0.72,0.76,im +NHAB_ECOLI,0.70,0.46,0.48,0.50,0.56,0.78,0.73,im +PHEP_ECOLI,0.12,0.43,0.48,0.50,0.63,0.70,0.74,im +PHOR_ECOLI,0.61,0.52,0.48,0.50,0.54,0.67,0.52,im +PNTA_ECOLI,0.33,0.37,0.48,0.50,0.46,0.65,0.69,im +POTE_ECOLI,0.63,0.65,0.48,0.50,0.66,0.67,0.71,im +PROP_ECOLI,0.41,0.51,0.48,0.50,0.53,0.75,0.78,im +PSTA_ECOLI,0.34,0.67,0.48,0.50,0.52,0.76,0.79,im +PSTC_ECOLI,0.58,0.34,0.48,0.50,0.56,0.87,0.81,im +PTAA_ECOLI,0.59,0.56,0.48,0.50,0.55,0.80,0.82,im +PTBA_ECOLI,0.51,0.40,0.48,0.50,0.57,0.62,0.67,im +PTCC_ECOLI,0.50,0.57,0.48,0.50,0.71,0.61,0.66,im +PTDA_ECOLI,0.60,0.46,0.48,0.50,0.45,0.81,0.83,im +PTFB_ECOLI,0.37,0.47,0.48,0.50,0.39,0.76,0.79,im +PTGB_ECOLI,0.58,0.55,0.48,0.50,0.57,0.70,0.74,im +PTHB_ECOLI,0.36,0.47,0.48,0.50,0.51,0.69,0.72,im +PTMA_ECOLI,0.39,0.41,0.48,0.50,0.52,0.72,0.75,im +PTOA_ECOLI,0.35,0.51,0.48,0.50,0.61,0.71,0.74,im +PTTB_ECOLI,0.31,0.44,0.48,0.50,0.50,0.79,0.82,im +RODA_ECOLI,0.61,0.66,0.48,0.50,0.46,0.87,0.88,im +SECE_ECOLI,0.48,0.49,0.48,0.50,0.52,0.77,0.71,im 
+SECF_ECOLI,0.11,0.50,0.48,0.50,0.58,0.72,0.68,im +SECY_ECOLI,0.31,0.36,0.48,0.50,0.58,0.94,0.94,im +TNAB_ECOLI,0.68,0.51,0.48,0.50,0.71,0.75,0.78,im +XYLE_ECOLI,0.69,0.39,0.48,0.50,0.57,0.76,0.79,im +YCEE_ECOLI,0.52,0.54,0.48,0.50,0.62,0.76,0.79,im +EXBD_ECOLI,0.46,0.59,0.48,0.50,0.36,0.76,0.23,im +FTSL_ECOLI,0.36,0.45,0.48,0.50,0.38,0.79,0.17,im +FTSN_ECOLI,0.00,0.51,0.48,0.50,0.35,0.67,0.44,im +FTSQ_ECOLI,0.10,0.49,0.48,0.50,0.41,0.67,0.21,im +MOTB_ECOLI,0.30,0.51,0.48,0.50,0.42,0.61,0.34,im +TOLA_ECOLI,0.61,0.47,0.48,0.50,0.00,0.80,0.32,im +TOLQ_ECOLI,0.63,0.75,0.48,0.50,0.64,0.73,0.66,im +EMRB_ECOLI,0.71,0.52,0.48,0.50,0.64,1.00,0.99,im +ATKC_ECOLI,0.85,0.53,0.48,0.50,0.53,0.52,0.35,imS +NFRB_ECOLI,0.63,0.49,0.48,0.50,0.54,0.76,0.79,imS +NLPA_ECOLI,0.75,0.55,1.00,1.00,0.40,0.47,0.30,imL +CYOA_ECOLI,0.70,0.39,1.00,0.50,0.51,0.82,0.84,imL +ATKA_ECOLI,0.72,0.42,0.48,0.50,0.65,0.77,0.79,imU +BCR_ECOLI,0.79,0.41,0.48,0.50,0.66,0.81,0.83,imU +CADB_ECOLI,0.83,0.48,0.48,0.50,0.65,0.76,0.79,imU +CAIT_ECOLI,0.69,0.43,0.48,0.50,0.59,0.74,0.77,imU +CPXA_ECOLI,0.79,0.36,0.48,0.50,0.46,0.82,0.70,imU +CRED_ECOLI,0.78,0.33,0.48,0.50,0.57,0.77,0.79,imU +CYDB_ECOLI,0.75,0.37,0.48,0.50,0.64,0.70,0.74,imU +CYOB_ECOLI,0.59,0.29,0.48,0.50,0.64,0.75,0.77,imU +CYOE_ECOLI,0.67,0.37,0.48,0.50,0.54,0.64,0.68,imU +DMSC_ECOLI,0.66,0.48,0.48,0.50,0.54,0.70,0.74,imU +DPPB_ECOLI,0.64,0.46,0.48,0.50,0.48,0.73,0.76,imU +DSBD_ECOLI,0.76,0.71,0.48,0.50,0.50,0.71,0.75,imU +FEPD_ECOLI,0.84,0.49,0.48,0.50,0.55,0.78,0.74,imU +FEPG_ECOLI,0.77,0.55,0.48,0.50,0.51,0.78,0.74,imU +FTSH_ECOLI,0.81,0.44,0.48,0.50,0.42,0.67,0.68,imU +GLTS_ECOLI,0.58,0.60,0.48,0.50,0.59,0.73,0.76,imU +KEFC_ECOLI,0.63,0.42,0.48,0.50,0.48,0.77,0.80,imU +KUP_ECOLI,0.62,0.42,0.48,0.50,0.58,0.79,0.81,imU +MCP1_ECOLI,0.86,0.39,0.48,0.50,0.59,0.89,0.90,imU +MCP2_ECOLI,0.81,0.53,0.48,0.50,0.57,0.87,0.88,imU +MCP4_ECOLI,0.87,0.49,0.48,0.50,0.61,0.76,0.79,imU +MELB_ECOLI,0.47,0.46,0.48,0.50,0.62,0.74,0.77,imU 
+MOTA_ECOLI,0.76,0.41,0.48,0.50,0.50,0.59,0.62,imU +NUPC_ECOLI,0.70,0.53,0.48,0.50,0.70,0.86,0.87,imU +NUPG_ECOLI,0.64,0.45,0.48,0.50,0.67,0.61,0.66,imU +PNTB_ECOLI,0.81,0.52,0.48,0.50,0.57,0.78,0.80,imU +PTKC_ECOLI,0.73,0.26,0.48,0.50,0.57,0.75,0.78,imU +RHAT_ECOLI,0.49,0.61,1.00,0.50,0.56,0.71,0.74,imU +SECD_ECOLI,0.88,0.42,0.48,0.50,0.52,0.73,0.75,imU +SECG_ECOLI,0.84,0.54,0.48,0.50,0.75,0.92,0.70,imU +TEHA_ECOLI,0.63,0.51,0.48,0.50,0.64,0.72,0.76,imU +TYRP_ECOLI,0.86,0.55,0.48,0.50,0.63,0.81,0.83,imU +UHPB_ECOLI,0.79,0.54,0.48,0.50,0.50,0.66,0.68,imU +TONB_ECOLI,0.57,0.38,0.48,0.50,0.06,0.49,0.33,imU +LEP_ECOLI,0.78,0.44,0.48,0.50,0.45,0.73,0.68,imU +FADL_ECOLI,0.78,0.68,0.48,0.50,0.83,0.40,0.29,om +FHUA_ECOLI,0.63,0.69,0.48,0.50,0.65,0.41,0.28,om +LAMB_ECOLI,0.67,0.88,0.48,0.50,0.73,0.50,0.25,om +NFRA_ECOLI,0.61,0.75,0.48,0.50,0.51,0.33,0.33,om +NMPC_ECOLI,0.67,0.84,0.48,0.50,0.74,0.54,0.37,om +OMPA_ECOLI,0.74,0.90,0.48,0.50,0.57,0.53,0.29,om +OMPC_ECOLI,0.73,0.84,0.48,0.50,0.86,0.58,0.29,om +OMPF_ECOLI,0.75,0.76,0.48,0.50,0.83,0.57,0.30,om +OMPX_ECOLI,0.77,0.57,0.48,0.50,0.88,0.53,0.20,om +PHOE_ECOLI,0.74,0.78,0.48,0.50,0.75,0.54,0.15,om +TSX_ECOLI,0.68,0.76,0.48,0.50,0.84,0.45,0.27,om +BTUB_ECOLI,0.56,0.68,0.48,0.50,0.77,0.36,0.45,om +CIRA_ECOLI,0.65,0.51,0.48,0.50,0.66,0.54,0.33,om +FECA_ECOLI,0.52,0.81,0.48,0.50,0.72,0.38,0.38,om +FEPA_ECOLI,0.64,0.57,0.48,0.50,0.70,0.33,0.26,om +FHUE_ECOLI,0.60,0.76,1.00,0.50,0.77,0.59,0.52,om +OMPP_ECOLI,0.69,0.59,0.48,0.50,0.77,0.39,0.21,om +OMPT_ECOLI,0.63,0.49,0.48,0.50,0.79,0.45,0.28,om +TOLC_ECOLI,0.71,0.71,0.48,0.50,0.68,0.43,0.36,om +PA1_ECOLI,0.68,0.63,0.48,0.50,0.73,0.40,0.30,om +MULI_ECOLI,0.77,0.57,1.00,0.50,0.37,0.54,0.01,omL +NLPB_ECOLI,0.66,0.49,1.00,0.50,0.54,0.56,0.36,omL +NLPE_ECOLI,0.71,0.46,1.00,0.50,0.52,0.59,0.30,omL +PAL_ECOLI,0.67,0.55,1.00,0.50,0.66,0.58,0.16,omL +SLP_ECOLI,0.68,0.49,1.00,0.50,0.62,0.55,0.28,omL +AGP_ECOLI,0.74,0.49,0.48,0.50,0.42,0.54,0.36,pp 
+AMY1_ECOLI,0.70,0.61,0.48,0.50,0.56,0.52,0.43,pp +ARAF_ECOLI,0.66,0.86,0.48,0.50,0.34,0.41,0.36,pp +ASG2_ECOLI,0.73,0.78,0.48,0.50,0.58,0.51,0.31,pp +BGLX_ECOLI,0.65,0.57,0.48,0.50,0.47,0.47,0.51,pp +C562_ECOLI,0.72,0.86,0.48,0.50,0.17,0.55,0.21,pp +CN16_ECOLI,0.67,0.70,0.48,0.50,0.46,0.45,0.33,pp +CYPH_ECOLI,0.67,0.81,0.48,0.50,0.54,0.49,0.23,pp +CYSP_ECOLI,0.67,0.61,0.48,0.50,0.51,0.37,0.38,pp +DGAL_ECOLI,0.63,1.00,0.48,0.50,0.35,0.51,0.49,pp +DPPA_ECOLI,0.57,0.59,0.48,0.50,0.39,0.47,0.33,pp +DSBA_ECOLI,0.71,0.71,0.48,0.50,0.40,0.54,0.39,pp +DSBC_ECOLI,0.66,0.74,0.48,0.50,0.31,0.38,0.43,pp +ECOT_ECOLI,0.67,0.81,0.48,0.50,0.25,0.42,0.25,pp +ECPD_ECOLI,0.64,0.72,0.48,0.50,0.49,0.42,0.19,pp +FECB_ECOLI,0.68,0.82,0.48,0.50,0.38,0.65,0.56,pp +FECR_ECOLI,0.32,0.39,0.48,0.50,0.53,0.28,0.38,pp +FEPB_ECOLI,0.70,0.64,0.48,0.50,0.47,0.51,0.47,pp +FIMC_ECOLI,0.63,0.57,0.48,0.50,0.49,0.70,0.20,pp +GGT_ECOLI,0.74,0.82,0.48,0.50,0.49,0.49,0.41,pp +GLNH_ECOLI,0.63,0.86,0.48,0.50,0.39,0.47,0.34,pp +GLPQ_ECOLI,0.63,0.83,0.48,0.50,0.40,0.39,0.19,pp +HTRA_ECOLI,0.63,0.71,0.48,0.50,0.60,0.40,0.39,pp +LIVJ_ECOLI,0.71,0.86,0.48,0.50,0.40,0.54,0.32,pp +LIVK_ECOLI,0.68,0.78,0.48,0.50,0.43,0.44,0.42,pp +MALE_ECOLI,0.64,0.84,0.48,0.50,0.37,0.45,0.40,pp +MALM_ECOLI,0.74,0.47,0.48,0.50,0.50,0.57,0.42,pp +MEPA_ECOLI,0.75,0.84,0.48,0.50,0.35,0.52,0.33,pp +MODA_ECOLI,0.63,0.65,0.48,0.50,0.39,0.44,0.35,pp +NRFA_ECOLI,0.69,0.67,0.48,0.50,0.30,0.39,0.24,pp +NRFF_ECOLI,0.70,0.71,0.48,0.50,0.42,0.84,0.85,pp +OPPA_ECOLI,0.69,0.80,0.48,0.50,0.46,0.57,0.26,pp +OSMY_ECOLI,0.64,0.66,0.48,0.50,0.41,0.39,0.20,pp +POTD_ECOLI,0.63,0.80,0.48,0.50,0.46,0.31,0.29,pp +POTF_ECOLI,0.66,0.71,0.48,0.50,0.41,0.50,0.35,pp +PPA_ECOLI,0.69,0.59,0.48,0.50,0.46,0.44,0.52,pp +PPB_ECOLI,0.68,0.67,0.48,0.50,0.49,0.40,0.34,pp +PROX_ECOLI,0.64,0.78,0.48,0.50,0.50,0.36,0.38,pp +PSTS_ECOLI,0.62,0.78,0.48,0.50,0.47,0.49,0.54,pp +PTR_ECOLI,0.76,0.73,0.48,0.50,0.44,0.39,0.39,pp +RBSB_ECOLI,0.64,0.81,0.48,0.50,0.37,0.39,0.44,pp 
+SPEA_ECOLI,0.29,0.39,0.48,0.50,0.52,0.40,0.48,pp +SUBI_ECOLI,0.62,0.83,0.48,0.50,0.46,0.36,0.40,pp +TBPA_ECOLI,0.56,0.54,0.48,0.50,0.43,0.37,0.30,pp +TESA_ECOLI,0.69,0.66,0.48,0.50,0.41,0.50,0.25,pp +TOLB_ECOLI,0.69,0.65,0.48,0.50,0.63,0.48,0.41,pp +TORA_ECOLI,0.43,0.59,0.48,0.50,0.52,0.49,0.56,pp +TREA_ECOLI,0.74,0.56,0.48,0.50,0.47,0.68,0.30,pp +UGPB_ECOLI,0.71,0.57,0.48,0.50,0.48,0.35,0.32,pp +USHA_ECOLI,0.61,0.60,0.48,0.50,0.44,0.39,0.38,pp +XYLF_ECOLI,0.59,0.61,0.48,0.50,0.42,0.42,0.37,pp +YTFQ_ECOLI,0.74,0.74,0.48,0.50,0.31,0.53,0.52,pp From ef9ea277b84926c06804f2e5e180f8136c7ecd71 Mon Sep 17 00:00:00 2001 From: Pavol Mulinka Date: Fri, 10 Dec 2021 21:42:45 +0100 Subject: [PATCH 2/5] name fix --- examples/15_AUC_multiclass.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/15_AUC_multiclass.ipynb b/examples/15_AUC_multiclass.ipynb index 878f4ee9..02cb8ea7 100644 --- a/examples/15_AUC_multiclass.ipynb +++ b/examples/15_AUC_multiclass.ipynb @@ -5,7 +5,7 @@ "id": "731975e2", "metadata": {}, "source": [ - "# Hyperparameter tuning with Raytune and visulization using Tensorboard and Weights & Biases" + "# AUC multiclass computation" ] }, { From d13e86f99edff88ef4c2fb8e229477f0a1e32626 Mon Sep 17 00:00:00 2001 From: Pavol Mulinka Date: Sat, 11 Dec 2021 16:37:13 +0100 Subject: [PATCH 3/5] fixed multiclass torchmetrics --- examples/15_AUC_multiclass.ipynb | 84 +++++++++++-------------- pytorch_widedeep/metrics.py | 77 ++++++++++++++++++++++- pytorch_widedeep/training/trainer.py | 10 ++- tests/test_metrics/test_torchmetrics.py | 9 ++- 4 files changed, 127 insertions(+), 53 deletions(-) diff --git a/examples/15_AUC_multiclass.ipynb b/examples/15_AUC_multiclass.ipynb index 02cb8ea7..f0af0e1a 100644 --- a/examples/15_AUC_multiclass.ipynb +++ b/examples/15_AUC_multiclass.ipynb @@ -26,8 +26,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-12-10 21:26:26.784170: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] 
Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", - "2021-12-10 21:26:26.784221: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" + "2021-12-11 16:34:55.734357: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2021-12-11 16:34:55.734404: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] } ], @@ -39,7 +39,7 @@ "from pytorch_widedeep import Trainer\n", "from pytorch_widedeep.preprocessing import TabPreprocessor\n", "from pytorch_widedeep.models import TabMlp, WideDeep\n", - "from torchmetrics import AUC\n", + "from torchmetrics import AUC, AUROC\n", "from pytorch_widedeep.initializers import XavierNormal\n", "from pytorch_widedeep.datasets import load_ecoli\n", "from pytorch_widedeep.utils import LabelEncoder\n", @@ -53,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 2, "id": "07c75f0c", "metadata": {}, "outputs": [ @@ -163,7 +163,7 @@ "4 ADI_ECOLI 0.23 0.32 0.48 0.5 0.55 0.25 0.35 cp" ] }, - "execution_count": 26, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -175,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 3, "id": "1e3f8efc", "metadata": {}, "outputs": [ @@ -193,7 +193,7 @@ "Name: class, dtype: int64" ] }, - "execution_count": 27, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -205,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 4, "id": "e4db0d6d", "metadata": {}, "outputs": [], @@ -216,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 5, "id": 
"005531a3", "metadata": {}, "outputs": [], @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 6, "id": "214b3071", "metadata": {}, "outputs": [], @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 7, "id": "168c81f1", "metadata": {}, "outputs": [], @@ -258,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 8, "id": "3a7b246b", "metadata": {}, "outputs": [], @@ -268,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 9, "id": "7a2dac24", "metadata": {}, "outputs": [], @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 10, "id": "511198d4", "metadata": {}, "outputs": [ @@ -331,7 +331,7 @@ ")" ] }, - "execution_count": 39, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -347,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 11, "id": "a5359b0f", "metadata": {}, "outputs": [ @@ -356,49 +356,40 @@ "output_type": "stream", "text": [ "/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n", + " warnings.warn(*args, **kwargs)\n", + "/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. 
For large datasets this may lead to large memory footprint.\n", " warnings.warn(*args, **kwargs)\n" ] } ], "source": [ - "auc = AUC()" + "auc = AUC(reorder=True)\n", + "auc.num_classes = df_enc[\"class\"].nunique()\n", + "auroc = AUROC(num_classes=df_enc[\"class\"].nunique())" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 12, "id": "34a18ac0", "metadata": { "scrolled": false }, "outputs": [ { - "ename": "AttributeError", - "evalue": "'AUC' object has no attribute 'num_classes'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_1193/1577559620.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m )\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_val\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - 
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 
61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/utils/general_utils.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m ] = alias\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X_wide, X_tab, X_text, X_img, X_train, X_val, val_split, target, n_epochs, validation_freq, batch_size, custom_dataloader, finetune, finetune_epochs, finetune_max_lr, finetune_deeptabular_gradual, finetune_deeptabular_max_lr, finetune_deeptabular_layers, finetune_deeptext_gradual, finetune_deeptext_max_lr, finetune_deeptext_layers, finetune_deepimage_gradual, finetune_deepimage_max_lr, finetune_deepimage_layers, finetune_routine, stop_after_finetuning, **kwargs)\u001b[0m\n\u001b[1;32m 615\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargett\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_description\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"epoch %i\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 617\u001b[0;31m \u001b[0mtrain_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargett\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 618\u001b[0m 
\u001b[0mprint_loss_and_metric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 619\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36m_train_step\u001b[0;34m(self, data, target, batch_idx)\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1169\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1170\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1171\u001b[0m \u001b[0;31m# TODO raise exception if the loss is exploding with non scaled target values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1172\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - 
"\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/training/trainer.py\u001b[0m in \u001b[0;36m_get_score\u001b[0;34m(self, y_pred, y)\u001b[0m\n\u001b[1;32m 1212\u001b[0m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1213\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"multiclass\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1214\u001b[0;31m \u001b[0mscore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1215\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1216\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/pytorch_widedeep/metrics.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, y_pred, y_true)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mlogs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprefix\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTorchMetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_classes\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mround\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_classes\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[operator]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1130\u001b[0;31m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 1131\u001b[0m type(self).__name__, name))\n\u001b[1;32m 1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'AUC' object has no attribute 'num_classes'" + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|██████████| 6/6 [00:00<00:00, 79.20it/s, loss=0.1, metrics={'AUC': 8.5, 'AUROC': 0.427}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 6.06it/s, loss=0.0961, metrics={'AUC': 6.5, 'AUROC': 0.419}]\n", + "epoch 2: 100%|██████████| 6/6 [00:00<00:00, 82.52it/s, loss=0.095, metrics={'AUC': 4.5, 'AUROC': 0.4418}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.69it/s, loss=0.0917, metrics={'AUC': 6.5, 'AUROC': 0.4351}]\n", + "epoch 3: 100%|██████████| 6/6 [00:00<00:00, 103.30it/s, loss=0.0908, metrics={'AUC': 5.5, 'AUROC': 0.4715}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.35it/s, loss=0.0875, metrics={'AUC': 6.5, 'AUROC': 0.4633}]\n", + "epoch 4: 100%|██████████| 6/6 [00:00<00:00, 90.88it/s, loss=0.0872, metrics={'AUC': 7.0, 'AUROC': 0.4767}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.44it/s, loss=0.0874, metrics={'AUC': 6.5, 'AUROC': 0.4652}]\n", + "epoch 5: 100%|██████████| 6/6 [00:00<00:00, 88.87it/s, loss=0.0866, metrics={'AUC': 6.0, 'AUROC': 0.4775}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.37it/s, loss=0.087, metrics={'AUC': 6.5, 'AUROC': 0.4524}]\n" ] } ], @@ -414,11 +405,10 @@ " lr_schedulers={\"deeptabular\": deep_sch},\n", " initializers={\"deeptabular\": XavierNormal},\n", " optimizers={\"deeptabular\": deep_opt},\n", - " metrics=[auc],\n", - " verbose=0,\n", + " metrics=[auc, auroc],\n", ")\n", "\n", - "trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=10)" + "trainer.fit(X_train=X_train, 
X_val=X_val, n_epochs=5, batch_size=50)" ] } ], diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index 2486eb13..e939ed80 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -1,6 +1,7 @@ import numpy as np import torch from torchmetrics import Metric as TorchMetric +from torchmetrics import AUC from .wdtypes import * # noqa: F403 @@ -38,10 +39,23 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> Dict: if isinstance(metric, Metric): logs[self.prefix + metric._name] = metric(y_pred, y_true) if isinstance(metric, TorchMetric): + if not hasattr(metric, "num_classes"): + raise ValueError( + """TorchMetric does not have num_classes attribute. + Use metric in this library or extend the metric by num_classes attribute, + see `examples ` + """ + ) if metric.num_classes == 2: - metric.update(torch.round(y_pred).int(), y_true.int()) + if isinstance(metric, AUC): + metric.update(torch.round(y_pred).int(), y_true.int()) + else: + metric.update(y_pred, y_true.int()) if metric.num_classes > 2: # type: ignore[operator] - metric.update(torch.max(y_pred, dim=1).indices, y_true.int()) # type: ignore[attr-defined] + if isinstance(metric, AUC): + metric.update(torch.max(y_pred, dim=1).indices, y_true.int()) # type: ignore[attr-defined] + else: + metric.update(y_pred, y_true.int()) # type: ignore[attr-defined] logs[self.prefix + type(metric).__name__] = ( metric.compute().detach().cpu().numpy() ) @@ -396,3 +410,62 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: y_true_avg = self.y_true_sum / self.num_examples self.denominator += ((y_true - y_true_avg) ** 2).sum().item() return np.array((1 - (self.numerator / self.denominator))) + + +class Accuracy(Metric): + r"""Class to calculate the accuracy for both binary and categorical problems + + Parameters + ---------- + top_k: int, default = 1 + Accuracy will be computed using the top k most likely classes in + multiclass problems + + Examples + -------- + >>> import 
torch + >>> + >>> from pytorch_widedeep.metrics import Accuracy + >>> + >>> acc = Accuracy() + >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) + >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) + >>> acc(y_pred, y_true) + array(0.5) + >>> + >>> acc = Accuracy(top_k=2) + >>> y_true = torch.tensor([0, 1, 2]) + >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) + >>> acc(y_pred, y_true) + array(0.66666667) + """ + + def __init__(self, top_k: int = 1): + super(Accuracy, self).__init__() + + self.top_k = top_k + self.correct_count = 0 + self.total_count = 0 + self._name = "acc" + + def reset(self): + """ + resets counters to 0 + """ + self.correct_count = 0 + self.total_count = 0 + + def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: + num_classes = y_pred.size(1) + + if num_classes == 1: + y_pred = y_pred.round() + y_true = y_true + elif num_classes > 1: + y_pred = y_pred.topk(self.top_k, 1)[1] + y_true = y_true.view(-1, 1).expand_as(y_pred) + + self.correct_count += y_pred.eq(y_true).sum().item() # type: ignore[assignment] + self.total_count += len(y_pred) + accuracy = float(self.correct_count) / float(self.total_count) + return np.array(accuracy) diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 4d32c829..db9004a0 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -147,10 +147,14 @@ class Trainer: `__ folder in the repo - List of objects of type :obj:`torchmetrics.Metric`. This can be any - metric from torchmetrics library `Examples + metric from torchmetrics library that has attribute num_classes `Examples `_. This can also be a custom metric as - long as it is an object of type :obj:`Metric`. See `the instructions + classification-metrics>`_. + Objects of type :obj:`torchmetrics.Metric` can be extended with num_classes + attribute to be used with the Trainer object, see `examples + `. 
+ This can also be a custom metric as long as it is an object of + type :obj:`Metric`. See `the instructions `_. class_weight: float, List or Tuple. optional. default=None - float indicating the weight of the minority class in binary classification diff --git a/tests/test_metrics/test_torchmetrics.py b/tests/test_metrics/test_torchmetrics.py index a5bd69ae..c28495f4 100644 --- a/tests/test_metrics/test_torchmetrics.py +++ b/tests/test_metrics/test_torchmetrics.py @@ -1,13 +1,14 @@ import numpy as np import torch import pytest -from torchmetrics import F1, FBeta, Recall, Accuracy, Precision +from torchmetrics import F1, FBeta, Recall, Accuracy, Precision, AUC from sklearn.metrics import ( f1_score, fbeta_score, recall_score, accuracy_score, precision_score, + auc_score, ) from pytorch_widedeep.metrics import MultipleMetrics @@ -35,9 +36,12 @@ def f2_score_bin(y_true, y_pred): ("Recall", recall_score, Recall(num_classes=2, average="none")), ("F1", f1_score, F1(num_classes=2, average="none")), ("FBeta", f2_score_bin, FBeta(beta=2, num_classes=2, average="none")), + ("AUC", auc_score, AUC()), ], ) def test_binary_metrics(metric_name, sklearn_metric, torch_metric): + if metric_name == "AUC": + torch_metric.num_classes=2 sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round()) wd_metric = MultipleMetrics(metrics=[torch_metric]) wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt) @@ -82,11 +86,14 @@ def f2_score_multi(y_true, y_pred, average): ("Recall", recall_score, Recall(num_classes=3, average="macro")), ("F1", f1_score, F1(num_classes=3, average="macro")), ("FBeta", f2_score_multi, FBeta(beta=3, num_classes=3, average="macro")), + ("AUC", auc_score, AUC()), ], ) def test_muticlass_metrics(metric_name, sklearn_metric, torch_metric): if metric_name == "Accuracy": sk_res = sklearn_metric(y_true_multi_np, y_pred_muli_np.argmax(axis=1)) + elif metric_name == "AUC": + torch_metric.num_classes=3 else: sk_res = sklearn_metric( y_true_multi_np, 
y_pred_muli_np.argmax(axis=1), average="macro" From b324222105a36b8da2ba2f164d1f11536debfc24 Mon Sep 17 00:00:00 2001 From: Pavol Mulinka Date: Sat, 11 Dec 2021 21:53:58 +0100 Subject: [PATCH 4/5] adjusted torchmetrics handling --- VERSION | 2 +- examples/15_AUC_multiclass.ipynb | 54 ++++++++++++------------- pytorch_widedeep/datasets/_base.py | 20 ++++----- pytorch_widedeep/metrics.py | 18 +-------- pytorch_widedeep/training/trainer.py | 10 ++--- pytorch_widedeep/version.py | 2 +- tests/test_datasets/test_datasets.py | 17 +++++++- tests/test_metrics/test_torchmetrics.py | 23 ++++------- 8 files changed, 65 insertions(+), 81 deletions(-) diff --git a/VERSION b/VERSION index 9256e288..97bceaaf 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0.13 \ No newline at end of file +1.0.14 \ No newline at end of file diff --git a/examples/15_AUC_multiclass.ipynb b/examples/15_AUC_multiclass.ipynb index f0af0e1a..87a8d590 100644 --- a/examples/15_AUC_multiclass.ipynb +++ b/examples/15_AUC_multiclass.ipynb @@ -26,8 +26,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-12-11 16:34:55.734357: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", - "2021-12-11 16:34:55.734404: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" + "2021-12-11 21:51:29.255591: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2021-12-11 21:51:29.255638: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] } ], @@ -175,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": 
"1e3f8efc", "metadata": {}, "outputs": [ @@ -193,7 +193,7 @@ "Name: class, dtype: int64" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -205,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "e4db0d6d", "metadata": {}, "outputs": [], @@ -216,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "005531a3", "metadata": {}, "outputs": [], @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "214b3071", "metadata": {}, "outputs": [], @@ -239,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "168c81f1", "metadata": {}, "outputs": [], @@ -258,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "3a7b246b", "metadata": {}, "outputs": [], @@ -268,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "7a2dac24", "metadata": {}, "outputs": [], @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "511198d4", "metadata": {}, "outputs": [ @@ -331,7 +331,7 @@ ")" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -347,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "a5359b0f", "metadata": {}, "outputs": [ @@ -355,22 +355,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n", - " warnings.warn(*args, **kwargs)\n", "/home/palo/miniconda3/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. 
For large datasets this may lead to large memory footprint.\n", " warnings.warn(*args, **kwargs)\n" ] } ], "source": [ - "auc = AUC(reorder=True)\n", - "auc.num_classes = df_enc[\"class\"].nunique()\n", "auroc = AUROC(num_classes=df_enc[\"class\"].nunique())" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "34a18ac0", "metadata": { "scrolled": false @@ -380,16 +376,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████| 6/6 [00:00<00:00, 79.20it/s, loss=0.1, metrics={'AUC': 8.5, 'AUROC': 0.427}]\n", - "valid: 100%|██████████| 1/1 [00:00<00:00, 6.06it/s, loss=0.0961, metrics={'AUC': 6.5, 'AUROC': 0.419}]\n", - "epoch 2: 100%|██████████| 6/6 [00:00<00:00, 82.52it/s, loss=0.095, metrics={'AUC': 4.5, 'AUROC': 0.4418}]\n", - "valid: 100%|██████████| 1/1 [00:00<00:00, 5.69it/s, loss=0.0917, metrics={'AUC': 6.5, 'AUROC': 0.4351}]\n", - "epoch 3: 100%|██████████| 6/6 [00:00<00:00, 103.30it/s, loss=0.0908, metrics={'AUC': 5.5, 'AUROC': 0.4715}]\n", - "valid: 100%|██████████| 1/1 [00:00<00:00, 5.35it/s, loss=0.0875, metrics={'AUC': 6.5, 'AUROC': 0.4633}]\n", - "epoch 4: 100%|██████████| 6/6 [00:00<00:00, 90.88it/s, loss=0.0872, metrics={'AUC': 7.0, 'AUROC': 0.4767}]\n", - "valid: 100%|██████████| 1/1 [00:00<00:00, 5.44it/s, loss=0.0874, metrics={'AUC': 6.5, 'AUROC': 0.4652}]\n", - "epoch 5: 100%|██████████| 6/6 [00:00<00:00, 88.87it/s, loss=0.0866, metrics={'AUC': 6.0, 'AUROC': 0.4775}]\n", - "valid: 100%|██████████| 1/1 [00:00<00:00, 5.37it/s, loss=0.087, metrics={'AUC': 6.5, 'AUROC': 0.4524}]\n" + "epoch 1: 100%|██████████| 6/6 [00:00<00:00, 84.27it/s, loss=0.111, metrics={'AUROC': 0.285}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.57it/s, loss=0.106, metrics={'AUROC': 0.3309}]\n", + "epoch 2: 100%|██████████| 6/6 [00:00<00:00, 111.92it/s, loss=0.106, metrics={'AUROC': 0.3124}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 4.99it/s, loss=0.102, metrics={'AUROC': 0.375}]\n", + "epoch 3: 
100%|██████████| 6/6 [00:00<00:00, 109.51it/s, loss=0.102, metrics={'AUROC': 0.3459}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 6.70it/s, loss=0.0967, metrics={'AUROC': 0.4444}]\n", + "epoch 4: 100%|██████████| 6/6 [00:00<00:00, 106.40it/s, loss=0.0984, metrics={'AUROC': 0.3717}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.93it/s, loss=0.0963, metrics={'AUROC': 0.4516}]\n", + "epoch 5: 100%|██████████| 6/6 [00:00<00:00, 93.06it/s, loss=0.0975, metrics={'AUROC': 0.3877}]\n", + "valid: 100%|██████████| 1/1 [00:00<00:00, 5.98it/s, loss=0.0961, metrics={'AUROC': 0.4404}]\n" ] } ], @@ -405,7 +401,7 @@ " lr_schedulers={\"deeptabular\": deep_sch},\n", " initializers={\"deeptabular\": XavierNormal},\n", " optimizers={\"deeptabular\": deep_opt},\n", - " metrics=[auc, auroc],\n", + " metrics=[auroc],\n", ")\n", "\n", "trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=50)" diff --git a/pytorch_widedeep/datasets/_base.py b/pytorch_widedeep/datasets/_base.py index 23187ed7..e6ea8cc7 100644 --- a/pytorch_widedeep/datasets/_base.py +++ b/pytorch_widedeep/datasets/_base.py @@ -67,7 +67,7 @@ def load_ecoli(as_frame: bool = False): See also: yeast database 3. Past Usage. - Reference: "A Probablistic Classification System for Predicting the Cellular + Reference: "A Probablistic Classification System for Predicting the Cellular Localization Sites of Proteins", Paul Horton & Kenta Nakai, Intelligent Systems in Molecular Biology, 109-115. St. Louis, USA 1996. @@ -78,23 +78,23 @@ def load_ecoli(as_frame: bool = False): Predicted Attribute: Localization site of protein. ( non-numeric ). - 4. The references below describe a predecessor to this dataset and its - development. They also give results (not cross-validated) for classification + 4. The references below describe a predecessor to this dataset and its + development. They also give results (not cross-validated) for classification by a rule-based expert system with that version of the dataset. 
- Reference: "Expert Sytem for Predicting Protein Localization Sites in - Gram-Negative Bacteria", Kenta Nakai & Minoru Kanehisa, + Reference: "Expert Sytem for Predicting Protein Localization Sites in + Gram-Negative Bacteria", Kenta Nakai & Minoru Kanehisa, PROTEINS: Structure, Function, and Genetics 11:95-110, 1991. Reference: "A Knowledge Base for Predicting Protein Localization Sites in - Eukaryotic Cells", Kenta Nakai & Minoru Kanehisa, + Eukaryotic Cells", Kenta Nakai & Minoru Kanehisa, Genomics 14:897-911, 1992. - 5. Number of Instances: 336 for the E.coli dataset and + 5. Number of Instances: 336 for the E.coli dataset and 6. Number of Attributes. for E.coli dataset: 8 ( 7 predictive, 1 name ) - + 7. Attribute Information. 1. Sequence Name: Accession number for the SWISS-PROT database @@ -115,7 +115,7 @@ def load_ecoli(as_frame: bool = False): 9. Class Distribution. The class is the localization site. Please see Nakai & Kanehisa referenced above for more details. cp (cytoplasm) 143 - im (inner membrane without signal sequence) 77 + im (inner membrane without signal sequence) 77 pp (perisplasm) 52 imU (inner membrane, uncleavable signal sequence) 35 om (outer membrane) 20 @@ -130,4 +130,4 @@ def load_ecoli(as_frame: bool = False): if as_frame: return df else: - return df.to_numpy() \ No newline at end of file + return df.to_numpy() diff --git a/pytorch_widedeep/metrics.py b/pytorch_widedeep/metrics.py index e939ed80..6b1cd162 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -39,23 +39,7 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> Dict: if isinstance(metric, Metric): logs[self.prefix + metric._name] = metric(y_pred, y_true) if isinstance(metric, TorchMetric): - if not hasattr(metric, "num_classes"): - raise ValueError( - """TorchMetric does not have num_classes attribute. 
- Use metric in this library or extend the metric by num_classes attribute, - see `examples ` - """ - ) - if metric.num_classes == 2: - if isinstance(metric, AUC): - metric.update(torch.round(y_pred).int(), y_true.int()) - else: - metric.update(y_pred, y_true.int()) - if metric.num_classes > 2: # type: ignore[operator] - if isinstance(metric, AUC): - metric.update(torch.max(y_pred, dim=1).indices, y_true.int()) # type: ignore[attr-defined] - else: - metric.update(y_pred, y_true.int()) # type: ignore[attr-defined] + metric.update(y_pred, y_true.int()) # type: ignore[attr-defined] logs[self.prefix + type(metric).__name__] = ( metric.compute().detach().cpu().numpy() ) diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index db9004a0..2a3bd564 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -147,14 +147,10 @@ class Trainer: `__ folder in the repo - List of objects of type :obj:`torchmetrics.Metric`. This can be any - metric from torchmetrics library that has attribute num_classes `Examples + metric from torchmetrics library `Examples `_. - Objects of type :obj:`torchmetrics.Metric` can be extended with num_classes - attribute to be used with the Trainer object, see `examples - `. - This can also be a custom metric as long as it is an object of - type :obj:`Metric`. See `the instructions + classification-metrics>`_. This can also be a custom metric as long as + it is an object of type :obj:`Metric`. See `the instructions `_. class_weight: float, List or Tuple. optional. 
default=None - float indicating the weight of the minority class in binary classification diff --git a/pytorch_widedeep/version.py b/pytorch_widedeep/version.py index 66c607f6..b19b12ea 100644 --- a/pytorch_widedeep/version.py +++ b/pytorch_widedeep/version.py @@ -1 +1 @@ -__version__ = "1.0.13" +__version__ = "1.0.14" diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py index acfca3c2..7c0a8bb8 100644 --- a/tests/test_datasets/test_datasets.py +++ b/tests/test_datasets/test_datasets.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from pytorch_widedeep.datasets import load_adult, load_bio_kdd04 +from pytorch_widedeep.datasets import load_adult, load_bio_kdd04, load_ecoli @pytest.mark.parametrize( @@ -33,3 +33,18 @@ def test_load_adult(as_frame): assert (df.shape, type(df)) == ((48842, 15), pd.DataFrame) else: assert (df.shape, type(df)) == ((48842, 15), np.ndarray) + + +@pytest.mark.parametrize( + "as_frame", + [ + (True), + (False), + ], +) +def test_load_ecoli(as_frame): + df = load_ecoli(as_frame=as_frame) + if as_frame: + assert (df.shape, type(df)) == ((336, 9), pd.DataFrame) + else: + assert (df.shape, type(df)) == ((336, 9), np.ndarray) diff --git a/tests/test_metrics/test_torchmetrics.py b/tests/test_metrics/test_torchmetrics.py index c28495f4..2fcaa5c0 100644 --- a/tests/test_metrics/test_torchmetrics.py +++ b/tests/test_metrics/test_torchmetrics.py @@ -1,14 +1,13 @@ import numpy as np import torch import pytest -from torchmetrics import F1, FBeta, Recall, Accuracy, Precision, AUC +from torchmetrics import F1, FBeta, Recall, Accuracy, Precision from sklearn.metrics import ( f1_score, fbeta_score, recall_score, accuracy_score, precision_score, - auc_score, ) from pytorch_widedeep.metrics import MultipleMetrics @@ -31,17 +30,14 @@ def f2_score_bin(y_true, y_pred): @pytest.mark.parametrize( "metric_name, sklearn_metric, torch_metric", [ - ("Accuracy", accuracy_score, Accuracy(num_classes=2)), - ("Precision", 
precision_score, Precision(num_classes=2, average="none")), - ("Recall", recall_score, Recall(num_classes=2, average="none")), - ("F1", f1_score, F1(num_classes=2, average="none")), - ("FBeta", f2_score_bin, FBeta(beta=2, num_classes=2, average="none")), - ("AUC", auc_score, AUC()), + ("Accuracy", accuracy_score, Accuracy()), + ("Precision", precision_score, Precision()), + ("Recall", recall_score, Recall()), + ("F1", f1_score, F1()), + ("FBeta", f2_score_bin, FBeta(beta=2)), ], ) def test_binary_metrics(metric_name, sklearn_metric, torch_metric): - if metric_name == "AUC": - torch_metric.num_classes=2 sk_res = sklearn_metric(y_true_bin_np, y_pred_bin_np.round()) wd_metric = MultipleMetrics(metrics=[torch_metric]) wd_logs = wd_metric(y_pred_bin_pt, y_true_bin_pt) @@ -62,8 +58,8 @@ def test_binary_metrics(metric_name, sklearn_metric, torch_metric): [0.1, 0.1, 0.8], [0.1, 0.6, 0.3], [0.1, 0.8, 0.1], - [0.1, 0.6, 0.6], - [0.2, 0.6, 0.8], + [0.1, 0.3, 0.6], + [0.1, 0.1, 0.8], [0.6, 0.1, 0.3], [0.7, 0.2, 0.1], [0.1, 0.7, 0.2], @@ -86,14 +82,11 @@ def f2_score_multi(y_true, y_pred, average): ("Recall", recall_score, Recall(num_classes=3, average="macro")), ("F1", f1_score, F1(num_classes=3, average="macro")), ("FBeta", f2_score_multi, FBeta(beta=3, num_classes=3, average="macro")), - ("AUC", auc_score, AUC()), ], ) def test_muticlass_metrics(metric_name, sklearn_metric, torch_metric): if metric_name == "Accuracy": sk_res = sklearn_metric(y_true_multi_np, y_pred_muli_np.argmax(axis=1)) - elif metric_name == "AUC": - torch_metric.num_classes=3 else: sk_res = sklearn_metric( y_true_multi_np, y_pred_muli_np.argmax(axis=1), average="macro" From f17fe8420006526aa4fa1c5b0e3ad77212361190 Mon Sep 17 00:00:00 2001 From: Pavol Mulinka Date: Sat, 11 Dec 2021 21:58:35 +0100 Subject: [PATCH 5/5] minor code cleanup --- pytorch_widedeep/metrics.py | 60 ------------------------------------- 1 file changed, 60 deletions(-) diff --git a/pytorch_widedeep/metrics.py 
b/pytorch_widedeep/metrics.py index 6b1cd162..8825fe62 100644 --- a/pytorch_widedeep/metrics.py +++ b/pytorch_widedeep/metrics.py @@ -1,7 +1,6 @@ import numpy as np import torch from torchmetrics import Metric as TorchMetric -from torchmetrics import AUC from .wdtypes import * # noqa: F403 @@ -394,62 +393,3 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: y_true_avg = self.y_true_sum / self.num_examples self.denominator += ((y_true - y_true_avg) ** 2).sum().item() return np.array((1 - (self.numerator / self.denominator))) - - -class Accuracy(Metric): - r"""Class to calculate the accuracy for both binary and categorical problems - - Parameters - ---------- - top_k: int, default = 1 - Accuracy will be computed using the top k most likely classes in - multiclass problems - - Examples - -------- - >>> import torch - >>> - >>> from pytorch_widedeep.metrics import Accuracy - >>> - >>> acc = Accuracy() - >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1) - >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1) - >>> acc(y_pred, y_true) - array(0.5) - >>> - >>> acc = Accuracy(top_k=2) - >>> y_true = torch.tensor([0, 1, 2]) - >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]]) - >>> acc(y_pred, y_true) - array(0.66666667) - """ - - def __init__(self, top_k: int = 1): - super(Accuracy, self).__init__() - - self.top_k = top_k - self.correct_count = 0 - self.total_count = 0 - self._name = "acc" - - def reset(self): - """ - resets counters to 0 - """ - self.correct_count = 0 - self.total_count = 0 - - def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray: - num_classes = y_pred.size(1) - - if num_classes == 1: - y_pred = y_pred.round() - y_true = y_true - elif num_classes > 1: - y_pred = y_pred.topk(self.top_k, 1)[1] - y_true = y_true.view(-1, 1).expand_as(y_pred) - - self.correct_count += y_pred.eq(y_true).sum().item() # type: ignore[assignment] - self.total_count += len(y_pred) - accuracy = 
float(self.correct_count) / float(self.total_count) - return np.array(accuracy)