diff --git a/examples/cxr/chestxray14_binary_classification.ipynb b/examples/cxr/chestxray14_binary_classification.ipynb index 270fa3af7..dde84f642 100644 --- a/examples/cxr/chestxray14_binary_classification.ipynb +++ b/examples/cxr/chestxray14_binary_classification.ipynb @@ -1,1426 +1,423 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "gpuType": "T4" + "cells": [ + { + "cell_type": "markdown", + "id": "c5b1f2fc", + "metadata": {}, + "source": [ + "# Binary Classification Using the ChestX-ray14 Dataset" + ] + }, + { + "cell_type": "markdown", + "id": "ef73c37a", + "metadata": {}, + "source": [ + "## Step 0: Install PyHealth" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89df5010", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install pyhealth ipywidgets" + ] + }, + { + "cell_type": "markdown", + "id": "1ad732a0", + "metadata": {}, + "source": [ + "## Step 1: Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1f8f4f2a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading ./images_01.tar.gz...\n", + "Checking MD5 checksum for ./images_01.tar.gz...\n", + "Extracting ./images_01.tar.gz...\n", + "Deleting ./images_01.tar.gz...\n", + "Download complete\n", + "Initializing ChestX-ray14 dataset from . (dev mode: False)\n", + "No cache_dir provided. Using default cache dir: /root/.cache/pyhealth/fb6e8a46-32a1-580b-bb6c-4015d54b1bc1\n", + "Scanning table: chestxray14 from /root/chestxray14-metadata-pyhealth.csv\n", + "Caching event dataframe to /root/.cache/pyhealth/fb6e8a46-32a1-580b-bb6c-4015d54b1bc1/global_event_df.parquet...\n", + "Dataset: ChestX-ray14\n", + "Dev mode: False\n", + "Number of patients: 1335\n", + "Number of events: 4999\n" + ] + } + ], + "source": [ + "from pyhealth.datasets import ChestXray14Dataset\n", + "\n", + "dataset = ChestXray14Dataset(download=True, partial=True)\n", + "dataset.stats()" + ] + }, + { + "cell_type": "markdown", + "id": "501af8c4", + "metadata": {}, + "source": [ + "## Step 2: Define Task" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6cf188af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task ChestXray14BinaryClassification for ChestX-ray14 base dataset...\n", + "Applying task transformations on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 1335 patients. (Polars threads: 22)\n" + ] }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1335 [00:00=1.3.5 (from pyhealth==2.0a8)\n", - " Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)\n", - "Collecting pandarallel~=1.6.5 (from pyhealth==2.0a8)\n", - " Downloading pandarallel-1.6.5.tar.gz (14 kB)\n", - " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "Collecting pandas~=2.3.1 (from pyhealth==2.0a8)\n", - " Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.17.1)\n", - "Requirement already satisfied: polars~=1.31.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.31.0)\n", - "Requirement already satisfied: pydantic~=2.11.7 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.11.10)\n", - "Collecting rdkit (from pyhealth==2.0a8)\n", - " Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)\n", - "Collecting scikit-learn~=1.7.0 (from pyhealth==2.0a8)\n", - " Downloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.23.0+cu126)\n", - "Collecting torch~=2.7.1 (from pyhealth==2.0a8)\n", - " Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.67.1)\n", - "Collecting transformers~=4.53.2 (from pyhealth==2.0a8)\n", - " Downloading transformers-4.53.3-py3-none-any.whl.metadata (40 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: urllib3~=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.5.0)\n", - "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (4.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.1.6)\n", - "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (0.4)\n", - "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.10.0)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (25.0)\n", - "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.8.2)\n", - "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.16.3)\n", - "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (1.17.0)\n", - "Collecting outdated>=0.2.0 (from ogb>=1.3.5->pyhealth==2.0a8)\n", - " Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)\n", - "Requirement already satisfied: dill>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (0.3.8)\n", - "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (5.9.5)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.7.0)\n", - "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (2.33.2)\n", - "Requirement already satisfied: typing-extensions>=4.12.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (4.15.0)\n", - "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.4.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (1.5.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (3.6.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.20.0)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (75.2.0)\n", - "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.13.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2025.3.0)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.80)\n", - "Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.4.1)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.3.0.4)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (10.3.7.77)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.7.1.2)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.5.4.2)\n", - "Collecting nvidia-cusparselt-cu12==0.6.3 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl.metadata (6.8 kB)\n", - "Collecting nvidia-nccl-cu12==2.26.2 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.85)\n", - "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.11.1.6)\n", - "Collecting triton==3.3.1 (from torch~=2.7.1->pyhealth==2.0a8)\n", - " Downloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)\n", - "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.36.0)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (6.0.3)\n", - "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2024.11.6)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2.32.4)\n", - "Collecting tokenizers<0.22,>=0.21 (from transformers~=4.53.2->pyhealth==2.0a8)\n", - " Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", - "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.6.2)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from rdkit->pyhealth==2.0a8) (11.3.0)\n", - "INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.\n", - "Collecting torchvision (from pyhealth==2.0a8)\n", - " Downloading torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n", - " Downloading torchvision-0.24.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n", - " Downloading torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n", - " Downloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n", - "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers~=4.53.2->pyhealth==2.0a8) (1.2.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.3.3)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (4.60.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.4.9)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (3.2.5)\n", - "Collecting littleutils (from outdated>=0.2.0->ogb>=1.3.5->pyhealth==2.0a8)\n", - " Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)\n", - "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne~=1.10.0->pyhealth==2.0a8) (4.5.0)\n", - "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.4.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.11)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (2025.10.5)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch~=2.7.1->pyhealth==2.0a8) (1.3.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne~=1.10.0->pyhealth==2.0a8) (3.0.3)\n", - "Downloading mne-1.10.2-py3-none-any.whl (7.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m100.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.0/18.0 MB\u001b[0m \u001b[31m125.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading ogb-1.3.6-py3-none-any.whl (78 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.8/78.8 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m145.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.5/9.5 MB\u001b[0m \u001b[31m145.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl (821.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m821.0/821.0 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl (571.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m571.0/571.0 MB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl (156.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m156.8/156.8 MB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (201.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m201.3/201.3 MB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (155.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m155.7/155.7 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading transformers-4.53.3-py3-none-any.whl (10.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m133.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.2/36.2 MB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl (7.5 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m84.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)\n", - "Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m63.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading littleutils-0.2.4-py3-none-any.whl (8.1 kB)\n", - "Building wheels for collected packages: pyhealth, pandarallel\n", - " Building editable for pyhealth (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pyhealth: filename=pyhealth-2.0a8-py3-none-any.whl size=10674 sha256=958c7e0bd8938910e22eda0840e62272710f8cae2e42ad8531f1012a34cd222f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-c1tiyeqt/wheels/1c/98/da/d6e74a692d0be5faeba6025d7302fd470b1ee8167b77261ad6\n", - " Building wheel for pandarallel (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pandarallel: filename=pandarallel-1.6.5-py3-none-any.whl size=16674 sha256=d2ad066c2563268e9811ae2c6adb46d872232aa5ad8891689caf3aea26b89d42\n", - " Stored in directory: /root/.cache/pip/wheels/46/f9/0d/40c9cd74a7cb8dc8fe57e8d6c3c19e2c730449c0d3f2bf66b5\n", - "Successfully built pyhealth pandarallel\n", - "Installing collected packages: nvidia-cusparselt-cu12, triton, nvidia-nccl-cu12, nvidia-cudnn-cu12, numpy, littleutils, rdkit, pandas, outdated, torch, tokenizers, scikit-learn, pandarallel, transformers, torchvision, ogb, mne, pyhealth\n", - " Attempting uninstall: nvidia-cusparselt-cu12\n", - " Found existing installation: nvidia-cusparselt-cu12 0.7.1\n", - " Uninstalling nvidia-cusparselt-cu12-0.7.1:\n", - " Successfully uninstalled nvidia-cusparselt-cu12-0.7.1\n", - " Attempting uninstall: triton\n", - " Found existing installation: triton 3.4.0\n", - " Uninstalling triton-3.4.0:\n", - " Successfully uninstalled triton-3.4.0\n", - " Attempting uninstall: nvidia-nccl-cu12\n", - " Found existing installation: nvidia-nccl-cu12 2.27.3\n", - " Uninstalling nvidia-nccl-cu12-2.27.3:\n", - " Successfully uninstalled nvidia-nccl-cu12-2.27.3\n", - " Attempting uninstall: nvidia-cudnn-cu12\n", - " Found existing installation: nvidia-cudnn-cu12 9.10.2.21\n", - " Uninstalling nvidia-cudnn-cu12-9.10.2.21:\n", - " Successfully uninstalled nvidia-cudnn-cu12-9.10.2.21\n", - " Attempting uninstall: numpy\n", - " Found existing installation: numpy 2.0.2\n", - " Uninstalling numpy-2.0.2:\n", - " Successfully uninstalled numpy-2.0.2\n", - " Attempting uninstall: pandas\n", - " Found existing installation: pandas 2.2.2\n", - " Uninstalling pandas-2.2.2:\n", - " Successfully uninstalled pandas-2.2.2\n", - " Attempting uninstall: torch\n", - " Found existing installation: torch 2.8.0+cu126\n", - " Uninstalling torch-2.8.0+cu126:\n", - " Successfully uninstalled torch-2.8.0+cu126\n", - " Attempting uninstall: tokenizers\n", - " Found existing installation: tokenizers 0.22.1\n", - " Uninstalling tokenizers-0.22.1:\n", - " Successfully uninstalled tokenizers-0.22.1\n", - " Attempting uninstall: scikit-learn\n", - " Found existing installation: scikit-learn 1.6.1\n", - " Uninstalling scikit-learn-1.6.1:\n", - " Successfully uninstalled scikit-learn-1.6.1\n", - " Attempting uninstall: transformers\n", - " Found existing installation: transformers 4.57.1\n", - " Uninstalling transformers-4.57.1:\n", - " Successfully uninstalled transformers-4.57.1\n", - " Attempting uninstall: torchvision\n", - " Found existing installation: torchvision 0.23.0+cu126\n", - " Uninstalling torchvision-0.23.0+cu126:\n", - " Successfully uninstalled torchvision-0.23.0+cu126\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.\n", - "pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", - "opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n", - "torchaudio 2.8.0+cu126 requires torch==2.8.0, but you have torch 2.7.1 which is incompatible.\n", - "opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", - "opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", - "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed littleutils-0.2.4 mne-1.10.2 numpy-1.26.4 nvidia-cudnn-cu12-9.5.1.17 nvidia-cusparselt-cu12-0.6.3 nvidia-nccl-cu12-2.26.2 ogb-1.3.6 outdated-0.2.2 pandarallel-1.6.5 pandas-2.3.3 pyhealth-2.0a8 rdkit-2025.9.1 scikit-learn-1.7.2 tokenizers-0.21.4 torch-2.7.1 torchvision-0.22.1 transformers-4.53.3 triton-3.3.1\n" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "numpy" - ] - }, - "id": "3737617eb2cf402699bacea64f559c14" - } - }, - "metadata": {} - } - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] }, { - "cell_type": "markdown", - "source": [ - "## Step 1: Load Dataset" - ], - "metadata": { - "id": "rMjzPqNbscDV" - } + "name": "stdout", + "output_type": "stream", + "text": [ + "Label label vocab: {0: 0, 1: 1}\n", + "Processing samples and saving to /root/.cache/pyhealth/fb6e8a46-32a1-580b-bb6c-4015d54b1bc1/tasks/ChestXray14BinaryClassification_acac6c08-a7e0-5016-99fb-ace461d83f56/samples_0a9d8d5e-42c4-534f-9f35-24cffedeb0db.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 4999 samples. (0 to 4999)\n" + ] }, { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import ChestXray14Dataset\n", - "\n", - "dataset = ChestXray14Dataset(download=True, partial=True)\n", - "dataset.stats()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q_fTVUTrsryn", - "outputId": "0660d909-31c6-48df-bb98-a015e48dd88d" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Dataset: ChestX-ray14\n", - "Dev mode: False\n", - "Number of patients: 1335\n", - "Number of events: 4999\n" - ] - } - ] + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/4999 [00:00\n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: None\n", + "Monitor criterion: max\n", + "Epochs: 1\n", + "Patience: None\n", + "\n" + ] }, { - "cell_type": "code", - "source": [ - "from pyhealth.trainer import Trainer\n", - "\n", - "trainer = Trainer(model=model)\n", - "trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "2750023fb2bc420c875b3fde2cef2843", - "d942639eecdf4f3c955a5ceabb2dd012", - "21aed97b90ea496dafbbfde642b8da3d", - "c38852bb4f82461dbc2f2ad28949e04f", - "36b15a47acca4d19ad19aa7a75d6adc4", - "489e34fc92e041f39771ff701ddd6969", - "00a24949f2324b8f855ec5bfdc92d434", - "ff1caf7de77b4b279a3b6d35669b79dd", - "4c653f14317240edbea90f9b198c10cf", - "a5e6278ca52d47f1b638e9a94be80cbd", - "85cac2d243444b52a23f40b87d1b4023" - ] - }, - "id": "-our6gpdAyGD", - "outputId": "ac84d73f-9940-4333-8f7e-f7b9f2614da9" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5e539aa75e4045fdbcc80b6ea50be128", + "version_major": 2, + "version_minor": 0 }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=1, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:CNN(\n", - " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())\n", - " (cnn): ModuleDict(\n", - " (image): CNNLayer(\n", - " (cnn): ModuleList(\n", - " (0): CNNBlock(\n", - " (conv1): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " )\n", - " (conv2): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (downsample): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (relu): ReLU()\n", - " )\n", - " )\n", - " (pooling): AdaptiveAvgPool2d(output_size=1)\n", - " )\n", - " )\n", - " (fc): Linear(in_features=128, out_features=1, bias=True)\n", - ")\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Metrics: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Metrics: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Device: cuda\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Training:\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Batch size: 16\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Epoch 0 / 1: 0%| | 0/219 [00:00=1.3.5 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.3.6)\n", - "Requirement already satisfied: pandarallel~=1.6.5 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.6.5)\n", - "Requirement already satisfied: pandas~=2.3.1 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.3.3)\n", - "Requirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.17.1)\n", - "Requirement already satisfied: polars~=1.31.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.31.0)\n", - "Requirement already satisfied: pydantic~=2.11.7 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.11.10)\n", - "Requirement already satisfied: rdkit in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2025.9.1)\n", - "Requirement already satisfied: scikit-learn~=1.7.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.7.2)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.22.1)\n", - "Requirement already satisfied: torch~=2.7.1 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.7.1)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.67.1)\n", - "Requirement already satisfied: transformers~=4.53.2 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.53.3)\n", - "Requirement already satisfied: urllib3~=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.5.0)\n", - "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (4.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.1.6)\n", - "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (0.4)\n", - "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.10.0)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (25.0)\n", - "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.8.2)\n", - "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.16.3)\n", - "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (1.17.0)\n", - "Requirement already satisfied: outdated>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (0.2.2)\n", - "Requirement already satisfied: dill>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (0.3.8)\n", - "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (5.9.5)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", - "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.7.0)\n", - "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (2.33.2)\n", - "Requirement already satisfied: typing-extensions>=4.12.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (4.15.0)\n", - "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.4.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (1.5.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (3.6.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.20.0)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (75.2.0)\n", - "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.13.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2025.3.0)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.80)\n", - "Requirement already satisfied: nvidia-cudnn-cu12==9.5.1.17 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (9.5.1.17)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.4.1)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.3.0.4)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (10.3.7.77)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.7.1.2)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.5.4.2)\n", - "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (0.6.3)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2.26.2)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.85)\n", - "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.11.1.6)\n", - "Requirement already satisfied: triton==3.3.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.3.1)\n", - "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.36.0)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (6.0.3)\n", - "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2024.11.6)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2.32.4)\n", - "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.21.4)\n", - "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.6.2)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from rdkit->pyhealth==2.0a8) (11.3.0)\n", - "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers~=4.53.2->pyhealth==2.0a8) (1.2.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.3.3)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (4.60.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.4.9)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (3.2.5)\n", - "Requirement already satisfied: littleutils in /usr/local/lib/python3.12/dist-packages (from outdated>=0.2.0->ogb>=1.3.5->pyhealth==2.0a8) (0.2.4)\n", - "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne~=1.10.0->pyhealth==2.0a8) (4.5.0)\n", - "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.4.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.11)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (2025.10.5)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch~=2.7.1->pyhealth==2.0a8) (1.3.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne~=1.10.0->pyhealth==2.0a8) (3.0.3)\n", - "Building wheels for collected packages: pyhealth\n", - " Building editable for pyhealth (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pyhealth: filename=pyhealth-2.0a8-py3-none-any.whl size=10674 sha256=958c7e0bd8938910e22eda0840e62272710f8cae2e42ad8531f1012a34cd222f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-netvrq88/wheels/1c/98/da/d6e74a692d0be5faeba6025d7302fd470b1ee8167b77261ad6\n", - "Successfully built pyhealth\n", - "Installing collected packages: pyhealth\n", - " Attempting uninstall: pyhealth\n", - " Found existing installation: pyhealth 2.0a8\n", - " Uninstalling pyhealth-2.0a8:\n", - " Successfully uninstalled pyhealth-2.0a8\n", - "Successfully installed pyhealth-2.0a8\n" - ] - } - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] }, { - "cell_type": "markdown", - "source": [ - "## Step 1: Load Dataset" - ], - "metadata": { - "id": "rMjzPqNbscDV" - } + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting processors on the dataset...\n", + "Label labels vocab: {'atelectasis': 0, 'cardiomegaly': 1, 'consolidation': 2, 'edema': 3, 'effusion': 4, 'emphysema': 5, 'fibrosis': 6, 'hernia': 7, 'infiltration': 8, 'mass': 9, 'nodule': 10, 'pleural_thickening': 11, 'pneumonia': 12, 'pneumothorax': 13}\n", + "Processing samples and saving to /root/.cache/pyhealth/fb6e8a46-32a1-580b-bb6c-4015d54b1bc1/tasks/ChestXray14MultilabelClassification_f8cedbe4-72a8-53c3-922d-4cc8730f4c2d/samples_e4cb1532-b4bc-5434-aac5-9269556ad11e.ld...\n", + "Applying processors on data with 1 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 4999 samples. (0 to 4999)\n" + ] }, { - "cell_type": "code", - "source": [ - "from pyhealth.datasets import ChestXray14Dataset\n", - "\n", - "dataset = ChestXray14Dataset(download=True, partial=True)\n", - "dataset.stats()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q_fTVUTrsryn", - "outputId": "942b186a-dc4d-4b05-eedd-c0d285aae951" - }, - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Downloading ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Checking MD5 checksum for ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Extracting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Deleting ./images_01.tar.gz...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.chestxray14:Download complete\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Initializing ChestX-ray14 dataset from . (dev mode: False)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (4999, 26)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Dataset: ChestX-ray14\n", - "Dev mode: False\n", - "Number of patients: 1335\n", - "Number of events: 4999\n" - ] - } - ] + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/4999 [00:00\n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: None\n", + "Monitor criterion: max\n", + "Epochs: 1\n", + "Patience: None\n", + "\n" + ] }, { - "cell_type": "code", - "source": [ - "from pyhealth.models import CNN\n", - "\n", - "model = CNN(dataset=samples)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "VydSOr8u0XWG", - "outputId": "52e4df8b-00fd-47b7-e3ce-5836df69ffa9" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "031d701ec469490a9cdde45fcb16ac87", + "version_major": 2, + "version_minor": 0 }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/content/PyHealth/pyhealth/metrics/calibration.py:122: SyntaxWarning: invalid escape sequence '\\c'\n", - " accuracy of 1. Thus, the ECE is :math:`\\\\frac{1}{3} \\cdot 0.49 + \\\\frac{2}{3}\\cdot 0.3=0.3633`.\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Warning: No embedding created for field due to lack of compatible processor: image\n" - ] - } + "text/plain": [ + "Epoch 0 / 1: 0%| | 0/219 [00:00\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Optimizer params: {'lr': 0.001}\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Weight decay: 0.0\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Max grad norm: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Val dataloader: \n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Monitor criterion: max\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Epochs: 1\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:Patience: None\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "INFO:pyhealth.trainer:\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Epoch 0 / 1: 0%| | 0/219 [00:00