39 changes: 39 additions & 0 deletions README.md
@@ -163,6 +163,45 @@ Where a reward signal is available we compute gradients using a weighted advanta
bergson build <output_path> --model <model_name> --dataset <dataset_name> --reward_column <reward_column_name>
```

## Numerical Stability

Some models produce inconsistent per-example gradients when sequences of different lengths are batched together. This happens because optimized SDPA (scaled dot-product attention) backends such as flash and memory-efficient attention compute slightly different results depending on the padding length.
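
The underlying effect is ordinary floating-point non-associativity: kernels that accumulate the same values in a different order (for example, over a different padded length) can round differently. A few lines of plain Python illustrate it:

```python
# Floating-point addition is not associative, so summing the same
# terms in a different order can give a different rounded result.
a, b, c = 0.1, 1e16, -1e16
left = (a + b) + c   # 0.1 is absorbed by 1e16 before the cancellation
right = a + (b + c)  # the large terms cancel first, preserving 0.1
print(left, right)   # the two orders disagree
```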

Use the built-in diagnostic to check your model:

```bash
bergson test_numerical_stability --model <model_name>
```

This automatically tests escalating configurations and reports exactly which flags (if any) you need. If your model fails the default test, add the recommended flags to your `build`/`score`/`trackstar` commands:

```bash
bergson build <output_path> --model <model_name> --force_math_sdp
# or if needed:
bergson build <output_path> --model <model_name> --force_math_sdp --precision fp32
```
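
The escalation can be thought of as trying the cheapest configuration first and stopping at the first one that yields consistent gradients. A minimal sketch of that search (the helper and configuration names here are illustrative, not Bergson's internals):

```python
# Candidate configurations, cheapest first (names are illustrative).
CONFIGS = [
    {},                                             # bf16 baseline
    {"force_math_sdp": True},                       # math-only SDPA kernel
    {"force_math_sdp": True, "precision": "fp32"},  # plus full precision
]

def first_stable(is_stable):
    """Return the first configuration that passes the stability check."""
    for cfg in CONFIGS:
        if is_stable(cfg):
            return cfg
    return None  # no configuration produced consistent gradients

# Example: suppose only the fp32 runs are consistent for this model.
print(first_stable(lambda cfg: cfg.get("precision") == "fp32"))
# → {'force_math_sdp': True, 'precision': 'fp32'}
```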

### Performance impact

The overhead of `--force_math_sdp` and `--precision fp32` varies by model. Benchmarked on a single A100 80GB GPU with 500 documents from pile-10k:

| Model | Settings | Build time | vs bf16 baseline |
|-------|----------|------------|------------------|
| Pythia-160M | bf16 | 30.2s | — |
| Pythia-160M | bf16 + `--force_math_sdp` | 30.4s | +0.8% |
| Pythia-160M | fp32 | 35.6s | +17.9% |
| Pythia-160M | fp32 + `--force_math_sdp` | 39.6s | +31.1% |
| OLMo-2-1B | bf16 | 43.1s | — |
| OLMo-2-1B | bf16 + `--force_math_sdp` | 53.6s | +24.5% |
| OLMo-2-1B | fp32 | 132.8s | +208.1% |
| OLMo-2-1B | fp32 + `--force_math_sdp` | 141.8s | +229.0% |
| OLMo-2-7B | bf16 | 105.5s | — |
| OLMo-2-7B | bf16 + `--force_math_sdp` | 151.1s | +43.2% |
| OLMo-2-7B | fp32 | 569.2s | +439.5% |
| OLMo-2-7B | fp32 + `--force_math_sdp` | 603.6s | +472.1% |

Not all models are affected — run `bergson test_numerical_stability` before enabling these flags to avoid unnecessary overhead.
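
At its core, a stability check like this compares per-parameter gradients computed under two different batchings against a relative tolerance. A hypothetical sketch of that comparison (not Bergson's actual implementation):

```python
def max_rel_diff(grads_a, grads_b, eps=1e-12):
    """Largest element-wise relative difference between two gradient dicts."""
    worst = 0.0
    for name, a in grads_a.items():
        for x, y in zip(a, grads_b[name]):
            worst = max(worst, abs(x - y) / max(abs(x), abs(y), eps))
    return worst

# Per-example gradients for one document, computed under two paddings.
padded   = {"layer0.weight": [0.500001, -1.25, 3.0]}
unpadded = {"layer0.weight": [0.500000, -1.25, 3.0]}

# A tiny discrepancy like this is rounding noise; a large one signals
# that the attention backend is sensitive to batch composition.
print(max_rel_diff(padded, unpadded) < 1e-4)  # True
```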

# Benchmarks

![CLI Benchmark](docs/benchmarks/cli_benchmark_NVIDIA_GH200_120GB.png)
26 changes: 25 additions & 1 deletion bergson/__main__.py
@@ -13,6 +13,7 @@
ScoreConfig,
TrackstarConfig,
)
from .diagnose import DiagnoseConfig, diagnose
from .hessians.hessian_approximations import approximate_hessians
from .magic import MagicConfig, run_magic
from .query.query_index import query
@@ -149,12 +150,35 @@ def execute(self):
run_magic(self.run_cfg, self.dist_cfg)


@dataclass
class Test_Numerical_Stability:
"""Test gradient consistency across padding and batch composition.

Tests whether a model produces consistent gradients regardless of how
documents are batched together. If inconsistencies are found, recommends
using --force_math_sdp on build/score/trackstar commands."""

diagnose_cfg: DiagnoseConfig

def execute(self):
"""Run the diagnostic."""
diagnose(self.diagnose_cfg)


@dataclass
class Main:
"""Routes to the subcommands."""

command: Union[
Build, Query, Preconditioners, Reduce, Score, Hessian, Trackstar, Magic
Build,
Query,
Preconditioners,
Reduce,
Score,
Hessian,
Trackstar,
Magic,
Test_Numerical_Stability,
]

def execute(self):
6 changes: 6 additions & 0 deletions bergson/config.py
@@ -220,6 +220,12 @@ class IndexConfig:
overwrite: bool = False
"""Whether to overwrite any existing index in the run path."""

force_math_sdp: bool = False
"""Disable flash and memory-efficient SDPA backends, forcing the
math-only kernel. Some models produce inconsistent gradients across
different padding lengths when using optimized attention backends.
Run `bergson test_numerical_stability` to check whether your model needs this."""

distributed: DistributedConfig = field(default_factory=DistributedConfig)
"""Configuration for multi-node distributed preconditioner computation."""
