Commit 940db1b

Merge pull request #68 from sunlabuiuc/develop
update pyhealth live 05, add deepr model, start unittest (#67)
2 parents e452ad4 + 81db578 commit 940db1b

18 files changed, +1199 -135 lines changed

docs/api/models.rst

Lines changed: 1 addition & 0 deletions
@@ -15,4 +15,5 @@ We implement the following models for supporting multiple healthcare predictive
     models/pyhealth.models.GAMENet
     models/pyhealth.models.MICRON
     models/pyhealth.models.SafeDrug
+    models/pyhealth.models.Deepr
 

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+pyhealth.models.Deepr
+===================================
+
+The separate callable DeeprLayer and the complete Deepr model.
+
+.. autoclass:: pyhealth.models.DeeprLayer
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. autoclass:: pyhealth.models.Deepr
+    :members:
+    :undoc-members:
+    :show-inheritance:

docs/live.rst

Lines changed: 5 additions & 3 deletions
@@ -13,6 +13,8 @@ PyHealth live
 
 **YouTube**: `Recorded Live Sessions <https://www.youtube.com/playlist?list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV>`_
 
+**User/Developer Slack**: `Click to join <https://join.slack.com/t/pyhealthworkspace/shared_invite/zt-1np4yxs77-aqTKxhlfLOjaPbqTzr6sTA>`_
+
 Schedules
 ^^^^^^^^^^^^^^
 **(Dec 21, Wed)** Live 01 - What is PyHealth and How to Get Started? `[Recap] <https://www.youtube.com/watch?v=1Ir6hzU4Nro&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=1>`_
@@ -23,10 +25,10 @@ Schedules
 
 **(Jan 11, Wed)** Live 04 - Tokenizer & Medcode: master the medical code lookup and mapping `[Recap I] <https://www.youtube.com/watch?v=MmmfU6_xkYg&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=9>`_ `[II] <https://www.youtube.com/watch?v=CeXJtf0lfs0&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=10>`_
 
-**(Jan 18, Wed)** Live 05 - PyHealth can support a complete healthcare ML pipeline
+**(Jan 18, Wed)** Live 05 - PyHealth can support a complete healthcare ML pipeline `[Recap I] <https://www.youtube.com/watch?v=GVLzc6E4og0&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=11>`_ `[II] <https://www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`_
 
-**(Jan 25, Wed)** Live 06 - Adopt your customized model and quickly try it on our data
+**(Jan 25, Wed)** Live 06 - Fit your own dataset into pipeline and use our model
 
-**(Feb 1, Wed)** Live 07 - Fit your own dataset into pipeline and use our model
+**(Feb 1, Wed)** Live 07 - Adopt your customized model and quickly try it on our data
 
 **(Feb 8, Wed)** Live 08 - Define your own healthcare task on MIMIC data

docs/log.rst

Lines changed: 30 additions & 0 deletions
@@ -2,6 +2,36 @@ Development logs
 ======================
 We track the new development here:
 
+**Jan 24, 2023**
+
+.. code-block:: bash
+
+    1. Fix the code typo in pyhealth/tasks/drug_recommendation.py for issue #71.
+    2. update the pyhealth live schedule
+
+**Jan 22, 2023**
+
+.. code-block:: bash
+
+    1. Fix the list of list of vector problem in RNN, Transformer, RETAIN, and CNN
+    2. Add initialization examples for RNN, Transformer, RETAIN, CNN, and Deepr
+    3. (minor) change the parameters from "Type" and "level" to "type_" and "dim_"
+    4. BPDanek adds the __repr__ function to medcode for better print understanding
+    5. add unittest for pyhealth.data
+
+**Jan 21, 2023**
+
+.. code-block:: bash
+
+    1. Added a new model, Deepr (models.Deepr)
+
+**Jan 20, 2023**
+
+.. code-block:: bash
+
+    1. add the pyhealth live 05
+    2. add slack channel invitation in pyhealth live page
+
 **Jan 13, 2023**
 
 .. code-block:: bash

docs/tutorials.rst

Lines changed: 3 additions & 3 deletions
@@ -15,14 +15,14 @@ Tutorials
 
 `Tutorial 4: Introduction to pyhealth.trainer <https://colab.research.google.com/drive/1L1Nz76cRNB7wTp5Pz_4Vp4N2eRZ9R6xl?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=5Hyw3of5pO4&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=7>`_
 
-`Tutorial 5: Introduction to pyhealth.metrics <https://colab.research.google.com/drive/1Mrs77EJ92HwMgDaElJ_CBXbi4iABZBeo?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=d-Kx_xCwre4&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=8>`_
+`Tutorial 5: Introduction to pyhealth.metrics <https://colab.research.google.com/drive/1Mrs77EJ92HwMgDaElJ_CBXbi4iABZBeo?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=d-Kx_xCwre4&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=8>`_
 
-`Tutorial 6: Introduction to pyhealth.tokenizer <https://colab.research.google.com/drive/1bDOb0A5g0umBjtz8NIp4wqye7taJ03D0?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=CeXJtf0lfs0&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=10>`_
+`Tutorial 6: Introduction to pyhealth.tokenizer <https://colab.research.google.com/drive/1bDOb0A5g0umBjtz8NIp4wqye7taJ03D0?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=CeXJtf0lfs0&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=10>`_
 
 `Tutorial 7: Introduction to pyhealth.medcode <https://colab.research.google.com/drive/1xrp_ACM2_Hg5Wxzj0SKKKgZfMY0WwEj3?usp=sharing>`_ `[Video] <https://www.youtube.com/watch?v=MmmfU6_xkYg&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=9>`_
 
 
-The following tutorials will help users build their own task pipelines. `[Video] <https://drive.google.com/file/d/1roWcfvjRrrtDWTWLjjhgZ1laD6p851Yi/view?usp=share_link>`_
+The following tutorials will help users build their own task pipelines. `[Video] <https://www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`_
 
 `Pipeline 1: Drug Recommendation <https://colab.research.google.com/drive/10CSb4F4llYJvv42yTUiRmvSZdoEsbmFF?usp=sharing>`_
 

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+from pyhealth.datasets import eICUDataset
+from pyhealth.datasets import split_by_patient, get_dataloader
+from pyhealth.models import Transformer
+from pyhealth.tasks import drug_recommendation_eicu_fn
+from pyhealth.trainer import Trainer
+
+# STEP 1: load data
+base_dataset = eICUDataset(
+    root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
+    tables=["diagnosis", "medication", "physicalExam"],
+    dev=True,
+)
+base_dataset.stat()
+
+# STEP 2: set task
+
+from pyhealth.data import Visit, Patient
+
+
+def drug_recommendation_eicu_fn(patient: Patient):
+    """Processes a single patient for the drug recommendation task.
+
+    Drug recommendation aims at recommending a set of drugs given the patient health
+    history (e.g., conditions and procedures).
+
+    Args:
+        patient: a Patient object
+
+    Returns:
+        samples: a list of samples, each sample is a dict with patient_id, visit_id,
+            and other task-specific attributes as key
+
+    Examples:
+        >>> from pyhealth.datasets import eICUDataset
+        >>> eicu_base = eICUDataset(
+        ...     root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
+        ...     tables=["diagnosis", "medication"],
+        ...     code_mapping={},
+        ...     dev=True
+        ... )
+        >>> from pyhealth.tasks import drug_recommendation_eicu_fn
+        >>> eicu_sample = eicu_base.set_task(drug_recommendation_eicu_fn)
+        >>> eicu_sample.samples[0]
+        [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '98', '663', '58', '51']], 'procedures': [['1']], 'label': [['2', '3', '4']]}]
+    """
+    samples = []
+    for i in range(len(patient)):
+        visit: Visit = patient[i]
+        conditions = visit.get_code_list(table="diagnosis")
+        procedures = visit.get_code_list(table="physicalExam")
+        drugs = visit.get_code_list(table="medication")
+        # exclude: visits without condition, procedure, or drug code
+        if len(conditions) * len(procedures) * len(drugs) == 0:
+            continue
+        # TODO: should also exclude visit with age < 18
+        samples.append(
+            {
+                "visit_id": visit.visit_id,
+                "patient_id": patient.patient_id,
+                "conditions": conditions,
+                "procedures": procedures,
+                "drugs": drugs,
+                "drugs_all": drugs,
+            }
+        )
+    # exclude: patients with less than 2 visit
+    if len(samples) < 2:
+        return []
+    # add history
+    samples[0]["conditions"] = [samples[0]["conditions"]]
+    samples[0]["procedures"] = [samples[0]["procedures"]]
+    samples[0]["drugs_all"] = [samples[0]["drugs_all"]]
+
+    for i in range(1, len(samples)):
+        samples[i]["conditions"] = samples[i - 1]["conditions"] + [
+            samples[i]["conditions"]
+        ]
+        samples[i]["procedures"] = samples[i - 1]["procedures"] + [
+            samples[i]["procedures"]
+        ]
+        samples[i]["drugs_all"] = samples[i - 1]["drugs_all"] + [
+            samples[i]["drugs_all"]
+        ]
+
+    return samples
+
+
+sample_dataset = base_dataset.set_task(drug_recommendation_eicu_fn)
+sample_dataset.stat()
+
+train_dataset, val_dataset, test_dataset = split_by_patient(
+    sample_dataset, [0.8, 0.1, 0.1]
+)
+train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
+val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
+test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
+
+# STEP 3: define model
+model = Transformer(
+    dataset=sample_dataset,
+    feature_keys=["conditions", "procedures"],
+    label_key="drugs",
+    mode="multilabel",
+)
+
+# STEP 4: define trainer
+trainer = Trainer(model=model)
+trainer.train(
+    train_dataloader=train_dataloader,
+    val_dataloader=val_dataloader,
+    epochs=50,
+    monitor="pr_auc_samples",
+)
+
+# STEP 5: evaluate
+trainer.evaluate(test_dataloader)
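
The "add history" step of the task function above can be tried on plain dictionaries without loading eICU. The following is a minimal sketch, not PyHealth code: `add_history` and the toy visit records are hypothetical names introduced only for illustration. It shows how each sample ends up carrying the code lists of all earlier visits plus its own.

```python
from typing import Dict, List


def add_history(samples: List[Dict], keys: List[str]) -> List[Dict]:
    """Hypothetical helper: turn per-visit code lists into cumulative histories."""
    if not samples:
        return samples
    # The first sample's history contains only its own visit.
    for key in keys:
        samples[0][key] = [samples[0][key]]
    # Each later sample extends the previous history with its own visit.
    # List concatenation builds a new list, so earlier samples are not mutated.
    for i in range(1, len(samples)):
        for key in keys:
            samples[i][key] = samples[i - 1][key] + [samples[i][key]]
    return samples


# Toy records standing in for the per-visit dicts built in the task function.
visits = [
    {"visit_id": "v1", "conditions": ["c1", "c2"], "drugs": ["d1"]},
    {"visit_id": "v2", "conditions": ["c3"], "drugs": ["d2"]},
    {"visit_id": "v3", "conditions": ["c4"], "drugs": ["d3"]},
]
history = add_history(visits, keys=["conditions"])
```

Keys that are not listed, such as `drugs` here, stay flat per visit. This mirrors the script above, where `drugs` remains the label for the current visit while `drugs_all` accumulates.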

pyhealth/datasets/base_dataset.py

Lines changed: 24 additions & 22 deletions
@@ -66,13 +66,13 @@ class BaseDataset(ABC):
     """
 
     def __init__(
-        self,
-        root: str,
-        tables: List[str],
-        dataset_name: Optional[str] = None,
-        code_mapping: Optional[Dict[str, Union[str, Tuple[str, Dict]]]] = None,
-        dev: bool = False,
-        refresh_cache: bool = False,
+        self,
+        root: str,
+        tables: List[str],
+        dataset_name: Optional[str] = None,
+        code_mapping: Optional[Dict[str, Union[str, Tuple[str, Dict]]]] = None,
+        dev: bool = False,
+        refresh_cache: bool = False,
     ):
         """Loads tables into a dict of patients and saves it to cache."""
 
@@ -93,10 +93,10 @@ def __init__(
 
         # hash filename for cache
         args_to_hash = (
-            [self.dataset_name, root]
-            + sorted(tables)
-            + sorted(code_mapping.items())
-            + ["dev" if dev else "prod"]
+            [self.dataset_name, root]
+            + sorted(tables)
+            + sorted(code_mapping.items())
+            + ["dev" if dev else "prod"]
         )
         filename = hash_str("+".join([str(arg) for arg in args_to_hash])) + ".pkl"
         self.filepath = os.path.join(MODULE_CACHE_PATH, filename)
@@ -174,8 +174,8 @@ def parse_tables(self) -> Dict[str, Patient]:
 
     @staticmethod
     def _add_event_to_patient_dict(
-        patient_dict: Dict[str, Patient],
-        event: Event,
+        patient_dict: Dict[str, Patient],
+        event: Event,
     ) -> Dict[str, Patient]:
         """Helper function which adds an event to the patient dict.
 
@@ -199,8 +199,8 @@ def _add_event_to_patient_dict(
         return patient_dict
 
     def _convert_code_in_patient_dict(
-        self,
-        patients: Dict[str, Patient],
+        self,
+        patients: Dict[str, Patient],
     ) -> Dict[str, Patient]:
         """Helper function which converts the codes for all patients.
 
@@ -322,9 +322,9 @@ def info():
         print(INFO_MSG)
 
     def set_task(
-        self,
-        task_fn: Callable,
-        task_name: Optional[str] = None,
+        self,
+        task_fn: Callable,
+        task_name: Optional[str] = None,
     ) -> SampleDataset:
         """Processes the base dataset to generate the task-specific sample dataset.
 
@@ -354,10 +354,12 @@ def set_task(
             task_name = task_fn.__name__
         samples = []
         for patient_id, patient in tqdm(
-            self.patients.items(), desc=f"Generating samples for {task_name}"
+            self.patients.items(), desc=f"Generating samples for {task_name}"
         ):
             samples.extend(task_fn(patient))
-        sample_dataset = SampleDataset(samples,
-                                       dataset_name=self.dataset_name,
-                                       task_name=task_name, )
+        sample_dataset = SampleDataset(
+            samples,
+            dataset_name=self.dataset_name,
+            task_name=task_name,
+        )
         return sample_dataset
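
The "hash filename for cache" step in the `__init__` hunk above builds a cache key from the constructor arguments. The sketch below reproduces that idea in isolation. `cache_filename` is a hypothetical helper, and because the body of `hash_str` is not part of this diff, MD5 is used here purely as an assumed stand-in.

```python
import hashlib
from typing import Dict, List


def cache_filename(
    dataset_name: str,
    root: str,
    tables: List[str],
    code_mapping: Dict[str, str],
    dev: bool,
) -> str:
    """Hypothetical helper: derive a stable cache filename from the arguments."""
    args_to_hash = (
        [dataset_name, root]
        + sorted(tables)  # table order should not change the cache key
        + sorted(code_mapping.items())
        + ["dev" if dev else "prod"]
    )
    key = "+".join(str(arg) for arg in args_to_hash)
    # Assumption: MD5 stands in for the library's hash_str helper.
    return hashlib.md5(key.encode("utf-8")).hexdigest() + ".pkl"


a = cache_filename("eICU", "/data/eicu", ["medication", "diagnosis"], {}, dev=True)
b = cache_filename("eICU", "/data/eicu", ["diagnosis", "medication"], {}, dev=True)
c = cache_filename("eICU", "/data/eicu", ["diagnosis", "medication"], {}, dev=False)
```

Sorting the tables and the mapping items means `a` and `b` resolve to the same file, while toggling `dev` yields a different one, so a development subset does not overwrite the full cache.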

pyhealth/datasets/sample_dataset.py

Lines changed: 3 additions & 2 deletions
@@ -172,8 +172,9 @@ def _validate(self) -> Dict:
             int, or str.
             """
             types = set([type(v) for v in flattened_values])
-            assert types == set([str]) or len(types.difference(set([int, float]))) == 0, \
-                f"Key {key} has mixed or unsupported types ({types}) across samples"
+            assert (
+                types == set([str]) or len(types.difference(set([int, float]))) == 0
+            ), f"Key {key} has mixed or unsupported types ({types}) across samples"
             type_ = types.pop()
             """
             4.3. Combined level and type check.
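
The assertion in this hunk accepts a key when all of its flattened values are strings, or when they are all numeric (`int`, `float`, or a mix of the two). A minimal standalone sketch of that rule follows; `check_value_types` is a hypothetical name, and the set comparisons are an equivalent rewrite of the `set([...])` expressions rather than the library's code.

```python
from typing import Any, Iterable


def check_value_types(key: str, flattened_values: Iterable[Any]) -> type:
    """Hypothetical helper: validate that values are all str or all numeric."""
    types = {type(v) for v in flattened_values}
    # Either exactly {str}, or a subset of {int, float}.
    assert types == {str} or types <= {int, float}, (
        f"Key {key} has mixed or unsupported types ({types}) across samples"
    )
    # With mixed int/float input, which of the two is returned is arbitrary.
    return types.pop()


code_type = check_value_types("conditions", ["428.0", "401.9"])
lab_type = check_value_types("labs", [1, 2, 3])
```

Mixing strings with numbers, for example `["428.0", 1]`, trips the assertion. Note that `bool` values are rejected as well, because `type(True)` is `bool` rather than `int`.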

pyhealth/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@
 from .gamenet import GAMENet, GAMENetLayer
 from .safedrug import SafeDrug, SafeDrugLayer
 from .mlp import MLP
+from .deepr import Deepr, DeeprLayer

pyhealth/models/base_model.py

Lines changed: 5 additions & 6 deletions
@@ -141,6 +141,9 @@ def padding3d(batch):
             [max([len(x) for x in visits]) for visits in batch]
         )
 
+        # the most inner vector length
+        vec_len = len(batch[0][0][0])
+
         # get mask
         mask = torch.zeros(
             len(batch),
@@ -154,16 +157,12 @@ def padding3d(batch):
 
         # level-2 padding
         batch = [
-            x + [[[0.0] * len(x[0])]] * (batch_max_length_level2 - len(x))
-            for x in batch
+            x + [[[0.0] * vec_len]] * (batch_max_length_level2 - len(x)) for x in batch
         ]
 
         # level-3 padding
         batch = [
-            [
-                x + [[0.0] * len(x[0])] * (batch_max_length_level3 - len(x))
-                for x in visits
-            ]
+            [x + [[0.0] * vec_len] * (batch_max_length_level3 - len(x)) for x in visits]
             for visits in batch
         ]
 
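
Reading the removed lines, the level-2 expression iterates over patients, so `len(x[0])` there counts the vectors in a patient's first visit rather than the length of one vector. The pure-Python sketch below follows the fixed version, which reads the vector length once from `batch[0][0][0]`. `pad3d` and the toy batch are hypothetical, and the mask and tensor conversion from the real method are left out.

```python
from typing import List

Vector = List[float]


def pad3d(batch: List[List[List[Vector]]]) -> List[List[List[Vector]]]:
    """Hypothetical helper: pad patient -> visit -> event -> vector lists."""
    max_visits = max(len(patient) for patient in batch)
    max_events = max(len(visit) for patient in batch for visit in patient)
    # Length of the innermost vector, read once from the first event.
    vec_len = len(batch[0][0][0])

    # Level-2 padding: append visits that hold a single all-zero vector.
    batch = [
        patient + [[[0.0] * vec_len]] * (max_visits - len(patient))
        for patient in batch
    ]
    # Level-3 padding: append all-zero vectors to visits with too few events.
    return [
        [visit + [[0.0] * vec_len] * (max_events - len(visit)) for visit in patient]
        for patient in batch
    ]


batch = [
    [[[1.0, 2.0, 3.0]]],  # one visit with one event
    [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0]]],  # two visits
]
padded = pad3d(batch)
```

For the first patient, the old level-2 expression would have produced a padding vector of length 1 (the number of events in the first visit) instead of 3, leaving the nested list ragged. The sketch assumes every event vector has the same length as the first one.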
