diff --git a/examples/eol/sard_eol.ipynb b/examples/eol/sard_eol.ipynb index c272c50..8c9f185 100644 --- a/examples/eol/sard_eol.ipynb +++ b/examples/eol/sard_eol.ipynb @@ -1,758 +1,758 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n3ieKAgvPfsZ" - }, - "source": [ - "# Run End of Life prediction task on Synthetic Patient Data in OMOP" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u1UUubwGPfsc" - }, - "source": [ - "This notebook runs the end-of-life (EOL) prediction task on synthetic patient data in OMOP using a linear baseline model and the SARD architecture [Kodialam et al. 2021].\n", - "\n", - "Data is sourced from the publicly available Medicare Claims Synthetic Public Use Files (SynPUF), released by the Centers for Medicare and Medicaid Services (CMS) and available in [Google BigQuery. The synthetic set contains 2008-2010 Medicare insurance claims for development and demonstration purposes and was coverted to the Medical Outcomes Partnership (OMOP) Common Data Model from its original CSV form." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YaHCADZXPfsd" - }, - "source": [ - "## Imports and GPU setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "lYmOTr6sPfse" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import time\n", - "import os\n", - "\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import roc_auc_score\n", - "\n", - "from ipywidgets import IntProgress, FloatText\n", - "from IPython.display import display\n", - "\n", - "import matplotlib\n", - "import matplotlib.pyplot as plt\n", - "\n", - "plt.rcParams[\"font.family\"] = \"serif\"\n", - "plt.rcParams[\"font.size\"] = 13" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "AEXAJWHXPfsf" - }, - "outputs": [], - "source": [ - "from omop_learn.backends.bigquery import BigQueryBackend\n", - "from omop_learn.data.cohort import Cohort\n", - "from omop_learn.data.feature import Feature\n", - "from omop_learn.utils.config import Config\n", - "from omop_learn.omop import OMOPDataset\n", - "from omop_learn.utils import date_utils, embedding_utils\n", - "from omop_learn.sparse.models import OMOPLogisticRegression\n", - "from omop_learn.models import transformer, visit_transformer" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GU-M1ASqPfsg" - }, - "source": [ - "## Cohort, Outcome and Feature Collection\n", - "\n", - "### 1. 
Set up a connection to the OMOP CDM database\n", - "\n", - "Parameters for connection to be specified in ./config.py" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false, - "id": "ZtAvjOCzPfsg" - }, - "outputs": [], - "source": [ - "config = Config({\n", - " \"project_name\": \"project\",\n", - " \"cdm_schema\": \"bigquery-public-data.cms_synthetic_patient_data_omop\",\n", - " \"prefix_schema\": \"username\",\n", - " \"datasets_dir\": \"data_dir\",\n", - " \"models_dir\": \"model_dir\"\n", - "})\n", - "\n", - "# Set up database, reset schemas as needed\n", - "backend = BigQueryBackend(config)\n", - "backend.reset_schema(config.prefix_schema) # Rebuild schema from scratch\n", - "backend.create_schema(config.prefix_schema) # Create schema if not exists\n", - "\n", - "cohort_params = {\n", - " \"cohort_table_name\": \"synpuf_eol_cohort\",\n", - " \"schema_name\": config.prefix_schema,\n", - " \"cdm_schema\": config.cdm_schema,\n", - " \"aux_data_schema\": config.aux_cdm_schema,\n", - " \"training_start_date\": \"2009-01-01\",\n", - " \"training_end_date\": \"2009-12-31\",\n", - " \"gap\": \"3 month\",\n", - " \"outcome_window\": \"6 month\",\n", - "}\n", - "sql_dir = \"./bigquery_sql\"\n", - "sql_file = open(f\"{sql_dir}/gen_EOL_cohort.sql\", 'r')\n", - "cohort = Cohort.from_sql_file(sql_file, backend, params=cohort_params)\n", - "\n", - "feature_names = [\"drugs\", \"conditions\", \"procedures\"]\n", - "feature_paths = [f\"{sql_dir}/{feature_name}.sql\" for feature_name in feature_names]\n", - "features = [Feature(n, p) for n, p in zip(feature_names, feature_paths)]\n", - "\n", - "init_args = {\n", - " \"config\" : config,\n", - " \"name\" : \"synpuf_eol\",\n", - " \"cohort\" : cohort,\n", - " \"features\": features,\n", - " \"backend\": backend,\n", - "}\n", - "\n", - "dataset = OMOPDataset(**init_args)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Oa7ULCSRPfsh" - }, - "source": [ - "### 4. 
Process the collected data and calculate indices needed for the deep model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "U-kcFPtOPfsh" - }, - "outputs": [], - "source": [ - "window_days = [30, 180, 365, 730, 1000]\n", - "windowed_dataset = dataset.to_windowed(window_days)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "nmZZunVJPfsi" - }, - "outputs": [], - "source": [ - "person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aLSE6AG8Pfsi" - }, - "outputs": [], - "source": [ - "# process data for deep model\n", - "person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords\n", - "outcomes_filt = windowed_dataset.outcomes\n", - "time_to_idx = windowed_dataset.times_map\n", - "idx_to_datetime = {idx: date_utils.from_unixtime([time])[0] for time, idx in time_to_idx.items()}\n", - "\n", - "all_codes_tensor = code_ixs\n", - "people = sorted(np.unique(person_ixs))\n", - "person_indices = np.searchsorted(person_ixs, people)\n", - "person_indices = np.append(person_indices, len(person_ixs))\n", - "person_chunks = [\n", - " time_ixs[person_indices[i]: person_indices[i + 1]]\n", - " for i in range(len(person_indices) - 1)\n", - "]\n", - "\n", - "visit_chunks = []\n", - "visit_times_raw = []\n", - "\n", - "for i, chunk in enumerate(person_chunks):\n", - " visits = sorted(np.unique(chunk))\n", - " visit_indices_local = np.searchsorted(chunk, visits)\n", - " visit_indices_local = np.append(\n", - " visit_indices_local,\n", - " len(chunk)\n", - " )\n", - " visit_chunks.append(visit_indices_local)\n", - " visit_times_raw.append(visits)\n", - "\n", - "n_visits = {i:len(j) for i,j in enumerate(visit_times_raw)}\n", - "\n", - "visit_days_rel = {\n", - " i: (\n", - " pd.to_datetime(cohort_params['training_end_date']) \\\n", - " - pd.to_datetime(idx_to_datetime[time])\n", - " ).days for time in time_ixs\n", - "}\n", - "vdrel_func = np.vectorize(visit_days_rel.get)\n", - "visit_time_rel = [\n", - " vdrel_func(v) for v in visit_times_raw\n", - "]\n", - "\n", - "remap = {\n", - " 'id': people,\n", - " 'time': sorted(np.unique(time_ixs)),\n", - " 'concept': sorted(np.unique(code_ixs))\n", - "}\n", - "\n", - "dataset_dict = {\n", - " 'all_codes_tensor': all_codes_tensor, # A tensor of all codes occurring in the dataset\n", - " 'person_indices': person_indices, # A list of indices such that all_codes_tensor[person_indices[i]: person_indices[i+1]] are the codes assigned to the ith patient\n", - " 'visit_chunks': visit_chunks, # A list of indices such that all_codes_tensor[person_indices[i]+visit_chunks[j]:person_indices[i]+visit_chunks[j+1]] are the codes assigned to the ith patient during their jth visit\n", - " 'visit_time_rel': visit_time_rel, # A list of times (as measured in days to the prediction date) for each visit\n", - " 'n_visits': n_visits, # A dict defined such that n_visits[i] is the number of visits made by the ith patient\n", - " 'outcomes_filt': outcomes_filt, # A pandas Series defined such that outcomes_filt.iloc[i] is the outcome of the ith patient\n", - " 'remap': remap,\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9rumxm09Pfsj" - }, - "source": [ - "## Run the windowed regression model on the task defined above" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "EBWOySUSPfsj" - }, - "outputs": [], - "source": [ - "# split data into 
train, validate and test sets\n", - "windowed_dataset.split()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jtVIFn2CPfsj" - }, - "outputs": [], - "source": [ - "# train the regression model over several choices of regularization parameter\n", - "reg_lambdas = [2, 0.2, 0.02]\n", - "lr_val_aucs = []\n", - "model = OMOPLogisticRegression(\"eol_new_50\", windowed_dataset)\n", - "\n", - "for reg_lambda in reg_lambdas:\n", - " # Gen and fit\n", - " model.gen_pipeline(reg_lambda)\n", - " model.fit()\n", - " # Eval on validation data\n", - " pred_lr = model._pipeline.predict_proba(windowed_dataset.val['X'])[:, 1]\n", - " lr_val_auc = roc_auc_score(windowed_dataset.val['y'], pred_lr)\n", - " lr_val_aucs.append(lr_val_auc)\n", - " print(\"C: %.4f, Val AUC: %.2f\" % (reg_lambda, lr_val_auc))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pYy4lcbePfsj" - }, - "outputs": [], - "source": [ - "# Gen and fit on best C\n", - "best_reg_lambda = reg_lambdas[np.argmax(lr_val_aucs)]\n", - "model.gen_pipeline(best_reg_lambda)\n", - "model.fit()\n", - "# Eval on test data\n", - "pred_lr = model._pipeline.predict_proba(windowed_dataset.test['X'])[:, 1]\n", - "score = roc_auc_score(windowed_dataset.test['y'], pred_lr)\n", - "print(\"C: %.4f, Test AUC: %.2f\" % (best_reg_lambda, score))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jvYkElLlPfsk" - }, - "source": [ - "### Learn a Word2Vec embedding" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e7oM4vOsPfsk" - }, - "outputs": [], - "source": [ - "%%time\n", - "embedding_dim = 300 # size of embedding, must be multiple of number of heads\n", - "window_days = 90 # number of days in window that defines a \"Sentence\" when learning the embedding\n", - "train_coords = np.nonzero(np.where(np.isin(person_ixs, indices_train), 1, 0))\n", - "embedding_filename = embedding_utils.train_embedding(featureSet, feature_matrix_3d_transpose, window_days, \\\n", - " person_ixs[train_coords], time_ixs[train_coords], \\\n", - " remap['time'], embedding_dim)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vlIGYdtIPfsk" - }, - "source": [ - "## Run the SARD deep model on the predictive task\n", - "### 1. Set Model Parameters and Construct the Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Yyr2CO1wPfsk" - }, - "outputs": [], - "source": [ - "# using the same split as before, create train/validate/test batches for the deep model\n", - "# `mbsz` might need to be decreased based on the GPU's memory and the number of features being used\n", - "mbsz = 50\n", - "def get_batches(arr, mbsz=mbsz):\n", - " curr, ret = 0, []\n", - " while curr < len(arr) - 1:\n", - " ret.append(arr[curr : curr + mbsz])\n", - " curr += mbsz\n", - " return ret\n", - "\n", - "p_ranges_train, p_ranges_test = [\n", - " get_batches(arr) for arr in (\n", - " indices_train, indices_test\n", - " )\n", - "]\n", - "p_ranges_val = p_ranges_test[:val_size // mbsz]\n", - "p_ranges_test = p_ranges_test[val_size // mbsz:]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "AfUpBu4sPfsl" - }, - "outputs": [], - "source": [ - "# Pick a name for the model (mn_prefix) that will be used when saving checkpoints\n", - "# Then, set some parameters for SARD. 
The values below reflect a good starting point that performed well on several tasks\n", - "mn_prefix = 'eol_experiment_prefix'\n", - "n_heads = 2\n", - "assert embedding_dim % n_heads == 0\n", - "model_params = {\n", - " 'embedding_dim': int(embedding_dim / n_heads), # Dimension per head of visit embeddings\n", - " 'n_heads': n_heads, # Number of self-attention heads\n", - " 'attn_depth': 2, # Number of stacked self-attention layers\n", - " 'dropout': 0.05, # Dropout rate for both self-attention and the final prediction layer\n", - " 'use_mask': True, # Only allow visits to attend to other actual visits, not to padding visits\n", - " 'concept_embedding_path': embedding_filename # if unspecified, uses default Torch embeddings\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GJ2DXdlWPfsl" - }, - "outputs": [], - "source": [ - "# Set up fixed model parameters, loss functions, and build the model on the GPU\n", - "lr = 2e-4\n", - "n_epochs_pretrain = 1\n", - "ft_epochs = 1\n", - "\n", - "update_every = 500\n", - "update_mod = update_every // mbsz\n", - "\n", - "base_model = visit_transformer.VisitTransformer(\n", - " featureSet, **model_params\n", - ")\n", - "\n", - "clf = visit_transformer.VTClassifer(\n", - " base_model, **model_params\n", - ").cuda()\n", - "\n", - "clf.bert.set_data(\n", - " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n", - " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n", - " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n", - ")\n", - "\n", - "loss_function_distill = torch.nn.BCEWithLogitsLoss(\n", - " pos_weight=torch.FloatTensor([\n", - " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n", - " ]), reduction='sum'\n", - ").cuda()\n", - "\n", - "optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=lr)\n", - "\n", - "def eval_curr_model_on(a):\n", - " with torch.no_grad():\n", - " preds_test, true_test = [], []\n", - " for batch_num, p_range in enumerate(a):\n", - " y_pred = clf(p_range)\n", - " preds_test += y_pred.tolist()\n", - " true_test += list(dataset_dict['outcomes_filt'].iloc[list(p_range)].values)\n", - " return roc_auc_score(true_test, preds_test)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4l95ivncPfsl" - }, - "source": [ - "### 2. Fit the SARD model to the best windowed linear model (Reverse Distillation)\n", - "\n", - "The following code saves models in a folder `/SavedModels/{task}/`; make sure to create the directory before running." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UgJt7p11Pfsl" - }, - "outputs": [], - "source": [ - "task = 'eol'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ESuCBwrMPfsl" - }, - "outputs": [], - "source": [ - "# Run `n_epochs_pretrain` of Reverse Distillation pretraining\n", - "val_losses = []\n", - "progress_bar = IntProgress(min=0, max=int(n_epochs_pretrain * len(p_ranges_train)))\n", - "batch_loss_disp = FloatText(value=0.0, description='Avg. 
Batch Loss for Last 50 Batches', disabled=True)\n", - "time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)\n", - "\n", - "display(progress_bar)\n", - "display(batch_loss_disp)\n", - "display(time_disp)\n", - "\n", - "for epoch in range(n_epochs_pretrain):\n", - " t, batch_loss = time.time(), 0\n", - "\n", - " for batch_num, p_range in enumerate(p_ranges_train):\n", - "\n", - " if batch_num % 50 == 0:\n", - " batch_loss_disp.value = round(batch_loss / 50, 2)\n", - " time_disp.value = round(time.time() - t, 2)\n", - " t, batch_loss = time.time(), 0\n", - "\n", - " y_pred = clf(p_range)\n", - " loss_distill = loss_function_distill(\n", - " y_pred, torch.FloatTensor(pred_lr_all[p_range]).cuda()\n", - " )\n", - "\n", - " batch_loss += loss_distill.item()\n", - " loss_distill.backward()\n", - "\n", - " if batch_num % update_mod == 0:\n", - " optimizer_clf.step()\n", - " optimizer_clf.zero_grad()\n", - "\n", - " progress_bar.value = batch_num + epoch * len(p_ranges_train)\n", - "\n", - " torch.save(\n", - " clf.state_dict(),\n", - " \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n", - " task=task, mn_prefix = mn_prefix, epochs = epoch + 1\n", - " )\n", - " )\n", - "\n", - " clf.eval()\n", - " ckpt_auc = eval_curr_model_on(p_ranges_val)\n", - " print('Epochs: {} | Val AUC: {}'.format(epoch + 1, ckpt_auc))\n", - " val_losses.append(ckpt_auc)\n", - " clf.train()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "S79PbKdhPfsm" - }, - "outputs": [], - "source": [ - "# Save the pretrained model with best validation-set performance\n", - "clf.load_state_dict(\n", - " torch.load(\"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n", - " task=task, mn_prefix=mn_prefix, epochs=np.argmax(val_losses) + 1\n", - " ))\n", - ")\n", - "torch.save(\n", - " clf.state_dict(),\n", - " \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n", - " task=task, mn_prefix = mn_prefix, epochs = 'BEST'\n", - " )\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6Wzmh6tCPfsm" - }, - "source": [ - "### 3. Fine-tune the SARD model by training to match the actual outcomes on the training set" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "iczzP5NAPfsm" - }, - "outputs": [], - "source": [ - "# Set up loss functions for fine-tuning. 
There are two terms:\n", - "# - `loss_function_distill`, which penalizes differences between the linear model prediction and SARD's prediction\n", - "# - `loss_function_clf`, which penalizes differences between the true outcome and SARD's prediction\n", - "loss_function_distill = torch.nn.BCEWithLogitsLoss(\n", - " pos_weight=torch.FloatTensor([\n", - " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n", - " ]), reduction='sum'\n", - ").cuda()\n", - "\n", - "loss_function_clf = torch.nn.BCEWithLogitsLoss(\n", - " pos_weight=torch.FloatTensor([\n", - " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n", - " ]), reduction='sum'\n", - ").cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "30_HYniePfsm" - }, - "outputs": [], - "source": [ - "# run `ft_epochs` of fine-tuning training, for each of the values of `alpha` below.\n", - "# Note that `alpha` is the relative weight of `loss_function_distill` as compared to `loss_function_clf`\n", - "\n", - "all_pred_models = {}\n", - "\n", - "progress_bar = IntProgress(min=0, max=int(ft_epochs * len(p_ranges_train)))\n", - "batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)\n", - "time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)\n", - "\n", - "display(progress_bar)\n", - "display(batch_loss_disp)\n", - "display(time_disp)\n", - "\n", - "\n", - "no_rd = False\n", - "for alpha in [0,0.05,0.1,0.15, 0.2]:\n", - "\n", - " progress_bar.value = 0\n", - "\n", - " if no_rd:\n", - " pretrained_model_fn = mn_prefix + '_None'\n", - " start_model = None\n", - " if start_model is None:\n", - " base_model = visit_transformer.VisitTransformer(\n", - " featureSet, **model_params\n", - " )\n", - "\n", - " clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()\n", - "\n", - " clf.bert.set_data(\n", - " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n", - " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n", - " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n", - " )\n", - " else:\n", - " pretrained_model_path = \"SavedModels/{task}/{start_model}\".format(\n", - " task=task, start_model=start_model\n", - " )\n", - " clf.load_state_dict(torch.load(pretrained_model_path))\n", - "\n", - " else:\n", - " pretrained_model_fn = \"{mn_prefix}_pretrain_epochs_{epochs}\".format(\n", - " mn_prefix=mn_prefix, epochs='BEST'\n", - " )\n", - " pretrained_model_path = \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n", - " task=task, mn_prefix=mn_prefix, epochs='BEST'\n", - " )\n", - "\n", - " clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()\n", - " clf.bert.set_data(\n", - " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n", - " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n", - " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n", - " )\n", - "\n", - " clf.load_state_dict(torch.load(pretrained_model_path))\n", - "\n", - " clf.train()\n", - "\n", - " optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=2e-4)\n", - "\n", - " for epoch in range(ft_epochs):\n", - "\n", - " t, batch_loss = time.time(), 0\n", - "\n", - " for batch_num, p_range in enumerate(p_ranges_train):\n", - "\n", - " if batch_num % 50 == 0:\n", - " batch_loss_disp.value = round(batch_loss / 50, 2)\n", - " time_disp.value = round(time.time() - t, 2)\n", - " t, batch_loss = time.time(), 0\n", 
- "\n", - " y_pred = clf(p_range)\n", - "\n", - " loss = loss_function_clf(\n", - " y_pred,\n", - " torch.FloatTensor(dataset_dict['outcomes_filt'].values[p_range]).cuda()\n", - " )\n", - "\n", - " loss_distill = loss_distill = loss_function_distill(\n", - " y_pred,\n", - " torch.FloatTensor(pred_lr_all[p_range]).cuda()\n", - " )\n", - "\n", - " batch_loss += loss.item() + alpha * loss_distill.item()\n", - " loss_total = loss + alpha * loss_distill\n", - " loss_total.backward()\n", - "\n", - " if batch_num % update_mod == 0:\n", - " optimizer_clf.step()\n", - " optimizer_clf.zero_grad()\n", - "\n", - " progress_bar.value = batch_num + epoch * len(p_ranges_train)\n", - "\n", - " saving_fn = \"{pretrain}_alpha_{alpha}_epochs_{epochs}\".format(\n", - " task=task, pretrain = pretrained_model_fn, alpha=alpha, epochs = epoch + 1\n", - " )\n", - " torch.save(\n", - " clf.state_dict(),\n", - " \"SavedModels/{task}/{saving_fn}\".format(\n", - " task=task, saving_fn=saving_fn\n", - " )\n", - " )\n", - "\n", - " clf.eval()\n", - " val_auc = eval_curr_model_on(p_ranges_val)\n", - " print(val_auc)\n", - " all_pred_models[saving_fn] = val_auc\n", - " clf.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rXk7tpRdPfsm" - }, - "source": [ - "### 4. Evaluate the best SARD model, as determined by validation performance" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "35WmOaisPfsm" - }, - "outputs": [], - "source": [ - "best_model = max(all_pred_models, key=all_pred_models.get)\n", - "clf.load_state_dict(\n", - " torch.load(\"SavedModels/{task}/{model}\".format(\n", - " task=task, model=best_model\n", - " ))\n", - ")\n", - "clf.eval();\n", - "with torch.no_grad():\n", - " preds_test, true_test = [], []\n", - " for batch_num, p_range in enumerate(p_ranges_test):\n", - " y_pred = clf(p_range)\n", - " preds_test += y_pred.tolist()\n", - " true_test += list(dataset_dict['outcomes_filt'].iloc[list(p_range)].values)\n", - " print(roc_auc_score(true_test, preds_test))\n", - "clf.train();" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "omop-learn", - "language": "python", - "name": "omop-learn" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - }, - "colab": { - "provenance": [], - "include_colab_link": true - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n3ieKAgvPfsZ" + }, + "source": [ + "# Run End of Life prediction task on Synthetic Patient Data in OMOP" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u1UUubwGPfsc" + }, + "source": [ + "This notebook runs the end-of-life (EOL) prediction task on synthetic patient data in OMOP using a linear baseline model and the SARD architecture [Kodialam et al. 2021].\n", + "\n", + "Data is sourced from the publicly available Medicare Claims Synthetic Public Use Files (SynPUF), released by the Centers for Medicare and Medicaid Services (CMS) and available in [Google BigQuery](https://console.cloud.google.com/marketplace/product/hhs/synpuf). 
The synthetic set contains 2008-2010 Medicare insurance claims for development and demonstration purposes and was converted to the Observational Medical Outcomes Partnership (OMOP) Common Data Model from its original CSV form."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "YaHCADZXPfsd"
+   },
+   "source": [
+    "## Imports and GPU setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "lYmOTr6sPfse"
+   },
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import torch\n",
+    "import time\n",
+    "import os\n",
+    "\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from sklearn.metrics import roc_auc_score\n",
+    "\n",
+    "from ipywidgets import IntProgress, FloatText\n",
+    "from IPython.display import display\n",
+    "\n",
+    "import matplotlib\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "plt.rcParams[\"font.family\"] = \"serif\"\n",
+    "plt.rcParams[\"font.size\"] = 13"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "AEXAJWHXPfsf"
+   },
+   "outputs": [],
+   "source": [
+    "from omop_learn.backends.bigquery import BigQueryBackend\n",
+    "from omop_learn.data.cohort import Cohort\n",
+    "from omop_learn.data.feature import Feature\n",
+    "from omop_learn.utils.config import Config\n",
+    "from omop_learn.omop import OMOPDataset\n",
+    "from omop_learn.utils import date_utils, embedding_utils\n",
+    "from omop_learn.sparse.models import OMOPLogisticRegression\n",
+    "from omop_learn.models import transformer, visit_transformer"
+   ]
+  },
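+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The SARD cells below call `.cuda()` directly, so a CUDA-capable GPU is assumed throughout; a minimal sketch of a fail-fast check:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal GPU check: later cells move the model and tensors to the GPU with\n",
+    "# `.cuda()`, so fail early if CUDA is unavailable.\n",
+    "assert torch.cuda.is_available(), \"This notebook expects a CUDA-capable GPU.\"\n",
+    "print(f\"Using GPU: {torch.cuda.get_device_name(0)}\")"
+   ]
+  },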
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "GU-M1ASqPfsg"
+   },
+   "source": [
+    "## Cohort, Outcome and Feature Collection\n",
+    "\n",
+    "### 1. Set up a connection to the OMOP CDM database\n",
+    "\n",
+    "Parameters for the connection are specified in the `Config` object below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "ZtAvjOCzPfsg",
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "config = Config({\n",
+    "    \"project_name\": \"project\",\n",
+    "    \"cdm_schema\": \"bigquery-public-data.cms_synthetic_patient_data_omop\",\n",
+    "    \"prefix_schema\": \"username\",\n",
+    "    \"datasets_dir\": \"data_dir\",\n",
+    "    \"models_dir\": \"model_dir\"\n",
+    "})\n",
+    "\n",
+    "# Set up database, reset schemas as needed\n",
+    "backend = BigQueryBackend(config)\n",
+    "backend.reset_schema(config.prefix_schema)  # Rebuild schema from scratch\n",
+    "backend.create_schema(config.prefix_schema)  # Create schema if not exists\n",
+    "\n",
+    "cohort_params = {\n",
+    "    \"cohort_table_name\": \"synpuf_eol_cohort\",\n",
+    "    \"schema_name\": config.prefix_schema,\n",
+    "    \"cdm_schema\": config.cdm_schema,\n",
+    "    \"aux_data_schema\": config.aux_cdm_schema,\n",
+    "    \"training_start_date\": \"2009-01-01\",\n",
+    "    \"training_end_date\": \"2009-12-31\",\n",
+    "    \"gap\": \"3 month\",\n",
+    "    \"outcome_window\": \"6 month\",\n",
+    "}\n",
+    "sql_dir = \"./bigquery_sql\"\n",
+    "# Use a context manager so the SQL file handle is closed once the cohort is built\n",
+    "with open(f\"{sql_dir}/gen_EOL_cohort.sql\", \"r\") as sql_file:\n",
+    "    cohort = Cohort.from_sql_file(sql_file, backend, params=cohort_params)\n",
+    "\n",
+    "feature_names = [\"drugs\", \"conditions\", \"procedures\"]\n",
+    "feature_paths = [f\"{sql_dir}/{feature_name}.sql\" for feature_name in feature_names]\n",
+    "features = [Feature(n, p) for n, p in zip(feature_names, feature_paths)]\n",
+    "\n",
+    "init_args = {\n",
+    "    \"config\": config,\n",
+    "    \"name\": \"synpuf_eol\",\n",
+    "    \"cohort\": cohort,\n",
+    "    \"features\": features,\n",
+    "    \"backend\": backend,\n",
+    "}\n",
+    "\n",
+    "dataset = OMOPDataset(**init_args)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Oa7ULCSRPfsh"
+   },
+   "source": [
+    "### 2. Process the collected data and calculate indices needed for the deep model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "U-kcFPtOPfsh"
+   },
+   "outputs": [],
+   "source": [
+    "window_days = [30, 180, 365, 730, 1000]\n",
+    "windowed_dataset = dataset.to_windowed(window_days)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "nmZZunVJPfsi"
+   },
+   "outputs": [],
+   "source": [
+    "person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "aLSE6AG8Pfsi"
+   },
+   "outputs": [],
+   "source": [
+    "# process data for deep model\n",
+    "person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords\n",
+    "outcomes_filt = windowed_dataset.outcomes\n",
+    "time_to_idx = windowed_dataset.times_map\n",
+    "idx_to_datetime = {idx: date_utils.from_unixtime([time])[0] for time, idx in time_to_idx.items()}\n",
+    "\n",
+    "all_codes_tensor = code_ixs\n",
+    "people = sorted(np.unique(person_ixs))\n",
+    "person_indices = np.searchsorted(person_ixs, people)\n",
+    "person_indices = np.append(person_indices, len(person_ixs))\n",
+    "person_chunks = [\n",
+    "    time_ixs[person_indices[i]: person_indices[i + 1]]\n",
+    "    for i in range(len(person_indices) - 1)\n",
+    "]\n",
+    "\n",
+    "visit_chunks = []\n",
+    "visit_times_raw = []\n",
+    "\n",
+    "for i, chunk in enumerate(person_chunks):\n",
+    "    visits = sorted(np.unique(chunk))\n",
+    "    visit_indices_local = np.searchsorted(chunk, visits)\n",
+    "    visit_indices_local = np.append(\n",
+    "        visit_indices_local,\n",
+    "        len(chunk)\n",
+    "    )\n",
+    "    visit_chunks.append(visit_indices_local)\n",
+    "    visit_times_raw.append(visits)\n",
+    "\n",
+    "n_visits = {i: len(j) for i, j in enumerate(visit_times_raw)}\n",
+    "\n",
+    "# Map each time index to its distance in days from the prediction date\n",
+    "visit_days_rel = {\n",
+    "    time: (\n",
+    "        pd.to_datetime(cohort_params['training_end_date'])\n",
+    "        - pd.to_datetime(idx_to_datetime[time])\n",
+    "    ).days for time in time_ixs\n",
+    "}\n",
+    "vdrel_func = np.vectorize(visit_days_rel.get)\n",
+    "visit_time_rel = [\n",
+    "    vdrel_func(v) for v in visit_times_raw\n",
+    "]\n",
+    "\n",
+    "remap = {\n",
+    "    'id': people,\n",
+    "    'time': sorted(np.unique(time_ixs)),\n",
+    "    'concept': sorted(np.unique(code_ixs))\n",
+    "}\n",
+    "\n",
+    "dataset_dict = {\n",
+    "    'all_codes_tensor': all_codes_tensor,  # A tensor of all codes occurring in the dataset\n",
+    "    'person_indices': person_indices,  # A list of indices such that all_codes_tensor[person_indices[i]: person_indices[i+1]] are the codes assigned to the ith patient\n",
+    "    'visit_chunks': visit_chunks,  # A list of indices such that all_codes_tensor[person_indices[i]+visit_chunks[j]:person_indices[i]+visit_chunks[j+1]] are the codes assigned to the ith patient during their jth visit\n",
+    "    'visit_time_rel': visit_time_rel,  # A list of times (as measured in days to the prediction date) for each visit\n",
+    "    'n_visits': n_visits,  # A dict defined such that n_visits[i] is the number of visits made by the ith patient\n",
+    "    'outcomes_filt': outcomes_filt,  # A pandas Series defined such that outcomes_filt.iloc[i] is the outcome of the ith patient\n",
+    "    'remap': remap,\n",
+    "}"
+   ]
+  },
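+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check, the counts below can be read straight off `dataset_dict` (an illustrative snippet; the exact numbers depend on the cohort):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sanity check on `dataset_dict` (assumes the cell above has run)\n",
+    "n_patients = len(dataset_dict['person_indices']) - 1\n",
+    "print(f\"patients: {n_patients}\")\n",
+    "print(f\"code occurrences: {len(dataset_dict['all_codes_tensor'])}\")\n",
+    "print(f\"mean visits per patient: {np.mean(list(dataset_dict['n_visits'].values())):.1f}\")"
+   ]
+  },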
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "9rumxm09Pfsj"
+   },
+   "source": [
+    "## Run the windowed regression model on the task defined above"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "EBWOySUSPfsj"
+   },
+   "outputs": [],
+   "source": [
+    "# Split data into train, validate and test sets.\n",
+    "# NOTE: the deep-model cells below also assume index arrays (`indices_train`,\n",
+    "# `indices_test`) and a validation-set size `val_size` derived from this split\n",
+    "# are in scope.\n",
+    "windowed_dataset.split()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "jtVIFn2CPfsj"
+   },
+   "outputs": [],
+   "source": [
+    "# train the regression model over several choices of regularization parameter\n",
+    "reg_lambdas = [2, 0.2, 0.02]\n",
+    "lr_val_aucs = []\n",
+    "model = OMOPLogisticRegression(\"eol_new_50\", windowed_dataset)\n",
+    "\n",
+    "for reg_lambda in reg_lambdas:\n",
+    "    # Gen and fit\n",
+    "    model.gen_pipeline(reg_lambda)\n",
+    "    model.fit()\n",
+    "    # Eval on validation data\n",
+    "    pred_lr = model._pipeline.predict_proba(windowed_dataset.val['X'])[:, 1]\n",
+    "    lr_val_auc = roc_auc_score(windowed_dataset.val['y'], pred_lr)\n",
+    "    lr_val_aucs.append(lr_val_auc)\n",
+    "    print(\"reg_lambda: %.4f, Val AUC: %.2f\" % (reg_lambda, lr_val_auc))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "pYy4lcbePfsj"
+   },
+   "outputs": [],
+   "source": [
+    "# Gen and fit on the best regularization parameter\n",
+    "best_reg_lambda = reg_lambdas[np.argmax(lr_val_aucs)]\n",
+    "model.gen_pipeline(best_reg_lambda)\n",
+    "model.fit()\n",
+    "# Eval on test data\n",
+    "pred_lr = model._pipeline.predict_proba(windowed_dataset.test['X'])[:, 1]\n",
+    "score = roc_auc_score(windowed_dataset.test['y'], pred_lr)\n",
+    "print(\"reg_lambda: %.4f, Test AUC: %.2f\" % (best_reg_lambda, score))\n",
+    "\n",
+    "# NOTE: the SARD distillation cells below expect `pred_lr_all`, this fitted\n",
+    "# pipeline's predicted probability for every patient (aligned with the person\n",
+    "# indices used by the deep model), to be computed before they run."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "jvYkElLlPfsk"
+   },
+   "source": [
+    "### Learn a Word2Vec embedding"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "e7oM4vOsPfsk"
+   },
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "# NOTE: `featureSet` and `feature_matrix_3d_transpose` come from the legacy\n",
+    "# omop-learn interface and must be defined before this cell runs; also note\n",
+    "# that `window_days` is rebound here from the windowing list above.\n",
+    "embedding_dim = 300  # size of embedding, must be a multiple of the number of heads\n",
+    "window_days = 90  # number of days in the window that defines a \"sentence\" when learning the embedding\n",
+    "train_coords = np.nonzero(np.where(np.isin(person_ixs, indices_train), 1, 0))\n",
+    "embedding_filename = embedding_utils.train_embedding(featureSet, feature_matrix_3d_transpose, window_days, \\\n",
+    "                                                     person_ixs[train_coords], time_ixs[train_coords], \\\n",
+    "                                                     remap['time'], embedding_dim)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "vlIGYdtIPfsk"
+   },
+   "source": [
+    "## Run the SARD deep model on the predictive task\n",
+    "### 1. 
Set Model Parameters and Construct the Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yyr2CO1wPfsk" + }, + "outputs": [], + "source": [ + "# using the same split as before, create train/validate/test batches for the deep model\n", + "# `mbsz` might need to be decreased based on the GPU's memory and the number of features being used\n", + "mbsz = 50\n", + "def get_batches(arr, mbsz=mbsz):\n", + " curr, ret = 0, []\n", + " while curr < len(arr) - 1:\n", + " ret.append(arr[curr : curr + mbsz])\n", + " curr += mbsz\n", + " return ret\n", + "\n", + "p_ranges_train, p_ranges_test = [\n", + " get_batches(arr) for arr in (\n", + " indices_train, indices_test\n", + " )\n", + "]\n", + "p_ranges_val = p_ranges_test[:val_size // mbsz]\n", + "p_ranges_test = p_ranges_test[val_size // mbsz:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AfUpBu4sPfsl" + }, + "outputs": [], + "source": [ + "# Pick a name for the model (mn_prefix) that will be used when saving checkpoints\n", + "# Then, set some parameters for SARD. The values below reflect a good starting point that performed well on several tasks\n", + "mn_prefix = 'eol_experiment_prefix'\n", + "n_heads = 2\n", + "assert embedding_dim % n_heads == 0\n", + "model_params = {\n", + " 'embedding_dim': int(embedding_dim / n_heads), # Dimension per head of visit embeddings\n", + " 'n_heads': n_heads, # Number of self-attention heads\n", + " 'attn_depth': 2, # Number of stacked self-attention layers\n", + " 'dropout': 0.05, # Dropout rate for both self-attention and the final prediction layer\n", + " 'use_mask': True, # Only allow visits to attend to other actual visits, not to padding visits\n", + " 'concept_embedding_path': embedding_filename # if unspecified, uses default Torch embeddings\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GJ2DXdlWPfsl" + }, + "outputs": [], + "source": [ + "# Set up fixed model parameters, loss functions, and build the model on the GPU\n", + "lr = 2e-4\n", + "n_epochs_pretrain = 1\n", + "ft_epochs = 1\n", + "\n", + "update_every = 500\n", + "update_mod = update_every // mbsz\n", + "\n", + "base_model = visit_transformer.VisitTransformer(\n", + " featureSet, **model_params\n", + ")\n", + "\n", + "clf = visit_transformer.VTClassifer(\n", + " base_model, **model_params\n", + ").cuda()\n", + "\n", + "clf.bert.set_data(\n", + " torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n", + " dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n", + " dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n", + ")\n", + "\n", + "loss_function_distill = torch.nn.BCEWithLogitsLoss(\n", + " pos_weight=torch.FloatTensor([\n", + " len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n", + " ]), reduction='sum'\n", + ").cuda()\n", + "\n", + "optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=lr)\n", + "\n", + "def eval_curr_model_on(a):\n", + " with torch.no_grad():\n", + " preds_test, true_test = [], []\n", + " for batch_num, p_range in enumerate(a):\n", + " y_pred = clf(p_range)\n", + " preds_test += y_pred.tolist()\n", + " true_test += list(dataset_dict['outcomes_filt'].iloc[list(p_range)].values)\n", + " return roc_auc_score(true_test, preds_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4l95ivncPfsl" + }, + "source": [ + "### 2. 
Fit the SARD model to the best windowed linear model (Reverse Distillation)\n", + "\n", + "The following code saves models in a folder `/SavedModels/{task}/`; make sure to create the directory before running." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UgJt7p11Pfsl" + }, + "outputs": [], + "source": [ + "task = 'eol'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ESuCBwrMPfsl" + }, + "outputs": [], + "source": [ + "# Run `n_epochs_pretrain` of Reverse Distillation pretraining\n", + "val_losses = []\n", + "progress_bar = IntProgress(min=0, max=int(n_epochs_pretrain * len(p_ranges_train)))\n", + "batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)\n", + "time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)\n", + "\n", + "display(progress_bar)\n", + "display(batch_loss_disp)\n", + "display(time_disp)\n", + "\n", + "for epoch in range(n_epochs_pretrain):\n", + " t, batch_loss = time.time(), 0\n", + "\n", + " for batch_num, p_range in enumerate(p_ranges_train):\n", + "\n", + " if batch_num % 50 == 0:\n", + " batch_loss_disp.value = round(batch_loss / 50, 2)\n", + " time_disp.value = round(time.time() - t, 2)\n", + " t, batch_loss = time.time(), 0\n", + "\n", + " y_pred = clf(p_range)\n", + " loss_distill = loss_function_distill(\n", + " y_pred, torch.FloatTensor(pred_lr_all[p_range]).cuda()\n", + " )\n", + "\n", + " batch_loss += loss_distill.item()\n", + " loss_distill.backward()\n", + "\n", + " if batch_num % update_mod == 0:\n", + " optimizer_clf.step()\n", + " optimizer_clf.zero_grad()\n", + "\n", + " progress_bar.value = batch_num + epoch * len(p_ranges_train)\n", + "\n", + " torch.save(\n", + " clf.state_dict(),\n", + " \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n", + " task=task, mn_prefix = mn_prefix, epochs = epoch + 1\n", + " )\n", + " )\n", + "\n", + " clf.eval()\n", + " ckpt_auc = eval_curr_model_on(p_ranges_val)\n", + " print('Epochs: {} | Val AUC: {}'.format(epoch + 1, ckpt_auc))\n", + " val_losses.append(ckpt_auc)\n", + " clf.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "S79PbKdhPfsm" + }, + "outputs": [], + "source": [ + "# Save the pretrained model with best validation-set performance\n", + "clf.load_state_dict(\n", + " torch.load(\"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n", + " task=task, mn_prefix=mn_prefix, epochs=np.argmax(val_losses) + 1\n", + " ))\n", + ")\n", + "torch.save(\n", + " clf.state_dict(),\n", + " \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n", + " task=task, mn_prefix = mn_prefix, epochs = 'BEST'\n", + " )\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Wzmh6tCPfsm" + }, + "source": [ + "### 3. Fine-tune the SARD model by training to match the actual outcomes on the training set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iczzP5NAPfsm" + }, + "outputs": [], + "source": [ + "# Set up loss functions for fine-tuning. 
There are two terms:\n",
+    "# - `loss_function_distill`, which penalizes differences between the linear model prediction and SARD's prediction\n",
+    "# - `loss_function_clf`, which penalizes differences between the true outcome and SARD's prediction\n",
+    "loss_function_distill = torch.nn.BCEWithLogitsLoss(\n",
+    "    pos_weight=torch.FloatTensor([\n",
+    "        len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n",
+    "    ]), reduction='sum'\n",
+    ").cuda()\n",
+    "\n",
+    "loss_function_clf = torch.nn.BCEWithLogitsLoss(\n",
+    "    pos_weight=torch.FloatTensor([\n",
+    "        len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1\n",
+    "    ]), reduction='sum'\n",
+    ").cuda()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "30_HYniePfsm"
+   },
+   "outputs": [],
+   "source": [
+    "# Run `ft_epochs` of fine-tuning for each value of `alpha` below.\n",
+    "# Note that `alpha` is the relative weight of `loss_function_distill` as compared to `loss_function_clf`\n",
+    "\n",
+    "all_pred_models = {}\n",
+    "\n",
+    "progress_bar = IntProgress(min=0, max=int(ft_epochs * len(p_ranges_train)))\n",
+    "batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)\n",
+    "time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)\n",
+    "\n",
+    "display(progress_bar)\n",
+    "display(batch_loss_disp)\n",
+    "display(time_disp)\n",
+    "\n",
+    "no_rd = False\n",
+    "for alpha in [0, 0.05, 0.1, 0.15, 0.2]:\n",
+    "\n",
+    "    progress_bar.value = 0\n",
+    "\n",
+    "    if no_rd:\n",
+    "        pretrained_model_fn = mn_prefix + '_None'\n",
+    "        # Set `start_model` to a saved checkpoint name to warm-start; with None,\n",
+    "        # a fresh model is built and the `else` branch below is skipped.\n",
+    "        start_model = None\n",
+    "        if start_model is None:\n",
+    "            base_model = visit_transformer.VisitTransformer(\n",
+    "                featureSet, **model_params\n",
+    "            )\n",
+    "\n",
+    "            clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()\n",
+    "\n",
+    "            clf.bert.set_data(\n",
+    "                torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n",
+    "                dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n",
+    "                dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n",
+    "            )\n",
+    "        else:\n",
+    "            pretrained_model_path = \"SavedModels/{task}/{start_model}\".format(\n",
+    "                task=task, start_model=start_model\n",
+    "            )\n",
+    "            clf.load_state_dict(torch.load(pretrained_model_path))\n",
+    "\n",
+    "    else:\n",
+    "        pretrained_model_fn = \"{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
+    "            mn_prefix=mn_prefix, epochs='BEST'\n",
+    "        )\n",
+    "        pretrained_model_path = \"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}\".format(\n",
+    "            task=task, mn_prefix=mn_prefix, epochs='BEST'\n",
+    "        )\n",
+    "\n",
+    "        clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()\n",
+    "        clf.bert.set_data(\n",
+    "            torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),\n",
+    "            dataset_dict['person_indices'], dataset_dict['visit_chunks'],\n",
+    "            dataset_dict['visit_time_rel'], dataset_dict['n_visits']\n",
+    "        )\n",
+    "\n",
+    "        clf.load_state_dict(torch.load(pretrained_model_path))\n",
+    "\n",
+    "    clf.train()\n",
+    "\n",
+    "    optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=2e-4)\n",
+    "\n",
+    "    for epoch in range(ft_epochs):\n",
+    "\n",
+    "        t, batch_loss = time.time(), 0\n",
+    "\n",
+    "        for batch_num, p_range in enumerate(p_ranges_train):\n",
+    "\n",
+    "            if batch_num % 50 == 0:\n",
+    "                batch_loss_disp.value = round(batch_loss / 50, 2)\n",
+    "                time_disp.value = round(time.time() - t, 2)\n",
+    "                t, batch_loss = time.time(), 0\n",
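+    "\n",
+    "            # The total objective below combines the classification loss on the\n",
+    "            # true outcomes with `alpha` times the distillation loss against the\n",
+    "            # linear model's predictions; gradients accumulate across batches and\n",
+    "            # the optimizer steps every `update_mod` batches.\n",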
+    "\n",
+    "            y_pred = clf(p_range)\n",
+    "\n",
+    "            loss = loss_function_clf(\n",
+    "                y_pred,\n",
+    "                torch.FloatTensor(dataset_dict['outcomes_filt'].values[p_range]).cuda()\n",
+    "            )\n",
+    "\n",
+    "            loss_distill = loss_function_distill(\n",
+    "                y_pred,\n",
+    "                torch.FloatTensor(pred_lr_all[p_range]).cuda()\n",
+    "            )\n",
+    "\n",
+    "            batch_loss += loss.item() + alpha * loss_distill.item()\n",
+    "            loss_total = loss + alpha * loss_distill\n",
+    "            loss_total.backward()\n",
+    "\n",
+    "            if batch_num % update_mod == 0:\n",
+    "                optimizer_clf.step()\n",
+    "                optimizer_clf.zero_grad()\n",
+    "\n",
+    "            progress_bar.value = batch_num + epoch * len(p_ranges_train)\n",
+    "\n",
+    "        saving_fn = \"{pretrain}_alpha_{alpha}_epochs_{epochs}\".format(\n",
+    "            pretrain=pretrained_model_fn, alpha=alpha, epochs=epoch + 1\n",
+    "        )\n",
+    "        torch.save(\n",
+    "            clf.state_dict(),\n",
+    "            \"SavedModels/{task}/{saving_fn}\".format(\n",
+    "                task=task, saving_fn=saving_fn\n",
+    "            )\n",
+    "        )\n",
+    "\n",
+    "        clf.eval()\n",
+    "        val_auc = eval_curr_model_on(p_ranges_val)\n",
+    "        print('alpha: {} | epoch: {} | Val AUC: {}'.format(alpha, epoch + 1, val_auc))\n",
+    "        all_pred_models[saving_fn] = val_auc\n",
+    "        clf.train()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "rXk7tpRdPfsm"
+   },
+   "source": [
+    "### 4. Evaluate the best SARD model, as determined by validation performance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "35WmOaisPfsm"
+   },
+   "outputs": [],
+   "source": [
+    "best_model = max(all_pred_models, key=all_pred_models.get)\n",
+    "clf.load_state_dict(\n",
+    "    torch.load(\"SavedModels/{task}/{model}\".format(\n",
+    "        task=task, model=best_model\n",
+    "    ))\n",
+    ")\n",
+    "clf.eval();\n",
+    "with torch.no_grad():\n",
+    "    preds_test, true_test = [], []\n",
+    "    for batch_num, p_range in enumerate(p_ranges_test):\n",
+    "        y_pred = clf(p_range)\n",
+    "        preds_test += y_pred.tolist()\n",
+    "        true_test += list(dataset_dict['outcomes_filt'].iloc[list(p_range)].values)\n",
+    "    print(\"Test AUC:\", roc_auc_score(true_test, preds_test))\n",
+    "clf.train();"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "include_colab_link": true,
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "omop-learn",
+   "language": "python",
+   "name": "omop-learn"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
 }