diff --git a/examples/eol/sard_eol.ipynb b/examples/eol/sard_eol.ipynb
index c272c50..8c9f185 100644
--- a/examples/eol/sard_eol.ipynb
+++ b/examples/eol/sard_eol.ipynb
@@ -1,758 +1,758 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "n3ieKAgvPfsZ"
- },
- "source": [
- "# Run End of Life prediction task on Synthetic Patient Data in OMOP"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "u1UUubwGPfsc"
- },
- "source": [
- "This notebook runs the end-of-life (EOL) prediction task on synthetic patient data in OMOP using a linear baseline model and the SARD architecture [Kodialam et al. 2021].\n",
- "\n",
- "Data is sourced from the publicly available Medicare Claims Synthetic Public Use Files (SynPUF), released by the Centers for Medicare and Medicaid Services (CMS) and available in [Google BigQuery. The synthetic set contains 2008-2010 Medicare insurance claims for development and demonstration purposes and was coverted to the Medical Outcomes Partnership (OMOP) Common Data Model from its original CSV form."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "YaHCADZXPfsd"
- },
- "source": [
- "## Imports and GPU setup"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "lYmOTr6sPfse"
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import pandas as pd\n",
- "import torch\n",
- "import time\n",
- "import os\n",
- "\n",
- "from sklearn.model_selection import train_test_split\n",
- "from sklearn.metrics import roc_auc_score\n",
- "\n",
- "from ipywidgets import IntProgress, FloatText\n",
- "from IPython.display import display\n",
- "\n",
- "import matplotlib\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "plt.rcParams[\"font.family\"] = \"serif\"\n",
- "plt.rcParams[\"font.size\"] = 13"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "AEXAJWHXPfsf"
- },
- "outputs": [],
- "source": [
- "from omop_learn.backends.bigquery import BigQueryBackend\n",
- "from omop_learn.data.cohort import Cohort\n",
- "from omop_learn.data.feature import Feature\n",
- "from omop_learn.utils.config import Config\n",
- "from omop_learn.omop import OMOPDataset\n",
- "from omop_learn.utils import date_utils, embedding_utils\n",
- "from omop_learn.sparse.models import OMOPLogisticRegression\n",
- "from omop_learn.models import transformer, visit_transformer"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "GU-M1ASqPfsg"
- },
- "source": [
- "## Cohort, Outcome and Feature Collection\n",
- "\n",
- "### 1. Set up a connection to the OMOP CDM database\n",
- "\n",
- "Parameters for connection to be specified in ./config.py"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": false,
- "id": "ZtAvjOCzPfsg"
- },
- "outputs": [],
- "source": [
- "config = Config({\n",
- " \"project_name\": \"project\",\n",
- " \"cdm_schema\": \"bigquery-public-data.cms_synthetic_patient_data_omop\",\n",
- " \"prefix_schema\": \"username\",\n",
- " \"datasets_dir\": \"data_dir\",\n",
- " \"models_dir\": \"model_dir\"\n",
- "})\n",
- "\n",
- "# Set up database, reset schemas as needed\n",
- "backend = BigQueryBackend(config)\n",
- "backend.reset_schema(config.prefix_schema) # Rebuild schema from scratch\n",
- "backend.create_schema(config.prefix_schema) # Create schema if not exists\n",
- "\n",
- "cohort_params = {\n",
- " \"cohort_table_name\": \"synpuf_eol_cohort\",\n",
- " \"schema_name\": config.prefix_schema,\n",
- " \"cdm_schema\": config.cdm_schema,\n",
- " \"aux_data_schema\": config.aux_cdm_schema,\n",
- " \"training_start_date\": \"2009-01-01\",\n",
- " \"training_end_date\": \"2009-12-31\",\n",
- " \"gap\": \"3 month\",\n",
- " \"outcome_window\": \"6 month\",\n",
- "}\n",
- "sql_dir = \"./bigquery_sql\"\n",
- "sql_file = open(f\"{sql_dir}/gen_EOL_cohort.sql\", 'r')\n",
- "cohort = Cohort.from_sql_file(sql_file, backend, params=cohort_params)\n",
- "\n",
- "feature_names = [\"drugs\", \"conditions\", \"procedures\"]\n",
- "feature_paths = [f\"{sql_dir}/{feature_name}.sql\" for feature_name in feature_names]\n",
- "features = [Feature(n, p) for n, p in zip(feature_names, feature_paths)]\n",
- "\n",
- "init_args = {\n",
- " \"config\" : config,\n",
- " \"name\" : \"synpuf_eol\",\n",
- " \"cohort\" : cohort,\n",
- " \"features\": features,\n",
- " \"backend\": backend,\n",
- "}\n",
- "\n",
- "dataset = OMOPDataset(**init_args)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Oa7ULCSRPfsh"
- },
- "source": [
- "### 4. Process the collected data and calculate indices needed for the deep model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "U-kcFPtOPfsh"
- },
- "outputs": [],
- "source": [
- "window_days = [30, 180, 365, 730, 1000]\n",
- "windowed_dataset = dataset.to_windowed(window_days)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "nmZZunVJPfsi"
- },
- "outputs": [],
- "source": [
- "person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "aLSE6AG8Pfsi"
- },
- "outputs": [],
- "source": [
- "# process data for deep model\n",
- "person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords\n",
- "outcomes_filt = windowed_dataset.outcomes\n",
- "time_to_idx = windowed_dataset.times_map\n",
- "idx_to_datetime = {idx: date_utils.from_unixtime([time])[0] for time, idx in time_to_idx.items()}\n",
- "\n",
- "all_codes_tensor = code_ixs\n",
- "people = sorted(np.unique(person_ixs))\n",
- "person_indices = np.searchsorted(person_ixs, people)\n",
- "person_indices = np.append(person_indices, len(person_ixs))\n",
- "person_chunks = [\n",
- " time_ixs[person_indices[i]: person_indices[i + 1]]\n",
- " for i in range(len(person_indices) - 1)\n",
- "]\n",
- "\n",
- "visit_chunks = []\n",
- "visit_times_raw = []\n",
- "\n",
- "for i, chunk in enumerate(person_chunks):\n",
- " visits = sorted(np.unique(chunk))\n",
- " visit_indices_local = np.searchsorted(chunk, visits)\n",
- " visit_indices_local = np.append(\n",
- " visit_indices_local,\n",
- " len(chunk)\n",
- " )\n",
- " visit_chunks.append(visit_indices_local)\n",
- " visit_times_raw.append(visits)\n",
- "\n",
- "n_visits = {i:len(j) for i,j in enumerate(visit_times_raw)}\n",
- "\n",
- "visit_days_rel = {\n",
- " i: (\n",
- " pd.to_datetime(cohort_params['training_end_date']) \\\n",
- " - pd.to_datetime(idx_to_datetime[time])\n",
- " ).days for time in time_ixs\n",
- "}\n",
- "vdrel_func = np.vectorize(visit_days_rel.get)\n",
- "visit_time_rel = [\n",
- " vdrel_func(v) for v in visit_times_raw\n",
- "]\n",
- "\n",
- "remap = {\n",
- " 'id': people,\n",
- " 'time': sorted(np.unique(time_ixs)),\n",
- " 'concept': sorted(np.unique(code_ixs))\n",
- "}\n",
- "\n",
- "dataset_dict = {\n",
- " 'all_codes_tensor': all_codes_tensor, # A tensor of all codes occurring in the dataset\n",
- " 'person_indices': person_indices, # A list of indices such that all_codes_tensor[person_indices[i]: person_indices[i+1]] are the codes assigned to the ith patient\n",
- " 'visit_chunks': visit_chunks, # A list of indices such that all_codes_tensor[person_indices[i]+visit_chunks[j]:person_indices[i]+visit_chunks[j+1]] are the codes assigned to the ith patient during their jth visit\n",
- " 'visit_time_rel': visit_time_rel, # A list of times (as measured in days to the prediction date) for each visit\n",
- " 'n_visits': n_visits, # A dict defined such that n_visits[i] is the number of visits made by the ith patient\n",
- " 'outcomes_filt': outcomes_filt, # A pandas Series defined such that outcomes_filt.iloc[i] is the outcome of the ith patient\n",
- " 'remap': remap,\n",
- "}"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9rumxm09Pfsj"
- },
- "source": [
- "## Run the windowed regression model on the task defined above"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "EBWOySUSPfsj"
- },
- "outputs": [],
- "source": [
- "# split data into train, validate and test sets\n",
- "windowed_dataset.split()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jtVIFn2CPfsj"
- },
- "outputs": [],
- "source": [
- "# train the regression model over several choices of regularization parameter\n",
- "reg_lambdas = [2, 0.2, 0.02]\n",
- "lr_val_aucs = []\n",
- "model = OMOPLogisticRegression(\"eol_new_50\", windowed_dataset)\n",
- "\n",
- "for reg_lambda in reg_lambdas:\n",
- " # Gen and fit\n",
- " model.gen_pipeline(reg_lambda)\n",
- " model.fit()\n",
- " # Eval on validation data\n",
- " pred_lr = model._pipeline.predict_proba(windowed_dataset.val['X'])[:, 1]\n",
- " lr_val_auc = roc_auc_score(windowed_dataset.val['y'], pred_lr)\n",
- " lr_val_aucs.append(lr_val_auc)\n",
- " print(\"C: %.4f, Val AUC: %.2f\" % (reg_lambda, lr_val_auc))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "pYy4lcbePfsj"
- },
- "outputs": [],
- "source": [
- "# Gen and fit on best C\n",
- "best_reg_lambda = reg_lambdas[np.argmax(lr_val_aucs)]\n",
- "model.gen_pipeline(best_reg_lambda)\n",
- "model.fit()\n",
- "# Eval on test data\n",
- "pred_lr = model._pipeline.predict_proba(windowed_dataset.test['X'])[:, 1]\n",
- "score = roc_auc_score(windowed_dataset.test['y'], pred_lr)\n",
- "print(\"C: %.4f, Test AUC: %.2f\" % (best_reg_lambda, score))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "jvYkElLlPfsk"
- },
- "source": [
- "### Learn a Word2Vec embedding"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "e7oM4vOsPfsk"
- },
- "outputs": [],
- "source": [
- "%%time\n",
- "embedding_dim = 300 # size of embedding, must be multiple of number of heads\n",
- "window_days = 90 # number of days in window that defines a \"Sentence\" when learning the embedding\n",
- "train_coords = np.nonzero(np.where(np.isin(person_ixs, indices_train), 1, 0))\n",
- "embedding_filename = embedding_utils.train_embedding(featureSet, feature_matrix_3d_transpose, window_days, \\\n",
- " person_ixs[train_coords], time_ixs[train_coords], \\\n",
- " remap['time'], embedding_dim)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "vlIGYdtIPfsk"
- },
- "source": [
- "## Run the SARD deep model on the predictive task\n",
- "### 1. Set Model Parameters and Construct the Model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Yyr2CO1wPfsk"
- },
- "outputs": [],
- "source": [
- "# using the same split as before, create train/validate/test batches for the deep model\n",
- "# `mbsz` might need to be decreased based on the GPU's memory and the number of features being used\n",
- "mbsz = 50\n",
- "def get_batches(arr, mbsz=mbsz):\n",
- " curr, ret = 0, []\n",
- " while curr < len(arr) - 1:\n",
- " ret.append(arr[curr : curr + mbsz])\n",
- " curr += mbsz\n",
- " return ret\n",
- "\n",
- "p_ranges_train, p_ranges_test = [\n",
- " get_batches(arr) for arr in (\n",
- " indices_train, indices_test\n",
- " )\n",
- "]\n",
- "p_ranges_val = p_ranges_test[:val_size // mbsz]\n",
- "p_ranges_test = p_ranges_test[val_size // mbsz:]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "AfUpBu4sPfsl"
- },
- "outputs": [],
- "source": [
- "# Pick a name for the model (mn_prefix) that will be used when saving checkpoints\n",
- "# Then, set some parameters for SARD. The values below reflect a good starting point that performed well on several tasks\n",
- "mn_prefix = 'eol_experiment_prefix'\n",
- "n_heads = 2\n",
- "assert embedding_dim % n_heads == 0\n",
- "model_params = {\n",
- " 'embedding_dim': int(embedding_dim / n_heads), # Dimension per head of visit embeddings\n",
- " 'n_heads': n_heads, # Number of self-attention heads\n",
- " 'attn_depth': 2, # Number of stacked self-attention layers\n",
- " 'dropout': 0.05, # Dropout rate for both self-attention and the final prediction layer\n",
- " 'use_mask': True, # Only allow visits to attend to other actual visits, not to padding visits\n",
- " 'concept_embedding_path': embedding_filename # if unspecified, uses default Torch embeddings\n",
- "}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "GJ2DXdlWPfsl"
- },
- "outputs": [],
- "source": [
- "# Set up fixed model parameters, loss functions, and build the model on the GPU\n",
- "lr = 2e-4\n",
- "n_epochs_pretrain = 1\n",
- "ft_epochs = 1\n",
- "\n",
- "update_every = 500\n",
- "update_mod = update_every // mbsz\n",
- "\n",
- "base_model = visit_transformer.VisitTransformer(\n",
- " featureSet, **model_params\n",
- ")\n",
- "\n",
- "clf = visit_transformer.VTClassifer(\n",
- " base_model, **model_params\n",
- ").cuda()\n",
- "\n",
- "clf.bert.set_data(\n",
- " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n",
- " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n",
- " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n",
- ")\n",
- "\n",
- "loss_function_distill = torch.nn.BCEWithLogitsLoss(\n",
- " pos_weight=torch.FloatTensor([\n",
- " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n",
- " ]), reduction='sum'\n",
- ").cuda()\n",
- "\n",
- "optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=lr)\n",
- "\n",
- "def eval_curr_model_on(a):\n",
- " with torch.no_grad():\n",
- " preds_test, true_test = [], []\n",
- " for batch_num, p_range in enumerate(a):\n",
- " y_pred = clf(p_range)\n",
- " preds_test += y_pred.tolist()\n",
- " true_test += list(dataset_dict['outcomes_filt'].iloc[list(p_range)].values)\n",
- " return roc_auc_score(true_test, preds_test)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4l95ivncPfsl"
- },
- "source": [
- "### 2. Fit the SARD model to the best windowed linear model (Reverse Distillation)\n",
- "\n",
- "The following code saves models in a folder `/SavedModels/{task}/`; make sure to create the directory before running."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "UgJt7p11Pfsl"
- },
- "outputs": [],
- "source": [
- "task = 'eol'"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ESuCBwrMPfsl"
- },
- "outputs": [],
- "source": [
- "# Run `n_epochs_pretrain` of Reverse Distillation pretraining\n",
- "val_losses = []\n",
- "progress_bar = IntProgress(min=0, max=int(n_epochs_pretrain * len(p_ranges_train)))\n",
- "batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)\n",
- "time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)\n",
- "\n",
- "display(progress_bar)\n",
- "display(batch_loss_disp)\n",
- "display(time_disp)\n",
- "\n",
- "for epoch in range(n_epochs_pretrain):\n",
- " t, batch_loss = time.time(), 0\n",
- "\n",
- " for batch_num, p_range in enumerate(p_ranges_train):\n",
- "\n",
- " if batch_num % 50 == 0:\n",
- " batch_loss_disp.value = round(batch_loss / 50, 2)\n",
- " time_disp.value = round(time.time() - t, 2)\n",
- " t, batch_loss = time.time(), 0\n",
- "\n",
- " y_pred = clf(p_range)\n",
- " loss_distill = loss_function_distill(\n",
- " y_pred, torch.FloatTensor(pred_lr_all[p_range]).cuda()\n",
- " )\n",
- "\n",
- " batch_loss += loss_distill.item()\n",
- " loss_distill.backward()\n",
- "\n",
- " if batch_num % update_mod == 0:\n",
- " optimizer_clf.step()\n",
- " optimizer_clf.zero_grad()\n",
- "\n",
- " progress_bar.value = batch_num + epoch * len(p_ranges_train)\n",
- "\n",
- " torch.save(\n",
- " clf.state_dict(),\n",
- " \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
- " task=task, mn_prefix = mn_prefix, epochs = epoch + 1\n",
- " )\n",
- " )\n",
- "\n",
- " clf.eval()\n",
- " ckpt_auc = eval_curr_model_on(p_ranges_val)\n",
- " print('Epochs: {} | Val AUC: {}'.format(epoch + 1, ckpt_auc))\n",
- " val_losses.append(ckpt_auc)\n",
- " clf.train()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "S79PbKdhPfsm"
- },
- "outputs": [],
- "source": [
- "# Save the pretrained model with best validation-set performance\n",
- "clf.load_state_dict(\n",
- " torch.load(\"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
- " task=task, mn_prefix=mn_prefix, epochs=np.argmax(val_losses) + 1\n",
- " ))\n",
- ")\n",
- "torch.save(\n",
- " clf.state_dict(),\n",
- " \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
- " task=task, mn_prefix = mn_prefix, epochs = 'BEST'\n",
- " )\n",
- " )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "6Wzmh6tCPfsm"
- },
- "source": [
- "### 3. Fine-tune the SARD model by training to match the actual outcomes on the training set"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "iczzP5NAPfsm"
- },
- "outputs": [],
- "source": [
- "# Set up loss functions for fine-tuning. There are two terms:\n",
- "# - `loss_function_distill`, which penalizes differences between the linear model prediction and SARD's prediction\n",
- "# - `loss_function_clf`, which penalizes differences between the true outcome and SARD's prediction\n",
- "loss_function_distill = torch.nn.BCEWithLogitsLoss(\n",
- " pos_weight=torch.FloatTensor([\n",
- " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n",
- " ]), reduction='sum'\n",
- ").cuda()\n",
- "\n",
- "loss_function_clf = torch.nn.BCEWithLogitsLoss(\n",
- " pos_weight=torch.FloatTensor([\n",
- " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n",
- " ]), reduction='sum'\n",
- ").cuda()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "30_HYniePfsm"
- },
- "outputs": [],
- "source": [
- "# run `ft_epochs` of fine-tuning training, for each of the values of `alpha` below.\n",
- "# Note that `alpha` is the relative weight of `loss_function_distill` as compared to `loss_function_clf`\n",
- "\n",
- "all_pred_models = {}\n",
- "\n",
- "progress_bar = IntProgress(min=0, max=int(ft_epochs * len(p_ranges_train)))\n",
- "batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)\n",
- "time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)\n",
- "\n",
- "display(progress_bar)\n",
- "display(batch_loss_disp)\n",
- "display(time_disp)\n",
- "\n",
- "\n",
- "no_rd = False\n",
- "for alpha in [0,0.05,0.1,0.15, 0.2]:\n",
- "\n",
- " progress_bar.value = 0\n",
- "\n",
- " if no_rd:\n",
- " pretrained_model_fn = mn_prefix + '_None'\n",
- " start_model = None\n",
- " if start_model is None:\n",
- " base_model = visit_transformer.VisitTransformer(\n",
- " featureSet, **model_params\n",
- " )\n",
- "\n",
- " clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()\n",
- "\n",
- " clf.bert.set_data(\n",
- " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n",
- " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n",
- " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n",
- " )\n",
- " else:\n",
- " pretrained_model_path = \"SavedModels/{task}/{start_model}\".format(\n",
- " task=task, start_model=start_model\n",
- " )\n",
- " clf.load_state_dict(torch.load(pretrained_model_path))\n",
- "\n",
- " else:\n",
- " pretrained_model_fn = \"{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
- " mn_prefix=mn_prefix, epochs='BEST'\n",
- " )\n",
- " pretrained_model_path = \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
- " task=task, mn_prefix=mn_prefix, epochs='BEST'\n",
- " )\n",
- "\n",
- " clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()\n",
- " clf.bert.set_data(\n",
- " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n",
- " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n",
- " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n",
- " )\n",
- "\n",
- " clf.load_state_dict(torch.load(pretrained_model_path))\n",
- "\n",
- " clf.train()\n",
- "\n",
- " optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=2e-4)\n",
- "\n",
- " for epoch in range(ft_epochs):\n",
- "\n",
- " t, batch_loss = time.time(), 0\n",
- "\n",
- " for batch_num, p_range in enumerate(p_ranges_train):\n",
- "\n",
- " if batch_num % 50 == 0:\n",
- " batch_loss_disp.value = round(batch_loss / 50, 2)\n",
- " time_disp.value = round(time.time() - t, 2)\n",
- " t, batch_loss = time.time(), 0\n",
- "\n",
- " y_pred = clf(p_range)\n",
- "\n",
- " loss = loss_function_clf(\n",
- " y_pred,\n",
- " torch.FloatTensor(dataset_dict['outcomes_filt'].values[p_range]).cuda()\n",
- " )\n",
- "\n",
- " loss_distill = loss_distill = loss_function_distill(\n",
- " y_pred,\n",
- " torch.FloatTensor(pred_lr_all[p_range]).cuda()\n",
- " )\n",
- "\n",
- " batch_loss += loss.item() + alpha * loss_distill.item()\n",
- " loss_total = loss + alpha * loss_distill\n",
- " loss_total.backward()\n",
- "\n",
- " if batch_num % update_mod == 0:\n",
- " optimizer_clf.step()\n",
- " optimizer_clf.zero_grad()\n",
- "\n",
- " progress_bar.value = batch_num + epoch * len(p_ranges_train)\n",
- "\n",
- " saving_fn = \"{pretrain}_alpha_{alpha}_epochs_{epochs}\".format(\n",
- " task=task, pretrain = pretrained_model_fn, alpha=alpha, epochs = epoch + 1\n",
- " )\n",
- " torch.save(\n",
- " clf.state_dict(),\n",
- " \"SavedModels/{task}/{saving_fn}\".format(\n",
- " task=task, saving_fn=saving_fn\n",
- " )\n",
- " )\n",
- "\n",
- " clf.eval()\n",
- " val_auc = eval_curr_model_on(p_ranges_val)\n",
- " print(val_auc)\n",
- " all_pred_models[saving_fn] = val_auc\n",
- " clf.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rXk7tpRdPfsm"
- },
- "source": [
- "### 4. Evaluate the best SARD model, as determined by validation performance"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "35WmOaisPfsm"
- },
- "outputs": [],
- "source": [
- "best_model = max(all_pred_models, key=all_pred_models.get)\n",
- "clf.load_state_dict(\n",
- " torch.load(\"SavedModels/{task}/{model}\".format(\n",
- " task=task, model=best_model\n",
- " ))\n",
- ")\n",
- "clf.eval();\n",
- "with torch.no_grad():\n",
- " preds_test, true_test = [], []\n",
- " for batch_num, p_range in enumerate(p_ranges_test):\n",
- " y_pred = clf(p_range)\n",
- " preds_test += y_pred.tolist()\n",
- " true_test += list(dataset_dict['outcomes_filt'].iloc[list(p_range)].values)\n",
- " print(roc_auc_score(true_test, preds_test))\n",
- "clf.train();"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "omop-learn",
- "language": "python",
- "name": "omop-learn"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.5"
- },
- "colab": {
- "provenance": [],
- "include_colab_link": true
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "n3ieKAgvPfsZ"
+ },
+ "source": [
+ "# Run End of Life prediction task on Synthetic Patient Data in OMOP"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "u1UUubwGPfsc"
+ },
+ "source": [
+ "This notebook runs the end-of-life (EOL) prediction task on synthetic patient data in OMOP using a linear baseline model and the SARD architecture [Kodialam et al. 2021].\n",
+ "\n",
+ "Data is sourced from the publicly available Medicare Claims Synthetic Public Use Files (SynPUF), released by the Centers for Medicare and Medicaid Services (CMS) and available in [Google BigQuery](https://console.cloud.google.com/marketplace/product/hhs/synpuf). The synthetic set contains 2008-2010 Medicare insurance claims for development and demonstration purposes and was coverted to the Medical Outcomes Partnership (OMOP) Common Data Model from its original CSV form."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YaHCADZXPfsd"
+ },
+ "source": [
+ "## Imports and GPU setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lYmOTr6sPfse"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "import time\n",
+ "import os\n",
+ "\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import roc_auc_score\n",
+ "\n",
+ "from ipywidgets import IntProgress, FloatText\n",
+ "from IPython.display import display\n",
+ "\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "plt.rcParams[\"font.family\"] = \"serif\"\n",
+ "plt.rcParams[\"font.size\"] = 13"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AEXAJWHXPfsf"
+ },
+ "outputs": [],
+ "source": [
+ "from omop_learn.backends.bigquery import BigQueryBackend\n",
+ "from omop_learn.data.cohort import Cohort\n",
+ "from omop_learn.data.feature import Feature\n",
+ "from omop_learn.utils.config import Config\n",
+ "from omop_learn.omop import OMOPDataset\n",
+ "from omop_learn.utils import date_utils, embedding_utils\n",
+ "from omop_learn.sparse.models import OMOPLogisticRegression\n",
+ "from omop_learn.models import transformer, visit_transformer"
+ ]
+ },
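+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sanity check: the SARD cells below call .cuda(), so a CUDA-capable GPU is assumed here.\n",
+ "# Skip this check if running only the linear baseline.\n",
+ "assert torch.cuda.is_available(), \"A CUDA GPU is required for the SARD model cells below\"\n",
+ "print(torch.cuda.get_device_name(0))"
+ ]
+ },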
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GU-M1ASqPfsg"
+ },
+ "source": [
+ "## Cohort, Outcome and Feature Collection\n",
+ "\n",
+ "### 1. Set up a connection to the OMOP CDM database\n",
+ "\n",
+ "Parameters for connection to be specified in ./config.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZtAvjOCzPfsg",
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "config = Config({\n",
+ " \"project_name\": \"project\",\n",
+ " \"cdm_schema\": \"bigquery-public-data.cms_synthetic_patient_data_omop\",\n",
+ " \"prefix_schema\": \"username\",\n",
+ " \"datasets_dir\": \"data_dir\",\n",
+ " \"models_dir\": \"model_dir\"\n",
+ "})\n",
+ "\n",
+ "# Set up database, reset schemas as needed\n",
+ "backend = BigQueryBackend(config)\n",
+ "backend.reset_schema(config.prefix_schema) # Rebuild schema from scratch\n",
+ "backend.create_schema(config.prefix_schema) # Create schema if not exists\n",
+ "\n",
+ "cohort_params = {\n",
+ " \"cohort_table_name\": \"synpuf_eol_cohort\",\n",
+ " \"schema_name\": config.prefix_schema,\n",
+ " \"cdm_schema\": config.cdm_schema,\n",
+ " \"aux_data_schema\": config.aux_cdm_schema,\n",
+ " \"training_start_date\": \"2009-01-01\",\n",
+ " \"training_end_date\": \"2009-12-31\",\n",
+ " \"gap\": \"3 month\",\n",
+ " \"outcome_window\": \"6 month\",\n",
+ "}\n",
+ "sql_dir = \"./bigquery_sql\"\n",
+ "sql_file = open(f\"{sql_dir}/gen_EOL_cohort.sql\", 'r')\n",
+ "cohort = Cohort.from_sql_file(sql_file, backend, params=cohort_params)\n",
+ "\n",
+ "feature_names = [\"drugs\", \"conditions\", \"procedures\"]\n",
+ "feature_paths = [f\"{sql_dir}/{feature_name}.sql\" for feature_name in feature_names]\n",
+ "features = [Feature(n, p) for n, p in zip(feature_names, feature_paths)]\n",
+ "\n",
+ "init_args = {\n",
+ " \"config\" : config,\n",
+ " \"name\" : \"synpuf_eol\",\n",
+ " \"cohort\" : cohort,\n",
+ " \"features\": features,\n",
+ " \"backend\": backend,\n",
+ "}\n",
+ "\n",
+ "dataset = OMOPDataset(**init_args)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Oa7ULCSRPfsh"
+ },
+ "source": [
+ "### 4. Process the collected data and calculate indices needed for the deep model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "U-kcFPtOPfsh"
+ },
+ "outputs": [],
+ "source": [
+ "window_days = [30, 180, 365, 730, 1000]\n",
+ "windowed_dataset = dataset.to_windowed(window_days)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nmZZunVJPfsi"
+ },
+ "outputs": [],
+ "source": [
+ "person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "aLSE6AG8Pfsi"
+ },
+ "outputs": [],
+ "source": [
+ "# process data for deep model\n",
+ "person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords\n",
+ "outcomes_filt = windowed_dataset.outcomes\n",
+ "time_to_idx = windowed_dataset.times_map\n",
+ "idx_to_datetime = {idx: date_utils.from_unixtime([time])[0] for time, idx in time_to_idx.items()}\n",
+ "\n",
+ "all_codes_tensor = code_ixs\n",
+ "people = sorted(np.unique(person_ixs))\n",
+ "person_indices = np.searchsorted(person_ixs, people)\n",
+ "person_indices = np.append(person_indices, len(person_ixs))\n",
+ "person_chunks = [\n",
+ " time_ixs[person_indices[i]: person_indices[i + 1]]\n",
+ " for i in range(len(person_indices) - 1)\n",
+ "]\n",
+ "\n",
+ "visit_chunks = []\n",
+ "visit_times_raw = []\n",
+ "\n",
+ "for i, chunk in enumerate(person_chunks):\n",
+ " visits = sorted(np.unique(chunk))\n",
+ " visit_indices_local = np.searchsorted(chunk, visits)\n",
+ " visit_indices_local = np.append(\n",
+ " visit_indices_local,\n",
+ " len(chunk)\n",
+ " )\n",
+ " visit_chunks.append(visit_indices_local)\n",
+ " visit_times_raw.append(visits)\n",
+ "\n",
+ "n_visits = {i:len(j) for i,j in enumerate(visit_times_raw)}\n",
+ "\n",
+ "visit_days_rel = {\n",
+ " i: (\n",
+ " pd.to_datetime(cohort_params['training_end_date']) \\\n",
+ " - pd.to_datetime(idx_to_datetime[time])\n",
+ " ).days for time in time_ixs\n",
+ "}\n",
+ "vdrel_func = np.vectorize(visit_days_rel.get)\n",
+ "visit_time_rel = [\n",
+ " vdrel_func(v) for v in visit_times_raw\n",
+ "]\n",
+ "\n",
+ "remap = {\n",
+ " 'id': people,\n",
+ " 'time': sorted(np.unique(time_ixs)),\n",
+ " 'concept': sorted(np.unique(code_ixs))\n",
+ "}\n",
+ "\n",
+ "dataset_dict = {\n",
+ " 'all_codes_tensor': all_codes_tensor, # A tensor of all codes occurring in the dataset\n",
+ " 'person_indices': person_indices, # A list of indices such that all_codes_tensor[person_indices[i]: person_indices[i+1]] are the codes assigned to the ith patient\n",
+ " 'visit_chunks': visit_chunks, # A list of indices such that all_codes_tensor[person_indices[i]+visit_chunks[j]:person_indices[i]+visit_chunks[j+1]] are the codes assigned to the ith patient during their jth visit\n",
+ " 'visit_time_rel': visit_time_rel, # A list of times (as measured in days to the prediction date) for each visit\n",
+ " 'n_visits': n_visits, # A dict defined such that n_visits[i] is the number of visits made by the ith patient\n",
+ " 'outcomes_filt': outcomes_filt, # A pandas Series defined such that outcomes_filt.iloc[i] is the outcome of the ith patient\n",
+ " 'remap': remap,\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9rumxm09Pfsj"
+ },
+ "source": [
+ "## Run the windowed regression model on the task defined above"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "EBWOySUSPfsj"
+ },
+ "outputs": [],
+ "source": [
+ "# split data into train, validate and test sets\n",
+ "windowed_dataset.split()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jtVIFn2CPfsj"
+ },
+ "outputs": [],
+ "source": [
+ "# train the regression model over several choices of regularization parameter\n",
+ "reg_lambdas = [2, 0.2, 0.02]\n",
+ "lr_val_aucs = []\n",
+ "model = OMOPLogisticRegression(\"eol_new_50\", windowed_dataset)\n",
+ "\n",
+ "for reg_lambda in reg_lambdas:\n",
+ " # Gen and fit\n",
+ " model.gen_pipeline(reg_lambda)\n",
+ " model.fit()\n",
+ " # Eval on validation data\n",
+ " pred_lr = model._pipeline.predict_proba(windowed_dataset.val['X'])[:, 1]\n",
+ " lr_val_auc = roc_auc_score(windowed_dataset.val['y'], pred_lr)\n",
+ " lr_val_aucs.append(lr_val_auc)\n",
+ " print(\"C: %.4f, Val AUC: %.2f\" % (reg_lambda, lr_val_auc))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "pYy4lcbePfsj"
+ },
+ "outputs": [],
+ "source": [
+ "# Gen and fit on best C\n",
+ "best_reg_lambda = reg_lambdas[np.argmax(lr_val_aucs)]\n",
+ "model.gen_pipeline(best_reg_lambda)\n",
+ "model.fit()\n",
+ "# Eval on test data\n",
+ "pred_lr = model._pipeline.predict_proba(windowed_dataset.test['X'])[:, 1]\n",
+ "score = roc_auc_score(windowed_dataset.test['y'], pred_lr)\n",
+ "print(\"C: %.4f, Test AUC: %.2f\" % (best_reg_lambda, score))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jvYkElLlPfsk"
+ },
+ "source": [
+ "### Learn a Word2Vec embedding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e7oM4vOsPfsk"
+ },
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "embedding_dim = 300 # size of embedding, must be multiple of number of heads\n",
+ "window_days = 90 # number of days in window that defines a \"Sentence\" when learning the embedding\n",
+ "train_coords = np.nonzero(np.where(np.isin(person_ixs, indices_train), 1, 0))\n",
+ "embedding_filename = embedding_utils.train_embedding(featureSet, feature_matrix_3d_transpose, window_days, \\\n",
+ " person_ixs[train_coords], time_ixs[train_coords], \\\n",
+ " remap['time'], embedding_dim)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vlIGYdtIPfsk"
+ },
+ "source": [
+ "## Run the SARD deep model on the predictive task\n",
+ "### 1. Set Model Parameters and Construct the Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Yyr2CO1wPfsk"
+ },
+ "outputs": [],
+ "source": [
+ "# using the same split as before, create train/validate/test batches for the deep model\n",
+ "# `mbsz` might need to be decreased based on the GPU's memory and the number of features being used\n",
+ "mbsz = 50\n",
+ "def get_batches(arr, mbsz=mbsz):\n",
+ " curr, ret = 0, []\n",
+ " while curr < len(arr) - 1:\n",
+ " ret.append(arr[curr : curr + mbsz])\n",
+ " curr += mbsz\n",
+ " return ret\n",
+ "\n",
+ "p_ranges_train, p_ranges_test = [\n",
+ " get_batches(arr) for arr in (\n",
+ " indices_train, indices_test\n",
+ " )\n",
+ "]\n",
+ "p_ranges_val = p_ranges_test[:val_size // mbsz]\n",
+ "p_ranges_test = p_ranges_test[val_size // mbsz:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AfUpBu4sPfsl"
+ },
+ "outputs": [],
+ "source": [
+ "# Pick a name for the model (mn_prefix) that will be used when saving checkpoints\n",
+ "# Then, set some parameters for SARD. The values below reflect a good starting point that performed well on several tasks\n",
+ "mn_prefix = 'eol_experiment_prefix'\n",
+ "n_heads = 2\n",
+ "assert embedding_dim % n_heads == 0\n",
+ "model_params = {\n",
+ " 'embedding_dim': int(embedding_dim / n_heads), # Dimension per head of visit embeddings\n",
+ " 'n_heads': n_heads, # Number of self-attention heads\n",
+ " 'attn_depth': 2, # Number of stacked self-attention layers\n",
+ " 'dropout': 0.05, # Dropout rate for both self-attention and the final prediction layer\n",
+ " 'use_mask': True, # Only allow visits to attend to other actual visits, not to padding visits\n",
+ " 'concept_embedding_path': embedding_filename # if unspecified, uses default Torch embeddings\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GJ2DXdlWPfsl"
+ },
+ "outputs": [],
+ "source": [
+ "# Set up fixed model parameters, loss functions, and build the model on the GPU\n",
+ "lr = 2e-4\n",
+ "n_epochs_pretrain = 1\n",
+ "ft_epochs = 1\n",
+ "\n",
+ "update_every = 500\n",
+ "update_mod = update_every // mbsz\n",
+ "\n",
+ "base_model = visit_transformer.VisitTransformer(\n",
+ " featureSet, **model_params\n",
+ ")\n",
+ "\n",
+ "clf = visit_transformer.VTClassifer(\n",
+ " base_model, **model_params\n",
+ ").cuda()\n",
+ "\n",
+ "clf.bert.set_data(\n",
+ " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n",
+ " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n",
+ " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n",
+ ")\n",
+ "\n",
+ "loss_function_distill = torch.nn.BCEWithLogitsLoss(\n",
+ " pos_weight=torch.FloatTensor([\n",
+ " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n",
+ " ]), reduction='sum'\n",
+ ").cuda()\n",
+ "\n",
+ "optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=lr)\n",
+ "\n",
+ "def eval_curr_model_on(a):\n",
+ " with torch.no_grad():\n",
+ " preds_test, true_test = [], []\n",
+ " for batch_num, p_range in enumerate(a):\n",
+ " y_pred = clf(p_range)\n",
+ " preds_test += y_pred.tolist()\n",
+ " true_test += list(dataset_dict['outcomes_filt'].iloc[list(p_range)].values)\n",
+ " return roc_auc_score(true_test, preds_test)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4l95ivncPfsl"
+ },
+ "source": [
+ "### 2. Fit the SARD model to the best windowed linear model (Reverse Distillation)\n",
+ "\n",
+ "The following code saves models in a folder `/SavedModels/{task}/`; make sure to create the directory before running."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "UgJt7p11Pfsl"
+ },
+ "outputs": [],
+ "source": [
+ "task = 'eol'"
+ ]
+ },
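+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional helper: create the checkpoint directory used by the training cells below,\n",
+ "# assuming the relative path SavedModels/{task}/ noted above\n",
+ "os.makedirs(f\"SavedModels/{task}\", exist_ok=True)"
+ ]
+ },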
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ESuCBwrMPfsl"
+ },
+ "outputs": [],
+ "source": [
+ "# Run `n_epochs_pretrain` of Reverse Distillation pretraining\n",
+ "val_losses = []\n",
+ "progress_bar = IntProgress(min=0, max=int(n_epochs_pretrain * len(p_ranges_train)))\n",
+ "batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)\n",
+ "time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)\n",
+ "\n",
+ "display(progress_bar)\n",
+ "display(batch_loss_disp)\n",
+ "display(time_disp)\n",
+ "\n",
+ "for epoch in range(n_epochs_pretrain):\n",
+ " t, batch_loss = time.time(), 0\n",
+ "\n",
+ " for batch_num, p_range in enumerate(p_ranges_train):\n",
+ "\n",
+ " if batch_num % 50 == 0:\n",
+ " batch_loss_disp.value = round(batch_loss / 50, 2)\n",
+ " time_disp.value = round(time.time() - t, 2)\n",
+ " t, batch_loss = time.time(), 0\n",
+ "\n",
+ " y_pred = clf(p_range)\n",
+ " loss_distill = loss_function_distill(\n",
+ " y_pred, torch.FloatTensor(pred_lr_all[p_range]).cuda()\n",
+ " )\n",
+ "\n",
+ " batch_loss += loss_distill.item()\n",
+ " loss_distill.backward()\n",
+ "\n",
+ " if batch_num % update_mod == 0:\n",
+ " optimizer_clf.step()\n",
+ " optimizer_clf.zero_grad()\n",
+ "\n",
+ " progress_bar.value = batch_num + epoch * len(p_ranges_train)\n",
+ "\n",
+ " torch.save(\n",
+ " clf.state_dict(),\n",
+ " \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
+ " task=task, mn_prefix = mn_prefix, epochs = epoch + 1\n",
+ " )\n",
+ " )\n",
+ "\n",
+ " clf.eval()\n",
+ " ckpt_auc = eval_curr_model_on(p_ranges_val)\n",
+ " print('Epochs: {} | Val AUC: {}'.format(epoch + 1, ckpt_auc))\n",
+ " val_losses.append(ckpt_auc)\n",
+ " clf.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "S79PbKdhPfsm"
+ },
+ "outputs": [],
+ "source": [
+ "# Save the pretrained model with best validation-set performance\n",
+ "clf.load_state_dict(\n",
+ " torch.load(\"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
+ " task=task, mn_prefix=mn_prefix, epochs=np.argmax(val_losses) + 1\n",
+ " ))\n",
+ ")\n",
+ "torch.save(\n",
+ " clf.state_dict(),\n",
+ " \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
+ " task=task, mn_prefix = mn_prefix, epochs = 'BEST'\n",
+ " )\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6Wzmh6tCPfsm"
+ },
+ "source": [
+ "### 3. Fine-tune the SARD model by training to match the actual outcomes on the training set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "iczzP5NAPfsm"
+ },
+ "outputs": [],
+ "source": [
+ "# Set up loss functions for fine-tuning. There are two terms:\n",
+ "# - `loss_function_distill`, which penalizes differences between the linear model prediction and SARD's prediction\n",
+ "# - `loss_function_clf`, which penalizes differences between the true outcome and SARD's prediction\n",
+ "loss_function_distill = torch.nn.BCEWithLogitsLoss(\n",
+ " pos_weight=torch.FloatTensor([\n",
+ " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n",
+ " ]), reduction='sum'\n",
+ ").cuda()\n",
+ "\n",
+ "loss_function_clf = torch.nn.BCEWithLogitsLoss(\n",
+ " pos_weight=torch.FloatTensor([\n",
+ " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n",
+ " ]), reduction='sum'\n",
+ ").cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "30_HYniePfsm"
+ },
+ "outputs": [],
+ "source": [
+ "# run `ft_epochs` of fine-tuning training, for each of the values of `alpha` below.\n",
+ "# Note that `alpha` is the relative weight of `loss_function_distill` as compared to `loss_function_clf`\n",
+ "\n",
+ "all_pred_models = {}\n",
+ "\n",
+ "progress_bar = IntProgress(min=0, max=int(ft_epochs * len(p_ranges_train)))\n",
+ "batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)\n",
+ "time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)\n",
+ "\n",
+ "display(progress_bar)\n",
+ "display(batch_loss_disp)\n",
+ "display(time_disp)\n",
+ "\n",
+ "\n",
+ "no_rd = False\n",
+ "for alpha in [0,0.05,0.1,0.15, 0.2]:\n",
+ "\n",
+ " progress_bar.value = 0\n",
+ "\n",
+ " if no_rd:\n",
+ " pretrained_model_fn = mn_prefix + '_None'\n",
+ " start_model = None\n",
+ " if start_model is None:\n",
+ " base_model = visit_transformer.VisitTransformer(\n",
+ " featureSet, **model_params\n",
+ " )\n",
+ "\n",
+ " clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()\n",
+ "\n",
+ " clf.bert.set_data(\n",
+ " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n",
+ " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n",
+ " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n",
+ " )\n",
+ " else:\n",
+ " pretrained_model_path = \"SavedModels/{task}/{start_model}\".format(\n",
+ " task=task, start_model=start_model\n",
+ " )\n",
+ " clf.load_state_dict(torch.load(pretrained_model_path))\n",
+ "\n",
+ " else:\n",
+ " pretrained_model_fn = \"{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
+ " mn_prefix=mn_prefix, epochs='BEST'\n",
+ " )\n",
+ " pretrained_model_path = \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
+ " task=task, mn_prefix=mn_prefix, epochs='BEST'\n",
+ " )\n",
+ "\n",
+ " clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()\n",
+ " clf.bert.set_data(\n",
+ " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n",
+ " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n",
+ " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n",
+ " )\n",
+ "\n",
+ " clf.load_state_dict(torch.load(pretrained_model_path))\n",
+ "\n",
+ " clf.train()\n",
+ "\n",
+ " optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=2e-4)\n",
+ "\n",
+ " for epoch in range(ft_epochs):\n",
+ "\n",
+ " t, batch_loss = time.time(), 0\n",
+ "\n",
+ " for batch_num, p_range in enumerate(p_ranges_train):\n",
+ "\n",
+ " if batch_num % 50 == 0:\n",
+ " batch_loss_disp.value = round(batch_loss / 50, 2)\n",
+ " time_disp.value = round(time.time() - t, 2)\n",
+ " t, batch_loss = time.time(), 0\n",
+ "\n",
+ " y_pred = clf(p_range)\n",
+ "\n",
+ " loss = loss_function_clf(\n",
+ " y_pred,\n",
+ " torch.FloatTensor(dataset_dict['outcomes_filt'].values[p_range]).cuda()\n",
+ " )\n",
+ "\n",
+ " loss_distill = loss_distill = loss_function_distill(\n",
+ " y_pred,\n",
+ " torch.FloatTensor(pred_lr_all[p_range]).cuda()\n",
+ " )\n",
+ "\n",
+ " batch_loss += loss.item() + alpha * loss_distill.item()\n",
+ " loss_total = loss + alpha * loss_distill\n",
+ " loss_total.backward()\n",
+ "\n",
+ " if batch_num % update_mod == 0:\n",
+ " optimizer_clf.step()\n",
+ " optimizer_clf.zero_grad()\n",
+ "\n",
+ " progress_bar.value = batch_num + epoch * len(p_ranges_train)\n",
+ "\n",
+ " saving_fn = \"{pretrain}_alpha_{alpha}_epochs_{epochs}\".format(\n",
+ " task=task, pretrain = pretrained_model_fn, alpha=alpha, epochs = epoch + 1\n",
+ " )\n",
+ " torch.save(\n",
+ " clf.state_dict(),\n",
+ " \"SavedModels/{task}/{saving_fn}\".format(\n",
+ " task=task, saving_fn=saving_fn\n",
+ " )\n",
+ " )\n",
+ "\n",
+ " clf.eval()\n",
+ " val_auc = eval_curr_model_on(p_ranges_val)\n",
+ " print(val_auc)\n",
+ " all_pred_models[saving_fn] = val_auc\n",
+ " clf.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rXk7tpRdPfsm"
+ },
+ "source": [
+ "### 4. Evaluate the best SARD model, as determined by validation performance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "35WmOaisPfsm"
+ },
+ "outputs": [],
+ "source": [
+ "best_model = max(all_pred_models, key=all_pred_models.get)\n",
+ "clf.load_state_dict(\n",
+ " torch.load(\"SavedModels/{task}/{model}\".format(\n",
+ " task=task, model=best_model\n",
+ " ))\n",
+ ")\n",
+ "clf.eval();\n",
+ "with torch.no_grad():\n",
+ " preds_test, true_test = [], []\n",
+ " for batch_num, p_range in enumerate(p_ranges_test):\n",
+ " y_pred = clf(p_range)\n",
+ " preds_test += y_pred.tolist()\n",
+ " true_test += list(dataset_dict['outcomes_filt'].iloc[list(p_range)].values)\n",
+ " print(roc_auc_score(true_test, preds_test))\n",
+ "clf.train();"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "include_colab_link": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "omop-learn",
+ "language": "python",
+ "name": "omop-learn"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
}