Merge pull request #68 from sunlabuiuc/develop
update pyhealth live 05, add deepr model, start unittest (#67)
ycq091044 authored Jan 24, 2023
2 parents e452ad4 + 81db578 commit 940db1b
Showing 18 changed files with 1,199 additions and 135 deletions.
1 change: 1 addition & 0 deletions docs/api/models.rst
@@ -15,4 +15,5 @@ We implement the following models for supporting multiple healthcare predictive
models/pyhealth.models.GAMENet
models/pyhealth.models.MICRON
models/pyhealth.models.SafeDrug
+models/pyhealth.models.Deepr

14 changes: 14 additions & 0 deletions docs/api/models/pyhealth.models.Deepr.rst
@@ -0,0 +1,14 @@
pyhealth.models.Deepr
===================================

The separate callable DeeprLayer and the complete Deepr model.

.. autoclass:: pyhealth.models.DeeprLayer
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: pyhealth.models.Deepr
    :members:
    :undoc-members:
    :show-inheritance:
8 changes: 5 additions & 3 deletions docs/live.rst
@@ -13,6 +13,8 @@ PyHealth live

**YouTube**: `Recorded Live Sessions <https://www.youtube.com/playlist?list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV>`_

+**User/Developer Slack**: `Click to join <https://join.slack.com/t/pyhealthworkspace/shared_invite/zt-1np4yxs77-aqTKxhlfLOjaPbqTzr6sTA>`_
+
Schedules
^^^^^^^^^^^^^^
**(Dec 21, Wed)** Live 01 - What is PyHealth and How to Get Started? `[Recap] <https://www.youtube.com/watch?v=1Ir6hzU4Nro&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=1>`_
@@ -23,10 +25,10 @@ Schedules

**(Jan 11, Wed)** Live 04 - Tokenizer & Medcode: master the medical code lookup and mapping `[Recap I] <https://www.youtube.com/watch?v=MmmfU6_xkYg&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=9>`_ `[II] <https://www.youtube.com/watch?v=CeXJtf0lfs0&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=10>`_

-**(Jan 18, Wed)** Live 05 - PyHealth can support a complete healthcare ML pipeline
+**(Jan 18, Wed)** Live 05 - PyHealth can support a complete healthcare ML pipeline `[Recap I] <https://www.youtube.com/watch?v=GVLzc6E4og0&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=11>`_ `[II] <https://www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`_

-**(Jan 25, Wed)** Live 06 - Adopt your customized model and quickly try it on our data
+**(Jan 25, Wed)** Live 06 - Fit your own dataset into pipeline and use our model

-**(Feb 1, Wed)** Live 07 - Fit your own dataset into pipeline and use our model
+**(Feb 1, Wed)** Live 07 - Adopt your customized model and quickly try it on our data

**(Feb 8, Wed)** Live 08 - Define your own healthcare task on MIMIC data
30 changes: 30 additions & 0 deletions docs/log.rst
@@ -2,6 +2,36 @@ Development logs
======================
We track the new development here:

+**Jan 24, 2023**
+
+.. code-block:: bash
+
+    1. Fix the code typo in pyhealth/tasks/drug_recommendation.py for issue #71.
+    2. update the pyhealth live schedule
+
+**Jan 22, 2023**
+
+.. code-block:: bash
+
+    1. Fix the list of list of vector problem in RNN, Transformer, RETAIN, and CNN
+    2. Add initialization examples for RNN, Transformer, RETAIN, CNN, and Deepr
+    3. (minor) change the parameters from "Type" and "level" to "type_" and "dim_"
+    4. BPDanek adds the __repr__ function to medcode for better print understanding
+    5. add unittest for pyhealth.data
+
+**Jan 21, 2023**
+
+.. code-block:: bash
+
+    1. Added a new model, Deepr (models.Deepr)
+
+**Jan 20, 2023**
+
+.. code-block:: bash
+
+    1. add the pyhealth live 05
+    2. add slack channel invitation in pyhealth live page
+
**Jan 13, 2023**

.. code-block:: bash
6 changes: 3 additions & 3 deletions docs/tutorials.rst
@@ -15,14 +15,14 @@ Tutorials

`Tutorial 4: Introduction to pyhealth.trainer <https://colab.research.google.com/drive/1L1Nz76cRNB7wTp5Pz_4Vp4N2eRZ9R6xl?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=5Hyw3of5pO4&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=7>`_

-`Tutorial 5: Introduction to pyhealth.metrics <https://colab.research.google.com/drive/1Mrs77EJ92HwMgDaElJ_CBXbi4iABZBeo?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=d-Kx_xCwre4&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=8>`_
+`Tutorial 5: Introduction to pyhealth.metrics <https://colab.research.google.com/drive/1Mrs77EJ92HwMgDaElJ_CBXbi4iABZBeo?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=d-Kx_xCwre4&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=8>`_

-`Tutorial 6: Introduction to pyhealth.tokenizer <https://colab.research.google.com/drive/1bDOb0A5g0umBjtz8NIp4wqye7taJ03D0?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=CeXJtf0lfs0&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=10>`_
+`Tutorial 6: Introduction to pyhealth.tokenizer <https://colab.research.google.com/drive/1bDOb0A5g0umBjtz8NIp4wqye7taJ03D0?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=CeXJtf0lfs0&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=10>`_

`Tutorial 7: Introduction to pyhealth.medcode <https://colab.research.google.com/drive/1xrp_ACM2_Hg5Wxzj0SKKKgZfMY0WwEj3?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=MmmfU6_xkYg&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=9>`_


-The following tutorials will help users build their own task pipelines. `[Video] <https://drive.google.com/file/d/1roWcfvjRrrtDWTWLjjhgZ1laD6p851Yi/view?usp=share_link>`_
+The following tutorials will help users build their own task pipelines. `[Video] <https://www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`_

`Pipeline 1: Drug Recommendation <https://colab.research.google.com/drive/10CSb4F4llYJvv42yTUiRmvSZdoEsbmFF?usp=sharing>`_

116 changes: 116 additions & 0 deletions examples/drug_recommendation_eICU_transformer.py
@@ -0,0 +1,116 @@
from pyhealth.datasets import eICUDataset
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import Transformer
from pyhealth.tasks import drug_recommendation_eicu_fn
from pyhealth.trainer import Trainer

# STEP 1: load data
base_dataset = eICUDataset(
    root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
    tables=["diagnosis", "medication", "physicalExam"],
    dev=True,
)
base_dataset.stat()

# STEP 2: set task

from pyhealth.data import Visit, Patient


def drug_recommendation_eicu_fn(patient: Patient):
    """Processes a single patient for the drug recommendation task.

    Drug recommendation aims at recommending a set of drugs given the patient health
    history (e.g., conditions and procedures).

    Args:
        patient: a Patient object

    Returns:
        samples: a list of samples, each sample is a dict with patient_id, visit_id,
            and other task-specific attributes as key

    Examples:
        >>> from pyhealth.datasets import eICUDataset
        >>> eicu_base = eICUDataset(
        ...     root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
        ...     tables=["diagnosis", "medication"],
        ...     code_mapping={},
        ...     dev=True
        ... )
        >>> from pyhealth.tasks import drug_recommendation_eicu_fn
        >>> eicu_sample = eicu_base.set_task(drug_recommendation_eicu_fn)
        >>> eicu_sample.samples[0]
        [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '98', '663', '58', '51']], 'procedures': [['1']], 'label': [['2', '3', '4']]}]
    """
    samples = []
    for i in range(len(patient)):
        visit: Visit = patient[i]
        conditions = visit.get_code_list(table="diagnosis")
        procedures = visit.get_code_list(table="physicalExam")
        drugs = visit.get_code_list(table="medication")
        # exclude: visits without condition, procedure, or drug code
        if len(conditions) * len(procedures) * len(drugs) == 0:
            continue
        # TODO: should also exclude visit with age < 18
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "drugs_all": drugs,
            }
        )
    # exclude: patients with fewer than 2 visits
    if len(samples) < 2:
        return []
    # add history: wrap the first visit's codes in a list, then let every later
    # sample carry all earlier visits plus its own
    samples[0]["conditions"] = [samples[0]["conditions"]]
    samples[0]["procedures"] = [samples[0]["procedures"]]
    samples[0]["drugs_all"] = [samples[0]["drugs_all"]]

    for i in range(1, len(samples)):
        samples[i]["conditions"] = samples[i - 1]["conditions"] + [
            samples[i]["conditions"]
        ]
        samples[i]["procedures"] = samples[i - 1]["procedures"] + [
            samples[i]["procedures"]
        ]
        samples[i]["drugs_all"] = samples[i - 1]["drugs_all"] + [
            samples[i]["drugs_all"]
        ]

    return samples


sample_dataset = base_dataset.set_task(drug_recommendation_eicu_fn)
sample_dataset.stat()

train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset, [0.8, 0.1, 0.1]
)
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

# STEP 3: define model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["conditions", "procedures"],
    label_key="drugs",
    mode="multilabel",
)

# STEP 4: define trainer
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="pr_auc_samples",
)

# STEP 5: evaluate
trainer.evaluate(test_dataloader)
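
The "add history" step of the task function in this example can be hard to follow from the diff alone. The sketch below reproduces only that step on plain dicts, without any pyhealth objects. The helper name `accumulate_history` and the toy codes are hypothetical and are not part of the commit.

```python
from typing import Dict, List


def accumulate_history(samples: List[Dict], keys: List[str]) -> List[Dict]:
    """Make samples[i][key] hold the codes of visits 0..i (hypothetical helper)."""
    if len(samples) < 2:
        # Mirrors the example: patients with fewer than 2 usable visits are dropped.
        return []
    out = [dict(s) for s in samples]  # shallow copies, so the input dicts stay untouched
    for key in keys:
        out[0][key] = [out[0][key]]
        for i in range(1, len(out)):
            out[i][key] = out[i - 1][key] + [out[i][key]]
    return out


# Toy per-visit samples (assumed data, not taken from eICU).
visits = [
    {"visit_id": "v1", "conditions": ["c1", "c2"], "drugs": ["d1"]},
    {"visit_id": "v2", "conditions": ["c3"], "drugs": ["d2", "d3"]},
]
history = accumulate_history(visits, keys=["conditions"])
print(history[1]["conditions"])  # [['c1', 'c2'], ['c3']]
print(history[1]["drugs"])       # ['d2', 'd3']
```

As in the example file, the label key (`drugs`) stays flat and describes only the current visit, while the feature keys become a list of per-visit code lists.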
46 changes: 24 additions & 22 deletions pyhealth/datasets/base_dataset.py
@@ -66,13 +66,13 @@ class BaseDataset(ABC):
"""

def __init__(
-self,
-root: str,
-tables: List[str],
-dataset_name: Optional[str] = None,
-code_mapping: Optional[Dict[str, Union[str, Tuple[str, Dict]]]] = None,
-dev: bool = False,
-refresh_cache: bool = False,
+self,
+root: str,
+tables: List[str],
+dataset_name: Optional[str] = None,
+code_mapping: Optional[Dict[str, Union[str, Tuple[str, Dict]]]] = None,
+dev: bool = False,
+refresh_cache: bool = False,
):
"""Loads tables into a dict of patients and saves it to cache."""

@@ -93,10 +93,10 @@ def __init__(

# hash filename for cache
args_to_hash = (
-[self.dataset_name, root]
-+ sorted(tables)
-+ sorted(code_mapping.items())
-+ ["dev" if dev else "prod"]
+[self.dataset_name, root]
++ sorted(tables)
++ sorted(code_mapping.items())
++ ["dev" if dev else "prod"]
)
filename = hash_str("+".join([str(arg) for arg in args_to_hash])) + ".pkl"
self.filepath = os.path.join(MODULE_CACHE_PATH, filename)
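
The hunk above derives the cache file name from the constructor arguments. A minimal sketch of that idea follows. It assumes MD5 as the hash, because the implementation of pyhealth's `hash_str` is not shown in this diff; `cache_filename` and the sample arguments are hypothetical.

```python
import hashlib
from typing import Dict, List


def hash_str(s: str) -> str:
    # Stand-in for pyhealth's helper of the same name; MD5 is an assumption.
    return hashlib.md5(s.encode()).hexdigest()


def cache_filename(dataset_name: str, root: str, tables: List[str],
                   code_mapping: Dict[str, str], dev: bool) -> str:
    """Build a deterministic cache file name from the dataset arguments."""
    args_to_hash = (
        [dataset_name, root]
        + sorted(tables)
        + sorted(code_mapping.items())
        + ["dev" if dev else "prod"]
    )
    return hash_str("+".join(str(arg) for arg in args_to_hash)) + ".pkl"


a = cache_filename("eICU", "/data/eicu", ["diagnosis", "medication"], {}, dev=True)
b = cache_filename("eICU", "/data/eicu", ["medication", "diagnosis"], {}, dev=True)
c = cache_filename("eICU", "/data/eicu", ["diagnosis", "medication"], {}, dev=False)
print(a == b)  # True: the table list is sorted, so its order does not matter
print(a == c)  # False: the dev flag is part of the key
```

Sorting the tables and the code-mapping items means that two calls differing only in argument order reuse the same cache file, while switching `dev` produces a separate one.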
@@ -174,8 +174,8 @@ def parse_tables(self) -> Dict[str, Patient]:

@staticmethod
def _add_event_to_patient_dict(
-patient_dict: Dict[str, Patient],
-event: Event,
+patient_dict: Dict[str, Patient],
+event: Event,
) -> Dict[str, Patient]:
"""Helper function which adds an event to the patient dict.
@@ -199,8 +199,8 @@ def _add_event_to_patient_dict(
return patient_dict

def _convert_code_in_patient_dict(
-self,
-patients: Dict[str, Patient],
+self,
+patients: Dict[str, Patient],
) -> Dict[str, Patient]:
"""Helper function which converts the codes for all patients.
@@ -322,9 +322,9 @@ def info():
print(INFO_MSG)

def set_task(
-self,
-task_fn: Callable,
-task_name: Optional[str] = None,
+self,
+task_fn: Callable,
+task_name: Optional[str] = None,
) -> SampleDataset:
"""Processes the base dataset to generate the task-specific sample dataset.
@@ -354,10 +354,12 @@ def set_task(
task_name = task_fn.__name__
samples = []
for patient_id, patient in tqdm(
-self.patients.items(), desc=f"Generating samples for {task_name}"
+self.patients.items(), desc=f"Generating samples for {task_name}"
):
samples.extend(task_fn(patient))
-sample_dataset = SampleDataset(samples,
-dataset_name=self.dataset_name,
-task_name=task_name, )
+sample_dataset = SampleDataset(
+samples,
+dataset_name=self.dataset_name,
+task_name=task_name,
+)
return sample_dataset
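
The control flow of `set_task` shown in this hunk is small enough to sketch without pyhealth: apply the task function to every patient and concatenate the returned sample lists. Below, plain dicts stand in for `Patient` objects and a tuple stands in for `SampleDataset`; `generate_samples`, `two_visit_task_fn` and the toy patients are hypothetical.

```python
from typing import Callable, Dict, List, Optional, Tuple


def generate_samples(
    patients: Dict[str, dict],
    task_fn: Callable[[dict], List[dict]],
    task_name: Optional[str] = None,
) -> Tuple[str, List[dict]]:
    """Apply task_fn to each patient and flatten the results into one list."""
    if task_name is None:
        # Same default as in the hunk: fall back to the function's own name.
        task_name = task_fn.__name__
    samples: List[dict] = []
    for patient_id, patient in patients.items():
        samples.extend(task_fn(patient))
    return task_name, samples


def two_visit_task_fn(patient: dict) -> List[dict]:
    """Toy task: one sample per visit, but only for patients with >= 2 visits."""
    if len(patient["visits"]) < 2:
        return []
    return [
        {"patient_id": patient["patient_id"], "visit_id": visit_id}
        for visit_id in patient["visits"]
    ]


patients = {
    "p1": {"patient_id": "p1", "visits": ["v1", "v2"]},
    "p2": {"patient_id": "p2", "visits": ["v3"]},
}
name, samples = generate_samples(patients, two_visit_task_fn)
print(name)          # two_visit_task_fn
print(len(samples))  # 2
```

Because a task function may return an empty list, filtering patients (as the eICU example does for patients with fewer than 2 visits) needs no special handling in the loop.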
5 changes: 3 additions & 2 deletions pyhealth/datasets/sample_dataset.py
@@ -172,8 +172,9 @@ def _validate(self) -> Dict:
int, or str.
"""
types = set([type(v) for v in flattened_values])
-assert types == set([str]) or len(types.difference(set([int, float]))) == 0, \
-f"Key {key} has mixed or unsupported types ({types}) across samples"
+assert (
+types == set([str]) or len(types.difference(set([int, float]))) == 0
+), f"Key {key} has mixed or unsupported types ({types}) across samples"
type_ = types.pop()
"""
4.3. Combined level and type check.
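
The reformatted assertion in this hunk accepts a key only if its flattened values are all `str`, or are drawn from `int` and `float`. A standalone sketch of that check is given below; the wrapper function `check_value_types` and the sample values are hypothetical, while the assertion itself follows the hunk.

```python
from typing import Any, List


def check_value_types(key: str, flattened_values: List[Any]) -> type:
    """Return one observed type after verifying the values are uniformly typed."""
    types = set(type(v) for v in flattened_values)
    assert (
        types == {str} or len(types.difference({int, float})) == 0
    ), f"Key {key} has mixed or unsupported types ({types}) across samples"
    # With mixed int/float input, which of the two is returned is arbitrary.
    return types.pop()


print(check_value_types("conditions", ["42", "109"]))  # <class 'str'>

try:
    check_value_types("conditions", ["42", 109])
except AssertionError as err:
    print("rejected:", err)
```

Note that a mix of `int` and `float` passes the check, whereas a mix of `str` with any numeric type is rejected.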
1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
@@ -9,3 +9,4 @@
from .gamenet import GAMENet, GAMENetLayer
from .safedrug import SafeDrug, SafeDrugLayer
from .mlp import MLP
+from .deepr import Deepr, DeeprLayer
11 changes: 5 additions & 6 deletions pyhealth/models/base_model.py
@@ -141,6 +141,9 @@ def padding3d(batch):
[max([len(x) for x in visits]) for visits in batch]
)

+# the most inner vector length
+vec_len = len(batch[0][0][0])
+
# get mask
mask = torch.zeros(
len(batch),
@@ -154,16 +157,12 @@ def padding3d(batch):

# level-2 padding
batch = [
-x + [[[0.0] * len(x[0])]] * (batch_max_length_level2 - len(x))
-for x in batch
+x + [[[0.0] * vec_len]] * (batch_max_length_level2 - len(x)) for x in batch
]

# level-3 padding
batch = [
-[
-x + [[0.0] * len(x[0])] * (batch_max_length_level3 - len(x))
-for x in visits
-]
+[x + [[0.0] * vec_len] * (batch_max_length_level3 - len(x)) for x in visits]
for visits in batch
]
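
As far as can be read from this hunk, the removed level-2 padding sized its zero vector with `len(x[0])`, which for a patient `x` is the number of events in the first visit rather than the length of an embedding vector. The change reads `vec_len` once from `batch[0][0][0]` instead. The sketch below shows the same padding on nested Python lists only, with no torch and no mask; the function name `pad3d` and the toy batch are hypothetical.

```python
from typing import List

Vector = List[float]


def pad3d(batch: List[List[List[Vector]]]) -> List[List[List[Vector]]]:
    """Pad patients -> visits -> events so every level has a uniform length."""
    max_visits = max(len(patient) for patient in batch)
    max_events = max(len(visit) for patient in batch for visit in patient)
    # Length of the innermost vector, read once from the first event.
    vec_len = len(batch[0][0][0])
    # The same zero vector object is reused; this is fine for read-only padding.
    zero_vec = [0.0] * vec_len

    # Level-2 padding: append visits that hold a single zero vector.
    batch = [p + [[zero_vec]] * (max_visits - len(p)) for p in batch]
    # Level-3 padding: append zero vectors to every visit.
    batch = [[v + [zero_vec] * (max_events - len(v)) for v in p] for p in batch]
    return batch


toy_batch = [
    [[[1.0, 2.0, 3.0]]],                                      # 1 visit, 1 event
    [[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], [[1.0, 1.0, 1.0]]],  # 2 visits
]
padded = pad3d(toy_batch)
print(len(padded), len(padded[0]), len(padded[0][0]), len(padded[0][0][0]))  # 2 2 2 3
```

In this toy batch the first patient has one visit with one event, so `len(x[0])` would be 1 and the padded visit would hold a vector of length 1 instead of 3, leaving the nested list ragged.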
