diff --git a/notebooks/README.md b/notebooks/README.md new file mode 100644 index 0000000..67e8932 --- /dev/null +++ b/notebooks/README.md @@ -0,0 +1,9 @@ +# Example notebooks using `folktexts` + +| File | Description | +| ---- | ----------- | +| [minimal-example.ipynb](minimal-example.ipynb) | A minimal example of how to use `folktexts` to evaluate model calibration and other statistics. | +| [run-benchmark.ipynb](run-benchmark.ipynb) | Running an ACS benchmark on a given LLM. | +| [parse-acs-results.ipynb](parse-acs-results.ipynb) | Aggregates and parses all ACS benchmark results saved under a given directory. | + diff --git a/notebooks/parse-acs-results.ipynb b/notebooks/parse-acs-results.ipynb new file mode 100755 index 0000000..9781e30 --- /dev/null +++ b/notebooks/parse-acs-results.ipynb @@ -0,0 +1,805 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "04c2f12d-f989-4ac9-b90f-3464a8bcca96", + "metadata": {}, + "source": [ + "# Fetch and parse ACS benchmark results under a given directory\n", + "Each ACS benchmark run outputs a json file. This script collects all such files under a given root directory, parses them, and aggregates them into a more easily digestible pandas DataFrame." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3b241208-d10f-43cf-a486-84c54bbf43c3", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm.auto import tqdm" + ] + }, + { + "cell_type": "markdown", + "id": "faf6afa4-8648-4d35-9312-65636ea5d0b2", + "metadata": {}, + "source": [ + "Set the local path to the root results directory:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "26089a60-81c0-4736-8ba5-99572ec01398", + "metadata": {}, + "outputs": [], + "source": [ + "RESULTS_ROOT_DIR = Path(\"/fast/groups/sf\") / \"folktexts-results\" / \"2024-06-30\"" + ] + }, + { + "cell_type": "markdown", + "id": "03d1af7f-d013-40a0-80de-cb9ad65cff6d", + "metadata": {}, + "source": [ + "Set the local path to the root data directory (needed only to train baseline ML methods):" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b18dad87-a1ed-495b-93b6-4c8af4043ddc", + "metadata": {}, + "outputs": [], + "source": [ + "DATA_DIR = Path(\"/fast/groups/sf\") / \"data\"" + ] + }, + { + "cell_type": "markdown", + "id": "db7c98c5-2942-4f88-984a-8c9014afe761", + "metadata": {}, + "source": [ + "Important results columns:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e96f0c4f-5683-4150-8cca-93a883f20154", + "metadata": {}, + "outputs": [], + "source": [ + "model_col = \"config_model_name\"\n", + "task_col = \"config_task_name\"\n", + "\n", + "feature_subset_col = \"config_feature_subset\"\n", + "population_subset_col = \"config_population_filter\"\n", + "predictions_path_col = \"predictions_path\"" + ] + }, + { + "cell_type": "markdown", + "id": "4296ffdf-a9af-43f7-bd0c-20f4ae5f9619", + "metadata": {}, + "source": [ + "Helper function to parse each dictionary containing benchmark results:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0cf4a6c8-ec28-4e66-9f6d-8a4e977d7b60", + "metadata": {}, + 
"outputs": [], + "source": [ + "from utils import (\n", + " num_features_helper,\n", + " parse_model_name,\n", + " get_non_instruction_tuned_name,\n", + " prettify_model_name,\n", + ")\n", + "\n", + "def parse_results_dict(dct) -> dict:\n", + " \"\"\"Parses results dict and brings all information to the top-level.\"\"\"\n", + "\n", + " # Make a copy so we don't modify the input object\n", + " dct = dct.copy()\n", + "\n", + " # Discard plots' paths\n", + " dct.pop(\"plots\", None)\n", + "\n", + " # Bring configs to top-level\n", + " config = dct.pop(\"config\", {})\n", + " for key, val in config.items():\n", + " dct[f\"config_{key}\"] = val\n", + "\n", + " # Parse model name\n", + " dct[model_col] = parse_model_name(dct[model_col])\n", + " dct[\"base_name\"] = get_non_instruction_tuned_name(dct[model_col])\n", + " dct[\"name\"] = prettify_model_name(dct[model_col])\n", + "\n", + " # Is instruction-tuned model?\n", + " dct[\"is_inst\"] = dct[\"base_name\"] != dct[model_col]\n", + "\n", + " # Log number of features\n", + " dct[\"num_features\"] = num_features_helper(dct[feature_subset_col], max_features_return=-1)\n", + " dct[\"uses_all_features\"] = (dct[feature_subset_col] is None) or (dct[\"num_features\"] == -1)\n", + "\n", + " if dct[feature_subset_col] is None:\n", + " dct[feature_subset_col] = \"full\"\n", + "\n", + " # Assert all results are at the top-level\n", + " assert not any(isinstance(val, dict) for val in dct.values())\n", + " return dct\n" + ] + }, + { + "cell_type": "markdown", + "id": "02009a86-e099-4ec0-8676-8d66190ceddb", + "metadata": {}, + "source": [ + "Iteratively search the root directory for results files matching the given regex:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "aaefa99d-dbd7-40a3-a3c1-7571d3811409", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c84441efd3d74b30953f951c762e9b62", + "version_major": 2, + "version_minor": 0 + }, + 
"text/plain": [ + "0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 94 benchmark results.\n" + ] + } + ], + "source": [ + "from utils import find_files, load_json\n", + "\n", + "# Results file name pattern\n", + "pattern = r'^results.bench-(?P<hash>\\d+)[.]json$'\n", + "\n", + "# Find results files and aggregate\n", + "results = {}\n", + "for file_path in tqdm(find_files(RESULTS_ROOT_DIR, pattern)):\n", + " results[Path(file_path).parent.name] = parse_results_dict(load_json(file_path))\n", + "\n", + "if len(results) == 0:\n", + " raise RuntimeError(f\"Couldn't find any results at {RESULTS_ROOT_DIR}\")\n", + "else:\n", + " print(f\"Found {len(results)} benchmark results.\")" + ] + }, + { + "cell_type": "markdown", + "id": "0ecc0d33-902d-492a-89c6-d9729fe69fa1", + "metadata": {}, + "source": [ + "Aggregate results into a single DataFrame, and generate a unique identifier for each row:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "94f1900f-fea3-4872-b722-ee1cd3f5f7e1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "df.shape=(94, 58)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
accuracyaccuracy_diffaccuracy_ratiobalanced_accuracybalanced_accuracy_diffbalanced_accuracy_ratiobrier_score_losseceece_quantileequalized_odds_diff...config_population_filterconfig_reuse_few_shot_examplesconfig_seedconfig_task_hashconfig_task_namebase_namenameis_instnum_featuresuses_all_features
id
gemma-2-9b__ACSIncome__-10.4029440.2207660.5147240.5276150.0652480.8859220.2467660.2567950.2544300.134841...NoneFalse422612382143ACSIncomegemma-2-9bGemma 2 9BFalse-1True
Mistral-7B-v0.1__ACSEmployment__-10.4534460.2620420.4734080.5000000.0000001.0000000.3009820.2628490.2734420.000000...NoneFalse421212564561ACSEmploymentMistral-7B-v0.1Mistral 7BFalse-1True
gemma-2b__ACSTravelTime__-10.4384480.4014340.2106720.4998910.0013780.9972490.2552490.0911380.0911380.002320...NoneFalse42233993660ACSTravelTimegemma-2bGemma 2BFalse-1True
\n", + "

3 rows × 58 columns

\n", + "
" + ], + "text/plain": [ + " accuracy accuracy_diff accuracy_ratio \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 0.402944 0.220766 0.514724 \n", + "Mistral-7B-v0.1__ACSEmployment__-1 0.453446 0.262042 0.473408 \n", + "gemma-2b__ACSTravelTime__-1 0.438448 0.401434 0.210672 \n", + "\n", + " balanced_accuracy balanced_accuracy_diff \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 0.527615 0.065248 \n", + "Mistral-7B-v0.1__ACSEmployment__-1 0.500000 0.000000 \n", + "gemma-2b__ACSTravelTime__-1 0.499891 0.001378 \n", + "\n", + " balanced_accuracy_ratio brier_score_loss \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 0.885922 0.246766 \n", + "Mistral-7B-v0.1__ACSEmployment__-1 1.000000 0.300982 \n", + "gemma-2b__ACSTravelTime__-1 0.997249 0.255249 \n", + "\n", + " ece ece_quantile \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 0.256795 0.254430 \n", + "Mistral-7B-v0.1__ACSEmployment__-1 0.262849 0.273442 \n", + "gemma-2b__ACSTravelTime__-1 0.091138 0.091138 \n", + "\n", + " equalized_odds_diff ... \\\n", + "id ... \n", + "gemma-2-9b__ACSIncome__-1 0.134841 ... \n", + "Mistral-7B-v0.1__ACSEmployment__-1 0.000000 ... \n", + "gemma-2b__ACSTravelTime__-1 0.002320 ... 
\n", + "\n", + " config_population_filter \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 None \n", + "Mistral-7B-v0.1__ACSEmployment__-1 None \n", + "gemma-2b__ACSTravelTime__-1 None \n", + "\n", + " config_reuse_few_shot_examples \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 False \n", + "Mistral-7B-v0.1__ACSEmployment__-1 False \n", + "gemma-2b__ACSTravelTime__-1 False \n", + "\n", + " config_seed config_task_hash \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 42 2612382143 \n", + "Mistral-7B-v0.1__ACSEmployment__-1 42 1212564561 \n", + "gemma-2b__ACSTravelTime__-1 42 233993660 \n", + "\n", + " config_task_name base_name \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 ACSIncome gemma-2-9b \n", + "Mistral-7B-v0.1__ACSEmployment__-1 ACSEmployment Mistral-7B-v0.1 \n", + "gemma-2b__ACSTravelTime__-1 ACSTravelTime gemma-2b \n", + "\n", + " name is_inst num_features \\\n", + "id \n", + "gemma-2-9b__ACSIncome__-1 Gemma 2 9B False -1 \n", + "Mistral-7B-v0.1__ACSEmployment__-1 Mistral 7B False -1 \n", + "gemma-2b__ACSTravelTime__-1 Gemma 2B False -1 \n", + "\n", + " uses_all_features \n", + "id \n", + "gemma-2-9b__ACSIncome__-1 True \n", + "Mistral-7B-v0.1__ACSEmployment__-1 True \n", + "gemma-2b__ACSTravelTime__-1 True \n", + "\n", + "[3 rows x 58 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.DataFrame(list(results.values()))\n", + "\n", + "def row_id(row) -> str:\n", + " \"\"\"Unique row identifier.\"\"\"\n", + " return f\"{row[model_col]}__{row[task_col]}__{row['num_features']}\"\n", + "\n", + "df[\"id\"] = df.apply(row_id, axis=1)\n", + "df = df.set_index(\"id\", drop=True, verify_integrity=True)\n", + "\n", + "print(f\"{df.shape=}\")\n", + "df.sample(3)" + ] + }, + { + "cell_type": "markdown", + "id": "5fd59aee-3802-4430-a977-243dd44f8f62", + "metadata": {}, + "source": [ + "Drop potential duplicates:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": 
"69ef57de-0059-4c40-9eae-fb4e6b3bf582", + "metadata": {}, + "outputs": [], + "source": [ + "parsed_df = df.drop_duplicates(subset=[\"name\", \"is_inst\", \"num_features\", task_col])\n", + "if len(parsed_df) != len(df):\n", + " print(f\"Found {len(df) - len(parsed_df)} duplicates! dropping rows...\")\n", + " df = parsed_df" + ] + }, + { + "cell_type": "markdown", + "id": "97cd24df-3c1f-49c0-bb50-a6fdba22e0fb", + "metadata": {}, + "source": [ + "Load scores DFs and analyze score distribution:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "43ba0fbf-b886-4f2a-8a01-564aa6f0f2c6", + "metadata": {}, + "outputs": [], + "source": [ + "def load_model_scores_df(df_row: pd.Series) -> pd.DataFrame:\n", + " \"\"\"Loads csv containing model scores corresponding to the given DF row.\"\"\"\n", + " if predictions_path_col in df_row and not pd.isna(df_row[predictions_path_col]):\n", + " return pd.read_csv(df_row[predictions_path_col], index_col=0)\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "54e22896-1d3d-4d95-bb90-2b453f58f09a", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be38d2d1723546e88324af13b4838343", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/94 [00:00= fit_thr).astype(int))\n", + " opt_acc = metrics.accuracy_score(labels, (risk_scores >= opt_thr).astype(int))\n", + "\n", + " # Save results\n", + " scores_stats[row_id] = {\n", + " fit_thresh_col: fit_thr,\n", + " fit_acc_col: fit_acc,\n", + " optimal_thres_col: opt_thr,\n", + " optimal_acc_col: opt_acc,\n", + " score_stdev_col: np.std(risk_scores),\n", + " score_mean_col: np.mean(risk_scores),\n", + " }" + ] + }, + { + "cell_type": "markdown", + "id": "b1a58b87-2de8-473a-badb-934951d1bcdc", + "metadata": {}, + "source": [ + "Update results DF with scores statistics:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": 
"0ba74c81-c008-4010-a9a3-1b511ae445df", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
accuracyaccuracy_diffaccuracy_ratiobalanced_accuracybalanced_accuracy_diffbalanced_accuracy_ratiobrier_score_losseceece_quantileequalized_odds_diff...nameis_instnum_featuresuses_all_featuresfit_thresh_on_100fit_thresh_accuracyoptimal_threshoptimal_thresh_accuracyscore_stdevscore_mean
gemma-2-27b__ACSMobility__-10.6070960.1406620.8052370.4972020.0460880.9125740.2749960.228450.2279590.091793...Gemma 2 27BFalse-1True0.6514230.5362510.0600860.7315200.2700620.664836
Meta-Llama-3-70B-Instruct__ACSTravelTime__-10.5956020.2567230.6874680.5408760.1221570.7901650.2405570.148480.1488090.198067...Llama 3 70B (it)True-1True0.2450700.6141890.2690030.6387410.1241550.289832
\n", + "

2 rows × 64 columns

\n", + "
" + ], + "text/plain": [ + " accuracy accuracy_diff \\\n", + "gemma-2-27b__ACSMobility__-1 0.607096 0.140662 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.595602 0.256723 \n", + "\n", + " accuracy_ratio \\\n", + "gemma-2-27b__ACSMobility__-1 0.805237 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.687468 \n", + "\n", + " balanced_accuracy \\\n", + "gemma-2-27b__ACSMobility__-1 0.497202 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.540876 \n", + "\n", + " balanced_accuracy_diff \\\n", + "gemma-2-27b__ACSMobility__-1 0.046088 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.122157 \n", + "\n", + " balanced_accuracy_ratio \\\n", + "gemma-2-27b__ACSMobility__-1 0.912574 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.790165 \n", + "\n", + " brier_score_loss ece \\\n", + "gemma-2-27b__ACSMobility__-1 0.274996 0.22845 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.240557 0.14848 \n", + "\n", + " ece_quantile \\\n", + "gemma-2-27b__ACSMobility__-1 0.227959 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.148809 \n", + "\n", + " equalized_odds_diff ... \\\n", + "gemma-2-27b__ACSMobility__-1 0.091793 ... \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.198067 ... 
\n", + "\n", + " name is_inst \\\n", + "gemma-2-27b__ACSMobility__-1 Gemma 2 27B False \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 Llama 3 70B (it) True \n", + "\n", + " num_features uses_all_features \\\n", + "gemma-2-27b__ACSMobility__-1 -1 True \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 -1 True \n", + "\n", + " fit_thresh_on_100 \\\n", + "gemma-2-27b__ACSMobility__-1 0.651423 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.245070 \n", + "\n", + " fit_thresh_accuracy \\\n", + "gemma-2-27b__ACSMobility__-1 0.536251 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.614189 \n", + "\n", + " optimal_thresh \\\n", + "gemma-2-27b__ACSMobility__-1 0.060086 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.269003 \n", + "\n", + " optimal_thresh_accuracy \\\n", + "gemma-2-27b__ACSMobility__-1 0.731520 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.638741 \n", + "\n", + " score_stdev score_mean \n", + "gemma-2-27b__ACSMobility__-1 0.270062 0.664836 \n", + "Meta-Llama-3-70B-Instruct__ACSTravelTime__-1 0.124155 0.289832 \n", + "\n", + "[2 rows x 64 columns]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores_stats_df = pd.DataFrame(scores_stats.values(), index=list(scores_stats.keys()))\n", + "\n", + "results_df = pd.concat((df, scores_stats_df), axis=\"columns\")\n", + "results_df.sample(2)" + ] + }, + { + "cell_type": "markdown", + "id": "a6e859fc-9d82-436d-a238-11fd008da44c", + "metadata": {}, + "source": [ + "Finally, save results DF to the results root directory:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9725afa2-b6f8-46b1-86e7-57ae82ef05d9", + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "def get_current_timestamp() -> str:\n", + " \"\"\"Return a timestamp representing the current time up to the second.\"\"\"\n", + " return datetime.now().strftime(\"%Y.%m.%d-%H.%M.%S\")\n", + "\n", + 
"results_df.to_csv(Path(RESULTS_ROOT_DIR) / f\"aggregated_results.{get_current_timestamp()}.csv\")  # results_df = df plus the per-row score statistics" + ] + }, + { + "cell_type": "markdown", + "id": "c3a6beb5-b44b-4f81-ae1a-d027afb2c5f4", + "metadata": {}, + "source": [ + "---" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 2842ed3..a70f97a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] -version = "0.0.16" +version = "0.0.17" requires-python = ">=3.8" dynamic = [ "readme",