ProtoCare+: Knowledge Graph Guided Representation Learning for Diagnosis Prediction


This repository contains the official PyTorch implementation of the paper "ProtoCare+: Knowledge Graph Guided Representation Learning for Diagnosis Prediction" (under review, 2026).

๐Ÿ“ Abstract

Deep learning models that combine Electronic Health Records (EHRs) with medical Knowledge Graphs (KGs) often struggle with KG incompleteness and noise (task-irrelevant information).

ProtoCare+ is a novel framework designed to address these limitations by integrating:

  1. Mixed Attention Module (MAM): Captures both explicit KG relations and implicit data-driven dependencies.
  2. Prototype Learning: Extracts shared latent attributes across patient groups to improve robustness, especially for low-frequency diseases.
  3. Multi-view Graph Representation Learning (GRL): Filters noise through graph contrastive learning from both local (patient) and global (prototype) perspectives.

Extensive experiments on the MIMIC-III dataset demonstrate that ProtoCare+ consistently outperforms state-of-the-art baselines.
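The multi-view graph contrastive learning mentioned in point 3 can be illustrated with a minimal InfoNCE-style objective that pulls together the local (patient) and global (prototype) views of the same patient while pushing apart views of different patients. This is an illustrative sketch under assumed shapes, not the paper's exact loss:

```python
import torch
import torch.nn.functional as F

def info_nce(local_z, global_z, temperature=0.5):
    """InfoNCE-style contrastive loss between two views of the same patients.

    local_z, global_z: (batch, dim) embeddings from the local (patient)
    and global (prototype) graph views. Hypothetical sketch; ProtoCare+'s
    exact objective and hyperparameters may differ.
    """
    local_z = F.normalize(local_z, dim=-1)
    global_z = F.normalize(global_z, dim=-1)
    logits = local_z @ global_z.t() / temperature  # (batch, batch) similarities
    targets = torch.arange(local_z.size(0))        # positives on the diagonal
    return F.cross_entropy(logits, targets)

# Smoke test with random embeddings
torch.manual_seed(0)
loss = info_nce(torch.randn(8, 16), torch.randn(8, 16))
```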

๐Ÿ—๏ธ Model Architecture

The overall framework of ProtoCare+ consists of four main components:

Figure 1: The overall framework of ProtoCare+ (Source: Figure 2 in the paper).

  1. Patient Feature Extraction: Encodes visit sequences using GRUs and clinical embeddings (Diagnosis, Procedure, Drug) enhanced by the Mixed Attention Module (MAM).
  2. Prototype Learning: Softly assigns patient embeddings to representative prototypes under diversity and clustering constraints.
  3. Multi-view GRL: Extracts task-relevant information via global (prototype) and local (patient) graph masking modules.
  4. Diagnosis Prediction: Fuses multi-view representations to predict future diagnoses.
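Steps 1 and 2 above can be sketched as a GRU visit encoder followed by a soft assignment of the patient embedding to a set of learnable prototypes. All names, dimensions, and the similarity/softmax assignment scheme below are assumptions for illustration, not the paper's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatientPrototypeSketch(nn.Module):
    """Minimal sketch of patient feature extraction + prototype learning.

    A GRU encodes the visit sequence into a patient embedding, which is
    then softly assigned to K learnable prototypes. The diversity and
    clustering constraints from the paper are omitted here.
    """
    def __init__(self, input_dim=32, hidden_dim=64, num_prototypes=8):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, hidden_dim))

    def forward(self, visits):              # visits: (batch, T, input_dim)
        _, h = self.gru(visits)             # h: (1, batch, hidden_dim)
        patient = h.squeeze(0)              # (batch, hidden_dim)
        sim = patient @ self.prototypes.t() # similarity to each prototype
        assign = F.softmax(sim, dim=-1)     # soft assignment weights
        return patient, assign

model = PatientPrototypeSketch()
patient, assign = model(torch.randn(4, 5, 32))
```

In practice the assignment weights would also feed the diversity and clustering losses the paper describes.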

Mixed Attention Module (MAM)

Figure 2: Detailed architecture of the MAM (Source: Figure 3 in the paper). MAM integrates structural signals from the medical KG (via Graph Attention Networks) and contextual co-occurrence patterns (via Transformer-based self-attention).
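The mixing idea behind MAM can be sketched as a learned gate between Transformer self-attention over code embeddings (implicit co-occurrence) and aggregated KG-neighbor features (explicit structure). The module names, the gating scheme, and the use of pre-aggregated neighbor embeddings in place of a full GAT layer are all assumptions made to keep the sketch self-contained:

```python
import torch
import torch.nn as nn

class MixedAttentionSketch(nn.Module):
    """Illustrative sketch of the MAM idea, not the paper's exact code.

    Combines self-attention over medical-code embeddings with projected
    KG-neighbor embeddings (standing in for a GAT layer's output), mixed
    through a learned sigmoid gate.
    """
    def __init__(self, dim=32, heads=4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.kg_proj = nn.Linear(dim, dim)
        self.gate = nn.Linear(2 * dim, 1)

    def forward(self, codes, kg_neighbors):
        # codes: (batch, n_codes, dim) medical-code embeddings
        # kg_neighbors: (batch, n_codes, dim) aggregated KG neighbor features
        ctx, _ = self.self_attn(codes, codes, codes)    # implicit co-occurrence
        struct = torch.tanh(self.kg_proj(kg_neighbors)) # explicit KG structure
        g = torch.sigmoid(self.gate(torch.cat([ctx, struct], dim=-1)))
        return g * ctx + (1 - g) * struct               # gated mixture

mam = MixedAttentionSketch()
out = mam(torch.randn(2, 6, 32), torch.randn(2, 6, 32))
```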

๐Ÿ› ๏ธ Requirements

The code was developed and tested in the following environment:

  • Python >= 3.8
  • PyTorch == 2.3.0
  • CUDA (Tested on NVIDIA RTX A6000)
  • Libraries: numpy, pandas, scikit-learn, dgl (or torch_geometric depending on implementation)

To install dependencies:

pip install -r requirements.txt
