Commit 224e6bb
v0.2.20: Fused NN Kernels + Flash Attention 3 SM120 + FP8 Block-Scale MMA (#193)
* feat(ops): add native Conv1d CUDA kernel (#180)
Add GPU-accelerated 1D convolution to replace CPU fallback in Whisper ASR encoder.
Changes:
- Add native/ops/conv/conv1d_kernels.cuh: F32/BF16/F16 kernels
- Add native/ops/conv/conv1d.cu: Dispatcher with dtype validation
- Add native/bindings/nn/conv.cpp: pybind11 bindings
- Add src/pygpukit/ops/conv.py: Python API with CPU fallback
- Update Whisper encoder to use native conv1d
Performance: Eliminates GPU->CPU->GPU roundtrip per audio frame.
Correctness: Max diff vs NumPy reference < 5e-7.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
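A minimal sketch of what a direct (non-FFT) Conv1d kernel of this kind can look like, assuming the channel-major layout the Whisper encoder uses; the kernel name, argument names, and tiling are illustrative, not the actual code in conv1d_kernels.cuh:

```cuda
__global__ void conv1d_f32_naive(
    const float* __restrict__ input,   // [C_in, L_in]
    const float* __restrict__ weight,  // [C_out, C_in, K]
    const float* __restrict__ bias,    // [C_out], may be nullptr
    float* __restrict__ output,        // [C_out, L_out]
    int c_in, int l_in, int c_out, int l_out,
    int kernel_size, int stride, int padding)
{
    int out_pos = blockIdx.x * blockDim.x + threadIdx.x;  // output time step
    int oc      = blockIdx.y;                             // output channel
    if (out_pos >= l_out || oc >= c_out) return;

    float acc = bias ? bias[oc] : 0.0f;
    for (int ic = 0; ic < c_in; ++ic) {
        for (int k = 0; k < kernel_size; ++k) {
            int in_pos = out_pos * stride + k - padding;
            if (in_pos >= 0 && in_pos < l_in) {
                acc += input[ic * l_in + in_pos] *
                       weight[(oc * c_in + ic) * kernel_size + k];
            }
        }
    }
    output[oc * l_out + out_pos] = acc;
}
```

Keeping this on the GPU is what removes the GPU->CPU->GPU roundtrip noted above.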
* feat(examples): add Llama Guard 3 content safety classifier
Add security example with Meta's Llama Guard 3 model for content moderation:
- MLCommons hazard taxonomy (S1-S14 categories)
- User input and agent response classification
- Interactive and batch classification modes
- Greedy decoding for deterministic safety classification
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* feat(llm): add LLaMA 4 native CUDA kernels
- Add LLaMA 4 model implementation with native CUDA kernels
- Update CMakeLists.txt and bindings for LLaMA 4 ops
Note: LLaMA 4 kernels are monolithic and need refactoring
to follow modular nn/ structure (see issue)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* chore: add security benchmark scripts to gitignore
* chore: update cutlass submodule (alignment fix)
* feat(attention): add Flash Attention 3 for SM120 (Blackwell)
Implement FA3 with WMMA tensor core acceleration:
- WMMA-based score computation (Q @ K^T)
- WMMA-based output computation (P @ V)
- Vectorized memory loads (float4)
- Warp-level softmax with shuffle reductions
Benchmark results (RTX 5090, 32 heads, head_dim=128):
- seq_len=128: FA3 1.02x vs SDPA
- seq_len=512: FA3 1.03x vs SDPA
- seq_len=1024: FA3 0.99x vs SDPA
- seq_len=2048: FA3 1.01x vs SDPA
All correctness tests pass (mean relative error < 2%).
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
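For reference, the warp-level softmax reduction mentioned above boils down to the following shuffle pattern; this is a sketch of the primitive only, since the real kernel fuses it with online-softmax rescaling and the WMMA fragments:

```cuda
__device__ inline float warp_reduce_max(float v) {
    for (int offset = 16; offset > 0; offset >>= 1)
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, offset));
    return v;  // every lane ends up holding the warp-wide max
}

__device__ inline float warp_reduce_sum(float v) {
    for (int offset = 16; offset > 0; offset >>= 1)
        v += __shfl_xor_sync(0xffffffffu, v, offset);
    return v;  // every lane ends up holding the warp-wide sum
}

// One warp normalizing one row of 32 attention scores, one score per lane:
__device__ inline float softmax_lane(float score) {
    float row_max = warp_reduce_max(score);
    float e       = __expf(score - row_max);
    float row_sum = warp_reduce_sum(e);
    return e / row_sum;
}
```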
* feat(ops): add TMA utilities for SM90+ kernels
Add reusable TMA (Tensor Memory Accelerator) utilities:
- tma_utils.cuh: CUtensorMap descriptor creation, async copy ops
- barrier_init/arrive/wait for mbarrier synchronization
- tma_load_2d/3d for async global->shared transfers
- Support for BF16, FP16, FP32 data types
- 128B swizzle for bank-conflict-free access
- warp_scheduler.cuh: Producer/consumer warp specialization
- WarpRole enum and detection helpers
- Warpgroup utilities for WGMMA
- Named barriers for SM90+
- FA3Config/GemmConfig presets
- pipeline.cuh: Multi-stage async pipeline management
- Pipeline<N> template for N-stage buffering
- DualBufferPipeline optimized 2-stage
- PipelineBuffer shared memory manager
These utilities enable TMA-based optimization for:
- Flash Attention 3
- Persistent GEMM
- Any kernel needing async global->shared transfers
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
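As a rough illustration of the host-side half of these utilities, this is what building a 2D CUtensorMap for a BF16 tile looks like with the CUDA driver API; the helper name and the tile/swizzle choices are assumptions, not tma_utils.cuh's actual defaults:

```cuda
#include <cuda.h>
#include <cuda_bf16.h>

// Encode a 2D tiled descriptor; the base pointer and the row pitch must be
// 16-byte aligned for TMA, and error checking is omitted for brevity.
CUtensorMap make_kv_tensor_map(__nv_bfloat16* gmem, cuuint64_t rows, cuuint64_t cols,
                               cuuint32_t tile_rows, cuuint32_t tile_cols) {
    CUtensorMap tmap{};
    cuuint64_t global_dim[2]    = {cols, rows};                    // fastest dim first
    cuuint64_t global_stride[1] = {cols * sizeof(__nv_bfloat16)};  // row pitch in bytes
    cuuint32_t box_dim[2]       = {tile_cols, tile_rows};          // shared-memory tile
    cuuint32_t elem_stride[2]   = {1, 1};

    cuTensorMapEncodeTiled(&tmap,
                           CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
                           /*tensorRank=*/2, gmem, global_dim, global_stride,
                           box_dim, elem_stride,
                           CU_TENSOR_MAP_INTERLEAVE_NONE,
                           CU_TENSOR_MAP_SWIZZLE_128B,        // bank-conflict-free smem
                           CU_TENSOR_MAP_L2_PROMOTION_NONE,
                           CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    return tmap;
}
```

The descriptor is then passed to the kernel and consumed by the tma_load_2d/3d wrappers together with an mbarrier.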
* wip(fa3): add TMA-enabled Flash Attention 3 kernel
Add flash_attention_3_tma.cuh with:
- TmaSharedMemory: Multi-stage K/V buffers with mbarrier
- TmaFA3Config: Warp-specialized configuration (4 producer, 8 consumer)
- Producer functions: TMA async bulk tensor loads
- Consumer functions: WMMA-based score and output computation
- 4-stage pipeline for K/V prefetching
Architecture:
- Producer warps (0-3): Issue TMA loads for K/V tiles
- Consumer warps (4-11): Compute attention scores and output
- mbarrier synchronization between stages
NOTE: Requires Python bindings to create CUtensorMap descriptors.
This is a WIP - kernel compiles but not yet callable from Python.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
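The producer/consumer split reads roughly like the skeleton below; the warp counts match the TmaFA3Config above (4 producer + 8 consumer), but the rest is simplified pseudostructure rather than the actual kernel:

```cuda
__global__ void fa3_tma_skeleton(int num_kv_tiles /*, CUtensorMap descriptors, ... */)
{
    const int warp_id      = threadIdx.x / 32;
    const bool is_producer = warp_id < 4;   // warps 0-3 issue TMA loads
                                            // warps 4-11 run WMMA compute
    for (int kv_tile = 0; kv_tile < num_kv_tiles; ++kv_tile) {
        if (is_producer) {
            // one elected thread per stage issues cp.async.bulk.tensor for the
            // next K/V tile and arrives on that stage's mbarrier
        } else {
            // wait on the stage's mbarrier, then:
            //   Q@K^T (WMMA) -> causal mask -> softmax -> P@V (WMMA)
        }
        __syncthreads();   // keep producer prefetch and consumer reads in step
    }
}
```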
* feat(attention): integrate TMA FA3 into SDPA dispatch
- Add TMA FA3 environment control (PYGPUKIT_FA3_TMA)
- Create TMA descriptor launcher function for Q/K/V tensors
- Integrate TMA path into sdpa_causal_dispatch before regular FA3
- Fix TMA kernel to use 3D loads for 3D tensor descriptors
- Add benchmark script for TMA vs baseline comparison
Benchmark results (RTX 5090, SM 120a):
- [32, 512, 128]: Baseline 2090us, TMA 2170us (0.96x)
- [32, 1024, 128]: Baseline 7175us, TMA 7187us (1.00x)
- [32, 2048, 128]: Baseline 27165us, TMA 27125us (1.00x)
- [32, 4096, 128]: Baseline 93848us, TMA 93444us (1.00x)
Correctness: PASS (results match baseline)
Note: TMA kernel is functional but not yet optimized for speedup.
Future work: warp specialization tuning, swizzle patterns.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* fix(fa3): resolve __syncthreads divergence causing kernel hang at scale
Bug: TMA FA3 kernel hung at 256+ blocks due to __syncthreads()
inside consumer-only code path. Producer warps never reached sync.
Fix:
- Split consumer_compute_output() into two functions:
- convert_scores_to_probs(): ALL threads participate (has syncs)
- consumer_compute_output_matmul(): consumers only (no syncs)
- Reduce TILE_Q 64->32 and NUM_STAGES 4->2 for 99KB smem limit
- Use union for smem_scores/smem_probs to save 8KB
Benchmark (RTX 5090, 32 heads):
- seq_len=512: 6.6ms, 0.65 TFLOPS
- seq_len=1024: 25.8ms, 0.66 TFLOPS
- seq_len=2048: 99.2ms, 0.69 TFLOPS
- seq_len=4096: 387.5ms, 0.71 TFLOPS
Correctness: PASS (matches FA3 baseline)
Next: Parallelize softmax across query positions for 8-32x speedup
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
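The essence of the bug and the fix, stripped of the kernel's details (the function names mirror the commit, but the bodies are illustrative):

```cuda
// BROKEN: a block-wide barrier inside a consumer-only branch hangs, because
// producer warps never reach it:
//   if (!is_producer) { ...; __syncthreads(); ... }

// FIXED: split into an all-threads phase that owns the barriers and a
// consumer-only phase that is purely warp-synchronous.
__device__ void convert_scores_to_probs(/* smem pointers */)
{
    // every thread in the block participates, so block-wide syncs are safe
    __syncthreads();
    // ... exp/normalize scores in shared memory ...
    __syncthreads();
}

__device__ void consumer_compute_output_matmul(bool is_producer /*, ... */)
{
    if (is_producer) return;   // producers fall through without ever syncing
    // consumer-only WMMA P@V; warp-synchronous work only, no __syncthreads()
}
```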
* refactor(fa3): parallelize softmax and fix consumer warp indexing
Changes:
1. Warp-parallel softmax: Each consumer warp handles different q rows
- 8 warps process 8 rows simultaneously (was: all warps on same row)
- Purely warp-synchronous with no __syncthreads() inside
2. Fix consumer warp indexing bug in matmul functions:
- consumer_compute_scores: use consumer_warp_idx (0-7) not global warp_id (4-11)
- consumer_compute_output_matmul: same fix
- Ensures all tiles are computed (was missing tiles 0-3)
3. Direct BF16 softmax output:
- Softmax writes BF16 directly to smem_probs
- Eliminates convert_scores_to_probs function call
- Saves 2 __syncthreads() per iteration
Sync point analysis (after optimization):
- 5 syncs per iteration (was 7):
1. After barrier_wait (TMA data visible)
2. After Q@K (scores ready for causal mask)
3. After causal mask (scores ready for softmax)
4. After softmax (probs ready for P@V)
5. End of iteration (next TMA)
Benchmark (RTX 5090, 32 heads):
- Performance: ~0.65-0.71 TFLOPS (similar to baseline)
- Correctness: PASS
Note: The unchanged performance suggests the bottleneck lies elsewhere
(WMMA efficiency, memory bandwidth, or single-block-per-SM occupancy).
Next optimization: wgmma instructions for SM120a.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
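A sketch of the row-per-warp assignment and of the index fix; for brevity it assumes one score column per lane (TILE_KV <= 32), whereas the real kernel strides lanes over the 64-wide tile. All names besides the shuffle and bf16 intrinsics are illustrative:

```cuda
#include <cuda_bf16.h>
#include <math.h>

// Same shuffle helpers as the earlier FA3 sketch.
__device__ inline float warp_reduce_max(float v) { for (int o = 16; o; o >>= 1) v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, o)); return v; }
__device__ inline float warp_reduce_sum(float v) { for (int o = 16; o; o >>= 1) v += __shfl_xor_sync(0xffffffffu, v, o); return v; }

__device__ void consumer_softmax_rows(const float* smem_scores, __nv_bfloat16* smem_probs,
                                      int tile_q, int tile_kv, int warp_id,
                                      int num_producer_warps, int num_consumer_warps)
{
    if (warp_id < num_producer_warps) return;                    // producers skip compute
    const int consumer_warp_idx = warp_id - num_producer_warps;  // 0..7 (the fix: not global 4..11)
    const int lane = threadIdx.x % 32;

    // Each consumer warp owns a strided subset of the TILE_Q rows; everything
    // below is warp-synchronous, so no __syncthreads() is needed inside.
    for (int row = consumer_warp_idx; row < tile_q; row += num_consumer_warps) {
        float v = (lane < tile_kv) ? smem_scores[row * tile_kv + lane] : -INFINITY;
        float m = warp_reduce_max(v);
        float e = (lane < tile_kv) ? __expf(v - m) : 0.0f;
        float s = warp_reduce_sum(e);
        if (lane < tile_kv)
            smem_probs[row * tile_kv + lane] = __float2bfloat16(e / s);  // direct BF16 write
    }
}
```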
* fix(fa3): resolve non-determinism in TMA FA3 attention kernel
Root cause: Union between smem_scores (float) and smem_probs (bf16)
caused a race condition when multiple warps processed different Q rows
in parallel. Warp B writing to smem_probs[row_B] could corrupt
smem_scores[row_A] that Warp A was still reading.
Fix: Two-phase softmax approach
- Phase 1: ALL warps read scores, compute probs, store to REGISTERS
- Phase 2: After __syncthreads(), ALL warps write probs to smem_probs
Also includes:
- TMA descriptor cache for reduced host-side overhead (99.4% hit rate)
- cudaEvent-based kernel timing for accurate benchmarks
- Proper handling of fully-masked rows (causal attention edge case)
Benchmark results (RTX 5090, SM120a):
- seq_len=1024: 51.21 TFLOPS (kernel-only)
- seq_len=2048: 59.86 TFLOPS (kernel-only)
- Correctness: PASS (max_diff=0.0)
- Determinism: PASS (all runs identical)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
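The two-phase pattern looks roughly like this; the union (float scores and bf16 probs sharing the same shared-memory bytes) is implied by the two pointer views, and producer warps call the function with has_row=false so every thread still reaches the barrier. Names are illustrative:

```cuda
#include <cuda_bf16.h>

// Same shuffle helpers as the earlier FA3 sketches.
__device__ inline float warp_reduce_max(float v) { for (int o = 16; o; o >>= 1) v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, o)); return v; }
__device__ inline float warp_reduce_sum(float v) { for (int o = 16; o; o >>= 1) v += __shfl_xor_sync(0xffffffffu, v, o); return v; }

__device__ void softmax_two_phase(const float* smem_scores,       // float view of the union
                                  __nv_bfloat16* smem_probs,      // bf16 view of the same bytes
                                  int my_row, bool has_row, int lane)
{
    const int TILE_KV = 64;                  // 2 score columns per lane
    float e0 = 0.0f, e1 = 0.0f, s = 1.0f;

    // Phase 1: read the float view and keep probabilities in registers only,
    // so no warp can corrupt another warp's still-unread scores.
    if (has_row) {
        float x0 = smem_scores[my_row * TILE_KV + lane];
        float x1 = smem_scores[my_row * TILE_KV + lane + 32];
        float m  = warp_reduce_max(fmaxf(x0, x1));
        e0 = __expf(x0 - m);
        e1 = __expf(x1 - m);
        s  = warp_reduce_sum(e0 + e1);
    }

    __syncthreads();   // every thread reaches this, including producer warps

    // Phase 2: all warps now overwrite the union with bf16 probabilities.
    if (has_row) {
        smem_probs[my_row * TILE_KV + lane]      = __float2bfloat16(e0 / s);
        smem_probs[my_row * TILE_KV + lane + 32] = __float2bfloat16(e1 / s);
    }
}
```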
* wip(fa4): add Flash Attention 4 SM120 Phase 1 BF16 baseline
The Phase 1 implementation is structurally identical to the FA3 TMA kernel.
This establishes the baseline for NVFP4 integration in Phase 2/3.
Benchmark results (RTX 5090, seq_len=1024, 32 heads, 128 head_dim):
- Kernel-only: 335.6 us (51.19 TFLOPS)
- E2E cached: 368.1 us (46.67 TFLOPS)
- Correctness: PASS (max diff = 0 vs FA3)
Files added:
- flash_attention_4_sm120.cuh: FA4 kernel with config structs for all phases
- benchmark_fa4_sm120.py: Benchmark script with correctness verification
- fa4_sm120_research.md: SM100 vs SM120 architecture research
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* wip(fa4): Phase 2 NVFP4 Q@K^T external validation
Phase 2 validates the NVFP4 GEMM path for attention scores.
Benchmark results (RTX 5090, seq_len=1024, single head):
- NVFP4 Q@K^T: 394.0 us (0.68 TFLOPS)
- Correctness: 21% rel_diff vs NumPy (acceptable for 4-bit)
Key finding: NVFP4 GEMM optimized for large K (LLM weights),
not attention's small K=128 (head_dim). CUTLASS uses K=256 tiles.
For comparison:
- Full FA3 TMA (32 heads): 330.9 us (51.92 TFLOPS)
NVFP4 benefit in attention comes from memory bandwidth (4x smaller
loads), not compute throughput. Full integration requires PTX
inline assembly for mma.sync.aligned.block_scale.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* bench(fa4): add Phase 3 full NVFP4 pipeline validation
Phase 3 Results (RTX 5090, seq_len=1024):
- P@V (K=seq_len=1024): 94.7 us (2.84 TFLOPS)
- Q@K^T (K=head_dim=128): 353.3 us (0.76 TFLOPS)
- Larger K speedup: 3.73x (better tile utilization)
Key Findings:
1. NVFP4 CUTLASS GEMM uses K=256 tile size, suboptimal for head_dim=128
2. P (softmax output) CANNOT use NVFP4 directly:
- Softmax values ~1/seq_len = 0.001
- NVFP4 smallest positive = 0.25
- All P values quantize to 0 (100% error)
Recommended FA4 Architecture:
- Q, K, V: pre-quantize to NVFP4 (static weights OK)
- P: keep in BF16 (dynamic, small values)
- Q@K^T: use mma.sync.block_scale (NVFP4)
- P@V: use mma.sync (BF16) or mixed precision
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* docs(fa4): add SM120 implementation report
Complete analysis of FA4 (Flash Attention 4) feasibility for RTX 5090.
Key Findings:
1. SM120 uses mma.sync.block_scale, NOT tcgen05.mma (datacenter)
2. NVFP4 GEMM optimized for K=256 tiles, suboptimal for head_dim=128
3. P (softmax output) CANNOT use NVFP4:
- Softmax values ~0.001 << NVFP4 minimum 0.25
- All P values quantize to 0 (100% error)
Recommendation: Do NOT proceed with FA4 NVFP4 for SM120.
FA3 TMA (51.97 TFLOPS) is already optimal for GeForce Blackwell.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* feat(fa3): add SM120 tuning configs with version selection
Add 5 SM120 config versions for FA3 TMA attention tuning:
- V0: Baseline (TILE_Q=32, TILE_KV=64, 4+8 warps) - 63.61 TFLOPS
- V1: Smaller tiles (TILE_KV=32) - 53.11 TFLOPS
- V2: 3-stage pipeline (TILE_KV=32) - 52.86 TFLOPS
- V3: More compute warps (2+10) - 64.01 TFLOPS
- V4: Most compute warps (4+12) - 66.62 TFLOPS (+4.7%)
Environment variable PYGPUKIT_FA3_SM120_VERSION (0-4) selects config.
Version 4 achieves best performance with 16 total warps.
Benchmark results (RTX 5090, seq_len=4096, heads=32, head_dim=128):
- V0 (baseline): 63.61 TFLOPS
- V4 (4+12 warps): 66.62 TFLOPS
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
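A host-side sketch of what env-selected tuning configs of this shape can look like; the struct fields mirror the V0-V4 table above, but the names, the default version, and the dispatch mechanics are assumptions rather than the library's actual code:

```cuda
#include <cstdlib>

struct FA3Sm120Config {
    int tile_q, tile_kv, num_stages, producer_warps, consumer_warps;
};

inline FA3Sm120Config fa3_sm120_config() {
    static const FA3Sm120Config kVersions[5] = {
        {32, 64, 2, 4,  8},   // V0 baseline              63.61 TFLOPS
        {32, 32, 2, 4,  8},   // V1 smaller KV tiles      53.11 TFLOPS
        {32, 32, 3, 4,  8},   // V2 3-stage pipeline      52.86 TFLOPS
        {32, 64, 2, 2, 10},   // V3 more compute warps    64.01 TFLOPS
        {32, 64, 2, 4, 12},   // V4 most compute warps    66.62 TFLOPS
    };
    const char* env = std::getenv("PYGPUKIT_FA3_SM120_VERSION");
    int v = env ? std::atoi(env) : 0;        // assumed default; clamp bad input
    if (v < 0 || v > 4) v = 0;
    return kVersions[v];
}
```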
* docs(fa3): document sync requirements in SM120 kernel
Add detailed documentation explaining why all 6 __syncthreads() per KV
tile are required and cannot be reduced:
1. After barrier_wait - mbarrier is per-thread, need block sync
2. After compute_scores - scores must complete before mask
3. After mask - mask must complete before softmax reads
4. After softmax phase1 - union race condition prevention
5. After softmax phase2 - probs must complete before P@V
6. End of loop - prevents cross-iteration TMA/read race
Attempted sync reduction failed due to:
- Removing sync after barrier_wait causes thread divergence races
- Removing end-of-loop sync causes prefetch/read stage conflicts
Current performance: 64.6 TFLOPS (SM120, seq_len=4096)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* feat(fp8): add native PTX inline assembly for FP8 block-scale MMA
Implement FP8 E4M3 block-scale MMA using native PTX inline assembly for SM120.
Fragment layouts derived from CUTLASS CuTe mma_traits_sm80.hpp analysis.
Key implementation details:
- PTX instruction: mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e4m3.f32.ue8m0
- A fragment: 4 x 32-bit registers (16 FP8 E4M3 elements per thread)
- B fragment: 2 x 32-bit registers (8 FP8 E4M3 elements per thread)
- C/D fragment: 4 FP32 registers (16x8 output tile)
- Scale factors: UE8M0 format (8-bit unsigned exponent)
CuTe Layout Analysis:
- ALayout: (T32,V16) -> (M16,K32), t0=lane/8, t1=lane%8
- BLayout: (T32,V8) -> (K32,N8), non-contiguous byte access
- CLayout: (T32,V4) -> (M16,N8), d[v] = C[4*t0+v, t1]
Test result: PASS on RTX 5090 (SM 120)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
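Put together, the fragment sizes above suggest an inline-asm wrapper of roughly the following shape. The instruction string is the one named in the commit; the operand grouping, register constraints, and the {byte-id, thread-id} selector immediates follow a reading of the PTX ISA description of block-scale MMA and should be checked against it, so treat this strictly as a sketch, not the repository's actual wrapper:

```cuda
__device__ inline void mma_fp8_block_scale_m16n8k32(
    float d[4], const unsigned a[4], const unsigned b[2], const float c[4],
    unsigned scale_a_ue8m0, unsigned scale_b_ue8m0)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200
    asm volatile(
        "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X"
        ".m16n8k32.row.col.f32.e4m3.e4m3.f32.ue8m0 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13}, "
        "%14, {0, 0}, %15, {0, 0};\n"          // selector immediates simplified
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
          "r"(scale_a_ue8m0), "r"(scale_b_ue8m0));
#endif
}
```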
* feat(fa3): add FP8 block-scale MMA Flash Attention 3 for SM120
Implements FA3 with FP8 E4M3 Q@K^T using SM120's block-scale MMA
instruction for ~50% memory bandwidth reduction vs BF16.
Key implementation details:
- FP8 E4M3 quantization with per-head global UE8M0 scaling
- mma.sync.aligned.kind::mxf8f6f4.block_scale.m16n8k32.f32.e4m3.e4m3
- B fragment loading: n_idx=lane_id/4, k_base=(lane_id%4)*8
- SM80_16x8_Row C fragment layout for correct output mapping
- BF16 P@V with WMMA for precision (FP8 V gave ~18% error)
Validation results (vs BF16 FA3 reference):
- Prefill (128 tokens): 1.97% error, 0.9999 correlation - PASS
- Prefill (512 tokens): 1.58% error, 0.9999 correlation - PASS
- Decode (single token): 0% error, perfect correlation - PASS
New files:
- native/ops/nn/attention/flash_attention_3_fp8_sm120.cuh
- native/ops/matmul/gemm/fp8_block_scale/test_mma_direct.cuh
Python API: sdpa_causal_fp8(), fa3_fp8_available(), test_fp8_mma_direct()
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
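The quantization side of this is straightforward to sketch: pick a power-of-two (UE8M0) scale per head from its absolute maximum, divide, and convert to E4M3. The helper names below are illustrative; only cuda_fp8.h and __nv_fp8_e4m3 are real APIs, and the encoding of the exponent into the biased 8-bit UE8M0 field is omitted:

```cuda
#include <cuda_fp8.h>
#include <math.h>

// Choose e so that head_amax / 2^e lands near the top of E4M3's range
// (max finite E4M3 value is 448).
__device__ inline int choose_ue8m0_exponent(float head_amax) {
    return (int)ceilf(log2f(fmaxf(head_amax, 1e-12f) / 448.0f));
}

__device__ inline __nv_fp8_e4m3 quantize_e4m3(float x, int scale_exp) {
    float scaled = x * exp2f((float)-scale_exp);   // divide by the 2^e scale
    return __nv_fp8_e4m3(scaled);                  // round/saturate into E4M3
}
// The block-scale MMA folds the UE8M0 scales of A and B back into the FP32
// accumulator, so nothing ever needs to dequantize the FP8 fragments by hand.
```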
* feat(nn): add fused kernels for RMSNorm+Residual, SwiGLU, GeGLU
Add high-performance fused kernels to reduce memory bandwidth and
kernel launch overhead in LLM inference pipelines.
New kernels:
- rmsnorm_residual: y = rmsnorm(x + residual) * gamma
- swiglu: y = silu(gate) * up (used in Qwen, LLaMA3, Mistral FFN)
- geglu: y = gelu(gate) * up
Benchmark results (RTX 5090):
- SwiGLU: 2.38-14.25x speedup vs separate ops
- RMSNorm+Residual: 2.03-12.37x speedup
- GeGLU: 2.40-13.10x speedup
Larger batch sizes show greater speedups due to memory bandwidth
savings from eliminating intermediate buffers.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
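The fusion itself is simple; a minimal FP32 sketch of the swiglu kernel shows where the savings come from: one pass over gate/up with no intermediate silu(gate) buffer (the real kernels also cover BF16/F16, and rmsnorm_residual and geglu follow the same single-pass pattern):

```cuda
__global__ void swiglu_f32(const float* __restrict__ gate,
                           const float* __restrict__ up,
                           float* __restrict__ out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    float g    = gate[i];
    float silu = g / (1.0f + __expf(-g));   // silu(g) = g * sigmoid(g)
    out[i]     = silu * up[i];              // y = silu(gate) * up, written once
}
```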
* fix(lint): resolve ruff lint errors in fused kernel files
- Fix import sorting (I001)
- Fix unused loop variable (B007) by renaming to _features
- Fix loop variable binding (B023) by using default args
- Remove unused mode argument in open() (UP015)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* fix(mypy): add type annotation for scores_max in llama4.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 3adce30 · commit 224e6bb
62 files changed
Lines changed: 16131 additions & 49 deletions
File tree
- .serena/memories
- docs
- examples
  - security
- native
  - bindings
    - gemm
    - nn
  - ops
    - common
    - conv
    - matmul
      - gemm/fp8_block_scale
    - nn
      - attention
      - arch
      - fused
    - llama4
- src/pygpukit
  - asr/whisper
  - llm/models
  - ops
    - nn
- third_party