From b7196c4d62dd0ae03887860eeb13d9f3268eabeb Mon Sep 17 00:00:00 2001 From: Felipe Bonchristiano Date: Sat, 6 Jun 2026 00:07:00 -0300 Subject: [PATCH 1/7] Attention Rollout skeleton --- .../interpret/methods/attention_rollout.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 pyhealth/interpret/methods/attention_rollout.py diff --git a/pyhealth/interpret/methods/attention_rollout.py b/pyhealth/interpret/methods/attention_rollout.py new file mode 100644 index 000000000..151443a27 --- /dev/null +++ b/pyhealth/interpret/methods/attention_rollout.py @@ -0,0 +1,46 @@ +from typing import Dict, Optional + +import torch + +from pyhealth.models.base_model import BaseModel +from .base_interpreter import BaseInterpreter + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Main interpreter +# --------------------------------------------------------------------------- + +class AttentionRollout(BaseInterpreter): + def __init__(self, model: BaseModel, head_fusion="mean"): + super().__init__() + + required_methods = [ + "set_attention_hooks", + "get_attention_layers", + "get_relevance_tensor", + ] + missing_methods = [ + method for method in required_methods if not hasattr(model, method) + ] + + if missing_methods: + raise TypeError( + "AttentionRollout requires a model that exposes the attention " + "interpretability methods: " + f"{', '.join(required_methods)}. " + f"Missing: {', '.join(missing_methods)}." + ) + + self.model = model + self.head_fusion = head_fusion + + def attribute( + self, + target_class_idx: Optional[int] = None, + **data, + ) -> Dict[str, torch.Tensor]: + pass \ No newline at end of file From a327010fb337b20ce4cf48f09115231086c2756d Mon Sep 17 00:00:00 2001 From: Felipe Bonchristiano Date: Sun, 7 Jun 2026 19:20:51 -0300 Subject: [PATCH 2/7] attention_rollout.py done --- pyhealth/interpret/methods/__init__.py | 2 + .../interpret/methods/attention_rollout.py | 249 ++++++++++++++++-- 2 files changed, 236 insertions(+), 15 deletions(-) diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index 6c92cb6e4..3f012f1b0 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -1,6 +1,7 @@ from pyhealth.interpret.methods.base_interpreter import BaseInterpreter from pyhealth.interpret.methods.baseline import RandomBaseline from pyhealth.interpret.methods.chefer import CheferRelevance +from pyhealth.interpret.methods.attention_rollout import AttentionRollout from pyhealth.interpret.methods.basic_gradient import BasicGradientSaliencyMaps from pyhealth.interpret.methods.deeplift import DeepLift from pyhealth.interpret.methods.gim import GIM @@ -15,6 +16,7 @@ __all__ = [ "BaseInterpreter", "CheferRelevance", + "AttentionRollout", "DeepLift", "GIM", "IntegratedGradientGIM", diff --git a/pyhealth/interpret/methods/attention_rollout.py b/pyhealth/interpret/methods/attention_rollout.py index 151443a27..dba9d336f 100644 --- a/pyhealth/interpret/methods/attention_rollout.py +++ b/pyhealth/interpret/methods/attention_rollout.py @@ -1,3 +1,7 @@ +# Author: Felipe Amaral Bonchristiano +# NetID: felipea5 +# Description: Attention rollout interpretability method implementation for PyHealth 2.0 + from typing import Dict, Optional import torch @@ -5,27 +9,106 @@ from pyhealth.models.base_model import BaseModel from .base_interpreter import BaseInterpreter -# --------------------------------------------------------------------------- -# Helper functions -# --------------------------------------------------------------------------- +class AttentionRollout(BaseInterpreter): + """Attention rollout for transformer interpretability. + + Implements the canonical attention rollout method of Abnar & Zuidema, + "Quantifying Attention Flow in Transformers" (2020), + https://arxiv.org/abs/2005.00928. -# --------------------------------------------------------------------------- -# Main interpreter -# --------------------------------------------------------------------------- + Unlike :class:`~pyhealth.interpret.methods.CheferRelevance`, which is + gradient-weighted and class-specific, rollout is **forward-pass only**, + **gradient-free**, and **class-agnostic**: it quantifies how attention + propagates information across layers, independent of any target class. + It serves as the standard baseline that gradient-based attention methods + are compared against. -class AttentionRollout(BaseInterpreter): - def __init__(self, model: BaseModel, head_fusion="mean"): - super().__init__() + This interpreter works with any model that exposes the attention-readout + methods ``set_attention_hooks``, ``get_attention_layers``, and + ``get_relevance_tensor`` (currently :class:`~pyhealth.models.Transformer` + and :class:`~pyhealth.models.StageAttentionNet`). Compatibility is checked + by duck-typing in ``__init__`` rather than by requiring a named interface, + since these methods are general attention readout and not specific to any + one method. + + The algorithm, per feature key: + + 1. Enable attention hooks via ``model.set_attention_hooks(True)`` and run a + single forward pass (no backward pass). + 2. Retrieve per-layer attention maps via ``model.get_attention_layers()``, + discarding the gradient element of each ``(attn_map, attn_grad)`` pair. + 3. Fuse heads (mean) to get one ``[batch, seq, seq]`` matrix per layer. + 4. Account for residual connections: ``A_hat = 0.5 * (A + I)``. + 5. Compose layers by matrix product: ``rollout = A_hat_L @ ... @ A_hat_1``. + 6. Reduce to per-token scores via ``model.get_relevance_tensor()``, then + expand to raw input value shapes. + + Because each ``A_hat`` is row-stochastic, so is their product; the + per-token relevance therefore forms a distribution over tokens (sums to 1 + before the input-shape expansion). + + Args: + model (BaseModel): A trained PyHealth model exposing the attention- + readout methods listed above. + head_fusion (str): How to combine attention heads into a single matrix + per layer. Currently only ``"mean"`` is supported (the canonical + choice from the paper). Defaults to ``"mean"``. + + Example: + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> from pyhealth.models import Transformer + >>> from pyhealth.interpret.methods import AttentionRollout + >>> + >>> samples = [ + ... { + ... "patient_id": "p0", + ... "visit_id": "v0", + ... "conditions": ["A05B", "A05C", "A06A"], + ... "procedures": ["P01", "P02"], + ... "label": 1, + ... }, + ... { + ... "patient_id": "p0", + ... "visit_id": "v1", + ... "conditions": ["A05B"], + ... "procedures": ["P01"], + ... "label": 0, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"conditions": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="ehr_example", + ... ) + >>> model = Transformer(dataset=dataset) + >>> # ... train the model ... + >>> + >>> interpreter = AttentionRollout(model) + >>> batch = next(iter(get_dataloader(dataset, batch_size=2))) + >>> + >>> attributions = interpreter.attribute(**batch) + >>> # Returns dict: {"conditions": tensor, "procedures": tensor} + >>> print(attributions["conditions"].shape) # [batch, num_tokens] + >>> + >>> # target_class_idx is accepted but ignored (rollout is class-agnostic) + >>> same = interpreter.attribute(target_class_idx=1, **batch) + """ + + def __init__(self, model: BaseModel, head_fusion: str = "mean"): + if head_fusion != "mean": + raise ValueError( + f"Unsupported head_fusion='{self.head_fusion}'. " + "Currently supported values: mean." + ) required_methods = [ "set_attention_hooks", "get_attention_layers", "get_relevance_tensor", ] - missing_methods = [ - method for method in required_methods if not hasattr(model, method) - ] + missing_methods = [m for m in required_methods if not hasattr(model, m)] if missing_methods: raise TypeError( @@ -34,13 +117,149 @@ def __init__(self, model: BaseModel, head_fusion="mean"): f"{', '.join(required_methods)}. " f"Missing: {', '.join(missing_methods)}." ) - - self.model = model + + super().__init__(model) self.head_fusion = head_fusion + def attribute( self, target_class_idx: Optional[int] = None, **data, ) -> Dict[str, torch.Tensor]: - pass \ No newline at end of file + """Compute class-agnostic attention rollout attributions. + + Args: + target_class_idx: Accepted for API compatibility with class-specific + interpreters. Attention rollout is class-agnostic, so this argument + is ignored. + **data: Batch input passed directly to the model. + + Returns: + Dict[str, torch.Tensor]: A dict keyed by the model's feature keys. + Each value holds the rollout relevance for that feature — the + CLS-token row of the composed attention-rollout matrix, reduced + to one score per token by ``model.get_relevance_tensor()`` and + then expanded to the raw input value shape by + ``_map_to_input_shapes``. For flat sequence features this is + ``[batch, num_tokens]``; for nested sequences the per-visit + score is replicated across the codes within each visit. + Scores are non-negative and, before the input-shape expansion, + sum to 1 across tokens (a consequence of composing + row-stochastic matrices). + """ + + self.model.set_attention_hooks(True) + try: + self.model(**data) + finally: + self.model.set_attention_hooks(False) + + attention_layers = self.model.get_attention_layers() + R = {} + + for feature_key, layers in attention_layers.items(): + rollout = None + + for attn_map, _ in layers: + if attn_map is None: + raise RuntimeError( + "AttentionRollout expected attention maps to be captured " + f"for feature '{feature_key}', but found None." + ) + + attn = self._fuse_heads(attn_map) + attn = self._add_residual(attn) + + if rollout is None: + batch_size, seq_len, _ = attn.shape + rollout = torch.eye( + seq_len, + device=attn.device, + dtype=attn.dtype, + ) + rollout = rollout.unsqueeze(0).expand( + batch_size, + seq_len, + seq_len, + ) + + rollout = torch.bmm(attn, rollout) + + if rollout is None: + raise RuntimeError( + "AttentionRollout expected at least one attention layer " + f"for feature '{feature_key}', but found none." + ) + + R[feature_key] = rollout + + attributions = self.model.get_relevance_tensor(R, **data) + return self._map_to_input_shapes(attributions, data) + + def _fuse_heads(self, attn_map: torch.Tensor) -> torch.Tensor: + """Fuse attention heads from [batch, heads, seq, seq] to [batch, seq, seq].""" + + if self.head_fusion == "mean": + return attn_map.mean(dim=1) + + raise ValueError( + f"Unsupported head_fusion='{self.head_fusion}'. " + "Currently supported values: mean." + ) + + def _map_to_input_shapes( + self, + attributions: Dict[str, torch.Tensor], + data: dict, + ) -> Dict[str, torch.Tensor]: + """Expand attributions to match raw input value shapes. + + For nested sequences the attention operates on a pooled + (visit-level) sequence, but downstream consumers (e.g. ablation + metrics) expect attributions to match the raw input value shape. + Per-visit relevance scores are replicated across all codes + within each visit. + + Args: + attributions: Per-feature attribution tensors returned by + ``model.get_relevance_tensor()``. + data: Original ``**data`` kwargs from the dataloader batch. + + Returns: + Attributions expanded to raw input value shapes where needed. + """ + result: Dict[str, torch.Tensor] = {} + for key, attr in attributions.items(): + feature = data.get(key) + if feature is not None: + if isinstance(feature, torch.Tensor): + val = feature + else: + schema = self.model.dataset.input_processors[key].schema() + val = ( + feature[schema.index("value")] + if "value" in schema + else None + ) + if val is not None and val.dim() > attr.dim(): + for _ in range(val.dim() - attr.dim()): + attr = attr.unsqueeze(-1) + attr = attr.expand_as(val) + result[key] = attr + return result + + @staticmethod + def _add_residual(attn: torch.Tensor) -> torch.Tensor: + """ + Add canonical rollout residual connection: 0.5 * (A + I). + 0.5 * (A + I) stays row-stochastic only because A is (soft-max ouput). + """ + + batch, seq_len, _ = attn.shape + identity = torch.eye( + seq_len, + device=attn.device, + dtype=attn.dtype, + ).unsqueeze(0) + return 0.5 * (attn + identity) \ No newline at end of file From 67475dfd5d8f6f95a4e1c912d86a68ea01003c17 Mon Sep 17 00:00:00 2001 From: Felipe Bonchristiano Date: Sun, 7 Jun 2026 19:52:47 -0300 Subject: [PATCH 3/7] tests/core/test_attention_rollout.py done --- pyhealth/interpret/methods/attention_rollout.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/interpret/methods/attention_rollout.py b/pyhealth/interpret/methods/attention_rollout.py index dba9d336f..b7150c70e 100644 --- a/pyhealth/interpret/methods/attention_rollout.py +++ b/pyhealth/interpret/methods/attention_rollout.py @@ -97,9 +97,9 @@ class AttentionRollout(BaseInterpreter): """ def __init__(self, model: BaseModel, head_fusion: str = "mean"): - if head_fusion != "mean": + if head_fusion != "mean": raise ValueError( - f"Unsupported head_fusion='{self.head_fusion}'. " + f"Unsupported head_fusion='{head_fusion}'. " "Currently supported values: mean." ) From 7089c67e00601378c9e4c75c9dcca341cea00c19 Mon Sep 17 00:00:00 2001 From: Felipe Bonchristiano Date: Sun, 7 Jun 2026 20:12:15 -0300 Subject: [PATCH 4/7] attention rollout integrated into example scripts --- examples/interpretability/dka_stageattn_mimic4_interpret.py | 1 + examples/interpretability/dka_transformer_mimic4_interpret.py | 1 + examples/interpretability/los_stageattn_mimic4_interpret.py | 1 + examples/interpretability/los_transformer_mimic4_interpret.py | 1 + examples/interpretability/mp_stageattn_mimic4_interpret.py | 1 + examples/interpretability/mp_transformer_mimic4_interpret.py | 1 + 6 files changed, 6 insertions(+) diff --git a/examples/interpretability/dka_stageattn_mimic4_interpret.py b/examples/interpretability/dka_stageattn_mimic4_interpret.py index 3b405bd52..dcfeecac9 100644 --- a/examples/interpretability/dka_stageattn_mimic4_interpret.py +++ b/examples/interpretability/dka_stageattn_mimic4_interpret.py @@ -135,6 +135,7 @@ def count_labels(ds): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), + "attention_rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/dka_transformer_mimic4_interpret.py b/examples/interpretability/dka_transformer_mimic4_interpret.py index d2617d652..0aa22f10a 100644 --- a/examples/interpretability/dka_transformer_mimic4_interpret.py +++ b/examples/interpretability/dka_transformer_mimic4_interpret.py @@ -135,6 +135,7 @@ def count_labels(ds): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), + "attention_rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/los_stageattn_mimic4_interpret.py b/examples/interpretability/los_stageattn_mimic4_interpret.py index 51d253f25..462050c15 100644 --- a/examples/interpretability/los_stageattn_mimic4_interpret.py +++ b/examples/interpretability/los_stageattn_mimic4_interpret.py @@ -121,6 +121,7 @@ def main(): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), + "attention_rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/los_transformer_mimic4_interpret.py b/examples/interpretability/los_transformer_mimic4_interpret.py index ccb06c707..6cab9572e 100644 --- a/examples/interpretability/los_transformer_mimic4_interpret.py +++ b/examples/interpretability/los_transformer_mimic4_interpret.py @@ -121,6 +121,7 @@ def main(): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), + "attention_rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/mp_stageattn_mimic4_interpret.py b/examples/interpretability/mp_stageattn_mimic4_interpret.py index e42b9aca6..5fedc21ff 100644 --- a/examples/interpretability/mp_stageattn_mimic4_interpret.py +++ b/examples/interpretability/mp_stageattn_mimic4_interpret.py @@ -121,6 +121,7 @@ def main(): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), + "attention_rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/mp_transformer_mimic4_interpret.py b/examples/interpretability/mp_transformer_mimic4_interpret.py index dcdb55215..ad6413789 100644 --- a/examples/interpretability/mp_transformer_mimic4_interpret.py +++ b/examples/interpretability/mp_transformer_mimic4_interpret.py @@ -121,6 +121,7 @@ def main(): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), + "attention_rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } From 76c79827e21127d99c023f4bbe5c077c6845b63a Mon Sep 17 00:00:00 2001 From: Felipe Bonchristiano Date: Sun, 7 Jun 2026 20:24:17 -0300 Subject: [PATCH 5/7] attention rollout docs --- docs/api/interpret.rst | 2 ++ pyhealth/interpret/methods/attention_rollout.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/docs/api/interpret.rst b/docs/api/interpret.rst index 747f05b56..d121d6ddb 100644 --- a/docs/api/interpret.rst +++ b/docs/api/interpret.rst @@ -57,6 +57,7 @@ New to interpretability in PyHealth? Check out these complete examples: - Train a ViT model on COVID-19 chest X-ray classification - Use CheferRelevance for gradient-weighted attention attribution - Visualize which image patches contribute to predictions + **LIME Example:** - ``examples/lime_stagenet_mimic4.py`` - Demonstrates LIME (Local Interpretable Model-agnostic Explanations) for StageNet mortality prediction. Shows how to: @@ -78,6 +79,7 @@ Attribution Methods interpret/pyhealth.interpret.methods.gim interpret/pyhealth.interpret.methods.basic_gradient interpret/pyhealth.interpret.methods.chefer + interpret/pyhealth.interpret.methods.attention_rollout interpret/pyhealth.interpret.methods.deeplift interpret/pyhealth.interpret.methods.integrated_gradients interpret/pyhealth.interpret.methods.shap diff --git a/pyhealth/interpret/methods/attention_rollout.py b/pyhealth/interpret/methods/attention_rollout.py index b7150c70e..902cc2dc0 100644 --- a/pyhealth/interpret/methods/attention_rollout.py +++ b/pyhealth/interpret/methods/attention_rollout.py @@ -24,6 +24,15 @@ class AttentionRollout(BaseInterpreter): It serves as the standard baseline that gradient-based attention methods are compared against. + .. note:: + "Gradient-free" refers to the attribution **math**: no backward pass + is run and no gradients enter the rollout computation. It does **not** + mean the call is safe inside ``torch.no_grad()``. The shared + attention-readout plumbing registers a gradient hook on the attention + tensors during the forward pass, so running ``attribute(**batch)`` + under ``torch.no_grad()`` raises a ``RuntimeError``. Call it under the + default (grad-enabled) context. + This interpreter works with any model that exposes the attention-readout methods ``set_attention_hooks``, ``get_attention_layers``, and ``get_relevance_tensor`` (currently :class:`~pyhealth.models.Transformer` @@ -147,6 +156,12 @@ def attribute( Scores are non-negative and, before the input-shape expansion, sum to 1 across tokens (a consequence of composing row-stochastic matrices). + + Note: + Do not call this method inside a ``torch.no_grad()`` context. Even + though rollout uses no gradients, enabling attention hooks registers + a gradient hook during the forward pass, which requires grad-enabled + tensors and otherwise raises a ``RuntimeError``. """ self.model.set_attention_hooks(True) From 68134f75eb32339e3ef5921b4bb9b57ace865634 Mon Sep 17 00:00:00 2001 From: Felipe Bonchristiano Date: Sun, 7 Jun 2026 20:25:14 -0300 Subject: [PATCH 6/7] attention rollout docs/interpret/pyhealth.interpret.methods.attention_rollout.rst added --- ...th.interpret.methods.attention_rollout.rst | 84 +++++ tests/core/test_attention_rollout.py | 311 ++++++++++++++++++ 2 files changed, 395 insertions(+) create mode 100644 docs/api/interpret/pyhealth.interpret.methods.attention_rollout.rst create mode 100644 tests/core/test_attention_rollout.py diff --git a/docs/api/interpret/pyhealth.interpret.methods.attention_rollout.rst b/docs/api/interpret/pyhealth.interpret.methods.attention_rollout.rst new file mode 100644 index 000000000..35875f3d5 --- /dev/null +++ b/docs/api/interpret/pyhealth.interpret.methods.attention_rollout.rst @@ -0,0 +1,84 @@ +pyhealth.interpret.methods.attention_rollout +============================================= + +Overview +-------- + +Attention Rollout provides token-level relevance scores for Transformer models +in PyHealth. It quantifies how attention propagates information across layers by +composing the per-layer attention matrices (with a residual-connection +correction), yielding a single importance score per input token (e.g. diagnosis +codes, procedure codes, medications) for a given patient sample. + +Unlike :class:`~pyhealth.interpret.methods.CheferRelevance`, which is +gradient-weighted and **class-specific**, attention rollout is **forward-pass +only**, **gradient-free**, and **class-agnostic**: it explains how information +flows through the attention mechanism independent of any target class. It serves +as the standard baseline that gradient-based attention methods are compared +against, and complements Chefer rather than replacing it. + +This method is particularly useful for: + +- **Clinical decision support**: Understanding which medical codes drove a particular prediction +- **Model debugging**: Identifying whether the model attends to clinically meaningful features +- **Feature importance**: Ranking tokens by how much attention flows to them +- **Trust and transparency**: Providing interpretable, class-agnostic explanations for model predictions + +The implementation follows the paper by Abnar & Zuidema (2020): "Quantifying +Attention Flow in Transformers" (https://arxiv.org/abs/2005.00928). + +Key Features +------------ + +- **Multi-modal support**: Works with multiple feature types (conditions, procedures, drugs, labs, etc.) +- **Gradient-free**: Computed from a single forward pass; no backward pass is used in the attribution math +- **Class-agnostic**: Independent of the predicted/target class (``target_class_idx`` is accepted but ignored) +- **Layer-wise composition**: Composes per-layer attention as ``rollout = Â_L @ ... @ Â_1`` with the residual correction ``Â = 0.5 * (A + I)`` +- **Distribution over tokens**: Because each ``Â`` is row-stochastic, so is their product; per-token relevance sums to 1 (before the input-shape expansion) +- **Model-agnostic by duck-typing**: Works with any model exposing the attention-readout methods ``set_attention_hooks``, ``get_attention_layers`` and ``get_relevance_tensor`` (currently :class:`~pyhealth.models.Transformer` and :class:`~pyhealth.models.StageAttentionNet`), not just one named model + +Usage Notes +----------- + +1. **Batch size**: For interpretability, use ``batch_size=1`` to get per-sample explanations. +2. **Do not wrap in** ``torch.no_grad()``: Although rollout is gradient-free in its math, the shared attention-readout plumbing registers a gradient hook on the attention tensors during the forward pass, so calling ``attribute(**batch)`` inside ``torch.no_grad()`` raises a ``RuntimeError``. Call it under the default (grad-enabled) context; no backward pass is performed. +3. **Model compatibility**: Works with any model that exposes ``set_attention_hooks``, ``get_attention_layers`` and ``get_relevance_tensor`` — not restricted to the Transformer. Incompatible models raise ``TypeError`` at construction. +4. **Class specification**: ``target_class_idx`` is accepted for API compatibility but ignored, since rollout is class-agnostic. + +Quick Start +----------- + +.. code-block:: python + + from pyhealth.models import Transformer + from pyhealth.interpret.methods import AttentionRollout + from pyhealth.datasets import get_dataloader + + # Assume you have a trained transformer model and dataset + model = Transformer(dataset=sample_dataset, ...) + # ... train the model ... + + # Create interpretability object + rollout = AttentionRollout(model) + + # Get a test sample (batch_size=1) + test_loader = get_dataloader(test_dataset, batch_size=1, shuffle=False) + batch = next(iter(test_loader)) + + # Compute attributions (target_class_idx is accepted but ignored) + scores = rollout.attribute(**batch) + + # Analyze results + for feature_key, attribution in scores.items(): + print(f"{feature_key}: {attribution.shape}") + top_tokens = attribution[0].topk(5).indices + print(f" Top 5 most relevant tokens: {top_tokens}") + +API Reference +------------- + +.. autoclass:: pyhealth.interpret.methods.AttentionRollout + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/tests/core/test_attention_rollout.py b/tests/core/test_attention_rollout.py new file mode 100644 index 000000000..45c745163 --- /dev/null +++ b/tests/core/test_attention_rollout.py @@ -0,0 +1,311 @@ +# Author: Felipe Amaral Bonchristiano +# NetID: felipea5 +# Description: Unit tests for the AttentionRollout interpretability method +# (Abnar & Zuidema, 2020, https://arxiv.org/abs/2005.00928). + +import unittest + +import torch +import torch.nn as nn + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.interpret.methods import AttentionRollout +from pyhealth.models import Transformer + + +def _make_dataset(samples, input_schema): + """Build a tiny sample dataset.""" + return create_sample_dataset( + samples=samples, + input_schema=input_schema, + output_schema={"label": "binary"}, + dataset_name="rollout_test", + ) + + +class TestAttentionRollout(unittest.TestCase): + """Tests for :class:`AttentionRollout`.""" + + def setUp(self): + torch.manual_seed(42) + + self.samples = [ + { + "patient_id": "p0", + "visit_id": "v0", + "conditions": ["A05B", "A05C", "A06A"], + "procedures": ["P01", "P02"], + "label": 1, + }, + { + "patient_id": "p1", + "visit_id": "v0", + "conditions": ["A05B"], + "procedures": ["P01"], + "label": 0, + }, + ] + self.input_schema = { + "conditions": "sequence", + "procedures": "sequence", + } + self.dataset = _make_dataset(self.samples, self.input_schema) + self.model = Transformer( + dataset=self.dataset, + embedding_dim=8, + heads=2, + num_layers=2, + ) + self.interpreter = AttentionRollout(self.model) + self.loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + self.batch = next(iter(self.loader)) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _capture_relevance(self, model): + """Wrap ``get_relevance_tensor`` to capture pre/post-expansion tensors. + + Returns ``(captured_R, captured_pre)`` dicts that are populated as a + side effect of the next ``attribute`` call: + + * ``captured_R`` — the composed rollout matrices ``[batch, seq, seq]``. + * ``captured_pre`` — the per-token relevance *before* the + input-shape expansion done by ``_map_to_input_shapes``. + """ + original = model.get_relevance_tensor + captured_R = {} + captured_pre = {} + + def spy(R, **data): + for key, value in R.items(): + captured_R[key] = value.detach().clone() + out = original(R, **data) + for key, value in out.items(): + captured_pre[key] = value.detach().clone() + return out + + model.get_relevance_tensor = spy + return captured_R, captured_pre + + # ------------------------------------------------------------------ + # Tests + # ------------------------------------------------------------------ + + def test_returns_dict_keyed_by_feature_keys(self): + """attribute() returns a dict keyed by exactly the model feature keys.""" + + attributions = self.interpreter.attribute(**self.batch) + + self.assertIsInstance(attributions, dict) + self.assertEqual( + set(attributions.keys()), + set(self.model.feature_keys), + ) + + def test_output_shape_matches_input_seq_length(self): + """Each attribution matches its input feature's shape (seq length).""" + + attributions = self.interpreter.attribute(**self.batch) + + for key in self.model.feature_keys: + self.assertIsInstance(attributions[key], torch.Tensor) + # Flat sequence inputs are [batch, seq_len] - the attribution + # must line up token-for-token with the raw input. + self.assertEqual(attributions[key].shape, self.batch[key].shape) + self.assertEqual(attributions[key].shape[0], 2) # batch size + + def test_multi_feature_key_model(self): + """A model with several feature streams yields one entry per stream.""" + + attributions = self.interpreter.attribute(**self.batch) + + self.assertIn("conditions", attributions) + self.assertIn("procedures", attributions) + self.assertEqual(len(attributions), 2) + + def test_row_stochastic_invariant(self): + """Pre-expansion relevance is a distribution over tokens (sums to 1).""" + + _, captured_pre = self._capture_relevance(self.model) + self.interpreter.attribute(**self.batch) + + self.assertTrue(captured_pre) # something was captured + for key, relevance in captured_pre.items(): + token_sums = relevance.sum(dim=-1) + self.assertTrue( + torch.allclose(token_sums, torch.ones_like(token_sums), atol=1e-5), + msg=f"feature '{key}' relevance does not sum to 1: {token_sums}", + ) + # Rollout produces non-negative relevance. + self.assertTrue(torch.all(relevance >= 0)) + + def test_rollout_matrices_are_row_stochastic(self): + """Every composed rollout matrix has rows summing to 1.""" + + captured_R, _ = self._capture_relevance(self.model) + self.interpreter.attribute(**self.batch) + + for key, rollout in captured_R.items(): + row_sums = rollout.sum(dim=-1) + self.assertTrue( + torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5), + msg=f"feature '{key}' rollout not row-stochastic: {row_sums}", + ) + + def test_identity_attention_gives_identity_rollout(self): + """If attention is the identity at every layer, rollout is identity.""" + + original_layers = self.model.get_attention_layers + + def identity_layers(): + # Reuse real shapes captured during the forward pass, but + # overwrite each attention map with an identity per head. + real = original_layers() + patched = {} + for key, layers in real.items(): + new_layers = [] + for attn_map, grad in layers: + batch, heads, seq, _ = attn_map.shape + eye = ( + torch.eye(seq, dtype=attn_map.dtype, device=attn_map.device) + .reshape(1, 1, seq, seq) + .expand(batch, heads, seq, seq) + .contiguous() + ) + new_layers.append((eye, grad)) + patched[key] = new_layers + return patched + + self.model.get_attention_layers = identity_layers + captured_R, _ = self._capture_relevance(self.model) + self.interpreter.attribute(**self.batch) + + self.assertTrue(captured_R) + for key, rollout in captured_R.items(): + batch, seq, _ = rollout.shape + expected = ( + torch.eye(seq, dtype=rollout.dtype) + .unsqueeze(0) + .expand(batch, seq, seq) + ) + self.assertTrue( + torch.allclose(rollout, expected, atol=1e-6), + msg=f"feature '{key}' rollout is not identity", + ) + + def test_single_layer(self): + """num_layers=1 still produces valid, row-stochastic attributions.""" + + torch.manual_seed(42) + model = Transformer( + dataset=self.dataset, embedding_dim=8, heads=2, num_layers=1 + ) + interpreter = AttentionRollout(model) + _, captured_pre = self._capture_relevance(model) + + attributions = interpreter.attribute(**self.batch) + + self.assertEqual(set(attributions.keys()), set(model.feature_keys)) + for relevance in captured_pre.values(): + token_sums = relevance.sum(dim=-1) + self.assertTrue( + torch.allclose(token_sums, torch.ones_like(token_sums), atol=1e-5) + ) + + def test_single_head(self): + """heads=1 (no head fusion needed) produces valid attributions.""" + + torch.manual_seed(42) + model = Transformer( + dataset=self.dataset, embedding_dim=8, heads=1, num_layers=2 + ) + interpreter = AttentionRollout(model) + _, captured_pre = self._capture_relevance(model) + + attributions = interpreter.attribute(**self.batch) + + self.assertEqual(set(attributions.keys()), set(model.feature_keys)) + for relevance in captured_pre.values(): + token_sums = relevance.sum(dim=-1) + self.assertTrue( + torch.allclose(token_sums, torch.ones_like(token_sums), atol=1e-5) + ) + + def test_masked_padded_sequence(self): + """Padded batches (uneven sequence lengths) stay row-stochastic.""" + # Sanity check that padding actually occurred. + self.assertEqual(self.batch["conditions"].shape[1], 3) + + _, captured_pre = self._capture_relevance(self.model) + self.interpreter.attribute(**self.batch) + + for key, relevance in captured_pre.items(): + token_sums = relevance.sum(dim=-1) + self.assertTrue( + torch.allclose(token_sums, torch.ones_like(token_sums), atol=1e-5), + msg=f"padded feature '{key}' relevance does not sum to 1", + ) + + def test_target_class_idx_is_a_noop(self): + """Rollout is class-agnostic: target_class_idx must not change output.""" + + baseline = self.interpreter.attribute(**self.batch) + with_target = self.interpreter.attribute(target_class_idx=1, **self.batch) + + for key in baseline: + self.assertTrue( + torch.allclose(baseline[key], with_target[key], atol=1e-6), + msg=f"target_class_idx changed attributions for '{key}'", + ) + + def test_incompatible_model_raises_type_error(self): + """Model lacking the attention-readout methods raises TypeError.""" + + class PlainModel(nn.Module): + def forward(self, **data): + return {"logit": torch.zeros(1, 1)} + + with self.assertRaises(TypeError): + AttentionRollout(PlainModel()) + + def test_unsupported_head_fusion_raises_value_error(self): + """An unsupported head_fusion value raises ValueError (not AttributeError).""" + with self.assertRaises(ValueError): + AttentionRollout(self.model, head_fusion="max") + + def test_model_is_in_eval_mode(self): + """Constructing the interpreter puts the model in eval mode, disabling dropout + and making attributions deterministic for a given input. + """ + + self.assertFalse(self.model.training) + + def test_attribute_is_deterministic(self): + """Repeated calls on the same batch produce identical attributions.""" + + first = self.interpreter.attribute(**self.batch) + second = self.interpreter.attribute(**self.batch) + + for key in first: + self.assertTrue( + torch.allclose(first[key], second[key], atol=1e-6), + msg=f"attributions for '{key}' are not deterministic", + ) + + def test_callable_interface_matches_attribute(self): + """Calling the interpreter directly is equivalent to attribute().""" + + via_attribute = self.interpreter.attribute(**self.batch) + via_call = self.interpreter(**self.batch) + + self.assertEqual(set(via_attribute.keys()), set(via_call.keys())) + for key in via_attribute: + self.assertTrue( + torch.allclose(via_attribute[key], via_call[key], atol=1e-6) + ) + + +if __name__ == "__main__": + unittest.main() From f865c3bfa9b331b554ef4e54342e7f81a6355f44 Mon Sep 17 00:00:00 2001 From: Felipe Bonchristiano Date: Sun, 7 Jun 2026 20:39:11 -0300 Subject: [PATCH 7/7] Stip trailing whitespace and rename example keys to rollout --- .../dka_stageattn_mimic4_interpret.py | 2 +- .../dka_transformer_mimic4_interpret.py | 2 +- .../los_stageattn_mimic4_interpret.py | 2 +- .../los_transformer_mimic4_interpret.py | 2 +- .../mp_stageattn_mimic4_interpret.py | 2 +- .../mp_transformer_mimic4_interpret.py | 2 +- pyhealth/interpret/methods/attention_rollout.py | 12 ++++++------ 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/interpretability/dka_stageattn_mimic4_interpret.py b/examples/interpretability/dka_stageattn_mimic4_interpret.py index dcfeecac9..3001a8d46 100644 --- a/examples/interpretability/dka_stageattn_mimic4_interpret.py +++ b/examples/interpretability/dka_stageattn_mimic4_interpret.py @@ -135,7 +135,7 @@ def count_labels(ds): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), - "attention_rollout": AttentionRollout(model), + "rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/dka_transformer_mimic4_interpret.py b/examples/interpretability/dka_transformer_mimic4_interpret.py index 0aa22f10a..1a7f87799 100644 --- a/examples/interpretability/dka_transformer_mimic4_interpret.py +++ b/examples/interpretability/dka_transformer_mimic4_interpret.py @@ -135,7 +135,7 @@ def count_labels(ds): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), - "attention_rollout": AttentionRollout(model), + "rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/los_stageattn_mimic4_interpret.py b/examples/interpretability/los_stageattn_mimic4_interpret.py index 462050c15..35a7c95ec 100644 --- a/examples/interpretability/los_stageattn_mimic4_interpret.py +++ b/examples/interpretability/los_stageattn_mimic4_interpret.py @@ -121,7 +121,7 @@ def main(): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), - "attention_rollout": AttentionRollout(model), + "rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/los_transformer_mimic4_interpret.py b/examples/interpretability/los_transformer_mimic4_interpret.py index 6cab9572e..cb5911cd2 100644 --- a/examples/interpretability/los_transformer_mimic4_interpret.py +++ b/examples/interpretability/los_transformer_mimic4_interpret.py @@ -121,7 +121,7 @@ def main(): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), - "attention_rollout": AttentionRollout(model), + "rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/mp_stageattn_mimic4_interpret.py b/examples/interpretability/mp_stageattn_mimic4_interpret.py index 5fedc21ff..e7eaa8541 100644 --- a/examples/interpretability/mp_stageattn_mimic4_interpret.py +++ b/examples/interpretability/mp_stageattn_mimic4_interpret.py @@ -121,7 +121,7 @@ def main(): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), - "attention_rollout": AttentionRollout(model), + "rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/examples/interpretability/mp_transformer_mimic4_interpret.py b/examples/interpretability/mp_transformer_mimic4_interpret.py index ad6413789..e3ffb4f85 100644 --- a/examples/interpretability/mp_transformer_mimic4_interpret.py +++ b/examples/interpretability/mp_transformer_mimic4_interpret.py @@ -121,7 +121,7 @@ def main(): "deeplift": DeepLift(model, use_embeddings=True), "gim": GIM(model), "chefer": CheferRelevance(model), - "attention_rollout": AttentionRollout(model), + "rollout": AttentionRollout(model), "shap": ShapExplainer(model, use_embeddings=True), "lime": LimeExplainer(model, use_embeddings=True, n_samples=200), } diff --git a/pyhealth/interpret/methods/attention_rollout.py b/pyhealth/interpret/methods/attention_rollout.py index 902cc2dc0..c853ee605 100644 --- a/pyhealth/interpret/methods/attention_rollout.py +++ b/pyhealth/interpret/methods/attention_rollout.py @@ -104,7 +104,7 @@ class AttentionRollout(BaseInterpreter): >>> # target_class_idx is accepted but ignored (rollout is class-agnostic) >>> same = interpreter.attribute(target_class_idx=1, **batch) """ - + def __init__(self, model: BaseModel, head_fusion: str = "mean"): if head_fusion != "mean": raise ValueError( @@ -126,7 +126,7 @@ def __init__(self, model: BaseModel, head_fusion: str = "mean"): f"{', '.join(required_methods)}. " f"Missing: {', '.join(missing_methods)}." ) - + super().__init__(model) self.head_fusion = head_fusion @@ -211,18 +211,18 @@ def attribute( attributions = self.model.get_relevance_tensor(R, **data) return self._map_to_input_shapes(attributions, data) - + def _fuse_heads(self, attn_map: torch.Tensor) -> torch.Tensor: """Fuse attention heads from [batch, heads, seq, seq] to [batch, seq, seq].""" if self.head_fusion == "mean": return attn_map.mean(dim=1) - + raise ValueError( f"Unsupported head_fusion='{self.head_fusion}'. " "Currently supported values: mean." ) - + def _map_to_input_shapes( self, attributions: Dict[str, torch.Tensor], @@ -263,7 +263,7 @@ def _map_to_input_shapes( attr = attr.expand_as(val) result[key] = attr return result - + @staticmethod def _add_residual(attn: torch.Tensor) -> torch.Tensor: """