Skip to content

Add per-class token and sentence attributions with interactive visualizer#250

Merged
tmills merged 4 commits into
Machine-Learning-for-Medical-Language:mainfrom
ianbulovic:rest_interp
Apr 17, 2026
Merged

Add per-class token and sentence attributions with interactive visualizer#250
tmills merged 4 commits into
Machine-Learning-for-Medical-Language:mainfrom
ianbulovic:rest_interp

Conversation

@ianbulovic
Copy link
Copy Markdown
Contributor

Summary

Lots of help from Claude on this one, but carefully steered and scrutinized by me.

  • Adds two new interpretation methods to the REST API: gradient-based token attributions and leave-one-out sentence attributions. This is adapted from @tmills rest_interp branch, slightly modified to return signed per-class scores rather than a single unsigned saliency score
  • Adds a GET / endpoint that serves an interactive HTML visualizer for exploring model predictions and attributions
  • Adds tests for both attribution methods against the existing negation model

Token attributions (return_attributions=true)

Uses the input × gradient method: one forward pass through the encoder, then one backward pass per class label (reusing the computation graph via retain_graph=True). The signed dot product of gradient and embedding over the hidden dimension gives a per-token, per-class score in [-1, 1]:

  • Positive: token pushes the model toward that class
  • Negative: token pushes the model away from that class

Each token in the response includes token_id, character-level start/end offsets into the original input string, and a scores dict keyed by label. Special tokens ([CLS], [SEP]) are included with start == end == 0. Only classification tasks are supported; tagging and relations tasks log a warning and are skipped.

Sentence attributions (return_sentence_attributions=true)

Uses leave-one-out ablation: runs one batched forward pass over the full text and each ablated variant (one sentence removed), then for each class computes p(class | full) − p(class | ablated). Scores are per-class and signed with the same polarity convention. Single-sentence inputs return a score of 0.0 for all classes.

Visualizer (GET /)

A self-contained HTML page (no external dependencies) served at the root of each model's route prefix. Features:

  • Form controls for all /process query parameters
  • Classification results with prediction, per-class probability badges
  • Token attribution display: original text rendered with colored highlights (green/red by class score intensity), class selector dropdown, hover tooltips showing all class scores
  • Sentence attribution display with signed scores and class selector
  • Basic span/relation lists for tagging and relation tasks

The HTML file is bundled as package data via importlib.resources and declared in pyproject.toml so it is included correctly in installed wheels.

image

@tmills tmills merged commit 091b5e7 into Machine-Learning-for-Medical-Language:main Apr 17, 2026
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants