From 70e375e613674d46a5686c8020e12aeb1a54ac79 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Tue, 17 Sep 2024 20:19:07 +0100 Subject: [PATCH 1/3] Cleaning up --- .github/workflows/pypi.yml | 2 +- .pre-commit-config.yaml | 2 +- .vscode/settings.json | 2 +- poetry.lock | 53 ++++++++++++++++++++++++++++------ pyproject.toml | 1 + tutorials/adult_tutorial.ipynb | 48 ++++++++++++++++++++++++++++++ 6 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 tutorials/adult_tutorial.ipynb diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 4397127..14344c8 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -58,4 +58,4 @@ jobs: name: release-dists path: dist/ - name: Publish release distributions to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fc68c96..4f42037 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: check-toml - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/.vscode/settings.json b/.vscode/settings.json index 9b38853..a3a1838 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,4 +4,4 @@ ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true -} \ No newline at end of file +} diff --git a/poetry.lock b/poetry.lock index 06f3d3b..d95249a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,40 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "aif360" +version = "0.6.1" +description = "IBM AI Fairness 360" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aif360-0.6.1-py3-none-any.whl", hash = "sha256:2bae0f7ba95c4902f551df33c7603c3be0319442a96b35ad99199a7d62093217"}, + {file = "aif360-0.6.1.tar.gz", hash = "sha256:635afc0bbe785e08fa242eee5e5080238193d0583bc0f5006a132050d4fb77f2"}, +] + +[package.dependencies] +matplotlib = "*" +numpy = ">=1.16" +pandas = ">=0.24.0" +scikit-learn = ">=1.0" +scipy = ">=1.2.0" + +[package.extras] +adversarialdebiasing = ["tensorflow (>=1.13.1)"] +all = ["BlackBoxAuditing", "adversarial-robustness-toolbox (>=1.0.0)", "colorama", "cvxpy (>=1.0)", "fairlearn (>=0.7,<1.0)", "igraph[plotting]", "inFairness (>=0.2.2)", "ipympl", "jinja2 (>3.1.0)", "jupyter", "lightgbm", "lime", "mlxtend", "pot", "pytest (>=3.5)", "pytest-cov (>=2.8.1)", "rpy2", "seaborn", "skorch", "sphinx", "sphinx-rtd-theme", "tensorflow (>=1.13.1)", "torch", "tqdm"] +art = ["adversarial-robustness-toolbox (>=1.0.0)"] +disparateimpactremover = ["BlackBoxAuditing"] +docs = ["fairlearn (>=0.7,<1.0)", "jinja2 (>3.1.0)", "sphinx", "sphinx-rtd-theme"] +facts = ["colorama", "mlxtend", "tqdm"] +fairadapt = ["rpy2"] +infairness = ["inFairness (>=0.2.2)", "skorch"] +lfr = ["torch"] +lime = ["lime"] +notebooks = ["igraph[plotting]", "ipympl", "jupyter", "lightgbm", "seaborn", "tqdm"] +optimaltransport = ["pot"] +optimpreproc = ["cvxpy (>=1.0)"] +reductions = ["fairlearn (>=0.7,<1.0)"] +tests = ["BlackBoxAuditing", "adversarial-robustness-toolbox (>=1.0.0)", "colorama", "cvxpy (>=1.0)", "fairlearn (>=0.7,<1.0)", "igraph[plotting]", "inFairness (>=0.2.2)", "ipympl", "jupyter", "lightgbm", "lime", "mlxtend", "pot", "pytest (>=3.5)", "pytest-cov (>=2.8.1)", "rpy2", "seaborn", "skorch", "tensorflow 
(>=1.13.1)", "torch", "tqdm"] + [[package]] name = "anyio" version = "4.4.0" @@ -757,18 +792,18 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.16.0" +version = "3.16.1" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.16.0-py3-none-any.whl", hash = "sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609"}, - {file = "filelock-3.16.0.tar.gz", hash = "sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec"}, + {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, + {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, ] [package.extras] -docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.1.1)", "pytest (>=8.3.2)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.3)"] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] typing = ["typing-extensions (>=4.12.2)"] [[package]] @@ -2138,13 +2173,13 @@ xmp = ["defusedxml"] [[package]] name = "platformdirs" -version = "4.3.4" +version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.3.4-py3-none-any.whl", hash = "sha256:8b4ba85412f5065dae40aa19feaa02ac2be584c8b14abd70712b5cd11ad80034"}, - {file = "platformdirs-4.3.4.tar.gz", hash = "sha256:9e8a037c36fe1b1f1b5de4482e60464272cc8dca725e40b568bf2c285f7509cf"}, + {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, + {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, ] [package.extras] @@ -3374,4 +3409,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "49cbc2fed64b1485ebf6bfb648f9b44d652e7078d77325f90ad07c4629acef63" +content-hash = "573fa94c197e46f82b8218406dab61abde75e8a963e25eae459c75b8e0fb956f" diff --git a/pyproject.toml b/pyproject.toml index f138067..bad8316 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ optional = true [tool.poetry.group.tutorials.dependencies] jupyter = "^1.1.1" notebook = "^7.2.2" +aif360 = "^0.6.1" [tool.ruff] include = ["*.py", "*.pyi", "pyproject.toml", "*.ipynb"] diff --git a/tutorials/adult_tutorial.ipynb b/tutorials/adult_tutorial.ipynb new file mode 100644 index 0000000..6f9aa86 --- /dev/null +++ b/tutorials/adult_tutorial.ipynb @@ -0,0 +1,48 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use the [`aif360`](https://aif360.readthedocs.io/en/latest/Getting%20Started.html) package to load the UCI adult dataset, fit a simple model and then analyse the fairness of the model using the DA-AUC." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from aif360.datasets import AdultDataset\n", + "from aif360.algorithms.preprocessing import Reweighing\n", + "\n", + "import numpy as np\n", + "\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.pipeline import make_pipeline\n", + "from metrics.eval_metrics import print_metrics_binary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "daindex", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 827e01d53ff4c1bb8bf3612ad014ed8b50c37e1b Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Tue, 17 Sep 2024 21:37:31 +0100 Subject: [PATCH 2/3] in progress --- daindex/util.py | 6 + tutorials/adult_tutorial.ipynb | 323 ++++++++++++++++++++++++++++++++- 2 files changed, 323 insertions(+), 6 deletions(-) diff --git a/daindex/util.py b/daindex/util.py index 7136993..a364473 100644 --- a/daindex/util.py +++ b/daindex/util.py @@ -181,3 +181,9 @@ def viz( plt.axvspan(decision_boundary, 1, facecolor="b", alpha=0.1) plt.legend(fontsize=font_size, loc="best") + + return_dict = {"AUC": (a2 - a1) / a1} + if da1: + return_dict.update({"Decision AUC": (da2 - da1) / da1}) + + return return_dict diff --git a/tutorials/adult_tutorial.ipynb b/tutorials/adult_tutorial.ipynb index 6f9aa86..344d10a 100644 --- a/tutorials/adult_tutorial.ipynb +++ b/tutorials/adult_tutorial.ipynb @@ -9,19 +9,122 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from aif360.datasets import AdultDataset\n", "from aif360.algorithms.preprocessing import Reweighing\n", "\n", - "import numpy as np\n", - "\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.linear_model import LogisticRegression\n", - "from sklearn.pipeline import make_pipeline\n", - "from metrics.eval_metrics import print_metrics_binary" + "from sklearn.pipeline import make_pipeline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The code below downloads the required files from the UCI website if they are not already present in aif360's data directory."
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "adult.data already exists.\n", + "adult.test already exists.\n", + "adult.names already exists.\n" + ] + } + ], + "source": [ + "import urllib.request\n", + "import aif360\n", + "import os\n", + "\n", + "aif360_path = os.path.dirname(aif360.__file__)\n", + "adult_path = os.path.join(aif360_path, \"data\", \"raw\", \"adult\")\n", + "\n", + "# Define the file paths\n", + "files = [\"adult.data\", \"adult.test\", \"adult.names\"]\n", + "urls = [\n", + " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n", + " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\",\n", + " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.names\",\n", + "]\n", + "\n", + "# Check and download files if they do not exist\n", + "for file, url in zip(files, urls):\n", + " file_path = os.path.join(adult_path, file)\n", + " if not os.path.exists(file_path):\n", + " print(f\"{file} not found. Downloading from {url}...\")\n", + " urllib.request.urlretrieve(url, file_path)\n", + " print(f\"{file} downloaded.\")\n", + " else:\n", + " print(f\"{file} already exists.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Missing Data: 3620 rows removed from AdultDataset.\n" + ] + } + ], + "source": [ + "privileged_groups1 = [{\"sex\": 1}]\n", + "unprivileged_groups1 = [{\"sex\": 0}]\n", + "privileged_groups2 = [{\"race\": 1}]\n", + "unprivileged_groups2 = [{\"race\": 0}]\n", + "\n", + "\n", + "dataset = AdultDataset()\n", + "\n", + "(dataset_orig_train, dataset_orig_val) = dataset.split([0.7], shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "model = make_pipeline(StandardScaler(), LogisticRegression(solver=\"liblinear\", random_state=1))\n", + "fit_params = {\"logisticregression__sample_weight\": dataset_orig_train.instance_weights}\n", + "\n", + "lr_orig = model.fit(dataset_orig_train.features, dataset_orig_train.labels.ravel(), **fit_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "rw_sex = Reweighing(unprivileged_groups=unprivileged_groups1, privileged_groups=privileged_groups1)\n", + "rw_race = Reweighing(unprivileged_groups=unprivileged_groups2, privileged_groups=privileged_groups2)\n", + "\n", + "trans_sex_dataset = rw_sex.fit(dataset).transform(dataset)\n", + "trans_sex_race_dataset = rw_race.fit(trans_sex_dataset).transform(trans_sex_dataset)\n", + "\n", + "trans_sex_dataset_train = rw_sex.fit(dataset_orig_train).transform(dataset_orig_train)\n", + "trans_sex_race_dataset_train = rw_race.fit(trans_sex_dataset_train).transform(trans_sex_dataset_train)\n", + "\n", + "model = make_pipeline(StandardScaler(), LogisticRegression(solver=\"liblinear\", random_state=1))\n", + "fit_params = {\"logisticregression__sample_weight\": trans_sex_race_dataset_train.instance_weights}\n", + "lr_rw = model.fit(trans_sex_race_dataset_train.features, trans_sex_race_dataset_train.labels.ravel(), **fit_params)" ] }, { @@ -29,7 +132,215 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from daindex.core import deterioration_index\n", + "import numpy as np\n", + "\n", + "def obtain_det_alo_index(\n", 
+ " df,\n", + " scores,\n", + " det_feature,\n", + " score_threshold,\n", + " cohort_name,\n", + " det_label,\n", + " score_margin=0.05,\n", + " det_threshold=3,\n", + " min_v=-5,\n", + " max_v=15,\n", + " reverse=False,\n", + " is_discrete=True,\n", + " search_bandwidth=True,\n", + " det_feature_func=None,\n", + "):\n", + " lb = score_threshold - score_margin\n", + " up = score_threshold + score_margin\n", + " det_list = []\n", + " for i, r in df.iterrows():\n", + " p = scores[i]\n", + " if lb <= p <= up:\n", + " if det_feature_func is not None:\n", + " det_list.append(det_feature_func(r))\n", + " else:\n", + " det_list.append(r[det_feature])\n", + " if len(det_list) > 20:\n", + " X = np.array(det_list)\n", + " di_ret = deterioration_index(\n", + " X[~np.isnan(X)].reshape(-1, 1),\n", + " min_v,\n", + " max_v,\n", + " threshold=det_threshold,\n", + " plot_title=f\"{cohort_name} | {det_label}\",\n", + " reverse=reverse,\n", + " is_discrete=is_discrete,\n", + " search_bandwidth=search_bandwidth,\n", + " do_plot=False,\n", + " )\n", + " return score_threshold, len(det_list), di_ret[\"k-step\"]\n", + " else:\n", + " return score_threshold, 0, 0\n", + "\n", + "\n", + "def get_scores(models, df, features):\n", + " predicted_probs = np.array([m.predict_proba(df[features].to_numpy()) for m in models])\n", + " return predicted_probs[:, :, 1].mean(axis=0)\n", + "\n", + "\n", + "df_white = dataset[dataset[\"ethnicity\"] == 1]\n", + "df_non_white = dataset[dataset[\"ethnicity\"] == 0]\n", + "\n", + "df_male = kidney_df[kidney_df[\"gender\"] == 1]\n", + "df_female = kidney_df[kidney_df[\"gender\"] == 0]\n", + "\n", + "steps = 50\n", + "det_feature = \"Creatinine Max\"\n", + "det_label = \"Creatinine Max\"\n", + "is_discrete = False\n", + "min_v = min(np.min(df_white[det_feature]), np.min(df_non_white[det_feature]))\n", + "max_v = max(np.max(df_white[det_feature]), np.max(df_non_white[det_feature]))\n", + "print(min_v, max_v)\n", + "det_threshold = 1.35 # .7" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import accuracy_score, precision_score, recall_score, precision_recall_curve, auc\n", + "from aif360.metrics import ClassificationMetric\n", + "from collections import defaultdict\n", + "from metrics.eval_metrics import print_metrics_binary\n", + "\n", + "\n", + "def compute_deterioration_index(clf):\n", + " white_ret = []\n", + " non_white_ret = []\n", + " male_ret = []\n", + " female_ret = []\n", + "\n", + " scores1 = get_scores(clf, df_white, feature_list)\n", + " scores2 = get_scores(clf, df_non_white, feature_list)\n", + " scores3 = get_scores(clf, df_male, feature_list)\n", + " scores4 = get_scores(clf, df_female, feature_list)\n", + "\n", + " # Calculate DA AUC for this fold and accumulate the values\n", + " for s in range(1, steps + 1):\n", + " white_ret.append(\n", + " obtain_det_alo_index(\n", + " df_white,\n", + " scores1,\n", + " det_feature,\n", + " s / steps,\n", + " cohort_name=\"White cohort\",\n", + " det_label=det_label,\n", + " det_threshold=det_threshold,\n", + " min_v=min_v,\n", + " max_v=max_v,\n", + " is_discrete=is_discrete,\n", + " )\n", + " )\n", + " non_white_ret.append(\n", + " obtain_det_alo_index(\n", + " df_non_white,\n", + " scores2,\n", + " det_feature,\n", + " s / steps,\n", + " cohort_name=\"Non-White cohort\",\n", + " det_label=det_label,\n", + " det_threshold=det_threshold,\n", + " min_v=min_v,\n", + " max_v=max_v,\n", + " is_discrete=is_discrete,\n", + " )\n", + " )\n", + " 
male_ret.append(\n", + " obtain_det_alo_index(\n", + " df_male,\n", + " scores3,\n", + " det_feature,\n", + " s / steps,\n", + " cohort_name=\"Male cohort\",\n", + " det_label=det_label,\n", + " det_threshold=det_threshold,\n", + " min_v=min_v,\n", + " max_v=max_v,\n", + " is_discrete=is_discrete,\n", + " )\n", + " )\n", + " female_ret.append(\n", + " obtain_det_alo_index(\n", + " df_female,\n", + " scores4,\n", + " det_feature,\n", + " s / steps,\n", + " cohort_name=\"Female cohort\",\n", + " det_label=det_label,\n", + " det_threshold=det_threshold,\n", + " min_v=min_v,\n", + " max_v=max_v,\n", + " is_discrete=is_discrete,\n", + " )\n", + " )\n", + "\n", + " return white_ret, non_white_ret, male_ret, female_ret" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_daauc_values_and_graphs(results):\n", + " # Visualization of Average DA AUC Values for Each Model\n", + " viz_config = {\"fig_size\": (4, 3), \"font_size\": 12}\n", + "\n", + " # Loop through each model in all_results to generate the visualizations\n", + " for model_name, _ in models.items():\n", + " print(f\"\\n\\n{model_name}\")\n", + " for i in range(0, 10):\n", + " white_ret = results[model_name][f\"fold_{i}\"][\"white_ret\"]\n", + " non_white_ret = results[model_name][f\"fold_{i}\"][\"non_white_ret\"]\n", + " male_ret = results[model_name][f\"fold_{i}\"][\"male_ret\"]\n", + " female_ret = results[model_name][f\"fold_{i}\"][\"female_ret\"]\n", + "\n", + " print(\"\\n===============\")\n", + " print(i)\n", + " print(\"ethnicity:\")\n", + " # Visualize White vs Non-White average DA AUC\n", + " a1, a2, da1, da2 = viz(\n", + " np.array(white_ret),\n", + " np.array(non_white_ret),\n", + " \"White\",\n", + " \"Non-White\",\n", + " f\"{det_label} {'>='}{det_threshold}\",\n", + " f\"{model_name}\",\n", + " viz_config,\n", + " )\n", + " results[model_name][f\"fold_{i}\"][\"ethnicity_auc\"] = (a2 - a1) / a1\n", + " try:\n", + " results[model_name][f\"fold_{i}\"][\"ethnicity_dauc\"] = (da2 - da1) / da1\n", + " except ZeroDivisionError:\n", + " results[model_name][f\"fold_{i}\"][\"ethnicity_dauc\"] = 0\n", + " print(\"sex:\")\n", + " # Visualize Male vs Female average DA AUC\n", + " a1, a2, da1, da2 = viz(\n", + " np.array(male_ret),\n", + " np.array(female_ret),\n", + " \"Male\",\n", + " \"Female\",\n", + " f\"{det_label} {'>='}{det_threshold}\",\n", + " f\"{model_name}\",\n", + " viz_config,\n", + " )\n", + " results[model_name][f\"fold_{i}\"][\"sex_auc\"] = (a2 - a1) / a1\n", + " try:\n", + " results[model_name][f\"fold_{i}\"][\"sex_dauc\"] = (da2 - da1) / da1\n", + " except ZeroDivisionError:\n", + " results[model_name][f\"fold_{i}\"][\"sex_dauc\"] = 0\n", + " return results" + ] } ], "metadata": { From d929b5b099691a50832b91a3c373acf6ee3a12dc Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Wed, 18 Sep 2024 15:37:50 +0100 Subject: [PATCH 3/3] Closes #13 --- daindex/core.py | 1 - daindex/util.py | 2 +- tutorials/adult_tutorial.ipynb | 630 +++++++++++++++++++++++---------- 3 files changed, 445 insertions(+), 188 deletions(-) diff --git a/daindex/core.py b/daindex/core.py index 7a30070..aa50f5a 100644 --- a/daindex/core.py +++ b/daindex/core.py @@ -102,7 +102,6 @@ def kde_estimate( kde = KernelDensity(bandwidth=bandwidth, kernel=kernel) kde.fit(X) - return kde, bandwidth diff --git a/daindex/util.py b/daindex/util.py index a364473..2326ba1 100644 --- a/daindex/util.py +++ b/daindex/util.py @@ -131,7 +131,7 @@ def viz( g2_label: str, deterioration_label: str, 
allocation_label: str, - config: dict, + config: dict = {}, decision_boundary: float = 0.5, ) -> None: """ diff --git a/tutorials/adult_tutorial.ipynb b/tutorials/adult_tutorial.ipynb index 344d10a..d851411 100644 --- a/tutorials/adult_tutorial.ipynb +++ b/tutorials/adult_tutorial.ipynb @@ -9,16 +9,19 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from aif360.datasets import AdultDataset\n", - "from aif360.algorithms.preprocessing import Reweighing\n", - "\n", - "from sklearn.preprocessing import StandardScaler\n", + "import numpy as np\n", + "from aif360.datasets import MEPSDataset19\n", + "from aif360.explainers import MetricTextExplainer\n", + "from aif360.metrics import BinaryLabelDatasetMetric, ClassificationMetric\n", "from sklearn.linear_model import LogisticRegression\n", - "from sklearn.pipeline import make_pipeline" + "from sklearn.pipeline import make_pipeline\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "np.random.seed(1)" ] }, { @@ -30,101 +33,418 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "adult.data already exists.\n", - "adult.test already exists.\n", - "adult.names already exists.\n" - ] - } - ], + "outputs": [], "source": [ - "import urllib.request\n", - "import aif360\n", "import os\n", + "import shutil\n", + "import subprocess\n", "\n", - "aif360_path = os.path.dirname(aif360.__file__)\n", - "adult_path = os.path.join(aif360_path, \"data\", \"raw\", \"adult\")\n", - "\n", - "# Define the file paths\n", - "files = [\"adult.data\", \"adult.test\", \"adult.names\"]\n", - "urls = [\n", - " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n", - " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\",\n", - " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.names\",\n", - "]\n", - "\n", - "# Check and download files if they do not exist\n", - "for file, url in zip(files, urls):\n", - " file_path = os.path.join(adult_path, file)\n", - " if not os.path.exists(file_path):\n", - " print(f\"{file} not found. Downloading from {url}...\")\n", - " urllib.request.urlretrieve(url, file_path)\n", - " print(f\"{file} downloaded.\")\n", - " else:\n", - " print(f\"{file} already exists.\")" + "import aif360\n", + "\n", + "aif360_location = os.path.dirname(aif360.__file__)\n", + "meps_data_dir = os.path.join(aif360_location, \"data\", \"raw\", \"meps\")\n", + "h181_file_path = os.path.join(meps_data_dir, \"h181.csv\")\n", + "\n", + "if not os.path.isfile(h181_file_path):\n", + " r_script_path = os.path.join(meps_data_dir, \"generate_data.R\")\n", + " process = subprocess.Popen([\"Rscript\", r_script_path], stdin=subprocess.PIPE)\n", + " process.communicate(input=b\"y\\n\")\n", + "\n", + " # Move the generated CSV files to meps_data_dir\n", + " generated_files = [\"h181.csv\", \"h192.csv\"]\n", + " for file_name in generated_files:\n", + " src_path = os.path.join(os.getcwd(), file_name)\n", + " dest_path = os.path.join(meps_data_dir, file_name)\n", + " if os.path.isfile(src_path):\n", + " shutil.move(src_path, dest_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocessing_w_multimorb(df):\n", + " \"\"\"\n", + " 1.Create a new column, RACE that is 'White' if RACEV2X = 1 and HISPANX = 2 i.e. 
non Hispanic White\n", + " and 'non-White' otherwise\n", + " 2. Restrict to Panel 19\n", + " 3. RENAME all columns that are PANEL/ROUND SPECIFIC\n", + " 4. Drop rows based on certain values of individual features that correspond to missing/unknown - generally < -1\n", + " 5. Compute UTILIZATION, binarize it to 0 (< 10) and 1 (>= 10)\n", + " \"\"\"\n", + "\n", + " def race(row):\n", + " if (row[\"HISPANX\"] == 2) and (\n", + " row[\"RACEV2X\"] == 1\n", + " ): # non-Hispanic Whites are marked as WHITE; all others as NON-WHITE\n", + " return \"White\"\n", + " return \"Non-White\"\n", + "\n", + " df[\"RACEV2X\"] = df.apply(lambda row: race(row), axis=1)\n", + " df = df.rename(columns={\"RACEV2X\": \"RACE\"})\n", + "\n", + " df = df[df[\"PANEL\"] == 19]\n", + "\n", + " # RENAME COLUMNS\n", + " df = df.rename(\n", + " columns={\n", + " \"FTSTU53X\": \"FTSTU\",\n", + " \"ACTDTY53\": \"ACTDTY\",\n", + " \"HONRDC53\": \"HONRDC\",\n", + " \"RTHLTH53\": \"RTHLTH\",\n", + " \"MNHLTH53\": \"MNHLTH\",\n", + " \"CHBRON53\": \"CHBRON\",\n", + " \"JTPAIN53\": \"JTPAIN\",\n", + " \"PREGNT53\": \"PREGNT\",\n", + " \"WLKLIM53\": \"WLKLIM\",\n", + " \"ACTLIM53\": \"ACTLIM\",\n", + " \"SOCLIM53\": \"SOCLIM\",\n", + " \"COGLIM53\": \"COGLIM\",\n", + " \"EMPST53\": \"EMPST\",\n", + " \"REGION53\": \"REGION\",\n", + " \"MARRY53X\": \"MARRY\",\n", + " \"AGE53X\": \"AGE\",\n", + " \"POVCAT15\": \"POVCAT\",\n", + " \"INSCOV15\": \"INSCOV\",\n", + " }\n", + " )\n", + "\n", + " df = df[df[\"REGION\"] >= 0] # remove values -1\n", + " df = df[df[\"AGE\"] >= 0] # remove values -1\n", + "\n", + " df = df[df[\"MARRY\"] >= 0] # remove values -1, -7, -8, -9\n", + "\n", + " df = df[df[\"ASTHDX\"] >= 0] # remove values -1, -7, -8, -9\n", + "\n", + " df = df[\n", + " (\n", + " df[\n", + " [\n", + " \"FTSTU\",\n", + " \"ACTDTY\",\n", + " \"HONRDC\",\n", + " \"RTHLTH\",\n", + " \"MNHLTH\",\n", + " \"HIBPDX\",\n", + " \"CHDDX\",\n", + " \"ANGIDX\",\n", + " \"EDUCYR\",\n", + " \"HIDEG\",\n", + " \"MIDX\",\n", + " \"OHRTDX\",\n", + " \"STRKDX\",\n", + " \"EMPHDX\",\n", + " \"CHBRON\",\n", + " \"CHOLDX\",\n", + " \"CANCERDX\",\n", + " \"DIABDX\",\n", + " \"JTPAIN\",\n", + " \"ARTHDX\",\n", + " \"ARTHTYPE\",\n", + " \"ASTHDX\",\n", + " \"ADHDADDX\",\n", + " \"PREGNT\",\n", + " \"WLKLIM\",\n", + " \"ACTLIM\",\n", + " \"SOCLIM\",\n", + " \"COGLIM\",\n", + " \"DFHEAR42\",\n", + " \"DFSEE42\",\n", + " \"ADSMOK42\",\n", + " \"PHQ242\",\n", + " \"EMPST\",\n", + " \"POVCAT\",\n", + " \"INSCOV\",\n", + " ]\n", + " ]\n", + " >= -1\n", + " ).all(1)\n", + " ] # for all other categorical features, remove values < -1\n", + "\n", + " def utilization(row):\n", + " return row[\"OBTOTV15\"] + row[\"OPTOTV15\"] + row[\"ERTOT15\"] + row[\"IPNGTD15\"] + row[\"HHTOTD15\"]\n", + "\n", + " df[\"TOTEXP15\"] = df.apply(lambda row: utilization(row), axis=1)\n", + " lessE = df[\"TOTEXP15\"] < 10.0\n", + " df.loc[lessE, \"TOTEXP15\"] = 0.0\n", + " moreE = df[\"TOTEXP15\"] >= 10.0\n", + " df.loc[moreE, \"TOTEXP15\"] = 1.0\n", + " df[\"MULTIMORBIDITY\"] = (\n", + " df.filter(regex=\"DX$|CHBRON$|JTPAIN$\").drop(columns=[\"ADHDADDX\"]).apply(lambda x: (x == 1).sum(), axis=1)\n", + " )\n", + "\n", + " df = df.rename(columns={\"TOTEXP15\": \"UTILIZATION\"})\n", + " return df" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Missing Data: 3620 rows removed from AdultDataset.\n" - ] - } - ], + "outputs": [], "source": [ - 
"privileged_groups1 = [{\"sex\": 1}]\n", - "unprivileged_groups1 = [{\"sex\": 0}]\n", - "privileged_groups2 = [{\"race\": 1}]\n", - "unprivileged_groups2 = [{\"race\": 0}]\n", + "dataset_orig_panel19_train, dataset_orig_panel19_val, dataset_orig_panel19_test = MEPSDataset19(\n", + " custom_preprocessing=preprocessing_w_multimorb,\n", + " features_to_keep=[\n", + " \"REGION\",\n", + " \"AGE\",\n", + " \"SEX\",\n", + " \"RACE\",\n", + " \"MARRY\",\n", + " \"FTSTU\",\n", + " \"ACTDTY\",\n", + " \"HONRDC\",\n", + " \"RTHLTH\",\n", + " \"MNHLTH\",\n", + " \"HIBPDX\",\n", + " \"CHDDX\",\n", + " \"ANGIDX\",\n", + " \"MIDX\",\n", + " \"OHRTDX\",\n", + " \"STRKDX\",\n", + " \"EMPHDX\",\n", + " \"CHBRON\",\n", + " \"CHOLDX\",\n", + " \"CANCERDX\",\n", + " \"DIABDX\",\n", + " \"JTPAIN\",\n", + " \"ARTHDX\",\n", + " \"ARTHTYPE\",\n", + " \"ASTHDX\",\n", + " \"ADHDADDX\",\n", + " \"PREGNT\",\n", + " \"WLKLIM\",\n", + " \"ACTLIM\",\n", + " \"SOCLIM\",\n", + " \"COGLIM\",\n", + " \"DFHEAR42\",\n", + " \"DFSEE42\",\n", + " \"ADSMOK42\",\n", + " \"PCS42\",\n", + " \"MCS42\",\n", + " \"K6SUM42\",\n", + " \"PHQ242\",\n", + " \"EMPST\",\n", + " \"POVCAT\",\n", + " \"INSCOV\",\n", + " \"UTILIZATION\",\n", + " \"PERWT15F\",\n", + " \"MULTIMORBIDITY\",\n", + " ],\n", + ").split([0.5, 0.8], shuffle=True)\n", "\n", + "sens_ind = 0\n", + "sens_attr = dataset_orig_panel19_train.protected_attribute_names[sens_ind]\n", "\n", - "dataset = AdultDataset()\n", + "unprivileged_groups = [{sens_attr: v} for v in dataset_orig_panel19_train.unprivileged_protected_attributes[sens_ind]]\n", + "privileged_groups = [{sens_attr: v} for v in dataset_orig_panel19_train.privileged_protected_attributes[sens_ind]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def describe(train=None, val=None, test=None):\n", + " if train is not None:\n", + " print(train.features.shape)\n", + " if val is not None:\n", + " print(val.features.shape)\n", + " print(test.features.shape)\n", + " print(test.favorable_label, test.unfavorable_label)\n", + " print(test.protected_attribute_names)\n", + " print(test.privileged_protected_attributes, test.unprivileged_protected_attributes)\n", + " print(test.feature_names)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "describe(dataset_orig_panel19_train, dataset_orig_panel19_val, dataset_orig_panel19_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_orig_panel19_train = BinaryLabelDatasetMetric(\n", + " dataset_orig_panel19_train, unprivileged_groups=unprivileged_groups, privileged_groups=privileged_groups\n", + ")\n", + "explainer_orig_panel19_train = MetricTextExplainer(metric_orig_panel19_train)\n", "\n", - "(dataset_orig_train, dataset_orig_val) = dataset.split([0.7], shuffle=True)" + "print(explainer_orig_panel19_train.disparate_impact())" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "dataset = dataset_orig_panel19_train\n", "model = make_pipeline(StandardScaler(), LogisticRegression(solver=\"liblinear\", random_state=1))\n", - "fit_params = {\"logisticregression__sample_weight\": dataset_orig_train.instance_weights}\n", + "fit_params = {\"logisticregression__sample_weight\": dataset.instance_weights}\n", "\n", - "lr_orig = model.fit(dataset_orig_train.features, dataset_orig_train.labels.ravel(), **fit_params)" + 
"lr_orig_panel19 = model.fit(dataset.features, dataset.labels.ravel(), **fit_params)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "rw_sex = Reweighing(unprivileged_groups=unprivileged_groups1, privileged_groups=privileged_groups1)\n", - "rw_race = Reweighing(unprivileged_groups=unprivileged_groups2, privileged_groups=privileged_groups2)\n", + "from collections import defaultdict\n", "\n", - "trans_sex_dataset = rw_sex.fit(dataset).transform(dataset)\n", - "trans_sex_race_dataset = rw_race.fit(trans_sex_dataset).transform(trans_sex_dataset)\n", "\n", - "trans_sex_dataset_train = rw_sex.fit(dataset_orig_train).transform(dataset_orig_train)\n", - "trans_sex_race_dataset_train = rw_race.fit(trans_sex_dataset_train).transform(trans_sex_dataset_train)\n", + "def test(dataset, model, thresh_arr):\n", + " try:\n", + " # sklearn classifier\n", + " y_val_pred_prob = model.predict_proba(dataset.features)\n", + " pos_ind = np.where(model.classes_ == dataset.favorable_label)[0][0]\n", + " except AttributeError:\n", + " # aif360 inprocessing algorithm\n", + " y_val_pred_prob = model.predict(dataset).scores\n", + " pos_ind = 0\n", "\n", - "model = make_pipeline(StandardScaler(), LogisticRegression(solver=\"liblinear\", random_state=1))\n", - "fit_params = {\"logisticregression__sample_weight\": trans_sex_race_dataset_train.instance_weights}\n", - "lr_rw = model.fit(trans_sex_race_dataset_train.features, trans_sex_race_dataset_train.labels.ravel(), **fit_params)" + " metric_arrs = defaultdict(list)\n", + " for thresh in thresh_arr:\n", + " y_val_pred = (y_val_pred_prob[:, pos_ind] > thresh).astype(np.float64)\n", + "\n", + " dataset_pred = dataset.copy()\n", + " dataset_pred.labels = y_val_pred\n", + " metric = ClassificationMetric(\n", + " dataset, dataset_pred, unprivileged_groups=unprivileged_groups, privileged_groups=privileged_groups\n", + " )\n", + "\n", + " metric_arrs[\"bal_acc\"].append((metric.true_positive_rate() + metric.true_negative_rate()) / 2)\n", + " metric_arrs[\"avg_odds_diff\"].append(metric.average_odds_difference())\n", + " metric_arrs[\"disp_imp\"].append(metric.disparate_impact())\n", + " metric_arrs[\"stat_par_diff\"].append(metric.statistical_parity_difference())\n", + " metric_arrs[\"eq_opp_diff\"].append(metric.equal_opportunity_difference())\n", + " metric_arrs[\"theil_ind\"].append(metric.theil_index())\n", + "\n", + " return metric_arrs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "thresh_arr = np.linspace(0.01, 0.5, 50)\n", + "val_metrics = test(dataset=dataset_orig_panel19_val, model=lr_orig_panel19, thresh_arr=thresh_arr)\n", + "lr_orig_best_ind = np.argmax(val_metrics[\"bal_acc\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def describe_metrics(metrics, thresh_arr):\n", + " best_ind = np.argmax(metrics[\"bal_acc\"])\n", + " print(\"Threshold corresponding to Best balanced accuracy: {:6.4f}\".format(thresh_arr[best_ind]))\n", + " print(\"Best balanced accuracy: {:6.4f}\".format(metrics[\"bal_acc\"][best_ind]))\n", + " disp_imp_at_best_ind = 1 - min(metrics[\"disp_imp\"][best_ind], 1 / metrics[\"disp_imp\"][best_ind])\n", + " print(\"Corresponding 1-min(DI, 1/DI) value: {:6.4f}\".format(disp_imp_at_best_ind))\n", + " print(\"Corresponding average odds difference value: {:6.4f}\".format(metrics[\"avg_odds_diff\"][best_ind]))\n", + " print(\"Corresponding 
statistical parity difference value: {:6.4f}\".format(metrics[\"stat_par_diff\"][best_ind]))\n", + " print(\"Corresponding equal opportunity difference value: {:6.4f}\".format(metrics[\"eq_opp_diff\"][best_ind]))\n", + " print(\"Corresponding Theil index value: {:6.4f}\".format(metrics[\"theil_ind\"][best_ind]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "describe_metrics(val_metrics, thresh_arr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lr_orig_metrics = test(\n", + " dataset=dataset_orig_panel19_test, model=lr_orig_panel19, thresh_arr=[thresh_arr[lr_orig_best_ind]]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "describe_metrics(lr_orig_metrics, [thresh_arr[lr_orig_best_ind]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = dataset.convert_to_dataframe()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df[\"MULTIMORBIDITY\"].hist()\n", + "df_add = df.copy()\n", + "df_add[\"RTHLTH\"] = df_add.filter(regex=\"RTHLTH\").idxmax(axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_add.boxplot(column=\"MULTIMORBIDITY\", by=\"RTHLTH\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feature_list = dataset.feature_names\n", + "target = dataset.label_names[0]\n", + "det_feature = \"MULTIMORBIDITY\"\n", + "reverse = False\n", + "\n", + "df_features = df[feature_list]\n", + "df_white, df_non_white = df_features[df_features[\"RACE\"] == 1], df_features[df_features[\"RACE\"] == 0]\n", + "\n", + "steps = 50\n", + "\n", + "is_discrete = True\n", + "min_v = min(np.min(df_white[det_feature]), np.min(df_non_white[det_feature]))\n", + "max_v = max(np.max(df_white[det_feature]), np.max(df_non_white[det_feature]))\n", + "print(min_v, max_v)\n", + "det_threshold = 1 # .7" ] }, { @@ -134,7 +454,14 @@ "outputs": [], "source": [ "from daindex.core import deterioration_index\n", - "import numpy as np\n", + "\n", + "\n", + "def get_scores(models, df, target):\n", + " if not isinstance(models, list):\n", + " models = [models]\n", + " predicted_probs = np.array([m.predict_proba(df.to_numpy()) for m in models])\n", + " return predicted_probs[:, :, 1].mean(axis=0)\n", + "\n", "\n", "def obtain_det_alo_index(\n", " df,\n", @@ -142,26 +469,28 @@ " det_feature,\n", " score_threshold,\n", " cohort_name,\n", - " det_label,\n", + " det_label=det_feature,\n", " score_margin=0.05,\n", - " det_threshold=3,\n", + " det_threshold=2,\n", " min_v=-5,\n", " max_v=15,\n", " reverse=False,\n", " is_discrete=True,\n", - " search_bandwidth=True,\n", + " optimise_bandwidth=True,\n", " det_feature_func=None,\n", "):\n", " lb = score_threshold - score_margin\n", " up = score_threshold + score_margin\n", " det_list = []\n", - " for i, r in df.iterrows():\n", + " i = 0\n", + " for _, r in df.iterrows():\n", " p = scores[i]\n", " if lb <= p <= up:\n", " if det_feature_func is not None:\n", " det_list.append(det_feature_func(r))\n", " else:\n", " det_list.append(r[det_feature])\n", + " i += 1\n", " if len(det_list) > 20:\n", " X = np.array(det_list)\n", " di_ret = deterioration_index(\n", @@ -172,33 +501,12 @@ " plot_title=f\"{cohort_name} | 
{det_label}\",\n", " reverse=reverse,\n", " is_discrete=is_discrete,\n", - " search_bandwidth=search_bandwidth,\n", + " optimise_bandwidth=optimise_bandwidth,\n", " do_plot=False,\n", " )\n", " return score_threshold, len(det_list), di_ret[\"k-step\"]\n", " else:\n", - " return score_threshold, 0, 0\n", - "\n", - "\n", - "def get_scores(models, df, features):\n", - " predicted_probs = np.array([m.predict_proba(df[features].to_numpy()) for m in models])\n", - " return predicted_probs[:, :, 1].mean(axis=0)\n", - "\n", - "\n", - "df_white = dataset[dataset[\"ethnicity\"] == 1]\n", - "df_non_white = dataset[dataset[\"ethnicity\"] == 0]\n", - "\n", - "df_male = kidney_df[kidney_df[\"gender\"] == 1]\n", - "df_female = kidney_df[kidney_df[\"gender\"] == 0]\n", - "\n", - "steps = 50\n", - "det_feature = \"Creatinine Max\"\n", - "det_label = \"Creatinine Max\"\n", - "is_discrete = False\n", - "min_v = min(np.min(df_white[det_feature]), np.min(df_non_white[det_feature]))\n", - "max_v = max(np.max(df_white[det_feature]), np.max(df_non_white[det_feature]))\n", - "print(min_v, max_v)\n", - "det_threshold = 1.35 # .7" + " return score_threshold, 0, 0" ] }, { @@ -207,22 +515,12 @@ "metadata": {}, "outputs": [], "source": [ - "from sklearn.metrics import accuracy_score, precision_score, recall_score, precision_recall_curve, auc\n", - "from aif360.metrics import ClassificationMetric\n", - "from collections import defaultdict\n", - "from metrics.eval_metrics import print_metrics_binary\n", - "\n", - "\n", "def compute_deterioration_index(clf):\n", " white_ret = []\n", " non_white_ret = []\n", - " male_ret = []\n", - " female_ret = []\n", "\n", " scores1 = get_scores(clf, df_white, feature_list)\n", " scores2 = get_scores(clf, df_non_white, feature_list)\n", - " scores3 = get_scores(clf, df_male, feature_list)\n", - " scores4 = get_scores(clf, df_female, feature_list)\n", "\n", " # Calculate DA AUC for this fold and accumulate the values\n", " for s in range(1, steps + 1):\n", @@ -233,11 +531,11 @@ " det_feature,\n", " s / steps,\n", " cohort_name=\"White cohort\",\n", - " det_label=det_label,\n", " det_threshold=det_threshold,\n", " min_v=min_v,\n", " max_v=max_v,\n", " is_discrete=is_discrete,\n", + " reverse=reverse,\n", " )\n", " )\n", " non_white_ret.append(\n", @@ -247,43 +545,15 @@ " det_feature,\n", " s / steps,\n", " cohort_name=\"Non-White cohort\",\n", - " det_label=det_label,\n", - " det_threshold=det_threshold,\n", - " min_v=min_v,\n", - " max_v=max_v,\n", - " is_discrete=is_discrete,\n", - " )\n", - " )\n", - " male_ret.append(\n", - " obtain_det_alo_index(\n", - " df_male,\n", - " scores3,\n", - " det_feature,\n", - " s / steps,\n", - " cohort_name=\"Male cohort\",\n", - " det_label=det_label,\n", - " det_threshold=det_threshold,\n", - " min_v=min_v,\n", - " max_v=max_v,\n", - " is_discrete=is_discrete,\n", - " )\n", - " )\n", - " female_ret.append(\n", - " obtain_det_alo_index(\n", - " df_female,\n", - " scores4,\n", - " det_feature,\n", - " s / steps,\n", - " cohort_name=\"Female cohort\",\n", - " det_label=det_label,\n", " det_threshold=det_threshold,\n", " min_v=min_v,\n", " max_v=max_v,\n", " is_discrete=is_discrete,\n", + " reverse=reverse,\n", " )\n", " )\n", "\n", - " return white_ret, non_white_ret, male_ret, female_ret" + " return white_ret, non_white_ret" ] }, { @@ -292,55 +562,35 @@ "metadata": {}, "outputs": [], "source": [ - "def calculate_daauc_values_and_graphs(results):\n", - " # Visualization of Average DA AUC Values for Each Model\n", - " viz_config = {\"fig_size\": (4, 3), 
\"font_size\": 12}\n", - "\n", - " # Loop through each model in all_results to generate the visualizations\n", - " for model_name, _ in models.items():\n", - " print(f\"\\n\\n{model_name}\")\n", - " for i in range(0, 10):\n", - " white_ret = results[model_name][f\"fold_{i}\"][\"white_ret\"]\n", - " non_white_ret = results[model_name][f\"fold_{i}\"][\"non_white_ret\"]\n", - " male_ret = results[model_name][f\"fold_{i}\"][\"male_ret\"]\n", - " female_ret = results[model_name][f\"fold_{i}\"][\"female_ret\"]\n", - "\n", - " print(\"\\n===============\")\n", - " print(i)\n", - " print(\"ethnicity:\")\n", - " # Visualize White vs Non-White average DA AUC\n", - " a1, a2, da1, da2 = viz(\n", - " np.array(white_ret),\n", - " np.array(non_white_ret),\n", - " \"White\",\n", - " \"Non-White\",\n", - " f\"{det_label} {'>='}{det_threshold}\",\n", - " f\"{model_name}\",\n", - " viz_config,\n", - " )\n", - " results[model_name][f\"fold_{i}\"][\"ethnicity_auc\"] = (a2 - a1) / a1\n", - " try:\n", - " results[model_name][f\"fold_{i}\"][\"ethnicity_dauc\"] = (da2 - da1) / da1\n", - " except ZeroDivisionError:\n", - " results[model_name][f\"fold_{i}\"][\"ethnicity_dauc\"] = 0\n", - " print(\"sex:\")\n", - " # Visualize Male vs Female average DA AUC\n", - " a1, a2, da1, da2 = viz(\n", - " np.array(male_ret),\n", - " np.array(female_ret),\n", - " \"Male\",\n", - " \"Female\",\n", - " f\"{det_label} {'>='}{det_threshold}\",\n", - " f\"{model_name}\",\n", - " viz_config,\n", - " )\n", - " results[model_name][f\"fold_{i}\"][\"sex_auc\"] = (a2 - a1) / a1\n", - " try:\n", - " results[model_name][f\"fold_{i}\"][\"sex_dauc\"] = (da2 - da1) / da1\n", - " except ZeroDivisionError:\n", - " results[model_name][f\"fold_{i}\"][\"sex_dauc\"] = 0\n", - " return results" + "white_ret, non_white_ret = compute_deterioration_index(lr_orig_panel19)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from daindex.util import viz\n", + "\n", + "# Visualize White vs Non-White average DA AUC\n", + "auc_dict = viz(\n", + " np.array(white_ret),\n", + " np.array(non_white_ret),\n", + " \"White\",\n", + " \"Non-White\",\n", + " f\"{det_feature}{'>='}{det_threshold}\",\n", + " \"Logistic Regression\",\n", + " config={},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -350,7 +600,15 @@ "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", "version": "3.12.5" } },