Skip to content

Commit 224e6bb

Browse files
m96-chanclaude
andauthored
v0.2.20: Fused NN Kernels + Flash Attention 3 SM120 + FP8 Block-Scale MMA (#193)
* feat(ops): add native Conv1d CUDA kernel (#180) Add GPU-accelerated 1D convolution to replace CPU fallback in Whisper ASR encoder. Changes: - Add native/ops/conv/conv1d_kernels.cuh: F32/BF16/F16 kernels - Add native/ops/conv/conv1d.cu: Dispatcher with dtype validation - Add native/bindings/nn/conv.cpp: pybind11 bindings - Add src/pygpukit/ops/conv.py: Python API with CPU fallback - Update Whisper encoder to use native conv1d Performance: Eliminates GPU->CPU->GPU roundtrip per audio frame. Correctness: Max diff vs NumPy reference < 5e-7. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(examples): add Llama Guard 3 content safety classifier Add security example with Meta's Llama Guard 3 model for content moderation: - MLCommons hazard taxonomy (S1-S14 categories) - User input and agent response classification - Interactive and batch classification modes - Greedy decoding for deterministic safety classification 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(llm): add LLaMA 4 native CUDA kernels - Add LLaMA 4 model implementation with native CUDA kernels - Update CMakeLists.txt and bindings for LLaMA 4 ops Note: LLaMA 4 kernels are monolithic and need refactoring to follow modular nn/ structure (see issue) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * chore: add security benchmark scripts to gitignore * chore: update cutlass submodule (alignment fix) * feat(attention): add Flash Attention 3 for SM120 (Blackwell) Implement FA3 with WMMA tensor core acceleration: - WMMA-based score computation (Q @ K^T) - WMMA-based output computation (P @ V) - Vectorized memory loads (float4) - Warp-level softmax with shuffle reductions Benchmark results (RTX 5090, 32 heads, head_dim=128): - seq_len=128: FA3 1.02x vs SDPA - seq_len=512: FA3 1.03x vs SDPA - seq_len=1024: FA3 0.99x vs SDPA - seq_len=2048: FA3 1.01x vs SDPA All correctness tests pass (mean relative error < 2%). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(ops): add TMA utilities for SM90+ kernels Add reusable TMA (Tensor Memory Accelerator) utilities: - tma_utils.cuh: CUtensorMap descriptor creation, async copy ops - barrier_init/arrive/wait for mbarrier synchronization - tma_load_2d/3d for async global->shared transfers - Support for BF16, FP16, FP32 data types - 128B swizzle for bank-conflict-free access - warp_scheduler.cuh: Producer/consumer warp specialization - WarpRole enum and detection helpers - Warpgroup utilities for WGMMA - Named barriers for SM90+ - FA3Config/GemmConfig presets - pipeline.cuh: Multi-stage async pipeline management - Pipeline<N> template for N-stage buffering - DualBufferPipeline optimized 2-stage - PipelineBuffer shared memory manager These utilities enable TMA-based optimization for: - Flash Attention 3 - Persistent GEMM - Any kernel needing async global->shared transfers Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * wip(fa3): add TMA-enabled Flash Attention 3 kernel Add flash_attention_3_tma.cuh with: - TmaSharedMemory: Multi-stage K/V buffers with mbarrier - TmaFA3Config: Warp-specialized configuration (4 producer, 8 consumer) - Producer functions: TMA async bulk tensor loads - Consumer functions: WMMA-based score and output computation - 4-stage pipeline for K/V prefetching Architecture: - Producer warps (0-3): Issue TMA loads for K/V tiles - Consumer warps (4-11): Compute attention scores and output - mbarrier synchronization between stages NOTE: Requires Python bindings to create CUtensorMap descriptors. This is a WIP - kernel compiles but not yet callable from Python. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(attention): integrate TMA FA3 into SDPA dispatch - Add TMA FA3 environment control (PYGPUKIT_FA3_TMA) - Create TMA descriptor launcher function for Q/K/V tensors - Integrate TMA path into sdpa_causal_dispatch before regular FA3 - Fix TMA kernel to use 3D loads for 3D tensor descriptors - Add benchmark script for TMA vs baseline comparison Benchmark results (RTX 5090, SM 120a): - [32, 512, 128]: Baseline 2090us, TMA 2170us (0.96x) - [32, 1024, 128]: Baseline 7175us, TMA 7187us (1.00x) - [32, 2048, 128]: Baseline 27165us, TMA 27125us (1.00x) - [32, 4096, 128]: Baseline 93848us, TMA 93444us (1.00x) Correctness: PASS (results match baseline) Note: TMA kernel is functional but not yet optimized for speedup. Future work: warp specialization tuning, swizzle patterns. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(fa3): resolve __syncthreads divergence causing kernel hang at scale Bug: TMA FA3 kernel hung at 256+ blocks due to __syncthreads() inside consumer-only code path. Producer warps never reached sync. Fix: - Split consumer_compute_output() into two functions: - convert_scores_to_probs(): ALL threads participate (has syncs) - consumer_compute_output_matmul(): consumers only (no syncs) - Reduce TILE_Q 64->32 and NUM_STAGES 4->2 for 99KB smem limit - Use union for smem_scores/smem_probs to save 8KB Benchmark (RTX 5090, 32 heads): - seq_len=512: 6.6ms, 0.65 TFLOPS - seq_len=1024: 25.8ms, 0.66 TFLOPS - seq_len=2048: 99.2ms, 0.69 TFLOPS - seq_len=4096: 387.5ms, 0.71 TFLOPS Correctness: PASS (matches FA3 baseline) Next: Parallelize softmax across query positions for 8-32x speedup Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * refactor(fa3): parallelize softmax and fix consumer warp indexing Changes: 1. Warp-parallel softmax: Each consumer warp handles different q rows - 8 warps process 8 rows simultaneously (was: all warps on same row) - Purely warp-synchronous with no __syncthreads() inside 2. Fix consumer warp indexing bug in matmul functions: - consumer_compute_scores: use consumer_warp_idx (0-7) not global warp_id (4-11) - consumer_compute_output_matmul: same fix - Ensures all tiles are computed (was missing tiles 0-3) 3. Direct BF16 softmax output: - Softmax writes BF16 directly to smem_probs - Eliminates convert_scores_to_probs function call - Saves 2 __syncthreads() per iteration Sync point analysis (after optimization): - 5 syncs per iteration (was 7): 1. After barrier_wait (TMA data visible) 2. After Q@K (scores ready for causal mask) 3. After causal mask (scores ready for softmax) 4. After softmax (probs ready for P@V) 5. End of iteration (next TMA) Benchmark (RTX 5090, 32 heads): - Performance: ~0.65-0.71 TFLOPS (similar to baseline) - Correctness: PASS Note: Performance unchanged suggests bottleneck is elsewhere (WMMA efficiency, memory bandwidth, or 1 block/SM occupancy). Next optimization: wgmma instructions for SM120a. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(fa3): resolve non-determinism in TMA FA3 attention kernel Root cause: Union between smem_scores (float) and smem_probs (bf16) caused a race condition when multiple warps processed different Q rows in parallel. Warp B writing to smem_probs[row_B] could corrupt smem_scores[row_A] that Warp A was still reading. Fix: Two-phase softmax approach - Phase 1: ALL warps read scores, compute probs, store to REGISTERS - Phase 2: After __syncthreads(), ALL warps write probs to smem_probs Also includes: - TMA descriptor cache for reduced host-side overhead (99.4% hit rate) - cudaEvent-based kernel timing for accurate benchmarks - Proper handling of fully-masked rows (causal attention edge case) Benchmark results (RTX 5090, SM120a): - seq_len=1024: 51.21 TFLOPS (kernel-only) - seq_len=2048: 59.86 TFLOPS (kernel-only) - Correctness: PASS (max_diff=0.0) - Determinism: PASS (all runs identical) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * wip(fa4): add Flash Attention 4 SM120 Phase 1 BF16 baseline Phase 1 implementation identical to FA3 TMA structure. This establishes the baseline for NVFP4 integration in Phase 2/3. Benchmark results (RTX 5090, seq_len=1024, 32 heads, 128 head_dim): - Kernel-only: 335.6 us (51.19 TFLOPS) - E2E cached: 368.1 us (46.67 TFLOPS) - Correctness: PASS (max diff = 0 vs FA3) Files added: - flash_attention_4_sm120.cuh: FA4 kernel with config structs for all phases - benchmark_fa4_sm120.py: Benchmark script with correctness verification - fa4_sm120_research.md: SM100 vs SM120 architecture research Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * wip(fa4): Phase 2 NVFP4 Q@K^T external validation Phase 2 validates the NVFP4 GEMM path for attention scores. Benchmark results (RTX 5090, seq_len=1024, single head): - NVFP4 Q@K^T: 394.0 us (0.68 TFLOPS) - Correctness: 21% rel_diff vs NumPy (acceptable for 4-bit) Key finding: NVFP4 GEMM optimized for large K (LLM weights), not attention's small K=128 (head_dim). CUTLASS uses K=256 tiles. For comparison: - Full FA3 TMA (32 heads): 330.9 us (51.92 TFLOPS) NVFP4 benefit in attention comes from memory bandwidth (4x smaller loads), not compute throughput. Full integration requires PTX inline assembly for mma.sync.aligned.block_scale. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * bench(fa4): add Phase 3 full NVFP4 pipeline validation Phase 3 Results (RTX 5090, seq_len=1024): - P@V (K=seq_len=1024): 94.7 us (2.84 TFLOPS) - Q@K^T (K=head_dim=128): 353.3 us (0.76 TFLOPS) - Larger K speedup: 3.73x (better tile utilization) Key Findings: 1. NVFP4 CUTLASS GEMM uses K=256 tile size, suboptimal for head_dim=128 2. P (softmax output) CANNOT use NVFP4 directly: - Softmax values ~1/seq_len = 0.001 - NVFP4 smallest positive = 0.25 - All P values quantize to 0 (100% error) Recommended FA4 Architecture: - Q, K, V: pre-quantize to NVFP4 (static weights OK) - P: keep in BF16 (dynamic, small values) - Q@K^T: use mma.sync.block_scale (NVFP4) - P@V: use mma.sync (BF16) or mixed precision Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * docs(fa4): add SM120 implementation report Complete analysis of FA4 (Flash Attention 4) feasibility for RTX 5090. Key Findings: 1. SM120 uses mma.sync.block_scale, NOT tcgen05.mma (datacenter) 2. NVFP4 GEMM optimized for K=256 tiles, suboptimal for head_dim=128 3. P (softmax output) CANNOT use NVFP4: - Softmax values ~0.001 << NVFP4 minimum 0.25 - All P values quantize to 0 (100% error) Recommendation: Do NOT proceed with FA4 NVFP4 for SM120. FA3 TMA (51.97 TFLOPS) is already optimal for GeForce Blackwell. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(fa3): add SM120 tuning configs with version selection Add 5 SM120 config versions for FA3 TMA attention tuning: - V0: Baseline (TILE_Q=32, TILE_KV=64, 4+8 warps) - 63.61 TFLOPS - V1: Smaller tiles (TILE_KV=32) - 53.11 TFLOPS - V2: 3-stage pipeline (TILE_KV=32) - 52.86 TFLOPS - V3: More compute warps (2+10) - 64.01 TFLOPS - V4: Most compute warps (4+12) - 66.62 TFLOPS (+4.7%) Environment variable PYGPUKIT_FA3_SM120_VERSION (0-4) selects config. Version 4 achieves best performance with 16 total warps. Benchmark results (RTX 5090, seq_len=4096, heads=32, head_dim=128): - V0 (baseline): 63.61 TFLOPS - V4 (4+12 warps): 66.62 TFLOPS Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * docs(fa3): document sync requirements in SM120 kernel Add detailed documentation explaining why all 6 __syncthreads() per KV tile are required and cannot be reduced: 1. After barrier_wait - mbarrier is per-thread, need block sync 2. After compute_scores - scores must complete before mask 3. After mask - mask must complete before softmax reads 4. After softmax phase1 - union race condition prevention 5. After softmax phase2 - probs must complete before P@V 6. End of loop - prevents cross-iteration TMA/read race Attempted sync reduction failed due to: - Removing sync after barrier_wait causes thread divergence races - Removing end-of-loop sync causes prefetch/read stage conflicts Current performance: 64.6 TFLOPS (SM120, seq_len=4096) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(fp8): add native PTX inline assembly for FP8 block-scale MMA Implement FP8 E4M3 block-scale MMA using native PTX inline assembly for SM120. Fragment layouts derived from CUTLASS CuTe mma_traits_sm80.hpp analysis. Key implementation details: - PTX instruction: mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e4m3.f32.ue8m0 - A fragment: 4 registers (16 FP8 E4M3 elements each) - B fragment: 2 registers (8 FP8 E4M3 elements each) - C/D fragment: 4 FP32 registers (16x8 output tile) - Scale factors: UE8M0 format (8-bit unsigned exponent) CuTe Layout Analysis: - ALayout: (T32,V16) -> (M16,K32), t0=lane/8, t1=lane%8 - BLayout: (T32,V8) -> (K32,N8), non-contiguous byte access - CLayout: (T32,V4) -> (M16,N8), d[v] = C[4*t0+v, t1] Test result: PASS on RTX 5090 (SM 120) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(fa3): add FP8 block-scale MMA Flash Attention 3 for SM120 Implements FA3 with FP8 E4M3 Q@K^T using SM120's block-scale MMA instruction for ~50% memory bandwidth reduction vs BF16. Key implementation details: - FP8 E4M3 quantization with per-head global UE8M0 scaling - mma.sync.aligned.kind::mxf8f6f4.block_scale.m16n8k32.f32.e4m3.e4m3 - B fragment loading: n_idx=lane_id/4, k_base=(lane_id%4)*8 - SM80_16x8_Row C fragment layout for correct output mapping - BF16 P@V with WMMA for precision (FP8 V gave ~18% error) Validation results (vs BF16 FA3 reference): - Prefill (128 tokens): 1.97% error, 0.9999 correlation - PASS - Prefill (512 tokens): 1.58% error, 0.9999 correlation - PASS - Decode (single token): 0% error, perfect correlation - PASS New files: - native/ops/nn/attention/flash_attention_3_fp8_sm120.cuh - native/ops/matmul/gemm/fp8_block_scale/test_mma_direct.cuh Python API: sdpa_causal_fp8(), fa3_fp8_available(), test_fp8_mma_direct() Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(nn): add fused kernels for RMSNorm+Residual, SwiGLU, GeGLU Add high-performance fused kernels to reduce memory bandwidth and kernel launch overhead in LLM inference pipelines. New kernels: - rmsnorm_residual: y = rmsnorm(x + residual) * gamma - swiglu: y = silu(gate) * up (used in Qwen, LLaMA3, Mistral FFN) - geglu: y = gelu(gate) * up Benchmark results (RTX 5090): - SwiGLU: 2.38-14.25x speedup vs separate ops - RMSNorm+Residual: 2.03-12.37x speedup - GeGLU: 2.40-13.10x speedup Larger batch sizes show greater speedups due to memory bandwidth savings from eliminating intermediate buffers. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(lint): resolve ruff lint errors in fused kernel files - Fix import sorting (I001) - Fix unused loop variable (B007) by renaming to _features - Fix loop variable binding (B023) by using default args - Remove unused mode argument in open() (UP015) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix(mypy): add type annotation for scores_max in llama4.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 3adce30 commit 224e6bb

62 files changed

Lines changed: 16131 additions & 49 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,4 @@ test_gpu/
150150
.claude/memory.jsonl
151151
.claude/benchmarks.db
152152
.claude/logs/
153+
examples/security/*_benchmark.py

0 commit comments

Comments
 (0)