feat(training): enable class weights computed from label frequency to handle class imbalance #262

@hanneshapke

Description

Summary

Compute inverse-frequency class weights from the training data label distribution and pass them to the existing MaskedSparseCategoricalCrossEntropy loss function via config.class_weights. The infrastructure already exists — this is a configuration and weight-computation change with no loss function modifications needed.

Problem

The training data is heavily imbalanced. The chance field in label_utils.py sets FIRSTNAME and SURNAME generation probability to 0.5 while all other entity types are at 0.1. This means the model sees ~5× more name entities than SSNs, IBANs, security tokens, or driver's license numbers during training.

Without class weighting, the model optimizes disproportionately for frequent labels. The loss from getting a FIRSTNAME token wrong contributes much more to the total gradient (by sheer volume) than the loss from missing an IBAN token. The result: high recall on names, poor recall on rare but high-sensitivity PII types.

Existing infrastructure

The class weight mechanism is fully implemented but unused:

  • config.py: class_weights: dict[int, float] = field(default_factory=dict) — defaults to empty
  • model.py: MaskedSparseCategoricalCrossEntropy._build_weight_tensor() creates a tensor of shape (num_classes,) initialized to 1.0, then fills in per-class weights from the config dict
  • model.py: In the forward pass, per-token loss is multiplied by weight_tensor[y_true_safe], applying the class-specific weight
  • trainer.py: config.class_weights is passed to the PII loss function during initialization

Passing a non-empty class_weights dict is all that's needed to activate this.
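The mechanism described above can be sketched in a few lines of NumPy. This is an illustrative stand-in for what `_build_weight_tensor()` and the forward pass do, not the project's actual TF/Keras code; the class IDs and weights are hypothetical:

```python
import numpy as np

# Sketch of the weighting mechanism in MaskedSparseCategoricalCrossEntropy
# (NumPy stand-in, not the real loss code; values are hypothetical).
num_classes = 4
class_weights = {2: 3.0, 3: 5.0}  # hypothetical rare-class weights from config

# _build_weight_tensor: start every class at 1.0, fill in configured weights
weight_tensor = np.ones(num_classes)
for label_id, w in class_weights.items():
    weight_tensor[label_id] = w

# Forward pass: per-token loss is scaled by the weight of its true label
y_true = np.array([0, 2, 3, 1])            # token labels
per_token_loss = np.array([0.5, 0.5, 0.5, 0.5])
weighted = per_token_loss * weight_tensor[y_true]
print(weighted)  # [0.5 1.5 2.5 0.5]
```

Tokens of unweighted classes keep their original loss; tokens of weighted classes contribute proportionally more to the gradient.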

How to calculate class weights

Step 1: Count label frequencies from the tokenized training set

After tokenization, count how many tokens carry each BIO label in the training split. The O label and padding (-100) dominate, so they should be handled separately.

from collections import Counter
import numpy as np

def compute_class_weights(dataset, id_to_label, scheme="inverse_sqrt"):
    """Compute class weights from a tokenized training dataset.

    Args:
        dataset: HuggingFace Dataset with 'pii_labels' column
        id_to_label: mapping from label ID to label string
        scheme: weighting scheme — "inverse", "inverse_sqrt", or "effective"

    Returns:
        dict mapping label ID to weight, normalized so entity-class weights
        average 1.0 (the O label stays at 1.0)
    """
    counts = Counter()
    for sample in dataset:
        for label_id in sample["pii_labels"]:
            if label_id == -100:
                continue  # skip padding
            counts[label_id] += 1

    total = sum(counts.values())
    num_classes = len(counts)

    weights = {}
    for label_id, count in counts.items():
        if id_to_label.get(label_id) == "O":
            # O is the majority class — keep weight at 1.0 or lower
            weights[label_id] = 1.0
            continue

        freq = count / total

        if scheme == "inverse":
            # Classic inverse frequency: w = total / (num_classes * count)
            weights[label_id] = total / (num_classes * count)
        elif scheme == "inverse_sqrt":
            # Smoothed: less aggressive than pure inverse
            weights[label_id] = np.sqrt(total / (num_classes * count))
        elif scheme == "effective":
            # Effective number of samples (Cui et al., 2019)
            beta = 0.999
            effective_num = (1 - beta**count) / (1 - beta)
            weights[label_id] = 1.0 / effective_num
        else:
            raise ValueError(f"unknown weighting scheme: {scheme!r}")

    # Normalize so the mean weight across entity classes is 1.0
    entity_weights = [w for lid, w in weights.items() if id_to_label.get(lid) != "O"]
    mean_w = np.mean(entity_weights) if entity_weights else 1.0
    for lid in weights:
        if id_to_label.get(lid) != "O":
            weights[lid] = weights[lid] / mean_w

    return weights

Step 2: Choose a weighting scheme

Three common approaches, from most to least aggressive:

  • Inverse frequency: w_c = N / (K × n_c). For severely imbalanced data; risks overfitting rare classes.
  • Inverse square root (recommended): w_c = √(N / (K × n_c)). A good default for NER with moderate imbalance.
  • Effective number (Cui et al.): w_c = (1 − β) / (1 − β^{n_c}). Theoretically grounded; handles long-tail distributions well.

Where N = total tokens, K = number of classes, n_c = count of class c, β ≈ 0.999.

Recommended: inverse square root. Pure inverse frequency can over-correct, pushing the model to over-predict rare entities and inflate false positives. The square root dampens the effect while still meaningfully upweighting underrepresented types.
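To see how much each scheme differs, here is a quick numeric comparison under hypothetical counts (1M tokens, 20 classes, a common class vs. a rare one); the numbers are illustrative, not from the real dataset:

```python
import numpy as np

# Hypothetical corpus: 1,000,000 tokens, 20 classes.
N, K = 1_000_000, 20
counts = {"B-FIRSTNAME": 25_000, "B-IBAN": 800}

for label, n in counts.items():
    inv = N / (K * n)                      # inverse frequency
    inv_sqrt = np.sqrt(inv)                # inverse square root
    beta = 0.999
    eff = (1 - beta) / (1 - beta**n)       # effective number (Cui et al.)
    print(f"{label}: inverse={inv:.2f} inverse_sqrt={inv_sqrt:.2f} effective={eff:.6f}")
```

With these counts, plain inverse frequency weights B-IBAN about 31× higher than B-FIRSTNAME, while the square-root variant compresses that ratio to roughly 5.6× — which is the dampening effect the recommendation relies on.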

Step 3: Example output

Given a hypothetical label distribution:

O:              800,000 tokens  →  weight: 1.0  (kept at baseline)
B-FIRSTNAME:     25,000 tokens  →  weight: 0.8
I-FIRSTNAME:     12,000 tokens  →  weight: 1.1
B-SURNAME:       24,000 tokens  →  weight: 0.8
I-SURNAME:       10,000 tokens  →  weight: 1.2
B-EMAIL:          5,000 tokens  →  weight: 1.7
I-EMAIL:         15,000 tokens  →  weight: 1.0
B-SSN:            1,200 tokens  →  weight: 3.5
I-SSN:            3,600 tokens  →  weight: 2.0
B-IBAN:             800 tokens  →  weight: 4.3
I-IBAN:           4,000 tokens  →  weight: 1.9
B-SECURITYTOKEN:    500 tokens  →  weight: 5.4
I-SECURITYTOKEN:  2,000 tokens  →  weight: 2.7
...

Rare entity types like IBAN and SECURITYTOKEN get 3–5× the weight of common types like FIRSTNAME, meaning a missed IBAN token contributes proportionally more to the loss.

Implementation

Option A: Precompute and set in config (simplest)

Run the weight computation once on the training data, then hardcode the result in training_config.toml:

[class_weights]
0 = 1.0    # O
1 = 0.8    # B-FIRSTNAME
2 = 1.1    # I-FIRSTNAME
3 = 0.8    # B-SURNAME
5 = 1.7    # B-EMAIL
17 = 3.5   # B-SSN
# ...

Option B: Compute automatically during training (recommended)

Add a step in trainer.py or preprocessing.py that computes weights from the training split before instantiating the loss function, and injects them into config.class_weights. This ensures weights stay in sync with the data as the dataset evolves.
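A hypothetical wiring sketch for this option — the TrainingConfig stand-in and function names below are illustrative, and the real integration point in trainer.py may differ. The key design choice is to compute weights only when the config doesn't already set them, so an explicit [class_weights] table still wins:

```python
from dataclasses import dataclass, field

# Stand-in for the project's config object (real one lives in config.py).
@dataclass
class TrainingConfig:
    class_weights: dict[int, float] = field(default_factory=dict)

def inject_class_weights(config: TrainingConfig, computed: dict[int, float]) -> TrainingConfig:
    """Fill config.class_weights only when the user hasn't set them explicitly."""
    if not config.class_weights:
        config.class_weights = computed
    return config

# Auto-computed weights are used when the config is empty...
config = inject_class_weights(TrainingConfig(), {5: 1.7, 17: 3.5})
print(config.class_weights)  # {5: 1.7, 17: 3.5}
```

This keeps Option A and Option B compatible: a hardcoded table disables the automatic computation.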

Notes

  • The O label should always have weight ≤ 1.0 — upweighting O would drown out entity signals
  • Padding tokens (-100) are already masked out in the loss function and should be excluded from frequency counting
  • Class weights apply only to the PII loss, not the coref loss (this is already how the trainer is wired)
  • Monitor per-class precision after enabling weights — aggressive upweighting of rare classes can increase false positives for those types
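For the monitoring point above, per-class precision can be tracked with a few lines of pure Python; a real evaluation pipeline would more likely use seqeval or sklearn, so treat this as a minimal sketch:

```python
from collections import Counter

def per_class_precision(y_true, y_pred):
    """Precision per predicted label: TP / (TP + FP), skipping padding (-100)."""
    tp, fp = Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == -100:          # padding is masked out of the loss; skip it here too
            continue
        if p == t:
            tp[p] += 1
        else:
            fp[p] += 1
    return {c: tp[c] / (tp[c] + fp[c]) for c in sorted(set(tp) | set(fp))}

print(per_class_precision([1, 1, 2, 0, -100], [1, 2, 2, 0, 1]))
# {0: 1.0, 1: 1.0, 2: 0.5}
```

A drop in precision for a heavily upweighted class is the over-correction signal the note warns about.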
