|
| 1 | +from pyhealth.datasets import eICUDataset |
| 2 | +from pyhealth.datasets import split_by_patient, get_dataloader |
| 3 | +from pyhealth.models import Transformer |
| 4 | +from pyhealth.tasks import drug_recommendation_eicu_fn |
| 5 | +from pyhealth.trainer import Trainer |
| 6 | + |
# STEP 1: load data
# Location of the eICU-CRD 2.0 files and the tables the task function below reads.
eicu_root = "/srv/local/data/physionet.org/files/eicu-crd/2.0"
eicu_tables = ["diagnosis", "medication", "physicalExam"]

base_dataset = eICUDataset(root=eicu_root, tables=eicu_tables, dev=True)
base_dataset.stat()
| 14 | + |
| 15 | +# STEP 2: set task |
| 16 | + |
| 17 | +from pyhealth.data import Visit, Patient |
| 18 | + |
| 19 | + |
def drug_recommendation_eicu_fn(patient: "Patient"):
    """Processes a single eICU patient for the drug recommendation task.

    Drug recommendation aims at recommending a set of drugs for a visit given
    the patient's health history. Here the history consists of the codes from
    the ``diagnosis`` table (stored as ``conditions``) and from the
    ``physicalExam`` table (stored as ``procedures``); the drugs come from the
    ``medication`` table.

    A visit is dropped when any of its three code lists is empty, and the whole
    patient is dropped (an empty list is returned) when fewer than two visits
    remain.

    Args:
        patient: a Patient object. Only ``len(patient)``, ``patient[i]`` and
            ``patient.patient_id`` are used; every visit must provide
            ``visit_id`` and ``get_code_list(table=...)``.

    Returns:
        samples: a list with one dict per kept visit, in visit order. Each dict
            has the keys

            - ``visit_id``: id of the visit,
            - ``patient_id``: id of the patient,
            - ``conditions``: list of per-visit diagnosis code lists, from the
              first kept visit up to and including this one,
            - ``procedures``: same layout, for physicalExam codes,
            - ``drugs``: medication codes of this visit only (flat list),
            - ``drugs_all``: same layout as ``conditions``, for medication
              codes.

    Examples:
        >>> from pyhealth.datasets import eICUDataset
        >>> eicu_base = eICUDataset(
        ...     root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
        ...     tables=["diagnosis", "medication", "physicalExam"],
        ...     dev=True,
        ... )
        >>> eicu_sample = eicu_base.set_task(drug_recommendation_eicu_fn)
    """
    samples = []
    # Access visits by position: len() and integer indexing are the only
    # container operations on Patient this function relies on.
    for i in range(len(patient)):
        visit: "Visit" = patient[i]
        conditions = visit.get_code_list(table="diagnosis")
        procedures = visit.get_code_list(table="physicalExam")
        drugs = visit.get_code_list(table="medication")
        # exclude: visits without condition, procedure, or drug code
        if not (conditions and procedures and drugs):
            continue
        # TODO: should also exclude visit with age < 18
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "drugs_all": drugs,
            }
        )

    # exclude: patients with less than 2 visit
    if len(samples) < 2:
        return []

    # add history: turn each per-visit code list into the cumulative list of
    # code lists seen so far. "drugs" is left flat (current visit only).
    history_keys = ("conditions", "procedures", "drugs_all")
    history = {key: [] for key in history_keys}
    for sample in samples:
        for key in history_keys:
            # `+` builds a fresh list, so samples never share a history list.
            history[key] = history[key] + [sample[key]]
            sample[key] = history[key]

    return samples
| 86 | + |
| 87 | + |
# Build the task-specific sample dataset with the function defined above.
sample_dataset = base_dataset.set_task(drug_recommendation_eicu_fn)
sample_dataset.stat()

# Fractions handed to split_by_patient for the train / val / test datasets.
split_ratios = [0.8, 0.1, 0.1]
train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset, split_ratios
)

# All three loaders use the same batch size; only the training loader shuffles.
batch_size = 32
train_dataloader = get_dataloader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = get_dataloader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = get_dataloader(
    test_dataset, batch_size=batch_size, shuffle=False
)

# STEP 3: define model
# Inputs are the two history features; the label is the current visit's drugs.
model_options = {
    "feature_keys": ["conditions", "procedures"],
    "label_key": "drugs",
    "mode": "multilabel",
}
model = Transformer(dataset=sample_dataset, **model_options)

# STEP 4: define trainer
trainer = Trainer(model=model)
train_options = {
    "train_dataloader": train_dataloader,
    "val_dataloader": val_dataloader,
    "epochs": 50,
    "monitor": "pr_auc_samples",
}
trainer.train(**train_options)

# STEP 5: evaluate
trainer.evaluate(test_dataloader)
0 commit comments