From 4807b50bd6a5e1d4d0a6f92e97349afad6ec1e98 Mon Sep 17 00:00:00 2001 From: EricSchrock Date: Sun, 8 Feb 2026 00:46:51 +0000 Subject: [PATCH 1/2] Remove old chestxray14 example notebooks --- .../chestxray14_binary_classification.ipynb | 1426 ----------------- ...hestxray14_multilabel_classification.ipynb | 1266 --------------- 2 files changed, 2692 deletions(-) delete mode 100644 examples/cxr/chestxray14_binary_classification.ipynb delete mode 100644 examples/cxr/chestxray14_multilabel_classification.ipynb diff --git a/examples/cxr/chestxray14_binary_classification.ipynb b/examples/cxr/chestxray14_binary_classification.ipynb deleted file mode 100644 index 270fa3af7..000000000 --- a/examples/cxr/chestxray14_binary_classification.ipynb +++ /dev/null @@ -1,1426 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "gpuType": "T4" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "2750023fb2bc420c875b3fde2cef2843": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_d942639eecdf4f3c955a5ceabb2dd012", - "IPY_MODEL_21aed97b90ea496dafbbfde642b8da3d", - "IPY_MODEL_c38852bb4f82461dbc2f2ad28949e04f" - ], - "layout": "IPY_MODEL_36b15a47acca4d19ad19aa7a75d6adc4" - } - }, - "d942639eecdf4f3c955a5ceabb2dd012": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_489e34fc92e041f39771ff701ddd6969", - "placeholder": "​", - "style": "IPY_MODEL_00a24949f2324b8f855ec5bfdc92d434", - "value": "Epoch 0 / 1: 100%" - } - }, - "21aed97b90ea496dafbbfde642b8da3d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ff1caf7de77b4b279a3b6d35669b79dd", - "max": 219, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_4c653f14317240edbea90f9b198c10cf", - "value": 219 - } - }, - "c38852bb4f82461dbc2f2ad28949e04f": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a5e6278ca52d47f1b638e9a94be80cbd", - "placeholder": "​", - "style": "IPY_MODEL_85cac2d243444b52a23f40b87d1b4023", - "value": " 219/219 [00:44<00:00,  5.12it/s]" - } - }, - "36b15a47acca4d19ad19aa7a75d6adc4": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "489e34fc92e041f39771ff701ddd6969": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "00a24949f2324b8f855ec5bfdc92d434": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "ff1caf7de77b4b279a3b6d35669b79dd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4c653f14317240edbea90f9b198c10cf": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "a5e6278ca52d47f1b638e9a94be80cbd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "85cac2d243444b52a23f40b87d1b4023": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - } - } - } - }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Binary Classification Using the ChestX-ray14 Dataset" - ], - "metadata": { - "id": "HaDNCcQJ3tD7" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Step 0: Install PyHealth" - ], - "metadata": { - "id": "j9Zj-n54qEwL" - } - }, - { - "cell_type": "code", - "source": [ - "!rm -rf PyHealth\n", - "!git clone https://github.com/EricSchrock/PyHealth.git\n", - "%cd PyHealth\n", - "!git checkout ChestX-ray14\n", - "!pip install -e ." - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "collapsed": true, - "id": "TWEAeB85p0C7", - "outputId": "32a89b86-4c11-49ca-9c46-867dbfaf2fa7" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Cloning into 'PyHealth'...\n", - "remote: Enumerating objects: 8101, done.\u001b[K\n", - "remote: Counting objects: 100% (1761/1761), done.\u001b[K\n", - "remote: Compressing objects: 100% (512/512), done.\u001b[K\n", - "remote: Total 8101 (delta 1555), reused 1251 (delta 1249), pack-reused 6340 (from 2)\u001b[K\n", - "Receiving objects: 100% (8101/8101), 113.88 MiB | 26.69 MiB/s, done.\n", - "Resolving deltas: 100% (5242/5242), done.\n", - "/content/PyHealth\n", - "Branch 'ChestX-ray14' set up to track remote branch 'ChestX-ray14' from 'origin'.\n", - "Switched to a new branch 'ChestX-ray14'\n", - "Obtaining file:///content/PyHealth\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n", - " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.11.0)\n", - "Collecting mne~=1.10.0 (from pyhealth==2.0a8)\n", - " Downloading mne-1.10.2-py3-none-any.whl.metadata (21 kB)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (3.5)\n", - "Collecting numpy~=1.26.4 (from pyhealth==2.0a8)\n", - " Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ogb>=1.3.5 (from pyhealth==2.0a8)\n", - " Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)\n", - "Collecting pandarallel~=1.6.5 (from pyhealth==2.0a8)\n", - " Downloading pandarallel-1.6.5.tar.gz (14 kB)\n", - " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "Collecting pandas~=2.3.1 (from pyhealth==2.0a8)\n", - " Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.17.1)\n", - "Requirement already satisfied: polars~=1.31.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.31.0)\n", - "Requirement already satisfied: pydantic~=2.11.7 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.11.10)\n", - "Collecting rdkit (from pyhealth==2.0a8)\n", - " Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)\n", - "Collecting scikit-learn~=1.7.0 (from pyhealth==2.0a8)\n", - " Downloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.23.0+cu126)\n", - "Collecting torch~=2.7.1 (from pyhealth==2.0a8)\n", - " Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.67.1)\n", - "Collecting transformers~=4.53.2 (from pyhealth==2.0a8)\n", - " Downloading transformers-4.53.3-py3-none-any.whl.metadata (40 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: urllib3~=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.5.0)\n", - "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (4.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.1.6)\n", - "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (0.4)\n", - "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.10.0)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (25.0)\n", - "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.8.2)\n", - "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.16.3)\n", - "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (1.17.0)\n", - "Collecting outdated>=0.2.0 (from ogb>=1.3.5->pyhealth==2.0a8)\n", - " Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)\n", - "Requirement already satisfied: dill>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (0.3.8)\n", - "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (5.9.5)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.7.0)\n", - "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (2.33.2)\n", - "Requirement already satisfied: typing-extensions>=4.12.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (4.15.0)\n", - "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.4.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (1.5.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (3.6.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.20.0)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (75.2.0)\n", - "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.13.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2025.3.0)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.80)\n", - "Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.4.1)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.3.0.4)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (10.3.7.77)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.7.1.2)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.5.4.2)\n", - "Collecting nvidia-cusparselt-cu12==0.6.3 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl.metadata (6.8 kB)\n", - "Collecting nvidia-nccl-cu12==2.26.2 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.85)\n", - "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.11.1.6)\n", - "Collecting triton==3.3.1 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)\n", - "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.36.0)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (6.0.3)\n", - "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2024.11.6)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2.32.4)\n", - "Collecting tokenizers<0.22,>=0.21 (from transformers~=4.53.2->pyhealth==2.0a8)\n", - " Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", - "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.6.2)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from rdkit->pyhealth==2.0a8) (11.3.0)\n", - "INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.\n", - "Collecting torchvision (from pyhealth==2.0a8)\n", - " Downloading torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n", - " Downloading torchvision-0.24.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n", - " Downloading torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n", - " Downloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n", - "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers~=4.53.2->pyhealth==2.0a8) (1.2.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.3.3)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (4.60.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.4.9)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (3.2.5)\n", - "Collecting littleutils (from outdated>=0.2.0->ogb>=1.3.5->pyhealth==2.0a8)\n", - " Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)\n", - "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne~=1.10.0->pyhealth==2.0a8) (4.5.0)\n", - "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.4.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.11)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (2025.10.5)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch~=2.7.1->pyhealth==2.0a8) (1.3.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne~=1.10.0->pyhealth==2.0a8) (3.0.3)\n", - "Downloading mne-1.10.2-py3-none-any.whl (7.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m100.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.0/18.0 MB\u001b[0m \u001b[31m125.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading ogb-1.3.6-py3-none-any.whl (78 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.8/78.8 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m145.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.5/9.5 MB\u001b[0m \u001b[31m145.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl (821.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m821.0/821.0 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl (571.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m571.0/571.0 MB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl (156.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m156.8/156.8 MB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (201.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m201.3/201.3 MB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (155.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m155.7/155.7 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading transformers-4.53.3-py3-none-any.whl (10.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m133.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.2/36.2 MB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl (7.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m84.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)\n", - "Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m63.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading littleutils-0.2.4-py3-none-any.whl (8.1 kB)\n", - "Building wheels for collected packages: pyhealth, pandarallel\n", - " Building editable for pyhealth (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pyhealth: filename=pyhealth-2.0a8-py3-none-any.whl size=10674 sha256=958c7e0bd8938910e22eda0840e62272710f8cae2e42ad8531f1012a34cd222f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-c1tiyeqt/wheels/1c/98/da/d6e74a692d0be5faeba6025d7302fd470b1ee8167b77261ad6\n", - " Building wheel for pandarallel (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pandarallel: filename=pandarallel-1.6.5-py3-none-any.whl size=16674 sha256=d2ad066c2563268e9811ae2c6adb46d872232aa5ad8891689caf3aea26b89d42\n", - " Stored in directory: /root/.cache/pip/wheels/46/f9/0d/40c9cd74a7cb8dc8fe57e8d6c3c19e2c730449c0d3f2bf66b5\n", - "Successfully built pyhealth pandarallel\n", - "Installing collected packages: nvidia-cusparselt-cu12, triton, nvidia-nccl-cu12, nvidia-cudnn-cu12, numpy, littleutils, rdkit, pandas, outdated, torch, tokenizers, scikit-learn, pandarallel, transformers, torchvision, ogb, mne, pyhealth\n", - " Attempting uninstall: nvidia-cusparselt-cu12\n", - " Found existing installation: nvidia-cusparselt-cu12 0.7.1\n", - " Uninstalling nvidia-cusparselt-cu12-0.7.1:\n", - " Successfully uninstalled nvidia-cusparselt-cu12-0.7.1\n", - " Attempting uninstall: triton\n", - " Found existing installation: triton 3.4.0\n", - " Uninstalling triton-3.4.0:\n", - " Successfully uninstalled triton-3.4.0\n", - " Attempting uninstall: nvidia-nccl-cu12\n", - " Found existing installation: nvidia-nccl-cu12 2.27.3\n", - " Uninstalling nvidia-nccl-cu12-2.27.3:\n", - " Successfully uninstalled nvidia-nccl-cu12-2.27.3\n", - " Attempting uninstall: nvidia-cudnn-cu12\n", - " Found existing installation: nvidia-cudnn-cu12 9.10.2.21\n", - " Uninstalling nvidia-cudnn-cu12-9.10.2.21:\n", - " Successfully uninstalled nvidia-cudnn-cu12-9.10.2.21\n", - " Attempting uninstall: numpy\n", - " Found existing installation: numpy 2.0.2\n", - " Uninstalling numpy-2.0.2:\n", - " Successfully uninstalled numpy-2.0.2\n", - " Attempting uninstall: pandas\n", - " Found existing installation: pandas 2.2.2\n", - " Uninstalling pandas-2.2.2:\n", - " Successfully uninstalled pandas-2.2.2\n", - " Attempting uninstall: torch\n", - " Found existing installation: torch 2.8.0+cu126\n", - " Uninstalling torch-2.8.0+cu126:\n", - " Successfully uninstalled torch-2.8.0+cu126\n", - " Attempting uninstall: tokenizers\n", - " Found existing installation: tokenizers 0.22.1\n", - " Uninstalling tokenizers-0.22.1:\n", - " Successfully uninstalled tokenizers-0.22.1\n", - " Attempting uninstall: scikit-learn\n", - " Found existing installation: scikit-learn 1.6.1\n", - " Uninstalling scikit-learn-1.6.1:\n", - " Successfully uninstalled scikit-learn-1.6.1\n", - " Attempting uninstall: transformers\n", - " Found existing installation: transformers 4.57.1\n", - " Uninstalling transformers-4.57.1:\n", - " Successfully uninstalled transformers-4.57.1\n", - " Attempting uninstall: torchvision\n", - " Found existing installation: torchvision 0.23.0+cu126\n", - " Uninstalling torchvision-0.23.0+cu126:\n", - " Successfully uninstalled torchvision-0.23.0+cu126\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.\n", - "pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", - "opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n", - "torchaudio 2.8.0+cu126 requires torch==2.8.0, but you have torch 2.7.1 which is incompatible.\n", - "opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", - "opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed littleutils-0.2.4 mne-1.10.2 numpy-1.26.4 nvidia-cudnn-cu12-9.5.1.17 nvidia-cusparselt-cu12-0.6.3 nvidia-nccl-cu12-2.26.2 ogb-1.3.6 outdated-0.2.2 pandarallel-1.6.5 pandas-2.3.3 pyhealth-2.0a8 rdkit-2025.9.1 scikit-learn-1.7.2 tokenizers-0.21.4 torch-2.7.1 torchvision-0.22.1 transformers-4.53.3 triton-3.3.1\n" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "numpy" - ] - }, - "id": "3737617eb2cf402699bacea64f559c14" - } - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 1: Load Dataset" - ], - "metadata": { - "id": "rMjzPqNbscDV" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import ChestXray14Dataset\n", - "\n", - "dataset = ChestXray14Dataset(download=True, partial=True)\n", - "dataset.stats()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q_fTVUTrsryn", - "outputId": "0660d909-31c6-48df-bb98-a015e48dd88d" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Dataset: ChestX-ray14\n", - "Dev mode: False\n", - "Number of patients: 1335\n", - "Number of events: 4999\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 2: Define Task" - ], - "metadata": { - "id": "ecF9IgCb22N5" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.tasks import ChestXray14BinaryClassification\n", - "\n", - "task = ChestXray14BinaryClassification(disease=\"infiltration\")\n", - "samples = dataset.set_task(task)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uj9ALkQGtVqF", - "outputId": "bfb6c953-9411-40be-9ca0-060077cbad96" - }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Setting task ChestXray14BinaryClassification for ChestX-ray14 base dataset...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Setting task ChestXray14BinaryClassification for ChestX-ray14 base dataset...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generating samples with 1 worker(s)...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Generating samples with 1 worker(s)...\n", - "Generating samples for ChestXray14BinaryClassification with 1 worker: 100%|██████████| 1335/1335 [00:00<00:00, 1770.85it/s]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Label label vocab: {0: 0, 1: 1}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n", - "INFO:pyhealth.processors.label_processor:Label label vocab: {0: 0, 1: 1}\n", - "Processing samples: 100%|██████████| 4999/4999 [01:22<00:00, 60.94it/s]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generated 4999 samples for task ChestXray14BinaryClassification\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n", - "INFO:pyhealth.datasets.base_dataset:Generated 4999 samples for task ChestXray14BinaryClassification\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import get_dataloader, split_by_sample\n", - "\n", - "train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.7, 0.1, 0.2])\n", - "\n", - "train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)\n", - "test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)" - ], - "metadata": { - "id": "8qS3hfKX5GNo" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 3: Define Model" - ], - "metadata": { - "id": "SjonWePy1r6N" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.models import CNN\n", - "\n", - "model = CNN(dataset=samples)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "VydSOr8u0XWG", - "outputId": "9f3bf251-0b5d-4457-9dfe-14f5d1beaeb5" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/content/PyHealth/pyhealth/metrics/calibration.py:122: SyntaxWarning: invalid escape sequence '\\c'\n", - " accuracy of 1. Thus, the ECE is :math:`\\\\frac{1}{3} \\cdot 0.49 + \\\\frac{2}{3}\\cdot 0.3=0.3633`.\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Warning: No embedding created for field due to lack of compatible processor: image\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 4: Train Model" - ], - "metadata": { - "id": "0jqDpKxgAu3-" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.trainer import Trainer\n", - "\n", - "trainer = Trainer(model=model)\n", - "trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "2750023fb2bc420c875b3fde2cef2843", - "d942639eecdf4f3c955a5ceabb2dd012", - "21aed97b90ea496dafbbfde642b8da3d", - "c38852bb4f82461dbc2f2ad28949e04f", - "36b15a47acca4d19ad19aa7a75d6adc4", - "489e34fc92e041f39771ff701ddd6969", - "00a24949f2324b8f855ec5bfdc92d434", - "ff1caf7de77b4b279a3b6d35669b79dd", - "4c653f14317240edbea90f9b198c10cf", - "a5e6278ca52d47f1b638e9a94be80cbd", - "85cac2d243444b52a23f40b87d1b4023" - ] - }, - "id": "-our6gpdAyGD", - "outputId": "ac84d73f-9940-4333-8f7e-f7b9f2614da9" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=1, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=1, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Metrics: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Metrics: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Epoch 0 / 1: 0%| | 0/219 [00:00=1.3.5 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.3.6)\n", - "Requirement already satisfied: pandarallel~=1.6.5 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.6.5)\n", - "Requirement already satisfied: pandas~=2.3.1 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.3.3)\n", - "Requirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.17.1)\n", - "Requirement already satisfied: polars~=1.31.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.31.0)\n", - "Requirement already satisfied: pydantic~=2.11.7 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.11.10)\n", - "Requirement already satisfied: rdkit in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2025.9.1)\n", - "Requirement already satisfied: scikit-learn~=1.7.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.7.2)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.22.1)\n", - "Requirement already satisfied: torch~=2.7.1 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.7.1)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.67.1)\n", - "Requirement already satisfied: transformers~=4.53.2 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.53.3)\n", - "Requirement already satisfied: urllib3~=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.5.0)\n", - "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (4.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.1.6)\n", - "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (0.4)\n", - "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.10.0)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (25.0)\n", - "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.8.2)\n", - "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.16.3)\n", - "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (1.17.0)\n", - "Requirement already satisfied: outdated>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (0.2.2)\n", - "Requirement already satisfied: dill>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (0.3.8)\n", - "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (5.9.5)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.7.0)\n", - "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (2.33.2)\n", - "Requirement already satisfied: typing-extensions>=4.12.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (4.15.0)\n", - "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.4.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (1.5.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (3.6.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.20.0)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (75.2.0)\n", - "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.13.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2025.3.0)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.80)\n", - "Requirement already satisfied: nvidia-cudnn-cu12==9.5.1.17 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (9.5.1.17)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.4.1)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.3.0.4)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (10.3.7.77)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.7.1.2)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.5.4.2)\n", - "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (0.6.3)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2.26.2)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.85)\n", - "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.11.1.6)\n", - "Requirement already satisfied: triton==3.3.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.3.1)\n", - "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.36.0)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (6.0.3)\n", - "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2024.11.6)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2.32.4)\n", - "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.21.4)\n", - "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.6.2)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from rdkit->pyhealth==2.0a8) (11.3.0)\n", - "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers~=4.53.2->pyhealth==2.0a8) (1.2.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.3.3)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (4.60.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.4.9)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (3.2.5)\n", - "Requirement already satisfied: littleutils in /usr/local/lib/python3.12/dist-packages (from outdated>=0.2.0->ogb>=1.3.5->pyhealth==2.0a8) (0.2.4)\n", - "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne~=1.10.0->pyhealth==2.0a8) (4.5.0)\n", - "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.4.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.11)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (2025.10.5)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch~=2.7.1->pyhealth==2.0a8) (1.3.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne~=1.10.0->pyhealth==2.0a8) (3.0.3)\n", - "Building wheels for collected packages: pyhealth\n", - " Building editable for pyhealth (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pyhealth: filename=pyhealth-2.0a8-py3-none-any.whl size=10674 sha256=958c7e0bd8938910e22eda0840e62272710f8cae2e42ad8531f1012a34cd222f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-netvrq88/wheels/1c/98/da/d6e74a692d0be5faeba6025d7302fd470b1ee8167b77261ad6\n", - "Successfully built pyhealth\n", - "Installing collected packages: pyhealth\n", - " Attempting uninstall: pyhealth\n", - " Found existing installation: pyhealth 2.0a8\n", - " Uninstalling pyhealth-2.0a8:\n", - " Successfully uninstalled pyhealth-2.0a8\n", - "Successfully installed pyhealth-2.0a8\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 1: Load Dataset" - ], - "metadata": { - "id": "rMjzPqNbscDV" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import ChestXray14Dataset\n", - "\n", - "dataset = ChestXray14Dataset(download=True, partial=True)\n", - "dataset.stats()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q_fTVUTrsryn", - "outputId": "942b186a-dc4d-4b05-eedd-c0d285aae951" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Dataset: ChestX-ray14\n", - "Dev mode: False\n", - "Number of patients: 1335\n", - "Number of events: 4999\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 2: Define Task" - ], - "metadata": { - "id": "ecF9IgCb22N5" - } - }, - { - "cell_type": "code", - "source": [ - "samples = dataset.set_task()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uj9ALkQGtVqF", - "outputId": "076cbb31-879f-4414-963e-7e3631f0ed31" - }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Setting task ChestXray14MultilabelClassification for ChestX-ray14 base dataset...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Setting task ChestXray14MultilabelClassification for ChestX-ray14 base dataset...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generating samples with 1 worker(s)...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Generating samples with 1 worker(s)...\n", - "Generating samples for ChestXray14MultilabelClassification with 1 worker: 100%|██████████| 1335/1335 [00:00<00:00, 1475.55it/s]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Label labels vocab: {'atelectasis': 0, 'cardiomegaly': 1, 'consolidation': 2, 'edema': 3, 'effusion': 4, 'emphysema': 5, 'fibrosis': 6, 'hernia': 7, 'infiltration': 8, 'mass': 9, 'nodule': 10, 'pleural_thickening': 11, 'pneumonia': 12, 'pneumothorax': 13}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n", - "INFO:pyhealth.processors.label_processor:Label labels vocab: {'atelectasis': 0, 'cardiomegaly': 1, 'consolidation': 2, 'edema': 3, 'effusion': 4, 'emphysema': 5, 'fibrosis': 6, 'hernia': 7, 'infiltration': 8, 'mass': 9, 'nodule': 10, 'pleural_thickening': 11, 'pneumonia': 12, 'pneumothorax': 13}\n", - "Processing samples: 100%|██████████| 4999/4999 [01:18<00:00, 63.31it/s]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Generated 4999 samples for task ChestXray14MultilabelClassification\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n", - "INFO:pyhealth.datasets.base_dataset:Generated 4999 samples for task ChestXray14MultilabelClassification\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import get_dataloader, split_by_sample\n", - "\n", - "train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.7, 0.1, 0.2])\n", - "\n", - "train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)\n", - "test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)" - ], - "metadata": { - "id": "8qS3hfKX5GNo" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 3: Define Model" - ], - "metadata": { - "id": "SjonWePy1r6N" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.models import CNN\n", - "\n", - "model = CNN(dataset=samples)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "VydSOr8u0XWG", - "outputId": "52e4df8b-00fd-47b7-e3ce-5836df69ffa9" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/content/PyHealth/pyhealth/metrics/calibration.py:122: SyntaxWarning: invalid escape sequence '\\c'\n", - " accuracy of 1. Thus, the ECE is :math:`\\\\frac{1}{3} \\cdot 0.49 + \\\\frac{2}{3}\\cdot 0.3=0.3633`.\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Warning: No embedding created for field due to lack of compatible processor: image\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Step 4: Train Model" - ], - "metadata": { - "id": "0jqDpKxgAu3-" - } - }, - { - "cell_type": "code", - "source": [ - "from pyhealth.trainer import Trainer\n", - "\n", - "# Only measure accurancy because with the \"partial\" dataset it is likely that\n", - "# there are not positive samples of every label present in the validation and test sets\n", - "trainer = Trainer(model=model, metrics=[\"accuracy\"])\n", - "trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "d5764f3cccdf4c52a25d0b8b2071e3b3", - "775da3f0d3e643f793ba6ad6aefdefca", - "87c9b17add0b434a897d46ec46826b4c", - "cba94b14345e40f4806e26a41edb72bc", - "207a6f173e57485b9abb66eb5f259c74", - "b32f9870995c4af3890deb5af41a77e0", - "517d39e922b543b4a111c1c72dc2abbd", - "3c26b80083274382b36f31add64ed5ed", - "8a9b003734834976aafca099ab6a37a5", - "059bda91051a46c18e0b02aa11eb73fa", - "ab3c660b9c2449619cca4a9d31392391" - ] - }, - "id": "-our6gpdAyGD", - "outputId": "d7360434-f396-4ffc-d348-04bfc6c3a524" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=14, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=14, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Metrics: ['accuracy']\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Metrics: ['accuracy']\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Epoch 0 / 1: 0%| | 0/219 [00:00 Date: Sun, 8 Feb 2026 00:51:35 +0000 Subject: [PATCH 2/2] Add back PyHealth 2.0 chestxray14 example notebooks --- .../chestxray14_binary_classification.ipynb | 423 ++++++++++++++++++ ...hestxray14_multilabel_classification.ipynb | 411 +++++++++++++++++ 2 files changed, 834 insertions(+) create mode 100644 examples/cxr/chestxray14_binary_classification.ipynb create mode 100644 examples/cxr/chestxray14_multilabel_classification.ipynb diff --git a/examples/cxr/chestxray14_binary_classification.ipynb b/examples/cxr/chestxray14_binary_classification.ipynb new file mode 100644 index 000000000..dde84f642 --- /dev/null +++ b/examples/cxr/chestxray14_binary_classification.ipynb @@ -0,0 +1,423 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c5b1f2fc", + "metadata": {}, + "source": [ + "# Binary Classification Using the ChestX-ray14 Dataset" + ] + }, + { + "cell_type": "markdown", + "id": "ef73c37a", + "metadata": {}, + "source": [ + "## Step 0: Install PyHealth" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89df5010", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install pyhealth ipywidgets" + ] + }, + { + "cell_type": "markdown", + "id": "1ad732a0", + "metadata": {}, + "source": [ + "## Step 1: Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1f8f4f2a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading ./images_01.tar.gz...\n", + "Checking MD5 checksum for ./images_01.tar.gz...\n", + "Extracting ./images_01.tar.gz...\n", + "Deleting ./images_01.tar.gz...\n", + "Download complete\n", + "Initializing ChestX-ray14 dataset from . (dev mode: False)\n", + "No cache_dir provided. Using default cache dir: /root/.cache/pyhealth/fb6e8a46-32a1-580b-bb6c-4015d54b1bc1\n", + "Scanning table: chestxray14 from /root/chestxray14-metadata-pyhealth.csv\n", + "Caching event dataframe to /root/.cache/pyhealth/fb6e8a46-32a1-580b-bb6c-4015d54b1bc1/global_event_df.parquet...\n", + "Dataset: ChestX-ray14\n", + "Dev mode: False\n", + "Number of patients: 1335\n", + "Number of events: 4999\n" + ] + } + ], + "source": [ + "from pyhealth.datasets import ChestXray14Dataset\n", + "\n", + "dataset = ChestXray14Dataset(download=True, partial=True)\n", + "dataset.stats()" + ] + }, + { + "cell_type": "markdown", + "id": "501af8c4", + "metadata": {}, + "source": [ + "## Step 2: Define Task" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6cf188af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task ChestXray14BinaryClassification for ChestX-ray14 base dataset...\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 1335 patients. (Polars threads: 22)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1335 [00:00\n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: None\n", + "Monitor criterion: max\n", + "Epochs: 1\n", + "Patience: None\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5e539aa75e4045fdbcc80b6ea50be128", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 1: 0%| | 0/219 [00:00\n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: None\n", + "Monitor criterion: max\n", + "Epochs: 1\n", + "Patience: None\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "031d701ec469490a9cdde45fcb16ac87", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 1: 0%| | 0/219 [00:00