Motivation
AITER is adding a CK-free build mode (ENABLE_CK=0) that removes Composable Kernel dependencies, reducing build time from ~35min to ~8min. ATOM needs corresponding changes to:
- Route around CK-dependent code paths when CK kernels are unavailable
- Maintain performance by preferring ASM PA (paged attention) and Triton kernels where possible
- Support clean Docker builds from pre-built wheels (zero compilation)
- Fix bugs discovered during CK-free validation
All changes are backward compatible — when ATOM_CK_FREE=0 (default), ATOM uses CK/ASM kernels as before.
Validated Locally
- MI300X (gfx942): Llama-3.1-8B, DeepSeek-R1-671B inference
- MI355X (gfx950): MXFP4 MOE with Swiglu paths
- Clean Docker (zero-compile from wheels): Llama-3.1-8B decode 80-82% of public image perf
Proposed PR Sequence
PR 1: Bug Fixes (Standalone, No CK-Free Dependency)
Files: atom/model_ops/linear.py, atom/model_ops/fused_moe_triton.py, atom/model_engine/scheduler.py
Individual fixes:
- linear.py: UnboundLocalError in MergedReplicatedLinear.weight_loader() — missing else clause for per_Token and per_1x32 quant types. Compute shard_offset from self.output_sizes.
- fused_moe_triton.py: Use CDNA4MXScaleLayout instead of GFX950MXScaleLayout for gfx950 arch detection (more accurate naming).
- fused_moe_triton.py: Add update_opt_flags_constraints({"block_m": 128}) for MI355X — default CDNA4 block_m=256 exceeds 160KB LDS limit.
- scheduler.py: Initialize num_rejected=0 to prevent UnboundLocalError in non-speculative path.
Dependency: None — these fix real bugs independent of CK-free work.
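The scheduler fix can be sketched as follows; the function and its argument are hypothetical stand-ins for the real scheduler code, which reads num_rejected after the speculative-decoding branch:

```python
# Hypothetical sketch of the scheduler.py fix: num_rejected is read after
# the speculative-decoding branch, so it must be bound on the
# non-speculative path too, or that path raises UnboundLocalError.
def count_rejected(spec_results=None):
    num_rejected = 0  # the fix: bind before the conditional
    if spec_results is not None:
        # speculative path: count draft tokens the target model rejected
        num_rejected = sum(1 for accepted in spec_results if not accepted)
    return num_rejected
```

Without the first assignment, calling count_rejected() with speculation disabled would hit the same UnboundLocalError class fixed in weight_loader().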
PR 2: MHA Attention Dispatch Decoupling
Files: atom/model_ops/attention_mha.py, atom/utils/envs.py
Key changes:
- Add ATOM_CK_FREE env var (default=0) as master switch
- Decouple cache update from paged attention backend selection:
  - Cache update: Always use Triton fused rope+cache (fast, no module_cache JIT dependency)
  - Paged attention: Independently select ASM PA for decode when head_dim=128 and no sliding window
- AITER_FORCE_TRITON_ATTN env override for forcing Triton PA
- FP8 KV cache: fill per-token scale buffers with uniform per-tensor scale so ASM PA can dequantize correctly
- Prefill: always use prefill_attention_triton (no CK/flash_attn_varlen_func dependency)
- Move kv_scale tensor to CUDA at init for graph capture compatibility
- Create proper block_tables for prefill with block_size=1
Dependency: AITER PR 1 (CK-free build gating)
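A sketch of the decoupled dispatch, with hypothetical function and backend names; only the env var names (ATOM_CK_FREE, AITER_FORCE_TRITON_ATTN) and the selection criteria come from this RFC, the rest is illustrative:

```python
import os

def select_pa_backend(head_dim, sliding_window, ck_free=None):
    """Select the paged-attention kernel for decode (sketch).

    The cache update is handled separately (always the Triton fused
    rope+cache path); this choice is independent of it.
    """
    if ck_free is None:
        ck_free = os.environ.get("ATOM_CK_FREE", "0") == "1"
    if os.environ.get("AITER_FORCE_TRITON_ATTN", "0") == "1":
        return "triton_pa"  # explicit override wins
    if head_dim == 128 and sliding_window is None:
        return "asm_pa"     # ASM PA handles this decode shape
    return "triton_pa" if ck_free else "ck_pa"
```

The point of the decoupling is that ASM PA stays reachable even in CK-free mode, since it has no CK build dependency.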
PR 3: MLA CK-Free Paths
Files: atom/model_ops/attention_mla.py, atom/model_ops/attentions/aiter_mla.py
Key changes:
- ATOM_USE_TRITON_MLA_DECODE or ATOM_CK_FREE → force Triton MLA decode using decode_attention_fwd_grouped_rope (AITER Triton kernel)
- MLA prefill fallback: Replace flash_attn_varlen_func with PyTorch F.scaled_dot_product_attention (loops over sequences)
- FP8 constraint: only use fp8 scales when max_seqlen_q == 1 (mla_decode_fwd limitation)
- Type casting: convert Q/KV to model dtype before mla_prefill_fwd if dtype mismatch
- aiter_mla.py: Build paged KV metadata (kv_indptr, kv_indices, block_tables) for MLA prefill paths
Dependency: PR 2
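The SDPA prefill fallback can be sketched as below; the function name and the packed-varlen layout are assumptions for illustration, while F.scaled_dot_product_attention and the loop-over-sequences approach come from the RFC:

```python
import torch
import torch.nn.functional as F

def prefill_sdpa_fallback(q, k, v, cu_seqlens):
    """Varlen prefill via per-sequence SDPA (sketch).

    q, k, v: [total_tokens, num_heads, head_dim] packed varlen tensors;
    cu_seqlens: cumulative sequence lengths, e.g. [0, s0, s0 + s1, ...].
    Slower than flash_attn_varlen_func, but has no CK dependency.
    """
    out = torch.empty_like(q)
    for i in range(len(cu_seqlens) - 1):
        s, e = cu_seqlens[i], cu_seqlens[i + 1]
        # SDPA expects [batch, heads, seq, dim]: add batch dim, move heads
        qi = q[s:e].transpose(0, 1).unsqueeze(0)
        ki = k[s:e].transpose(0, 1).unsqueeze(0)
        vi = v[s:e].transpose(0, 1).unsqueeze(0)
        oi = F.scaled_dot_product_attention(qi, ki, vi, is_causal=True)
        out[s:e] = oi.squeeze(0).transpose(0, 1)
    return out
```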
PR 4: MOE Cascade Routing (CK → FlyDSL → Triton)
Files: atom/model_ops/moe.py, atom/model_ops/flydsl_moe.py (NEW)
Key changes:
- Detection functions: _has_ck_moe_sorting(), _has_flydsl_moe() with caching
- Cascade logic for Fp8MoEMethod / CompressedTensorsFp8MoEMethod:
  - Check CK MOE sorting availability
  - If unavailable or ATOM_CK_FREE=1: try FlyDSL MOE (ATOM_USE_FLYDSL_MOE=1)
  - Else: fall back to Triton MOE (_triton_fp8_moe())
- Weight shuffle skip: Triton expects standard row-major weights (not CK's shuffled layout)
- flydsl_moe.py (NEW): FlyDSL MOE backend with torch-native sorting, per-token FP8 quant, 5-stage GEMM pipeline
- _triton_fp8_moe(): Complete Triton MOE pipeline (sort → GEMM1+SiLU → GEMM2)
- _per_token_group_quant_fp8(): Per-token-group FP8 quantization helper
Dependency: PR 2 (for ATOM_CK_FREE env var)
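The cascade can be sketched as a small selection function; the backend strings and availability arguments are hypothetical, while the env vars and fallback order match the list above:

```python
import os

def select_moe_backend(has_ck_sorting, has_flydsl):
    """CK → FlyDSL → Triton cascade for FP8 MOE methods (sketch)."""
    ck_free = os.environ.get("ATOM_CK_FREE", "0") == "1"
    want_flydsl = os.environ.get("ATOM_USE_FLYDSL_MOE", "0") == "1"
    if has_ck_sorting and not ck_free:
        return "ck_moe"       # default path when CK is built in
    if has_flydsl and want_flydsl:
        return "flydsl_moe"   # opt-in FlyDSL backend
    return "triton_moe"       # _triton_fp8_moe() is always available
```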
PR 5: Docker Infrastructure
Files: docker/Dockerfile, docker/Dockerfile.clean (NEW), docker/Dockerfile.wheels (NEW), .dockerignore (NEW)
Key changes:
- Dockerfile: ARG ENABLE_CK=1 parameter, conditional git submodule, pass to AITER setup.py, install triton_kernels wheel
- Dockerfile.wheels (NEW, ~160 lines): Multi-stage builder — PyTorch ROCm 7.2, Triton 3.5.x, FlyDSL, MORI, AITER (ENABLE_CK=0)
- Dockerfile.clean (NEW, ~70 lines): Zero-compilation runtime from pre-built wheels via bind-mount
- .dockerignore (NEW): Exclude .git/, build/, dist/ — reduces context from 67.9GB to 37.9GB
Build time comparison:
| Image | Build Time | Size |
|---|---|---|
| Current (full CK) | ~60 min | Large |
| Dockerfile.wheels | ~60 min (one-time) | Wheels only |
| Dockerfile.clean | ~10 min | Minimal runtime |
Dependency: AITER PR 1 (ENABLE_CK support)
PR 6: Test Suite & CI (Nice-to-have)
Files: tests/test_ck_free_mode.py, tests/test_flydsl_moe.py, tests/test_attention_dispatch.py, tests/test_mla_prefill_routing.py, tests/test_aiter_mla_metadata.py, tests/test_moe_shapes.py (all NEW), pyproject.toml, .github/workflows/pre-checks.yaml, .github/workflows/atom-test.yaml
Key changes:
- 6 new test files (~970 lines) covering: env var detection, MOE routing, MHA dispatch, MLA prefill/decode, metadata construction, MOE shapes
- pyproject.toml: gpu pytest marker for tagging GPU-requiring tests
- pre-checks.yaml: Add unit-tests job running CPU-only pytest (-m "not gpu") — works on forks without GPU runners
- atom-test.yaml: Guard to only run on ROCm/ATOM (not forks), re-enable golden output tests
Dependency: PR 2-4
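A sketch of how a GPU-requiring test would be tagged, assuming the gpu marker registered in pyproject.toml; the test name and body are hypothetical:

```python
import pytest

@pytest.mark.gpu  # deselected by `pytest -m "not gpu"` on CPU-only runners
def test_asm_pa_decode_matches_triton():
    torch = pytest.importorskip("torch")
    if not torch.cuda.is_available():
        pytest.skip("needs a GPU")
    # ... compare ASM PA output against the Triton PA reference here ...
```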
PR 7: CI Nightly Sync (Nice-to-have)
Files: .github/workflows/sync-upstream.yaml (NEW), scripts/test_golden_output.sh (NEW)
- Nightly scheduled sync of fork main with upstream
- Golden output comparison script for CI regression testing
Dependency: None
Execution Path Summary
When ATOM_CK_FREE=1:
- MHA: Triton fused rope+cache → ASM PA (head_dim=128) or Triton PA
- MLA: PyTorch SDPA prefill → Triton decode (decode_attention_fwd_grouped_rope)
- MOE: torch-native sorting → FlyDSL GEMM (if available) or Triton GEMM
When ATOM_CK_FREE=0 (default): unchanged behavior using CK/ASM kernels.
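For concreteness, the per-token-group FP8 quantization used on the CK-free MOE path (the _per_token_group_quant_fp8 helper from PR 4) can be sketched in pure Python; the real helper works on tensors and casts to float8_e4m3, which is simplified away here:

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3

def per_token_group_quant(row, group_size):
    """Quantize one token row group-by-group (sketch).

    Each group gets its own scale so the group's absolute max maps to the
    FP8 max, preserving dynamic range per group rather than per row.
    Returns (quantized_values, per_group_scales).
    """
    q, scales = [], []
    for g in range(0, len(row), group_size):
        group = row[g:g + group_size]
        amax = max(abs(v) for v in group)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        scales.append(scale)
        # a real kernel would cast v / scale to float8_e4m3 here
        q.extend(v / scale for v in group)
    return q, scales
```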
Performance Summary (MI300X, Llama-3.1-8B)
| Configuration | Decode tok/s | vs Public |
|---|---|---|
| Public image (full CK+ASM) | ~8,200 | 100% |
| CK-free, ASM PA, fp8 KV | ~6,830 | ~83% |
| CK-free, Triton PA, bf16 KV | ~6,255 | ~76% |
The remaining gap is primarily due to: (1) no ASM GEMM for decode (M=1), and (2) tuned GEMM CSV coverage limited to M≤256.
Open Questions
- Should FlyDSL MOE be the default fallback, or should we wait for more validation?
- Should Dockerfile.clean/Dockerfile.wheels live in ATOM or a separate build-infra repo?
- Priority of ASM GEMM re-enablement (v2 clean Docker) vs other optimizations?
Related
- AITER RFC: (will link after creation)