This repository contains the official PyTorch implementation of the paper: "ProtoCare+: Knowledge Graph Guided Representation Learning for Diagnosis Prediction", Under review, 2026.
Deep learning models that use Electronic Health Records (EHR) and Medical Knowledge Graphs (KGs) often struggle with KG incompleteness and noise (task-irrelevant information).
ProtoCare+ is a novel framework designed to address these limitations by integrating:
- Mixed Attention Module (MAM): Captures both explicit KG relations and implicit data-driven dependencies.
- Prototype Learning: Extracts shared latent attributes across patient groups to improve robustness, especially for low-frequency diseases.
- Multi-view Graph Representation Learning (GRL): Filters noise through graph contrastive learning from both local (patient) and global (prototype) perspectives.
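Graph contrastive learning across a local and a global view is commonly implemented as an InfoNCE objective that pulls together two views of the same patient and pushes apart views of different patients. The snippet below is a minimal sketch under that assumption only: the symmetric InfoNCE form, the temperature of 0.2, and the function name `cross_view_info_nce` are illustrative choices, not the exact ProtoCare+ objective, and the embeddings are random placeholders.

```python
import torch
import torch.nn.functional as F


def cross_view_info_nce(z_local: torch.Tensor, z_global: torch.Tensor, temperature: float = 0.2) -> torch.Tensor:
    """Hypothetical symmetric InfoNCE between two views of the same batch of patients.

    Row i of z_local and row i of z_global form a positive pair; every other
    row in the batch acts as a negative.
    """
    z_local = F.normalize(z_local, dim=-1)
    z_global = F.normalize(z_global, dim=-1)
    logits = z_local @ z_global.t() / temperature  # (B, B) scaled cosine similarities
    targets = torch.arange(z_local.size(0), device=z_local.device)
    # Average both directions so each view serves as the anchor once.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


torch.manual_seed(0)
z_patient = torch.randn(8, 64)  # placeholder local (patient-graph) embeddings
z_prototype = z_patient + 0.1 * torch.randn(8, 64)  # placeholder global (prototype-graph) embeddings
loss_aligned = cross_view_info_nce(z_patient, z_prototype)
loss_shifted = cross_view_info_nce(z_patient, z_prototype.roll(1, dims=0))  # deliberately mismatched pairs
```

Views that agree (`loss_aligned`) produce a much lower loss than mismatched ones (`loss_shifted`), which is the signal that encourages both views to keep shared, task-relevant information while discarding view-specific noise.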
Extensive experiments on the MIMIC-III dataset demonstrate that ProtoCare+ consistently outperforms state-of-the-art baselines.
Figure 1: The overall framework of ProtoCare+ (Source: Figure 2 in the paper).
The overall framework of ProtoCare+ consists of four main components:
- Patient Feature Extraction: Embeds clinical codes (diagnoses, procedures, drugs), enhances the embeddings with the Mixed Attention Module (MAM), and encodes the resulting visit sequences with GRUs.
- Prototype Learning: Softly assigns patient embeddings to representative prototypes under diversity and clustering constraints.
- Multi-view GRL: Extracts task-relevant information via global (prototype) and local (patient) graph masking modules.
- Diagnosis Prediction: Fuses multi-view representations to predict future diagnoses.
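The following is a simplified, hypothetical sketch of how the GRU encoder, soft prototype assignment, and fused prediction could fit together. Several pieces are assumptions, not details confirmed by this README: MAM and the graph-masking modules are omitted, visit vectors are taken as precomputed inputs, soft assignment uses a softmax over negative squared distances, the clustering and diversity constraints use common prototype-learning forms, fusion is plain concatenation, and the class name `ProtoCareSketch` and the loss weights of 0.1 are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProtoCareSketch(nn.Module):
    """Hypothetical simplification: GRU patient encoder -> soft prototypes -> fused prediction."""

    def __init__(self, dim: int, num_prototypes: int, num_diagnoses: int, tau: float = 1.0):
        super().__init__()
        self.gru = nn.GRU(dim, dim, batch_first=True)
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, dim))
        self.tau = tau
        self.classifier = nn.Linear(2 * dim, num_diagnoses)

    def forward(self, visits: torch.Tensor):
        # visits: (B, T, dim), one precomputed vector per visit (assumed input).
        _, h_n = self.gru(visits)
        h = h_n[-1]  # (B, dim) local patient representation from the last GRU state
        # Soft assignment: softmax over negative squared distances (rows sum to 1).
        dist = torch.cdist(h, self.prototypes).pow(2)  # (B, K)
        assign = F.softmax(-dist / self.tau, dim=-1)
        h_proto = assign @ self.prototypes  # (B, dim) global, prototype-level view
        # Clustering constraint: each patient should lie close to at least one prototype.
        cluster_loss = dist.min(dim=-1).values.mean()
        # Diversity constraint: penalize positive cosine similarity between distinct prototypes.
        p = F.normalize(self.prototypes, dim=-1)
        off_diag = ~torch.eye(p.size(0), device=p.device).bool()
        diversity_loss = (p @ p.t())[off_diag].clamp(min=0).pow(2).mean()
        # Fuse both views by concatenation; one logit per diagnosis code (multi-label).
        logits = self.classifier(torch.cat([h, h_proto], dim=-1))
        return logits, assign, cluster_loss, diversity_loss


torch.manual_seed(0)
model = ProtoCareSketch(dim=32, num_prototypes=4, num_diagnoses=10)
visits = torch.randn(8, 5, 32)  # 8 patients, 5 visits each (random placeholder data)
labels = torch.randint(0, 2, (8, 10)).float()  # placeholder multi-hot next-visit diagnoses
logits, assign, l_clu, l_div = model(visits)
loss = F.binary_cross_entropy_with_logits(logits, labels) + 0.1 * l_clu + 0.1 * l_div
loss.backward()
```

Because the assignment is soft, every prototype receives gradient from every patient, which is one plausible way shared attributes learned from frequent patterns can also inform patients with rare diagnoses.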
Figure 2: Detailed architecture of the MAM (Source: Figure 3 in the paper).
MAM integrates structural signals from the medical KG (via Graph Attention Networks) and contextual co-occurrence patterns (via Transformer-based self-attention).
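To make the two MAM branches concrete, here is a minimal sketch written in plain PyTorch rather than dgl or torch_geometric. It assumes a single-head GAT-style score restricted to KG edges plus self-loops for the explicit branch, `nn.MultiheadAttention` over the codes for the implicit branch, and a learned sigmoid gate to mix them. The gate, the class name `MixedAttentionSketch`, and all sizes are hypothetical; the paper's exact fusion may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MixedAttentionSketch(nn.Module):
    """Hypothetical MAM sketch: KG-restricted (GAT-style) attention + full self-attention."""

    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        # GAT-style branch (single head): score(i, j) = LeakyReLU(a^T [W h_i || W h_j]).
        self.W = nn.Linear(dim, dim, bias=False)
        self.a = nn.Linear(2 * dim, 1, bias=False)
        # Transformer-style branch for implicit, data-driven co-occurrence.
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # Learned gate that mixes the two branches per feature (assumed fusion).
        self.gate = nn.Linear(2 * dim, dim)

    def kg_attention(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: (N, dim) code embeddings; adj: (N, N) boolean KG adjacency.
        wh = self.W(x)
        n = wh.size(0)
        pairs = torch.cat([wh.unsqueeze(1).expand(n, n, -1), wh.unsqueeze(0).expand(n, n, -1)], dim=-1)
        scores = F.leaky_relu(self.a(pairs).squeeze(-1), 0.2)  # (N, N)
        # Only explicit KG edges (plus self-loops) may receive attention.
        mask = adj | torch.eye(n, device=x.device).bool()
        scores = scores.masked_fill(~mask, float("-inf"))
        return torch.softmax(scores, dim=-1) @ wh

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        h_kg = self.kg_attention(x, adj)
        # Treat the set of codes as one sequence (batch of 1) for self-attention.
        seq = x.unsqueeze(0)
        h_ctx, _ = self.self_attn(seq, seq, seq)
        h_ctx = h_ctx.squeeze(0)
        g = torch.sigmoid(self.gate(torch.cat([h_kg, h_ctx], dim=-1)))
        return g * h_kg + (1 - g) * h_ctx


torch.manual_seed(0)
num_codes, dim = 6, 32
x = torch.randn(num_codes, dim)  # placeholder code embeddings
adj = torch.zeros(num_codes, num_codes, dtype=torch.bool)
adj[0, 1] = adj[1, 0] = True  # hypothetical KG edge between codes 0 and 1
mam = MixedAttentionSketch(dim)
out = mam(x, adj)
```

A code with no KG neighbors (code 2 here) attends only to itself in the KG branch, so any contextual signal it receives must come from the self-attention branch; this is how the two branches can complement an incomplete KG.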
The code was developed and tested in the following environment:
- Python >= 3.8
- PyTorch == 2.3.0
- CUDA (Tested on NVIDIA RTX A6000)
- Libraries: numpy, pandas, scikit-learn, dgl (or torch_geometric, depending on the implementation)
To install dependencies:

```bash
pip install -r requirements.txt
```