Skip to content

Conversation

@agolajko
Copy link
Contributor

@agolajko agolajko commented Jan 15, 2026

Draft PR re #862

Replaces the Jax ragged_dot with a cuda tile implementation
Inspired by https://github.com/NVIDIA/cutile-python/blob/main/samples/MoE.py

Benchmarking cuda-tile and existing ragged_dot implementation

On RTX Pro 6000 via Runpod


================================================================================
                      CUTILE vs RAGGED_DOT BENCHMARK SUITE                      
================================================================================

Small (original)
  Config: 1024 tokens × 512 hidden → 512 out, 16 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 1.739 ms
  cutile:     1.721 ms
  Speedup:    1.01x

Medium (Qwen-0.6B scale)
  Config: 2048 tokens × 1024 hidden → 1024 out, 16 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 3.222 ms
  cutile:     3.506 ms
  Speedup:    0.92x

Large (Qwen2.5-1.5B scale)
  Config: 4096 tokens × 1536 hidden → 1536 out, 32 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 9.152 ms
  cutile:     9.142 ms
  Speedup:    1.00x

Large+ (2B scale)
  Config: 4096 tokens × 2048 hidden → 2048 out, 32 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 14.459 ms
  cutile:     14.464 ms
  Speedup:    1.00x

XLarge (Llama 3 8B scale)
  Config: 8192 tokens × 4096 hidden → 4096 out, 64 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 94.049 ms
  cutile:     92.585 ms
  Speedup:    1.02x

Results of time_cutile_parts.py giving breakdown of time spent on different tasks

TX_USE_CUTILE_LORA=1 uv run tests/cutile/time_cutile_parts.py
Config: m=2048, d=1024, out=1024, E=16, dtype=torch.float16
TILE_M/N/K = 128/128/64
rhs contiguous=True stride=(1048576, 1024, 1)

=== CUDA-event timing breakdown ===
pad_groups:    0.199 ms
cutile_launch: 0.068 ms
combined:      0.272 ms
pad fraction:  73.3%
launch frac:   25.0%
(pad+launch):  0.267 ms (rough expected)

Todo:

  • Multi GPU support
  • backward pass
  • more tests
  • profile

@pcmoritz pcmoritz added the tx label Jan 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants