From 02819a35b37ab049b80cebf41d6f95ea01991e00 Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Tue, 17 Sep 2024 21:37:31 +0100 Subject: [PATCH] in progress --- daindex/util.py | 6 + tutorials/adult_tutorial.ipynb | 323 ++++++++++++++++++++++++++++++++- 2 files changed, 323 insertions(+), 6 deletions(-) diff --git a/daindex/util.py b/daindex/util.py index 7136993..a364473 100644 --- a/daindex/util.py +++ b/daindex/util.py @@ -181,3 +181,9 @@ def viz( plt.axvspan(decision_boundary, 1, facecolor="b", alpha=0.1) plt.legend(fontsize=font_size, loc="best") + + return_dict = {"AUC": (a2 - a1) / a1} + if da1: + return_dict.update({"Decision AUC": (da2 - da1) / da1}) + + return return_dict diff --git a/tutorials/adult_tutorial.ipynb b/tutorials/adult_tutorial.ipynb index 6f9aa86..344d10a 100644 --- a/tutorials/adult_tutorial.ipynb +++ b/tutorials/adult_tutorial.ipynb @@ -9,19 +9,122 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from aif360.datasets import AdultDataset\n", "from aif360.algorithms.preprocessing import Reweighing\n", "\n", - "import numpy as np\n", - "\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.linear_model import LogisticRegression\n", - "from sklearn.pipeline import make_pipeline\n", - "from metrics.eval_metrics import print_metrics_binary" + "from sklearn.pipeline import make_pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The below code downloads the required files from the UCI website if they are not already present in ai360's data directory." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "adult.data already exists.\n", + "adult.test already exists.\n", + "adult.names already exists.\n" + ] + } + ], + "source": [ + "import urllib.request\n", + "import aif360\n", + "import os\n", + "\n", + "aif360_path = os.path.dirname(aif360.__file__)\n", + "adult_path = os.path.join(aif360_path, \"data\", \"raw\", \"adult\")\n", + "\n", + "# Define the file paths\n", + "files = [\"adult.data\", \"adult.test\", \"adult.names\"]\n", + "urls = [\n", + " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n", + " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\",\n", + " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.names\",\n", + "]\n", + "\n", + "# Check and download files if they do not exist\n", + "for file, url in zip(files, urls):\n", + " file_path = os.path.join(adult_path, file)\n", + " if not os.path.exists(file_path):\n", + " print(f\"{file} not found. Downloading from {url}...\")\n", + " urllib.request.urlretrieve(url, file_path)\n", + " print(f\"{file} downloaded.\")\n", + " else:\n", + " print(f\"{file} already exists.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Missing Data: 3620 rows removed from AdultDataset.\n" + ] + } + ], + "source": [ + "privileged_groups1 = [{\"sex\": 1}]\n", + "unprivileged_groups1 = [{\"sex\": 0}]\n", + "privileged_groups2 = [{\"race\": 1}]\n", + "unprivileged_groups2 = [{\"race\": 0}]\n", + "\n", + "\n", + "dataset = AdultDataset()\n", + "\n", + "(dataset_orig_train, dataset_orig_val) = dataset.split([0.7], shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "model = make_pipeline(StandardScaler(), LogisticRegression(solver=\"liblinear\", random_state=1))\n", + "fit_params = {\"logisticregression__sample_weight\": dataset_orig_train.instance_weights}\n", + "\n", + "lr_orig = model.fit(dataset_orig_train.features, dataset_orig_train.labels.ravel(), **fit_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "rw_sex = Reweighing(unprivileged_groups=unprivileged_groups1, privileged_groups=privileged_groups1)\n", + "rw_race = Reweighing(unprivileged_groups=unprivileged_groups2, privileged_groups=privileged_groups2)\n", + "\n", + "trans_sex_dataset = rw_sex.fit(dataset).transform(dataset)\n", + "trans_sex_race_dataset = rw_race.fit(trans_sex_dataset).transform(trans_sex_dataset)\n", + "\n", + "trans_sex_dataset_train = rw_sex.fit(dataset_orig_train).transform(dataset_orig_train)\n", + "trans_sex_race_dataset_train = rw_race.fit(trans_sex_dataset_train).transform(trans_sex_dataset_train)\n", + "\n", + "model = make_pipeline(StandardScaler(), LogisticRegression(solver=\"liblinear\", random_state=1))\n", + "fit_params = {\"logisticregression__sample_weight\": trans_sex_race_dataset_train.instance_weights}\n", + "lr_rw = model.fit(trans_sex_race_dataset_train.features, trans_sex_race_dataset_train.labels.ravel(), **fit_params)" ] }, { @@ -29,7 +132,215 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from daindex.core import deterioration_index\n", + "import numpy as np\n", + "\n", + "def obtain_det_alo_index(\n", + " df,\n", + " scores,\n", + " det_feature,\n", + " score_threshold,\n", + " cohort_name,\n", + " det_label,\n", + " score_margin=0.05,\n", + " det_threshold=3,\n", + " min_v=-5,\n", + " max_v=15,\n", + " reverse=False,\n", + " is_discrete=True,\n", + " search_bandwidth=True,\n", + " det_feature_func=None,\n", + "):\n", + " lb = score_threshold - score_margin\n", + " up = score_threshold + score_margin\n", + " det_list = []\n", + " for i, r in df.iterrows():\n", + " p = scores[i]\n", + " if lb <= p <= up:\n", + " if det_feature_func is not None:\n", + " det_list.append(det_feature_func(r))\n", + " else:\n", + " det_list.append(r[det_feature])\n", + " if len(det_list) > 20:\n", + " X = np.array(det_list)\n", + " di_ret = deterioration_index(\n", + " X[~np.isnan(X)].reshape(-1, 1),\n", + " min_v,\n", + " max_v,\n", + " threshold=det_threshold,\n", + " plot_title=f\"{cohort_name} | {det_label}\",\n", + " reverse=reverse,\n", + " is_discrete=is_discrete,\n", + " search_bandwidth=search_bandwidth,\n", + " do_plot=False,\n", + " )\n", + " return score_threshold, len(det_list), di_ret[\"k-step\"]\n", + " else:\n", + " return score_threshold, 0, 0\n", + "\n", + "\n", + "def get_scores(models, df, features):\n", + " predicted_probs = np.array([m.predict_proba(df[features].to_numpy()) for m in models])\n", + " return predicted_probs[:, :, 1].mean(axis=0)\n", + "\n", + "\n", + "df_white = dataset[dataset[\"ethnicity\"] == 1]\n", + "df_non_white = dataset[dataset[\"ethnicity\"] == 0]\n", + "\n", + "df_male = kidney_df[kidney_df[\"gender\"] == 1]\n", + "df_female = kidney_df[kidney_df[\"gender\"] == 0]\n", + "\n", + "steps = 50\n", + "det_feature = \"Creatinine Max\"\n", + "det_label = \"Creatinine Max\"\n", + "is_discrete = False\n", + "min_v = min(np.min(df_white[det_feature]), np.min(df_non_white[det_feature]))\n", + "max_v = max(np.max(df_white[det_feature]), np.max(df_non_white[det_feature]))\n", + "print(min_v, max_v)\n", + "det_threshold = 1.35 # .7" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import accuracy_score, precision_score, recall_score, precision_recall_curve, auc\n", + "from aif360.metrics import ClassificationMetric\n", + "from collections import defaultdict\n", + "from metrics.eval_metrics import print_metrics_binary\n", + "\n", + "\n", + "def compute_deterioration_index(clf):\n", + " white_ret = []\n", + " non_white_ret = []\n", + " male_ret = []\n", + " female_ret = []\n", + "\n", + " scores1 = get_scores(clf, df_white, feature_list)\n", + " scores2 = get_scores(clf, df_non_white, feature_list)\n", + " scores3 = get_scores(clf, df_male, feature_list)\n", + " scores4 = get_scores(clf, df_female, feature_list)\n", + "\n", + " # Calculate DA AUC for this fold and accumulate the values\n", + " for s in range(1, steps + 1):\n", + " white_ret.append(\n", + " obtain_det_alo_index(\n", + " df_white,\n", + " scores1,\n", + " det_feature,\n", + " s / steps,\n", + " cohort_name=\"White cohort\",\n", + " det_label=det_label,\n", + " det_threshold=det_threshold,\n", + " min_v=min_v,\n", + " max_v=max_v,\n", + " is_discrete=is_discrete,\n", + " )\n", + " )\n", + " non_white_ret.append(\n", + " obtain_det_alo_index(\n", + " df_non_white,\n", + " scores2,\n", + " det_feature,\n", + " s / steps,\n", + " cohort_name=\"Non-White cohort\",\n", + " det_label=det_label,\n", + " det_threshold=det_threshold,\n", + " min_v=min_v,\n", + " max_v=max_v,\n", + " is_discrete=is_discrete,\n", + " )\n", + " )\n", + " male_ret.append(\n", + " obtain_det_alo_index(\n", + " df_male,\n", + " scores3,\n", + " det_feature,\n", + " s / steps,\n", + " cohort_name=\"Male cohort\",\n", + " det_label=det_label,\n", + " det_threshold=det_threshold,\n", + " min_v=min_v,\n", + " max_v=max_v,\n", + " is_discrete=is_discrete,\n", + " )\n", + " )\n", + " female_ret.append(\n", + " obtain_det_alo_index(\n", + " df_female,\n", + " scores4,\n", + " det_feature,\n", + " s / steps,\n", + " cohort_name=\"Female cohort\",\n", + " det_label=det_label,\n", + " det_threshold=det_threshold,\n", + " min_v=min_v,\n", + " max_v=max_v,\n", + " is_discrete=is_discrete,\n", + " )\n", + " )\n", + "\n", + " return white_ret, non_white_ret, male_ret, female_ret" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_daauc_values_and_graphs(results):\n", + " # Visualization of Average DA AUC Values for Each Model\n", + " viz_config = {\"fig_size\": (4, 3), \"font_size\": 12}\n", + "\n", + " # Loop through each model in all_results to generate the visualizations\n", + " for model_name, _ in models.items():\n", + " print(f\"\\n\\n{model_name}\")\n", + " for i in range(0, 10):\n", + " white_ret = results[model_name][f\"fold_{i}\"][\"white_ret\"]\n", + " non_white_ret = results[model_name][f\"fold_{i}\"][\"non_white_ret\"]\n", + " male_ret = results[model_name][f\"fold_{i}\"][\"male_ret\"]\n", + " female_ret = results[model_name][f\"fold_{i}\"][\"female_ret\"]\n", + "\n", + " print(\"\\n===============\")\n", + " print(i)\n", + " print(\"ethnicity:\")\n", + " # Visualize White vs Non-White average DA AUC\n", + " a1, a2, da1, da2 = viz(\n", + " np.array(white_ret),\n", + " np.array(non_white_ret),\n", + " \"White\",\n", + " \"Non-White\",\n", + " f\"{det_label} {'>='}{det_threshold}\",\n", + " f\"{model_name}\",\n", + " viz_config,\n", + " )\n", + " results[model_name][f\"fold_{i}\"][\"ethnicity_auc\"] = (a2 - a1) / a1\n", + " try:\n", + " results[model_name][f\"fold_{i}\"][\"ethnicity_dauc\"] = (da2 - da1) / da1\n", + " except ZeroDivisionError:\n", + " results[model_name][f\"fold_{i}\"][\"ethnicity_dauc\"] = 0\n", + " print(\"sex:\")\n", + " # Visualize Male vs Female average DA AUC\n", + " a1, a2, da1, da2 = viz(\n", + " np.array(male_ret),\n", + " np.array(female_ret),\n", + " \"Male\",\n", + " \"Female\",\n", + " f\"{det_label} {'>='}{det_threshold}\",\n", + " f\"{model_name}\",\n", + " viz_config,\n", + " )\n", + " results[model_name][f\"fold_{i}\"][\"sex_auc\"] = (a2 - a1) / a1\n", + " try:\n", + " results[model_name][f\"fold_{i}\"][\"sex_dauc\"] = (da2 - da1) / da1\n", + " except ZeroDivisionError:\n", + " results[model_name][f\"fold_{i}\"][\"sex_dauc\"] = 0\n", + " return results" + ] } ], "metadata": {