Target audience: Engineers who need to integrate a new VLA model into FlashRT (e.g. Pi0.6, a fresh open-source VLA).
Time estimate: A single
(framework, hardware)combination runs around 800-1200 lines of code, or 1-2 weeks of work, assuming the model's structure is close to Pi0.5 / Pi0. All four combinations (torch/jax × thor/rtx) take roughly 3-4 weeks.Read in this order (don't skip ahead — each doc assumes the previous):
- This doc §0–§1 (you are here) — the repository contract and which files you'll touch
flash_rt/frontends/torch/_template/README.md— the starter package. Open in a separate window before reading further; the rest of this doc references it. The template has 4 stub files (frontend / pipeline / weights_spec / attention) you copy and fill in.docs/stable_api.md— public API surface and naming conventions you must respectdocs/calibration.md— how FP8 calibration works (read §2 + §10 before writing your_calibratetwin)docs/kernel_fusion.md— kernel naming + decision tree for whichfvk.*to call where (skim §1 + §2; reference the rest as you writepipeline.py)docs/plugin_model_template.md— only if you're shipping a closed-source model as an external plugin instead of a PR into this repo. Also contains the first-light cosine routing table (Q&A section) — the most useful debugging reference once your model first runs.
Suggested first-week schedule for an ML-infra engineer with the source model already running in PyTorch:
- Day 1 — read items 1–3, copy the template, list every weight tensor in your checkpoint
- Day 2–3 — fill in
weights_spec.py(declarative, mostly mechanical) andattention.py(~60 lines) - Day 4–5 — write
pipeline.py(the bulk of the work;encoder_forwardfirst, leavedecoder_forwardfor day 6) - Day 6–7 — wire up
frontend.py, run first inference, debug cosine using the routing table
Mandatory rules:
1. Every (model, hardware) compute path lives in its own file:
models/<model>/pipeline_<hw>.py
— The suffix is required (_thor / _rtx). There is no default
pipeline.py entry point.
— No runtime hardware forks such as `if self._has_sm100` or
`if arch == 'thor'`.
— Even if two hardware implementations are 90% identical, they must
still be split. Shared logic is reused through function calls or
imports, not through in-file branching.
2. Every (model, framework, hardware) IO path = its own frontend:
frontends/<framework>/<model>_<hw>.py
class name: <Model><Fw>Frontend<Hw>
— Example: frontends/torch/pi05_rtx.py contains Pi05TorchFrontendRtx
— Same rule: split thor and rtx frontends even when they duplicate
most of their code.
3. hardware/<hw>/shared_primitives.py is a closed set:
— Only model-agnostic helpers belong here:
_gpu_* helpers, _measure_scale_gpu,
siglip_forward (usable by any model with a SigLIP tower),
encoder_forward (usable by any Paligemma-encoder model),
encoder_forward_calibrate
— Model-specific forward / decoder functions are not allowed in this
file. They go into models/<m>/pipeline_<hw>.py.
4. _PIPELINE_MAP is strictly one-to-one:
("model", "framework", "hw") -> ("flash_rt.frontends.<fw>.<m>_<hw>",
"ClassName")
Each tuple points to exactly one file and one class. Multiple tuples
sharing a class (i.e. runtime forking) is forbidden.
Known historical exception (do NOT copy this pattern):
pi0fast: frontends/torch/pi0fast.py is a single file with 14 runtime
`_has_sm100` branches. Deprecated layout — pending split into
per-hardware files. New models must follow rules 1-4 above.
Walking through a hypothetical pi06 model (Paligemma backbone) that needs to support both Thor and RTX, under both torch and jax:
flash_rt/
├── hardware/__init__.py # 4 new lines in _PIPELINE_MAP
├── hardware/thor/attn_backend.py # add make_pi06_attention_spec (if shapes change)
├── hardware/rtx/attn_backend.py # (RTX) same
├── models/pi06/
│ ├── pipeline_thor.py # NEW — Thor forward functions
│ └── pipeline_rtx.py # NEW — RTX Pi06Pipeline class
├── frontends/torch/
│ ├── _pi06_thor_spec.py # NEW — Thor torch WEIGHT_SPEC
│ ├── _pi06_rtx_spec.py # NEW — RTX torch WEIGHT_SPEC (if dims / iface differ)
│ ├── pi06_thor.py # NEW — Thor torch frontend
│ └── pi06_rtx.py # NEW — RTX torch frontend
├── frontends/jax/
│ ├── _pi06_thor_spec.py # NEW
│ ├── _pi06_rtx_spec.py # NEW
│ ├── pi06_thor.py # NEW
│ └── pi06_rtx.py # NEW
├── configs/pi06.yaml # metadata
└── tests/test_all_models_precision.py # add one segment
All four combinations together: ~3500-4500 lines.
A single (framework, hardware) combination: ~800-1200 lines (of which ~120 lines are declarative spec).
Reference implementations:
- pi05 — all four combinations complete:
models/pi05/{pipeline_thor.py, pipeline_rtx.py}plus four frontends - pi0 — Thor is done, RTX is being refactored in stage 8
- groot — Thor and RTX are done (jax only on Thor)
Before reading §2, copy the starter template:
# For a new model called "mymodel" on Thor:
cp -r flash_rt/frontends/torch/_template /tmp/mymodel_thor_work
cd /tmp/mymodel_thor_work
$EDITOR README.md # 5-min read; explains the file splitThen work file-by-file in this order (each file's docstring tells you exactly what to translate from your source model):
-
weights_spec.py→flash_rt/frontends/torch/_<mymodel>_thor_spec.pyThe declarative weight loader. Map eachstate_dict[...]key from your reference checkpoint to aWEIGHT_SPECrow. Pure mechanical work; ~60-120 lines for a Pi0.5-shape model. -
attention.py→ appendmake_<mymodel>_attention_spec()toflash_rt/hardware/thor/attn_backend.py~60 lines. Declares oneadd_site()call per distinct attention shape in your model (vision, encoder, decoder, etc.). -
pipeline.py→flash_rt/models/<mymodel>/pipeline_thor.pyThe hard part. Translate your reference model'sforward()into a sequence offvk.*kernel calls. The template's# WHAT YOU TRANSLATEblock at the top shows the line-by-line mapping pattern. You'll write two functions per stage: a production forward (FP8, captured into CUDA Graph) and a calibration twin (BF16 + measures activation amax). 200-400 lines per hardware target. -
frontend.py→flash_rt/frontends/torch/<mymodel>_thor.pyWires it all together. Owns weight upload, buffer allocation, calibration cache, and CUDA Graph capture. Class name must be<Model>TorchFrontendThorper §0 rule 2. ~400-700 lines.
After all four files compile and your first infer() produces non-NaN output, run cosine vs your PyTorch FP32 reference. Use the first-light cosine routing table in plugin_model_template.md to narrow down where to look — that table maps the cos number you see directly to the most likely root cause.
For RTX, repeat with pipeline_rtx.py / <mymodel>_rtx.py. For JAX, the template covers torch only — copy from frontends/jax/pi05_thor.py for the JAX patterns (Orbax loading, weight cache, etc.).
§2 below provides the detailed reference for each step the template guides you through. Use it as a lookup, not a tutorial — you should already have copied the template before reading further.
File: flash_rt/hardware/thor/attn_backend.py (Thor) or flash_rt/hardware/rtx/attn_backend.py (RTX).
Copy make_pi05_attention_spec and adjust:
def make_pi06_attention_spec(num_views: int, *,
enc_total_keys: int, dec_total_keys: int) -> AttentionSpec:
"""Pi0.6: 24L encoder / 24L decoder / H_dim=256 / GQA 8:1."""
S_sig = num_views * 256
NH_sig, HD_sig = 16, 72 # SigLIP 1152/16
NH_enc, HD_enc = 8, 256 # Paligemma 2048/8
NH_dec, HD_dec = 8, 256 # Action expert
return AttentionSpec(
sites=[
SiteSpec(
name="siglip", layer_count=27, q_seq_len=S_sig, kv_seq_len=S_sig,
num_heads=NH_sig, head_dim=HD_sig,
extra={"kernel": None}, # fmha_strided_full dispatcher
),
SiteSpec(
name="encoder", layer_count=24, q_seq_len=..., # Se filled at runtime
kv_seq_len=enc_total_keys,
num_heads=NH_enc, head_dim=HD_enc, num_kv_heads=1,
extra={"kernel": "standard"},
),
SiteSpec(
name="decoder", layer_count=24, q_seq_len=10,
kv_seq_len=dec_total_keys,
num_heads=NH_dec, head_dim=HD_dec, num_kv_heads=1,
extra={"kernel": "standard"},
),
],
)Allowed values for extra["kernel"] (see backend.py:AttentionBackend for the full table):
| kernel value | underlying fvk primitive | used for |
|---|---|---|
None (siglip only) |
fmha_strided_full |
SigLIP 2D vision attention |
"standard" |
attention_qkv_fp16 |
GQA encoder/decoder, no state padding |
"state_masked" |
attention_qkv_fp16_state_masked |
Pi0 decoder (the first token is state) |
"mha" |
attention_mha_fp16 |
GROOT Qwen3 full-MHA plus DiT self/cross |
Do not invent new kernel values. If you need a new variant, extend the dispatch branches in hardware/thor/attn_backend.py:ThorFlashAttnBackend.run.
Files:
flash_rt/models/pi06/pipeline_thor.py(Thor path)flash_rt/models/pi06/pipeline_rtx.py(RTX path)
Each hardware gets its own file, even if the two paths end up looking similar. Genuinely shared code lives in hardware/<hw>/shared_primitives.py (reserved for truly model-agnostic helpers) or is imported between sibling functions.
Recent references to copy from:
- Thor:
models/pi0/pipeline_thor.py— Pi0 decoder forward - Thor:
models/pi05/pipeline_thor.py— Pi0.5postln_project/decoder_forward/decoder_forward_calibrate - RTX:
models/pi05/pipeline_rtx.py— thePi05Pipelineclass (framework-agnostic, takes AttnBackend via injection) - RTX:
models/groot/pipeline_rtx.py— GROOT's three-graph flow
Every forward function must obey the pointer-interface contract:
def decoder_forward_pi06(
gemm: fvk.GemmRunner,
fvk_module, # flash_rt_kernels
bufs: dict, # {'x': ptr, 'xn': ptr, ...}
weights: dict, # {'qw': ptr, 'ow': ptr, 'gu': ptr, 'd': ptr, ...}
dims: dict, # {'S': 10, 'Da': 1024, 'Ha': 4096, 'La': 24, ...}
scales_dev: int, # device pointer to fp32 scale array
*,
attn=None, # AttentionBackend; None = legacy fallback
stream: int = 0,
):
"""Every argument is a raw pointer or a Python primitive that ctypes can pass
through — this is what makes the function safe to call repeatedly during
CUDA Graph capture."""
...Why this interface: CUDA Graph capture requires the same Python function, calling the same sequence of kernels, with the same pointers, on every replay. Tensor objects can be garbage-collected or reallocated between replays; raw .data_ptr() values cannot.
Catalog of kernels available for building forwards: docs/kernel_fusion.md lists all 93 public fvk functions and the legal fusion patterns.
Key rules:
- All intermediate buffers must be pre-allocated in
frontend._load_weights. The forward only reads pointers — no dynamic allocation. - Never call
.cpu(),.numpy(),torch.empty(), orsync()inside a forward. - Attention goes through
attn.run(site=..., layer_idx=i, ...). Do not callfvk.attention_qkv_fp16(...)directly. - Full rule set:
docs/kernel_fusion.md§5 known-failed optimizations
Files:
flash_rt/frontends/torch/_pi06_thor_spec.pyflash_rt/frontends/torch/_pi06_rtx_spec.py(only if dims or weight interface differ)
When the two hardwares share the exact same weight interface (common — both sides read the same safetensors checkpoint), a single spec file can back both frontends via from ._pi06_thor_spec import build_spec. The default is still one spec file per hardware, to avoid a future dim change on one side accidentally regressing the other.
If the backbone is in the Paligemma / Gemma family (very likely):
from flash_rt.executors.weight_loader import Item, LayerBlock, ModelWeightSpec
from flash_rt.executors.torch_weights import (
FlatCat, FusedGateUp, FusedQKV, Quant, TensorList, ToFp16, tT,
)
from flash_rt.frontends.torch._thor_spec_common import (
paligemma_encoder_block, paligemma_siglip_block,
)
def _decoder_block():
dp = "paligemma_with_expert.gemma_expert.model.layers.{i}"
return LayerBlock(
prefix_fmt="", num_layers=24, name="decoder",
items=[
Item("qkv_w",
FusedQKV(q=f"{dp}.self_attn.q_proj.weight",
k=f"{dp}.self_attn.k_proj.weight",
v=f"{dp}.self_attn.v_proj.weight",
norm_fuse=f"{dp}.input_layernorm.weight",
interleave_q_heads=8,
interleave_k_heads=1),
[tT(), Quant()],
FlatCat("_dec_qkv_flat"), scale_into="_ae_w_scales"),
# ... follow the pattern in _pi0_thor_spec.py
],
)
def build_spec() -> ModelWeightSpec:
return ModelWeightSpec(
framework="torch",
blocks=[
paligemma_siglip_block(),
paligemma_encoder_block(num_layers=24),
_decoder_block(),
],
)If the backbone is a novel architecture (Qwen3, etc.): look at frontends/torch/groot_thor.py::_load_qwen3_weights, which is still a hand-written loop rather than a declarative spec. You will likely need to either:
- add a new shared block builder to
_thor_spec_common.py, or - define a new composite source (something like
FusedQKV) — seeflash_rt/executors/torch_weights.py.
Op order must be byte-identical: compare your spec, op by op, against the legacy hand-written loader — .T.contiguous() vs .t().contiguous(), ToFp16 before or after Quant, exactly when norm_fuse is applied. A single character wrong causes precision regressions.
Files:
flash_rt/frontends/torch/pi06_thor.py(class:Pi06TorchFrontendThor)flash_rt/frontends/torch/pi06_rtx.py(class:Pi06TorchFrontendRtx)flash_rt/frontends/jax/pi06_thor.py(class:Pi06JaxFrontendThor)flash_rt/frontends/jax/pi06_rtx.py(class:Pi06JaxFrontendRtx)
Class-name rule: <Model><Framework>Frontend<HW> in CamelCase — e.g. Pi05TorchFrontendThor, Pi05TorchFrontendRtx, GrootJaxFrontendThor.
Skeleton: copy the nearest sibling (same framework, same hardware) and edit:
| What to change | Where | Lines |
|---|---|---|
__init__ |
num_views, checkpoint path |
a few |
_load_norm_stats |
new norm_stats path (if it moved) | 20 |
_load_weights |
call _pi06_<hw>_spec.build_spec(); update dim constants; update _sig_weights keys |
120 |
set_prompt |
tokenizer; time_mlp precompute; call _calibrate and _capture_*_graph |
100 |
_calibrate |
call encoder_forward_calibrate / decoder_forward_calibrate |
150 |
_capture_*_graph |
update dims; call models/pi06/pipeline_<hw>.py::{encoder,decoder}_forward_pi06 |
135 |
_autotune_enc_ae |
copy unchanged | 50 |
infer |
input preprocessing / noise / action decode | 80 |
get_latency_stats |
copy unchanged | 15 |
Things you must never do:
- Allocate new tensors inside
infer(violates the CUDA Graph contract). - Change graph topology inside
_calibrate(triggers Myelin tactic drift). - Skip
.contiguous()(column-major vs row-major layout bugs). - Detect hardware at runtime inside a frontend (
hasattr(fvk, ...)) and branch on it — this is the pi0fast anti-pattern. New models must ship two independent thor/rtx frontends.
Where the JAX side diverges from torch:
- The source is
OrbaxDictSource(engine_w), whereengine_wis thedict[str, ndarray]exported by openpi. The key names are not HF safetensors paths; they follow openpi's internal schema (e.g.vision.layer.{i}.qkv.weight). See_thor_spec_common.py. engine_wtypically has QKV and gate_up already fused (openpi does this during export). So the spec does not needFusedQKV/FusedGateUp— plainJaxQuant()is enough.- The sink is
CudaBufferFlat/CudaBufferAttrplus an explicitcache=...argument (weight caching relies on it). - The frontend must set
self._cache_blobs = {}before callingWeightLoader.run(...).
Weight cache behavior: the default is weight_cache=True. The first load takes ~30-40s; results are cached to ~/.flash_rt/weights/<hash>_nv<N>.bin, and subsequent loads take ~5s. When modifying a spec you must preserve the cache key schema (sig_wt_fp8.{0..11}, etc.); otherwise the cache format changes and users have to wipe it manually.
File: flash_rt/hardware/__init__.py
_PIPELINE_MAP: dict[...] = {
... # existing entries
# ── Pi0.6 ──
("pi06", "torch", "thor"):
("flash_rt.frontends.torch.pi06_thor", "Pi06TorchFrontendThor"),
("pi06", "torch", "rtx_sm120"):
("flash_rt.frontends.torch.pi06_rtx", "Pi06TorchFrontendRtx"),
("pi06", "jax", "thor"):
("flash_rt.frontends.jax.pi06_thor", "Pi06JaxFrontendThor"),
("pi06", "jax", "rtx_sm120"):
("flash_rt.frontends.jax.pi06_rtx", "Pi06JaxFrontendRtx"),
}One entry, one class. Two entries pointing at the same class is the pi0fast anti-pattern.
Then confirm that config="pi06" is accepted in api.py::load_model — the function validates configs near the top.
File: flash_rt/configs/pi06.yaml
Copy pi05.yaml as a starting point. Fields: num_layers, hidden_dim, num_heads, head_dim, ffn_hidden_dim, num_views, action_horizon, vocab_size, and so on.
This YAML is used only for logging and metadata. Runtime dimensions still come from the constants hard-coded inside the frontend.
File: tests/test_all_models_precision.py
- Near the top add
PI06_SCRIPT = """...""": load the pipe, set a prompt, run 5 warmup iterations, patch the RNG, record 20 latency samples, compute cosine similarity against the pytorch reference. - Add
'pi06': ('Pi0.6', PI06_SCRIPT)to the_configsdict. - Save the pytorch reference to
/tmp/pi06_pytorch_ref.npy.
On the 5090 side, additionally add a pi06 segment to your local smoke / cosine / latency benchmark scripts (you'll likely have your own; the public test suite covers smoke and unit tests, latency benchmarks are typically per-team).
# CPU unit tests (seconds)
python tests/test_weight_loader.py # 16/16
python tests/test_thor_attn_backend.py # 12/12
python tests/test_thor_groot_attn_backend.py # 11/11
# 5090 GPU validation (if you added an RTX path)
python examples/quickstart.py --checkpoint <ckpt> --config pi06 \
--benchmark 200 # smoke + latency
# Cosine: load the model, run predict() with matched_noise, compare
# against your PyTorch FP32 reference run on the same observation.
# Thor GPU precision sweep (~3-5 minutes)
free -h | head -2 # always check free memory before heavy Thor runs
python tests/test_all_models_precision.py --model pi06Thresholds:
- First-time bring-up of a new model: cos ≥ 0.995 (vs pytorch ref), and P50 inside the target latency budget.
- Models structurally derived from Pi0.5 / Pi0: the "bit-identical" band (~0.9986) indicates the FP8 bytes match exactly.
If cosine falls out of the window:
- Don't guess. First check the spec's op order byte-for-byte against the legacy loader.
- Use an A/B comparison to rule out Myelin tactic jitter — run 2-3 times back-to-back.
- If it really is a regression, revert the commit immediately. Don't try to patch it in a follow-up.
Recompiling the same MLIR → Myelin picks a different tactic → ±2ms P50 drift and ~0.001 cos jitter. This is specific to Thor.
Don't:
- Draw conclusions from a single measurement (always A/B).
- "Fix" a ±0.001 jitter in a new commit (it's almost certainly tactic drift, not code).
- Compare latency numbers taken at different times.
Do:
- Use a timing cache to pin the tactic (though you cannot choose the optimal one directly).
- Keep a reference timing cache around (see
deployment/scripts/l2v2_timing_cache.bin).
Thor has 122Gi of unified memory. Loading two models concurrently will OOM. Tests must run serially.
flash_rt/flash_rt_kernels.cpython-312-aarch64-linux-gnu.so (3.6MB) is a production binary. Adding a new model should not trigger a kernel rebuild — every fvk function you need is already in this .so. If you genuinely need a new kernel, that's a separate CUDA development flow, with explicit version backups.
Assuming the new model is structurally similar to Pi0.5 / Pi0 (Paligemma backbone, flow-matching decoder), for a single (framework, hardware) combination:
| Phase | Estimate |
|---|---|
| (1)(6)(7) Skeleton and registration | half a day |
| (2) Pipeline forward — forked from Pi0 with dim-constant edits | 1-2 days |
| (3) WEIGHT_SPEC authoring + byte-diff validation | 1 day |
| (4) Frontend — fork Pi0, edit dims / calibration / graph capture | 3-4 days |
| (8) Tests and debugging | 2-3 days |
| Total per combination | ~1-2 weeks |
All four combinations (torch/jax × thor/rtx): roughly 3-4 weeks — subsequent frontends reuse a lot of code.
If the backbone is a new architecture (Qwen3-like), add 1-2 more weeks for shared-block extensions, kernel compatibility evaluation, and possibly a new attention variant.
- (1) New AttentionSpec added to the correct hardware's
attn_backend.py; unit tests pass. - (2) Pipeline forward functions use the pointer-only interface, do no dynamic allocation, and each hardware has its own
pipeline_<hw>.pyfile. - (3)
_<model>_<hw>_spec.pysmoke-builds viabuild_spec(). - (4) Frontend is fully implemented, each
(framework, hardware)has its own<m>_<hw>.pyfile, and all buffers are pre-allocated in_load_weights. - No file uses
if self._has_sm100orhasattr(fvk, '...')to branch on hardware. -
shared_primitives.pyhas not gained any model-specific functions. - (6) The four
_PIPELINE_MAPentries are one-to-one, with no two rows pointing at the same class. - (7) YAML dims match the constants in the code.
- (8)
test_all_models_precision.pypasses three consecutive A/B runs. - Weight-cache keys remain compatible with legacy (if the JAX spec changed).
- Commit format:
feat(<model>-<framework>-<hw>): ...
Q: Why are runtime hardware forks like if hasattr(fvk, 'cutlass_fp8_sq') disallowed?
A: Because of the lesson learned from pi0fast. A single file with many branches grows maintenance cost explosively: adding a new hardware means touching 14 spots; adding a new optimization means redoing it on every branch; stack traces no longer tell you which hardware path you were on; and CUDA Graphs capture different kernel sequences per hardware anyway, so if branching can't actually unify them. Splitting per hardware lets each file focus on exactly one execution path.
Q: The thor and rtx frontends are 90% identical — wouldn't merging them save a lot of code? A: Short-term, yes. But "shared between two ends" means adding a third hardware requires splitting again, every change risks breaking the other side, and the test matrix becomes N×M. With per-hardware files, adding a new hardware is just adding a new file while the existing files stay stable. The total line count is slightly higher, but maintenance entropy is dramatically lower.
Q: KeyError: ... at load time?
A: Some key in your WEIGHT_SPEC doesn't exist in the checkpoint. Inspect the actual safetensors keys:
python -c "from safetensors import safe_open; sf=safe_open('/path/to/model.safetensors', 'pt'); [print(k) for k in list(sf.keys())[:50]]"Q: After loading, cosine is around 0.5?
A: Likely causes: wrong QKV interleave (bad GQA head count), mixing .T.contiguous() with .t().contiguous(), or applying norm_fuse at the wrong point. Start with docs/calibration.md §4 precision journey.
Q: CUDA Graph capture fails?
A: Your forward contains a dynamic allocation or a conditional branch that causes capture to take a different kernel path. Details in kernel_fusion.md §6.
Q: JAX loading takes ~40s — too slow?
A: That's the expected first-load cost. Keep weight_cache=True (the default); subsequent loads are ~5s. If you changed the WEIGHT_SPEC's cache key, you need rm -rf ~/.flash_rt/weights/ so the cache can be rebuilt.
Q: New model OOMs on Thor?
A: Thor has 122Gi of unified memory. Check: (1) free -h shows free memory greater than model size × 1.5; (2) no other pipeline is running concurrently; (3) the previous weight_cache version has been cleaned up.