Summary
Add a Conditional Random Field (CRF) layer on top of the PII token classifier to learn transition constraints between labels and enforce valid BIO sequences during inference.
Problem
The current architecture in model.py independently classifies each token: DistilBERT produces a hidden state per token, and a linear layer maps each to scores over the 33+ BIO labels. The argmax of each token's scores becomes the prediction. Nothing prevents illegal sequences such as:
- I-EMAIL following B-PHONENUMBER — an "inside" tag can only follow its matching "begin" or "inside" tag
- I-SSN appearing as the first token of an entity — a span must start with B-SSN
- B-FIRSTNAME → I-SURNAME — mixing entity types mid-span
With 24 entity types and synthetic training data, there will always be edge cases where independent token classification produces invalid sequences. For a privacy tool, a single missed token in a phone number means the number leaks through.
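The BIO constraints above can be stated precisely as a small validity check. This is a standalone sketch (the helper name is illustrative, not from the codebase):

```python
def is_valid_bio(labels):
    """Check that every I- tag continues a span of the same entity type."""
    prev = "O"
    for label in labels:
        if label.startswith("I-"):
            entity = label[2:]
            # An I- tag is only legal immediately after B-<entity> or I-<entity>.
            if prev not in (f"B-{entity}", f"I-{entity}"):
                return False
        prev = label
    return True

# The illegal sequences described above all fail this check:
print(is_valid_bio(["B-PHONENUMBER", "I-EMAIL"]))       # False
print(is_valid_bio(["I-SSN"]))                          # False
print(is_valid_bio(["B-FIRSTNAME", "I-SURNAME"]))       # False
print(is_valid_bio(["B-PHONENUMBER", "I-PHONENUMBER"])) # True
```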
Proposed solution
Add a learnable transition matrix of shape (num_labels, num_labels) via pytorch-crf. The entry T[i][j] represents the score for transitioning from label i to label j. During training, the CRF computes the negative log-likelihood of the entire gold label sequence. During inference, the Viterbi algorithm replaces per-token argmax, finding the globally optimal label sequence.
Time complexity is O(n × k²) where n = sequence length, k = number of labels. For ~33 labels and 512 max length this adds negligible overhead.
Integration points
model.py
- Add from torchcrf import CRF (new dependency: pytorch-crf)
- Add self.crf = CRF(num_pii_labels, batch_first=True) to the model
- Training forward pass: self.crf(emissions, pii_labels, mask=attention_mask.bool()) returns the sequence log-likelihood; negate it for the loss
- Add a decode() method using self.crf.decode(emissions, mask) for Viterbi inference
trainer.py
- compute_loss: use the CRF's built-in negative log-likelihood instead of per-token cross-entropy for the PII task; the coref task stays as-is.
- compute_metrics: replace np.argmax(pii_predictions, axis=2) with model.decode(emissions, mask), since Viterbi decoding supersedes per-token argmax.
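One subtlety in compute_metrics: crf.decode returns variable-length Python lists (trimmed to each example's mask), while the metric code that consumed the argmax output expects a padded (batch, seq_len) array. A sketch of the conversion, with illustrative names and an assumed ignore-index of -100:

```python
import numpy as np

def pad_decoded(decoded, seq_len, pad_label_id=-100):
    """Pad variable-length Viterbi outputs back to a (batch, seq_len) array."""
    out = np.full((len(decoded), seq_len), pad_label_id, dtype=np.int64)
    for i, path in enumerate(decoded):
        out[i, :len(path)] = path
    return out

# Example: two sequences decoded to different lengths.
padded = pad_decoded([[1, 2, 3], [4, 5]], seq_len=4)
# [[   1    2    3 -100]
#  [   4    5 -100 -100]]
```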
quantize.py
- Main complication: the CRF's Viterbi decode loop is hard to export to ONNX cleanly.
- Practical workaround: export only the emission scores to ONNX and implement Viterbi decoding in Go (in the proxy's inference code), using the learned transition matrix — plus pytorch-crf's start and end transition vectors — exported as a separate JSON/numpy file.
- The Viterbi algorithm is ~50 lines of code in any language.
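For illustration, a pure-Python Viterbi decoder over the exported parameters, mirroring what the Go port would do (argument names are illustrative):

```python
import numpy as np

def viterbi_decode(emissions, transitions, start, end):
    """Find the highest-scoring label sequence.

    emissions:   (seq_len, num_labels) per-token scores from the ONNX model
    transitions: (num_labels, num_labels) learned matrix, T[i][j] = score of i -> j
    start, end:  (num_labels,) learned start/end scores
    """
    n, k = emissions.shape
    score = start + emissions[0]                # best score ending in each label at t=0
    backpointers = np.zeros((n, k), dtype=np.int64)
    for t in range(1, n):
        # Broadcast score[i] + T[i][j] over all label pairs, then take the
        # best predecessor for each current label j.
        total = score[:, None] + transitions    # (k, k)
        backpointers[t] = total.argmax(axis=0)
        score = total.max(axis=0) + emissions[t]
    score = score + end
    # Follow backpointers from the best final label.
    best = [int(score.argmax())]
    for t in range(n - 1, 0, -1):
        best.append(int(backpointers[t, best[-1]]))
    return best[::-1]
```

Each time step costs O(k²), giving the O(n × k²) total noted above.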
Expected impact
- Entity-level F1 improvement of 1–3 points (typical for CRF on BERT-based NER)
- Larger gains expected given Kiji's 24 entity types
- Most pronounced improvement on multi-token entities (phone numbers, addresses, IBANs) where the entire span must be predicted as a unit
- Virtually eliminates invalid BIO sequences at decode time