## Summary

Compute inverse-frequency class weights from the training data label distribution and pass them to the existing `MaskedSparseCategoricalCrossEntropy` loss function via `config.class_weights`. The infrastructure already exists; this is a configuration and weight-computation change with no loss function modifications needed.
## Problem

The training data is heavily imbalanced. The `chance` field in `label_utils.py` sets FIRSTNAME and SURNAME generation probability to 0.5 while all other entity types are at 0.1, so the model sees roughly 5× more name entities than SSNs, IBANs, security tokens, or driver's license numbers during training.

Without class weighting, the model optimizes disproportionately for frequent labels. By sheer volume, the loss from getting a FIRSTNAME token wrong contributes far more to the total gradient than the loss from missing an IBAN token. The result: high recall on names, poor recall on rare but high-sensitivity PII types.
## Existing infrastructure

The class weight mechanism is fully implemented but unused:

- `config.py`: `class_weights: dict[int, float] = field(default_factory=dict)`, which defaults to empty
- `model.py`: `MaskedSparseCategoricalCrossEntropy._build_weight_tensor()` creates a tensor of shape `(num_classes,)` initialized to 1.0, then fills in per-class weights from the config dict
- `model.py`: in the forward pass, per-token loss is multiplied by `weight_tensor[y_true_safe]`, applying the class-specific weight
- `trainer.py`: `config.class_weights` is passed to the PII loss function during initialization

Passing a non-empty `class_weights` dict is all that's needed to activate this.
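To make the mechanism concrete, here is a NumPy sketch of how a weight tensor modulates a masked per-token loss. The function name `masked_weighted_ce` and its shapes are illustrative assumptions, not the actual `model.py` code:

```python
import numpy as np

def masked_weighted_ce(per_token_loss, y_true, class_weights, num_classes):
    """Apply per-class weights to an already-computed per-token loss.

    per_token_loss: (batch, seq) unweighted cross-entropy per token
    y_true:         (batch, seq) label IDs, with -100 for padding
    class_weights:  dict[int, float], as in config.class_weights
    """
    # Weight tensor of shape (num_classes,): default 1.0, overrides from config.
    weight_tensor = np.ones(num_classes, dtype=np.float32)
    for label_id, w in class_weights.items():
        weight_tensor[label_id] = w

    mask = y_true != -100                    # padding contributes nothing
    y_true_safe = np.where(mask, y_true, 0)  # avoid indexing with -100
    weighted = per_token_loss * weight_tensor[y_true_safe] * mask
    return weighted.sum() / np.maximum(mask.sum(), 1)
```

With an empty `class_weights` dict, every class keeps weight 1.0 and the loss is unchanged, which is why the mechanism is currently dormant.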
## How to calculate class weights

### Step 1: Count label frequencies from the tokenized training set

After tokenization, count how many tokens carry each BIO label in the training split. The `O` label and padding (`-100`) dominate, so they should be handled separately.
```python
from collections import Counter

import numpy as np


def compute_class_weights(dataset, id_to_label, scheme="inverse_sqrt"):
    """Compute class weights from a tokenized training dataset.

    Args:
        dataset: HuggingFace Dataset with a 'pii_labels' column
        id_to_label: mapping from label ID to label string
        scheme: weighting scheme, one of "inverse", "inverse_sqrt", or "effective"
    """
    counts = Counter()
    for sample in dataset:
        for label_id in sample["pii_labels"]:
            if label_id == -100:
                continue  # padding is masked out of the loss; skip it here too
            counts[label_id] += 1

    total = sum(counts.values())
    num_classes = len(counts)

    weights = {}
    for label_id, count in counts.items():
        if id_to_label.get(label_id) == "O":
            # O is the majority class; keep its weight at 1.0 or lower
            weights[label_id] = 1.0
            continue
        if scheme == "inverse":
            # Classic inverse frequency: w = N / (K * n_c)
            weights[label_id] = total / (num_classes * count)
        elif scheme == "inverse_sqrt":
            # Smoothed: less aggressive than pure inverse
            weights[label_id] = np.sqrt(total / (num_classes * count))
        elif scheme == "effective":
            # Effective number of samples (Cui et al., 2019)
            beta = 0.999
            effective_num = (1 - beta**count) / (1 - beta)
            weights[label_id] = 1.0 / effective_num
        else:
            raise ValueError(f"unknown weighting scheme: {scheme!r}")

    # Normalize so the mean weight across entity classes is 1.0
    entity_weights = [w for lid, w in weights.items() if id_to_label.get(lid) != "O"]
    mean_w = np.mean(entity_weights) if entity_weights else 1.0
    for lid in weights:
        if id_to_label.get(lid) != "O":
            weights[lid] = weights[lid] / mean_w
    return weights
```
### Step 2: Choose a weighting scheme

Three common approaches, from most to least aggressive:

| Scheme | Formula | When to use |
|---|---|---|
| Inverse frequency | w_c = N / (K × n_c) | Severely imbalanced data; risk of overfitting rare classes |
| Inverse square root (recommended) | w_c = √(N / (K × n_c)) | Moderate imbalance; good default for NER |
| Effective number (Cui et al.) | w_c = (1 − β) / (1 − β^{n_c}) | Theoretically grounded; handles long-tail well |

Where N = total tokens, K = number of classes, n_c = count of class c, β ≈ 0.999.

Recommended: inverse square root. Pure inverse frequency can over-correct, making the model hallucinate rare entities. The square root dampens the effect while still meaningfully upweighting underrepresented types.
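The difference in aggressiveness is easy to see on made-up counts, with one class 100× rarer than the other:

```python
# Toy comparison of the three schemes on two classes. Counts are invented;
# N and K follow the formulas in the table above.
counts = {"B-FIRSTNAME": 10_000, "B-IBAN": 100}
N = sum(counts.values())
K = len(counts)

inverse      = {c: N / (K * n) for c, n in counts.items()}
inverse_sqrt = {c: (N / (K * n)) ** 0.5 for c, n in counts.items()}
beta = 0.999
effective    = {c: (1 - beta) / (1 - beta**n) for c, n in counts.items()}

for name, w in [("inverse", inverse),
                ("inverse_sqrt", inverse_sqrt),
                ("effective", effective)]:
    ratio = w["B-IBAN"] / w["B-FIRSTNAME"]
    print(f"{name:13s} rare/common weight ratio: {ratio:.1f}")
```

Pure inverse gives the rare class 100× the weight of the common one; both the square-root and effective-number schemes compress that to roughly 10×, which is why they are gentler on precision.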
### Step 3: Example output

Given a hypothetical label distribution:
```
O:               800,000 tokens → weight 1.0 (kept at baseline)
B-FIRSTNAME:      25,000 tokens → weight 0.8
I-FIRSTNAME:      12,000 tokens → weight 1.1
B-SURNAME:        24,000 tokens → weight 0.8
I-SURNAME:        10,000 tokens → weight 1.2
B-EMAIL:           5,000 tokens → weight 1.7
I-EMAIL:          15,000 tokens → weight 1.0
B-SSN:             1,200 tokens → weight 3.5
I-SSN:             3,600 tokens → weight 2.0
B-IBAN:              800 tokens → weight 4.3
I-IBAN:            4,000 tokens → weight 1.9
B-SECURITYTOKEN:     500 tokens → weight 5.4
I-SECURITYTOKEN:   2,000 tokens → weight 2.7
...
```
Rare entity types like IBAN and SECURITYTOKEN get 3–5× the weight of common types like FIRSTNAME, meaning a missed IBAN token contributes proportionally more to the loss.
## Implementation

### Option A: Precompute and set in config (simplest)

Run the weight computation once on the training data, then hardcode the result in `training_config.toml`:
```toml
[class_weights]
0 = 1.0   # O
1 = 0.8   # B-FIRSTNAME
2 = 1.1   # I-FIRSTNAME
3 = 0.8   # B-SURNAME
5 = 1.7   # B-EMAIL
17 = 3.5  # B-SSN
# ...
```

Note that TOML keys are always strings, so the config loader must convert them to `int` to match `class_weights: dict[int, float]`.
### Option B: Compute automatically during training (recommended)

Add a step in `trainer.py` or `preprocessing.py` that computes weights from the training split before the loss function is instantiated, and injects them into `config.class_weights`. This keeps the weights in sync with the data as the dataset evolves.
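A minimal sketch of this wiring. `TrainingConfig` and `inject_class_weights` are hypothetical stand-ins for the real trainer code, which would call `compute_class_weights` from Step 1; normalization is omitted here for brevity:

```python
from collections import Counter
from dataclasses import dataclass, field


@dataclass
class TrainingConfig:
    # Stand-in for the real config.py dataclass
    class_weights: dict[int, float] = field(default_factory=dict)


def inject_class_weights(config, train_dataset, id_to_label):
    """Fill config.class_weights from the training split (inverse sqrt)."""
    counts = Counter(
        lid
        for sample in train_dataset
        for lid in sample["pii_labels"]
        if lid != -100  # padding is excluded from frequency counting
    )
    total, k = sum(counts.values()), len(counts)
    config.class_weights = {
        lid: 1.0 if id_to_label[lid] == "O" else (total / (k * n)) ** 0.5
        for lid, n in counts.items()
    }
    return config
```

The key ordering constraint is that this runs before the loss function is built, so the populated dict is what `trainer.py` passes to `MaskedSparseCategoricalCrossEntropy`.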
## Notes

- The `O` label should always have weight ≤ 1.0; upweighting `O` would drown out entity signals
- Padding tokens (`-100`) are already masked out in the loss function and should be excluded from frequency counting
- Class weights apply only to the PII loss, not the coref loss (this is already how the trainer is wired)
- Monitor per-class precision after enabling weights: aggressive upweighting of rare classes can increase false positives for those types