Add BCB'21 paper "DeepNote-GNN: predicting hospital readmission using clinical notes and patient network" #229

Open · wants to merge 5 commits into base: master
Changes from all commits
4 changes: 3 additions & 1 deletion pyhealth/datasets/__init__.py
@@ -1,16 +1,18 @@
from .base_ehr_dataset import BaseEHRDataset
from .base_signal_dataset import BaseSignalDataset
from .base_note_dataset import BaseNoteDataset
from .cardiology import CardiologyDataset
from .eicu import eICUDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4Dataset
from .mimicextract import MIMICExtractDataset
from .mimic3_note import MIMIC3NoteDataset
from .omop import OMOPDataset
from .sleepedf import SleepEDFDataset
from .isruc import ISRUCDataset
from .shhs import SHHSDataset
from .tuab import TUABDataset
from .tuev import TUEVDataset
from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset
from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleNoteDataset, SampleEHRDataset
from .splitter import split_by_patient, split_by_visit
from .utils import collate_fn_dict, get_dataloader, strptime
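
For context while reviewing: a minimal end-to-end sketch of how the newly exported classes compose. The root path is a placeholder, `readmission_task_fn` is a hypothetical task function (not part of this PR), and it assumes the merged admission/note record is reachable through the `attr` dict that `parse_basic_info` attaches to each `Patient` below:

import pandas as pd

from pyhealth.datasets import MIMIC3NoteDataset, get_dataloader


def readmission_task_fn(patient):
    # one sample per patient: the note text plus the precomputed 30-day label
    record = patient.attr  # merged PATIENTS/ADMISSIONS/NOTEEVENTS row (see mimic3_note.py)
    if pd.isna(record.get("TEXT")) or pd.isna(record.get("OUTPUT_LABEL")):
        return []  # returning an empty list excludes the patient from the task
    return [{
        "patient_id": patient.patient_id,
        "text": record["TEXT"],
        "label": int(record["OUTPUT_LABEL"]),
    }]


base = MIMIC3NoteDataset(root="/path/to/mimic-iii", dev=True)
samples = base.set_task(readmission_task_fn)  # returns a SampleNoteDataset
loader = get_dataloader(samples, batch_size=32, shuffle=True)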
163 changes: 163 additions & 0 deletions pyhealth/datasets/base_note_dataset.py
@@ -0,0 +1,163 @@
import logging
import time
import os
from abc import ABC
from collections import Counter
from copy import deepcopy
from typing import Dict, Callable, Tuple, Union, List, Optional

import pandas as pd
from tqdm import tqdm
from pandarallel import pandarallel

from pyhealth.data import Patient, Event
from pyhealth.datasets.sample_dataset import SampleNoteDataset
from pyhealth.datasets.utils import MODULE_CACHE_PATH, DATASET_BASIC_TABLES
from pyhealth.datasets.utils import hash_str
from pyhealth.medcode import CrossMap
from pyhealth.utils import load_pickle, save_pickle
from pyhealth.tasks.utils import add_embedding

logger = logging.getLogger(__name__)

INFO_MSG = """
"""


class BaseNoteDataset(ABC):

def __init__(
self,
root: str,
dataset_name: Optional[str] = None,
dev: bool = False,
refresh_cache: bool = False,
):

"""Loads tables into a dict of patients and saves it to cache."""

# base attributes
self.dataset_name = (
self.__class__.__name__ if dataset_name is None else dataset_name
)
self.root = root

self.dev = dev

# hash filename for cache
args_to_hash = (
[self.dataset_name, root]
+ ["dev" if dev else "prod"]
)
filename = hash_str("+".join([str(arg) for arg in args_to_hash])) + ".pkl"
self.filepath = os.path.join(MODULE_CACHE_PATH, filename)

# check if cache exists or refresh_cache is True
if os.path.exists(self.filepath) and (not refresh_cache):
# load from cache
logger.debug(
f"Loaded {self.dataset_name} base dataset from {self.filepath}"
)
self.patients = load_pickle(self.filepath)
else:
# load from raw data
logger.debug(f"Processing {self.dataset_name} base dataset...")
# parse tables
patients = self.parse_tables()

self.patients = patients
# save to cache
logger.debug(f"Saved {self.dataset_name} base dataset to {self.filepath}")
save_pickle(self.patients, self.filepath)



def parse_tables(self) -> Dict[str, Patient]:
"""Parses the tables in `self.tables` and return a dict of patients.

Will be called in `self.__init__()` if cache file does not exist or
refresh_cache is True.

This function will first call `self.parse_basic_info()` to parse the
basic patient information, and then call `self.parse_[table_name]()` to
parse the table with name `table_name`. Both `self.parse_basic_info()` and
`self.parse_[table_name]()` should be implemented in the subclass.

Returns:
A dict mapping patient_id to `Patient` object.
"""
pandarallel.initialize(progress_bar=False)

# patients is a dict of Patient objects indexed by patient_id
patients: Dict[str, Patient] = dict()
# process basic information (e.g., patients and visits)
tic = time.time()
patients = self.parse_basic_info(patients)
print(
"finish basic patient information parsing : {}s".format(time.time() - tic)
)

return patients


def __str__(self):
"""Prints some information of the dataset."""
return f"Base dataset {self.dataset_name}"

def stat(self) -> str:
"""Returns some statistics of the base dataset."""
lines = list()
lines.append("")
lines.append(f"Statistics of base dataset (dev={self.dev}):")
lines.append(f"\t- Dataset: {self.dataset_name}")
lines.append(f"\t- Number of patients: {len(self.patients)}")
lines.append("")
print("\n".join(lines))
return "\n".join(lines)


def set_task(
self,
task_fn: Callable,
task_name: Optional[str] = None,
emb: bool = False,
) -> SampleNoteDataset:
"""Processes the base dataset to generate the task-specific sample dataset.

This function should be called by the user after the base dataset is
initialized. It will iterate through all patients in the base dataset
and call `task_fn` which should be implemented by the specific task.

Args:
task_fn: a function that takes a single patient and returns a
list of samples (each sample is a dict with patient_id, visit_id,
and other task-specific attributes as key). The samples will be
concatenated to form the sample dataset.
task_name: the name of the task. If None, the name of the task
function will be used.
emb: whether to post-process the generated samples with
`add_embedding` (e.g., to attach note embeddings) before
building the sample dataset.

Returns:
sample_dataset: the task-specific sample dataset.

Note:
In `task_fn`, a patient may be converted to multiple samples, e.g.,
a patient with three visits may be converted to three samples
([visit 1], [visit 1, visit 2], [visit 1, visit 2, visit 3]).
Patients can also be excluded from the task dataset by returning
an empty list.
"""
if task_name is None:
task_name = task_fn.__name__
samples = []
for patient_id, patient in tqdm(
self.patients.items(), desc=f"Generating samples for {task_name}"
):
samples.extend(task_fn(patient))
if emb:
samples = add_embedding(samples)
sample_dataset = SampleNoteDataset(
samples,
dataset_name=self.dataset_name,
task_name=task_name,
)
return sample_dataset
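
A side note on the caching in `__init__` above: the cache filename is a digest of the constructor arguments, so the same (dataset_name, root, dev) triple always resolves to the same pickle, and `refresh_cache=True` forces a rebuild. A minimal sketch of the idea, assuming `hash_str` is an MD5-style digest (the actual helper lives in `pyhealth.datasets.utils`):

import hashlib


def hash_str(s: str) -> str:
    # stand-in for pyhealth.datasets.utils.hash_str: a stable digest of a string
    return hashlib.md5(s.encode("utf-8")).hexdigest()


args_to_hash = ["MIMIC3NoteDataset", "/path/to/mimic-iii", "dev"]
filename = hash_str("+".join(args_to_hash)) + ".pkl"
print(filename)  # identical arguments always produce the same cache filename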
115 changes: 115 additions & 0 deletions pyhealth/datasets/mimic3_note.py
@@ -0,0 +1,115 @@
import os
from typing import Optional, List, Dict, Tuple, Union

import pandas as pd
import numpy as np

from pyhealth.data import Event, Visit, Patient
from pyhealth.datasets import BaseNoteDataset
from pyhealth.datasets.utils import strptime


class MIMIC3NoteDataset(BaseNoteDataset):
"""
TODO: add docs
"""

def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
"""Helper function which parses PATIENTS and ADMISSIONS tables.

Will be called in `self.parse_tables()`

Docs:
- PATIENTS: https://mimic.mit.edu/docs/iii/tables/patients/
- ADMISSIONS: https://mimic.mit.edu/docs/iii/tables/admissions/
- NOTEEVENTS: https://mimic.mit.edu/docs/iii/tables/noteevents/

Args:
patients: a dict of `Patient` objects indexed by patient_id which is updated with the mimic-3 table result.

Returns:
The updated patients dict.
"""
# read patients table
self.patients_df = pd.read_csv(
os.path.join(self.root, "PATIENTS.csv"),
dtype={"SUBJECT_ID": str},
nrows=1000 if self.dev else None,
)
# read admissions table
self.admissions_df = pd.read_csv(
os.path.join(self.root, "ADMISSIONS.csv"),
dtype={"SUBJECT_ID": str, "HADM_ID": str},
)
# read noteevents table
self.noteevents_df = pd.read_csv(
os.path.join(self.root, "NOTEEVENTS.csv"),
dtype={"SUBJECT_ID": str, "HADM_ID": str},
)

self.admissions_df.ADMITTIME = pd.to_datetime(self.admissions_df.ADMITTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')
self.admissions_df.DISCHTIME = pd.to_datetime(self.admissions_df.DISCHTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')
self.admissions_df.DEATHTIME = pd.to_datetime(self.admissions_df.DEATHTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')

self.admissions_df = self.admissions_df.sort_values(['SUBJECT_ID', 'ADMITTIME'])
self.admissions_df = self.admissions_df.reset_index(drop=True)
# record each patient's next admission time and type
self.admissions_df['NEXT_ADMITTIME'] = self.admissions_df.groupby('SUBJECT_ID').ADMITTIME.shift(-1)
self.admissions_df['NEXT_ADMISSION_TYPE'] = self.admissions_df.groupby('SUBJECT_ID').ADMISSION_TYPE.shift(-1)

# elective admissions do not count as unplanned readmissions
rows = self.admissions_df.NEXT_ADMISSION_TYPE == 'ELECTIVE'
self.admissions_df.loc[rows, 'NEXT_ADMITTIME'] = pd.NaT
self.admissions_df.loc[rows, 'NEXT_ADMISSION_TYPE'] = np.nan

# backfill so each stay points to the next non-elective admission
self.admissions_df = self.admissions_df.sort_values(['SUBJECT_ID', 'ADMITTIME'])
self.admissions_df[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']] = \
    self.admissions_df.groupby(['SUBJECT_ID'])[['NEXT_ADMITTIME', 'NEXT_ADMISSION_TYPE']].bfill()
# label = 1 if readmitted within 30 days of discharge (no next admission -> NaN -> 0)
self.admissions_df['DAYS_NEXT_ADMIT'] = (self.admissions_df.NEXT_ADMITTIME - self.admissions_df.DISCHTIME).dt.total_seconds() / (24 * 60 * 60)
self.admissions_df['OUTPUT_LABEL'] = (self.admissions_df.DAYS_NEXT_ADMIT < 30).astype('int')

# filter out newborn and death
self.admissions_df = self.admissions_df[self.admissions_df['ADMISSION_TYPE'] != 'NEWBORN']
self.admissions_df = self.admissions_df[self.admissions_df.DEATHTIME.isnull()]
self.admissions_df['DURATION'] = (self.admissions_df['DISCHTIME'] - self.admissions_df['ADMITTIME']).dt.total_seconds() / (24 * 60 * 60)

self.noteevents_df = self.noteevents_df.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'CHARTDATE'])

# merge admission and noteevents tables
self.admission_notes_df = pd.merge(
self.admissions_df[['SUBJECT_ID', 'HADM_ID', 'ADMITTIME', 'DISCHTIME', 'DAYS_NEXT_ADMIT', 'NEXT_ADMITTIME',
'ADMISSION_TYPE', 'DEATHTIME', 'OUTPUT_LABEL', 'DURATION']],
self.noteevents_df[['SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'TEXT', 'CATEGORY']],
on=['SUBJECT_ID', 'HADM_ID'], how='left'
)

# reduce ADMITTIME to date-only so it is comparable with CHARTDATE
self.admission_notes_df['ADMITTIME_C'] = self.admission_notes_df.ADMITTIME.apply(lambda x: str(x).split(' ')[0])
self.admission_notes_df['ADMITTIME_C'] = pd.to_datetime(self.admission_notes_df.ADMITTIME_C, format='%Y-%m-%d', errors='coerce')
self.admission_notes_df['CHARTDATE'] = pd.to_datetime(self.admission_notes_df.CHARTDATE, format='%Y-%m-%d', errors='coerce')

# merge patient and admission_noteevents tables
df = pd.merge(self.patients_df, self.admission_notes_df, on="SUBJECT_ID", how="inner")

# sort by admission and discharge time
df = df.sort_values(["SUBJECT_ID", "ADMITTIME", "DISCHTIME"], ascending=True)
# group by patient
df_group = df.groupby("SUBJECT_ID")

# parallel unit of basic information (per patient)
def basic_unit(p_id, p_info):
p_info_dict = p_info.to_dict(orient='records')[0]  # keep only the first merged record for this patient
patient = Patient(
patient_id=p_id,
birth_datetime=strptime(p_info["DOB"].values[0]),
death_datetime=strptime(p_info["DOD_HOSP"].values[0]),
attr=p_info_dict,
)
return patient

# parallel apply
df_group = df_group.parallel_apply(
lambda x: basic_unit(x.SUBJECT_ID.unique()[0], x)
)
# summarize the results
for pat_id, pat in df_group.items():
patients[pat_id] = pat

return patients
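
To sanity-check the 30-day labeling above, here is a self-contained sketch on a toy admissions table (column names follow the code; the data and dates are made up):

import pandas as pd

adm = pd.DataFrame({
    "SUBJECT_ID": ["1", "1", "2"],
    "ADMITTIME": pd.to_datetime(["2100-01-01", "2100-01-20", "2100-03-01"]),
    "DISCHTIME": pd.to_datetime(["2100-01-05", "2100-01-25", "2100-03-05"]),
})

adm = adm.sort_values(["SUBJECT_ID", "ADMITTIME"]).reset_index(drop=True)
adm["NEXT_ADMITTIME"] = adm.groupby("SUBJECT_ID").ADMITTIME.shift(-1)
adm["DAYS_NEXT_ADMIT"] = (adm.NEXT_ADMITTIME - adm.DISCHTIME).dt.total_seconds() / (24 * 60 * 60)
adm["OUTPUT_LABEL"] = (adm.DAYS_NEXT_ADMIT < 30).astype(int)
print(adm[["SUBJECT_ID", "DAYS_NEXT_ADMIT", "OUTPUT_LABEL"]])
# patient 1's first stay is readmitted 15 days after discharge -> label 1;
# the remaining stays have no next admission -> NaN days -> label 0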
44 changes: 44 additions & 0 deletions pyhealth/datasets/sample_dataset.py
@@ -187,6 +187,50 @@ def stat(self) -> str:
return "\n".join(lines)


class SampleNoteDataset(SampleBaseDataset):
"""
TODO: add documentation
"""
def __init__(self, samples: List[Dict], dataset_name="", task_name=""):
super().__init__(samples, dataset_name, task_name)
self.patient_to_index: Dict[str, List[int]] = self._index_patient()
# self.record_to_index: Dict[str, List[int]] = self._index_record()
# self.input_info: Dict = self._validate()
self.type_ = "note"
self.pos_neg_labels = []
self._check_label()


def _check_label(self):
    """Splits sample patient_ids into positive / negative label groups."""
    pos_labels = [s['patient_id'] for s in self.samples if s['label'] == 1]
    neg_labels = [s['patient_id'] for s in self.samples if s['label'] != 1]
    self.pos_neg_labels.append(pos_labels)
    self.pos_neg_labels.append(neg_labels)


def _index_patient(self) -> Dict[str, List[int]]:
"""Helper function which indexes the samples by patient_id.

Will be called in `self.__init__()`.
Returns:
patient_to_index: Dict[str, List[int]], a dict mapping patient_id to
a list of sample indices.
"""
patient_to_index = {}
for idx, sample in enumerate(self.samples):
patient_to_index.setdefault(sample["patient_id"], []).append(idx)
return patient_to_index
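
A quick illustration of why this per-patient index matters: it is what patient-level splitters such as `split_by_patient` rely on to keep all of one patient's samples on the same side of a split, so notes from the same patient cannot leak across train and test (toy data, illustrative only):

samples = [
    {"patient_id": "p1", "text": "note a", "label": 1},
    {"patient_id": "p1", "text": "note b", "label": 1},
    {"patient_id": "p2", "text": "note c", "label": 0},
]

patient_to_index = {}
for idx, sample in enumerate(samples):
    patient_to_index.setdefault(sample["patient_id"], []).append(idx)

print(patient_to_index)  # {'p1': [0, 1], 'p2': [2]}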




class SampleEHRDataset(SampleBaseDataset):
"""Sample EHR dataset class.
