
feat(model): add CRF layer on top of token classifier for valid BIO sequence decoding #256

@hanneshapke

Description


Summary

Add a Conditional Random Field (CRF) layer on top of the PII token classifier to learn transition constraints between labels and enforce valid BIO sequences during inference.

Problem

The current architecture in model.py classifies each token independently: DistilBERT produces a hidden state per token, a linear layer maps each hidden state to scores over the 33+ BIO labels, and the argmax of each token's scores becomes its prediction. Nothing prevents illegal sequences such as:

  • I-EMAIL following B-PHONENUMBER — an "inside" tag can only follow its matching "begin" or "inside" tag
  • I-SSN appearing as the first token of an entity — must start with B-SSN
  • I-SURNAME following B-FIRSTNAME — mixing entity types mid-span

With 24 entity types and synthetic training data, there will always be edge cases where independent token classification produces invalid sequences. For a privacy tool, a single missed token in a phone number means the number leaks through.

Proposed solution

Add a learnable transition matrix of shape (num_labels, num_labels) via pytorch-crf. The entry T[i][j] is the score for transitioning from label i to label j. During training, the CRF computes the negative log-likelihood of the entire correct label sequence. During inference, the Viterbi algorithm replaces per-token argmax and finds the globally optimal label sequence.

Time complexity is O(n × k²) where n = sequence length, k = number of labels. For ~33 labels and 512 max length this adds negligible overhead.
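To make the scoring concrete, here is a toy example (the three labels and all the numbers are hypothetical, not Kiji's actual label set or learned weights): a linear-chain CRF scores a whole sequence as the sum of per-token emission scores plus the learned transition scores between consecutive labels, so one strongly negative T[i][j] on an illegal transition sinks every sequence that uses it.

```python
# Hypothetical 3-label example: 0=O, 1=B-EMAIL, 2=I-EMAIL.
# A sequence's score = sum of emissions + sum of learned transitions.

def sequence_score(emissions, transitions, labels):
    """Score one label sequence under a linear-chain CRF."""
    score = sum(emissions[t][y] for t, y in enumerate(labels))
    score += sum(transitions[a][b] for a, b in zip(labels, labels[1:]))
    return score

# Toy emissions for 2 tokens, and toy transitions in which O -> I-EMAIL
# and B-EMAIL -> B-EMAIL are heavily penalized (a trained CRF learns
# such values from data; -8.0 here is just an illustrative penalty).
emissions = [[0.1, 2.0, 0.3], [0.2, 0.1, 1.5]]
transitions = [
    [0.5, 0.4, -8.0],   # from O
    [0.0, -8.0, 1.0],   # from B-EMAIL
    [0.3, 0.2, 0.8],    # from I-EMAIL
]

valid = sequence_score(emissions, transitions, [1, 2])    # B-EMAIL, I-EMAIL
invalid = sequence_score(emissions, transitions, [0, 2])  # O, I-EMAIL
```

Even though I-EMAIL has the highest emission at token 2 in both cases, the illegal O → I-EMAIL transition makes the invalid sequence score far lower than the valid one.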

Integration points

model.py

  • Add from torchcrf import CRF (new dep: pytorch-crf)
  • Add self.crf = CRF(num_pii_labels, batch_first=True) to the model
  • Training forward pass: use self.crf(emissions, pii_labels, mask=attention_mask.bool()) for loss
  • Add a decode() method using self.crf.decode(emissions, mask) for Viterbi inference

trainer.py

  • compute_loss: use the CRF's built-in negative log-likelihood instead of per-token cross-entropy for the PII task. Coref task stays as-is.
  • compute_metrics: replace np.argmax(pii_predictions, axis=2) with model.decode(emissions, mask); the Viterbi path becomes the prediction.
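Sketched as standalone helpers for illustration (the real code would live inside trainer.py's compute_loss and compute_metrics, whose surrounding structure is assumed here), the trainer-side changes amount to swapping the loss term and the decode step. Padding the variable-length Viterbi paths back to a rectangle keeps any downstream metric code that expects a (batch, seq_len) array working unchanged:

```python
import numpy as np


def compute_pii_loss(model, emissions, pii_labels, attention_mask):
    """CRF negative log-likelihood replaces per-token cross-entropy
    for the PII task; the coref loss stays as-is."""
    mask = attention_mask.bool()
    return -model.crf(emissions, pii_labels, mask=mask, reduction='mean')


def decode_pii_predictions(model, emissions, attention_mask):
    """Viterbi paths replace np.argmax(pii_predictions, axis=2)."""
    paths = model.crf.decode(emissions, mask=attention_mask.bool())
    # Paths are variable-length (masked positions are dropped); pad with
    # label id 0 back to a rectangular array for the metric code.
    max_len = emissions.shape[1]
    return np.array([p + [0] * (max_len - len(p)) for p in paths])
```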

quantize.py

  • Main complication: the CRF's Viterbi decode loop is hard to export to ONNX cleanly.
  • Practical workaround: export only the emission scores to ONNX and implement Viterbi decoding in Go (in the proxy's inference code), using the learned transition matrix exported as a separate JSON/numpy file.
  • The Viterbi algorithm is ~50 lines of code in any language.
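For reference, a sketch of that decode loop, written in Python here for readability (the issue proposes Go in the proxy; the logic ports one-to-one). It assumes the emission scores come from the ONNX-exported model and the transition matrix from the separately exported JSON/numpy file:

```python
def viterbi_decode(emissions, transitions):
    """Find the highest-scoring label sequence.

    emissions:   n x k per-token label scores (from the ONNX model)
    transitions: k x k matrix, transitions[i][j] = score of label i -> j
                 (the learned CRF transition matrix, exported separately)
    """
    n, k = len(emissions), len(emissions[0])
    # score[j] = best score of any path ending in label j at the current step
    score = list(emissions[0])
    backptr = []
    for t in range(1, n):
        prev, score, ptrs = score, [], []
        for j in range(k):
            best_i = max(range(k), key=lambda i: prev[i] + transitions[i][j])
            ptrs.append(best_i)
            score.append(prev[best_i] + transitions[best_i][j] + emissions[t][j])
        backptr.append(ptrs)
    # Backtrack from the best final label.
    best = max(range(k), key=lambda j: score[j])
    path = [best]
    for ptrs in reversed(backptr):
        best = ptrs[best]
        path.append(best)
    return list(reversed(path))
```

The inner loop over (i, j) pairs is the O(n × k²) cost discussed above; everything else is O(n).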

Expected impact

  • Entity-level F1 improvement of 1–3 points (typical for CRF on BERT-based NER)
  • Larger gains expected given Kiji's 24 entity types
  • Most pronounced improvement on multi-token entities (phone numbers, addresses, IBANs) where the entire span must be predicted as a unit
  • Virtually eliminates invalid BIO sequences at decode time
