
# RFC: Split cuLA into Core Kernel Layer and FLA Transitional Wrapper #25

@icavan


Status: Draft
Created: 2026-04-04

Summary

Split the cula package into two layers:

  1. cula (core) — pure CUDA kernels with zero FLA dependency. FLA calls into this directly via its backend dispatch. This is the long-term deliverable.
  2. cula-fla (transitional wrapper) — a temporary package that exists only during the transition period before cuLA is fully integrated into FLA. It allows early adopters to try cuLA kernels directly with a familiar FLA-like API, without waiting for FLA upstream integration to complete. cula-fla is expected to be deprecated and removed once FLA natively dispatches to cuLA.
End state (target):  FLA user API → @dispatch → cula (core kernels)
Transition period:   User → cula-fla → cula (core kernels) + fla (gate, bwd, etc.)

Motivation

Current Problem

cuLA imports from FLA at multiple levels:

| Category | Examples | Count |
|---|---|---|
| Utils | tensor_cache, prepare_chunk_indices, prepare_lens, RCP_LN2, autocast_custom_fwd/bwd, input_guard | ~10 imports across 6 files |
| Algorithm logic | kda_gate_chunk_cumsum, recompute_w_u_fwd, chunk_kda_bwd, l2norm_fwd/bwd, chunk_local_cumsum | ~8 imports across 5 files |
| CP / Distributed | FLACPContext, chunk_gated_delta_rule_fwd_h_pre_process, compress_h0 | 3 imports across 2 files |

Today the dependency is one-directional (cula → fla), so no circular import exists. But the end goal is for FLA to call cuLA kernels directly:

fla.ops.kda.chunk_kda() → @dispatch('common') → cula.ops.chunk_delta_h.chunk_gated_delta_rule_fwd_h()

If cula core still imports fla, this creates a circular dependency at import time. The solution: cula core must have zero FLA imports.
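To make the failure mode concrete, the sketch below reproduces it with two hypothetical toy packages, toy_fla and toy_cula, written to a temporary directory (neither is the real FLA or cuLA). toy_fla imports a kernel from toy_cula at import time, and toy_cula imports a constant back from toy_fla before toy_fla has defined it, so the import fails. The decoupled variant gives the kernel package its own copy of the constant, which is the Phase 1 approach. Whether a real cycle fails depends on import order and on which names are touched while a module is only partially initialized, so this shows the failing case rather than predicting FLA's exact behaviour.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path


def write_pkg(root: Path, name: str, init_body: str) -> None:
    """Create root/name/__init__.py containing the dedented source."""
    pkg = root / name
    pkg.mkdir()
    (pkg / "__init__.py").write_text(textwrap.dedent(init_body))


def try_import(name: str) -> str:
    try:
        importlib.import_module(name)
        return "ok"
    except ImportError as exc:
        return f"ImportError: {exc}"


def demo() -> dict:
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        # Coupled (hypothetical): framework -> kernels -> framework.
        write_pkg(root, "toy_fla", """
            from toy_cula import fwd_h          # framework dispatches to kernels
            RCP_LN2 = 1.4426950408889634        # not yet defined when toy_cula runs
        """)
        write_pkg(root, "toy_cula", """
            from toy_fla import RCP_LN2         # back-edge: kernels import framework
            def fwd_h(x):
                return x * RCP_LN2
        """)
        # Decoupled: the kernel package owns its constant, so there is no back-edge.
        write_pkg(root, "toy_fla2", """
            from toy_cula2 import fwd_h
        """)
        write_pkg(root, "toy_cula2", """
            RCP_LN2 = 1.4426950408889634
            def fwd_h(x):
                return x * RCP_LN2
        """)
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            return {"coupled": try_import("toy_fla"), "decoupled": try_import("toy_fla2")}
        finally:
            sys.path.remove(tmp)
            for name in ("toy_fla", "toy_cula", "toy_fla2", "toy_cula2"):
                sys.modules.pop(name, None)


print(demo())
```

The coupled import raises Python's "cannot import name ... from partially initialized module" error, while the decoupled one succeeds.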

Architecture

┌──────────────────────────────────────────────────────────┐
│                     FLA                                  │
│  (user-facing API, Triton fallback, algorithm logic)     │
│           │                                              │
│     @dispatch('common')                                  │
│     @dispatch('kda')    ← backend dispatch               │
│           │                                              │
│     try: import cula    ← optional dependency            │
└───────────┼──────────────────────────────────────────────┘
            ↓
┌───────────────────────────────┐    ┌──────────────────────────────┐
│         cula (core)           │    │       cula-fla (trial)       │
│  CUDA kernels, zero FLA dep   │    │  Quick-start wrapper for     │
│                               │    │  users to try cuLA directly  │
│  • chunk_delta_h (fwd_h)      │    │                              │
│  • fwd_o                      │    │  • chunk_kda() autograd fn   │
│  • recompute_wu               │    │  • kda_prefill_hopper()      │
│  • lightning_attn (prefill)   │    │  • kda_prefill_blackwell()   │
│  • la_decode                  │    │                              │
│  • C++ kernels (sm90/sm10x)   │    │  depends on: cula + fla      │
│                               │    │  (uses fla for gate, bwd,    │
│  depends on: torch, cutlass,  │    │   l2norm, CP, intra-chunk)   │
│  tvm-ffi, triton              │    │                              │
└───────────────────────────────┘    └──────────────────────────────┘

Key insight: FLA already owns the orchestration (gating, intra-chunk WY repr, backward pass, CP coordination). It just needs to call cuLA's optimized CUDA kernels for the compute-heavy inner loops. FLA does this via @dispatch — no adapter layer needed.

cula-fla is a transitional package that bridges the gap during the period when:

  • FLA has not yet added cuLA backend dispatch upstream
  • cuLA has not yet implemented the backward pass, gate, and intra-chunk computation as CUDA kernels
  • Early adopters want to test cuLA's performance improvements immediately

It is today's cula/kda/ extracted into a separate package: it reuses FLA's algorithm logic (gate, backward, etc.) and plugs in cuLA's CUDA kernels. Once FLA integration is complete, users change their import from "from cula_fla.kda import chunk_kda" to "from fla.ops.kda import chunk_kda", and cula-fla is archived.

Current Dependency Analysis

Files with zero or trivially-removable FLA imports (core-ready)

| File | FLA Imports | Action |
|---|---|---|
| cula/ops/chunk_delta_h.py | prepare_chunk_indices, prepare_lens, tensor_cache | Internalize 3 small utils |
| cula/ops/fwd_o.py | prepare_chunk_indices | Same |
| cula/ops/lightning_attn.py | None | Ready |
| cula/ops/linear_attn.py | None | Ready |
| cula/ops/inv.py | None | Ready |
| cula/ops/recompute_wu.py | None (FLA import only in __main__ benchmark) | Ready |
| cula/ops/recompute_wu_occ.py | None (FLA import only in __main__ benchmark) | Ready |
| cula/lightning/la_decode.py | None | Ready |
| cula/utils.py | None | Ready |
| csrc/** (C++/CUDA) | None | Ready |

Files that belong in cula-fla (heavy FLA dependency)

| File | FLA Dependencies | Role |
|---|---|---|
| cula/kda/chunk.py | l2norm_fwd/bwd, chunk_kda_bwd, FLACPContext, autocast_custom_fwd/bwd, input_guard | torch.autograd.Function: forward uses cuLA, backward is entirely FLA |
| cula/kda/chunk_fwd.py | kda_gate_chunk_cumsum, chunk_local_cumsum, RCP_LN2, FLACPContext, CP pre-process, compress_h0 | Forward orchestration: gate → intra → fwd_h → fwd_o |
| cula/kda/chunk_intra.py | recompute_w_u_fwd, prepare_chunk_indices, Triton ops (exp2, gather) | Triton intra-chunk kernel (not a CUDA kernel) |
| cula/kda/hopper_fused_fwd.py | l2norm_fwd, kda_gate_chunk_cumsum, chunk_local_cumsum, autocast utils | Hopper fused forward orchestration |
| cula/kda/blackwell_fused_fwd.py | Same as above + kda_gate_fwd | Blackwell fused forward orchestration |

FLA utils to internalize into cula core

| Utility | Approx. lines | Current location | Action |
|---|---|---|---|
| prepare_chunk_indices | ~15 | fla.ops.utils.index | Copy to cula/utils.py |
| prepare_lens | ~3 | fla.ops.utils | Copy to cula/utils.py |
| tensor_cache | ~30 | fla.utils | Copy to cula/utils.py |
| RCP_LN2 / INV_LN2 | 1 | fla.ops.utils.constant | Already exists as INV_LN2 in cula/ops/chunk_delta_h.py |

These are stable, trivial utilities unlikely to change. Internalizing avoids any FLA dependency in core.
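Before copying these utilities, it may help to pin down their semantics with a small pure-Python reference that the internalized torch versions can be parity-tested against. The sketch below reflects my understanding of what FLA's prepare_lens, prepare_chunk_indices and tensor_cache compute; it is not a copy of FLA's code, the real versions operate on torch tensors, and edge cases (for example zero-length sequences, or how many entries tensor_cache keeps in newer FLA releases) should be checked against the pinned FLA version. The cdiv helper here is a local stand-in for triton.cdiv.

```python
import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division for non-negative ints (what triton.cdiv computes)."""
    return -(-a // b)


def prepare_lens(cu_seqlens: Sequence[int]) -> List[int]:
    """Sequence lengths from cumulative offsets [0, l0, l0 + l1, ...]."""
    return [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]


def prepare_chunk_indices(cu_seqlens: Sequence[int], chunk_size: int) -> List[Tuple[int, int]]:
    """One (sequence index, chunk index within that sequence) pair per chunk.

    In this reference, zero-length sequences contribute no chunks.
    """
    indices: List[Tuple[int, int]] = []
    for seq_idx, length in enumerate(prepare_lens(cu_seqlens)):
        indices.extend((seq_idx, chunk_idx) for chunk_idx in range(cdiv(length, chunk_size)))
    return indices


def tensor_cache(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Single-entry cache: reuse the last result if every argument is the same object."""
    last_args: Optional[Tuple[Any, ...]] = None
    last_kwargs: Dict[str, Any] = {}
    last_result: Any = None

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal last_args, last_kwargs, last_result
        if (
            last_args is not None
            and len(args) == len(last_args)
            and all(a is b for a, b in zip(args, last_args))
            and kwargs.keys() == last_kwargs.keys()
            and all(v is last_kwargs[k] for k, v in kwargs.items())
        ):
            return last_result
        last_result = fn(*args, **kwargs)
        last_args, last_kwargs = args, kwargs
        return last_result

    return wrapper


# Two packed sequences of length 100 and 64, chunk size 64:
print(prepare_lens([0, 100, 164]))               # [100, 64]
print(prepare_chunk_indices([0, 100, 164], 64))  # [(0, 0), (0, 1), (1, 0)]
```

Comparing arguments by identity rather than equality is deliberate: `==` on tensors is elementwise, so an equality-based cache key would be both ambiguous and expensive.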

Proposed Package Structure

cula (core) — pip install cula

cula/
├── __init__.py              # exports kernel functions
├── _version.py
├── utils.py                 # device detection, prepare_chunk_indices,
│                            # prepare_lens, tensor_cache, constants
├── ops/
│   ├── __init__.py
│   ├── chunk_delta_h.py     # fwd_h kernel (CuTe DSL, SM10X)
│   ├── fwd_o.py             # fwd_o kernel (CuTe DSL, SM10X)
│   ├── recompute_wu.py      # WU recomputation (CuTe DSL)
│   ├── inv.py               # matrix inverse
│   ├── lightning_attn.py    # Lightning Attention prefill
│   └── linear_attn.py       # Lightning Attention varlen
├── lightning/
│   ├── __init__.py
│   └── la_decode.py         # Lightning Attention decode
└── cudac/                   # C++/CUDA extension (sm90/sm100/sm103)

Dependencies: torch, nvidia-cutlass-dsl, apache-tvm-ffi, triton

Zero FLA dependency. Every function is a pure kernel: tensors in, tensors out.

cula-fla (transitional wrapper) — pip install cula-fla

Lifecycle: This package is a transitional artifact. It exists because cuLA's FLA integration is not yet complete — FLA has not yet added cuLA backend dispatch, and cuLA has not yet implemented all required kernels (backward pass, gate, intra-chunk) to be self-sufficient. Once FLA natively dispatches to cuLA core, cula-fla will be deprecated with a notice pointing users to from fla.ops.kda import chunk_kda instead.

cula_fla/
├── __init__.py
└── kda/
    ├── __init__.py            # exports: chunk_kda, kda_prefill_hopper, etc.
    ├── chunk.py               # torch.autograd.Function (fwd: cuLA, bwd: FLA)
    ├── chunk_fwd.py           # forward orchestration
    ├── chunk_intra.py         # Triton intra-chunk (FLA's recompute_w_u_fwd)
    ├── hopper_fused_fwd.py    # Hopper fused forward
    └── blackwell_fused_fwd.py # Blackwell fused forward

Dependencies: cula, flash-linear-attention

This is what cula/kda/ is today — moved into a separate package. Usage:

# Quick-start: try cuLA KDA directly (same interface as fla.ops.kda)
from cula_fla.kda import chunk_kda
o, final_state = chunk_kda(q, k, v, g, beta, ...)
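Because the lifecycle above plans to deprecate cula-fla with a notice pointing to fla.ops.kda, the entry points could carry that notice themselves. The sketch below uses a hypothetical `deprecated` decorator and a placeholder `chunk_kda` body; neither is existing cula-fla code, and the exact message wording is an assumption.

```python
import functools
import warnings
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def deprecated(replacement: str) -> Callable[[F], F]:
    """Hypothetical decorator: warn on every call and name the replacement import."""

    def decorate(fn: F) -> F:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            warnings.warn(
                f"{fn.__module__}.{fn.__qualname__} is deprecated and will be removed; "
                f"use `{replacement}` instead.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )
            return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorate


@deprecated("from fla.ops.kda import chunk_kda")
def chunk_kda(*args: Any, **kwargs: Any) -> str:
    # Placeholder body for the sketch; the real wrapper would run the cuLA forward.
    return "cula_fla.chunk_kda"
```

Python hides DeprecationWarning by default except for code run as `__main__` and under test runners, so FutureWarning may be the better category if the notice must reach end users.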

FLA integration (in FLA's repo, not cuLA's)

FLA registers cuLA as a backend via its existing BackendRegistry. This lives in FLA's codebase:

# fla/ops/kda/backends/cula.py (in FLA repo)
from fla.ops.backends import BaseBackend

class CuLAKDABackend(BaseBackend):
    backend_type = "cula"
    package_name = "cula"       # is_available() → checks if cula is installed
    env_var = "FLA_CULA"        # FLA_CULA=0 to disable
    priority = 3                # lower value = higher priority; tried before default Triton (5)

    def chunk_gated_delta_rule_fwd_h_verifier(self, k, **kw):
        from cula.utils import is_blackwell
        if not is_blackwell(k.device):
            return False, "cuLA fwd_h requires Blackwell GPU"
        return True, None

    def chunk_gated_delta_rule_fwd_h(self, k, w, u, **kw):
        from cula.ops.chunk_delta_h import chunk_gated_delta_rule_fwd_h
        return chunk_gated_delta_rule_fwd_h(k=k, w=w, u=u, **kw)

No import cycle: FLA imports cula.ops.* (core) which has no FLA dependency.

FLA can add more dispatch points as cuLA implements more kernels:

# Future: FLA dispatches fwd_o to cuLA
class CuLAFwdOBackend(BaseBackend):
    def chunk_gla_fwd_o(self, q, v, g, A, h, o, **kw):
        from cula.ops.fwd_o import chunk_gla_fwd_o
        return chunk_gla_fwd_o(q=q, v=v, g=g, A=A, h=h, o=o, **kw)
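The backend classes above rely on FLA's registry to decide, per call, whether the cuLA path runs. The toy below sketches that selection logic under stated assumptions: the `Backend` class, the `dispatch` decorator and `ToyCuLABackend` are hypothetical stand-ins, not FLA's BaseBackend or @dispatch implementation. It assumes a smaller priority value is tried first, that `<env_var>=0` disables a backend, that an optional `<op>_verifier` method can veto a call, and that the decorated function serves as the Triton fallback.

```python
import importlib.util
import os
from typing import Any, Callable, List, Optional, Tuple


class Backend:
    """Hypothetical stand-in for FLA's BaseBackend; the real class may differ."""

    backend_type: str = "base"
    package_name: Optional[str] = None  # must be installed for the backend to run
    env_var: Optional[str] = None       # "<env_var>=0" disables the backend
    priority: int = 5                   # assumed: smaller value is tried first

    def is_available(self) -> bool:
        if self.env_var and os.environ.get(self.env_var) == "0":
            return False
        if self.package_name is None:
            return True
        # find_spec checks installation without importing the package.
        return importlib.util.find_spec(self.package_name) is not None


def dispatch(op: str, backends: List[Backend]) -> Callable[[Callable], Callable]:
    """Route calls to the first available backend whose verifier accepts them."""
    ordered = sorted(backends, key=lambda b: b.priority)

    def decorate(default_impl: Callable) -> Callable:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for backend in ordered:
                impl = getattr(backend, op, None)
                if impl is None or not backend.is_available():
                    continue
                verifier = getattr(backend, f"{op}_verifier", None)
                if verifier is not None:
                    ok, _reason = verifier(*args, **kwargs)
                    if not ok:
                        continue
                return impl(*args, **kwargs)
            return default_impl(*args, **kwargs)  # fallback, e.g. the Triton kernel

        return wrapper

    return decorate


class ToyCuLABackend(Backend):
    backend_type = "cula"
    env_var = "TOY_CULA"  # hypothetical; the RFC's real backend uses FLA_CULA
    priority = 3

    def fwd_h_verifier(self, k: str, arch: str = "sm90", **kw: Any) -> Tuple[bool, Optional[str]]:
        if not arch.startswith("sm10"):
            return False, "toy cuLA fwd_h requires an SM10x (Blackwell) device"
        return True, None

    def fwd_h(self, k: str, arch: str = "sm90", **kw: Any) -> str:
        return f"cula:{k}"


@dispatch("fwd_h", [ToyCuLABackend()])
def fwd_h(k: str, arch: str = "sm90", **kw: Any) -> str:
    return f"triton:{k}"


print(fwd_h("k", arch="sm100"))  # verifier accepts: cuLA path
print(fwd_h("k", arch="sm90"))   # verifier rejects: fallback path
```

Since the backend imports cuLA only inside its methods, merely importing FLA never touches cula, which keeps the optional dependency truly optional.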

Migration Plan

Phase 1: Internalize utils — make cula core FLA-free

Low-risk refactor within the current single-package structure:

  1. Copy prepare_chunk_indices, prepare_lens, tensor_cache into cula/utils.py
  2. Update imports in cula/ops/chunk_delta_h.py, cula/ops/fwd_o.py to use cula.utils
  3. Verify: all files under cula/ops/ and cula/lightning/ have zero fla.* imports
  4. Keep cula/kda/ unchanged — still imports FLA, still works as today

No user-facing change. The cula/ops/ and cula/lightning/ modules are now independently importable without FLA installed.

Phase 2: Extract cula-fla package

  1. Move cula/kda/ contents into a new cula-fla package (separate repo or subdirectory)
  2. cula-fla depends on cula (core) + fla
  3. cula retains cula/ops/, cula/lightning/, cula/utils.py
  4. Publish both to PyPI: pip install cula and pip install cula-fla

Users who currently do from cula.kda import chunk_kda switch to from cula_fla.kda import chunk_kda. This is a breaking change, which is acceptable while the project is pre-1.0.
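For downstream code bases, the import switch can be scripted. The sketch below is a hypothetical regex-based helper, not a cuLA tool: it rewrites `from cula.kda[.sub] import ...` lines to cula_fla.kda, and the same helper covers the later move to FLA with a different mapping. It is line-based rather than AST-aware, so it would also rewrite matching lines inside docstrings, and it leaves plain `import cula.kda` statements alone because those also change the bound name.

```python
import re
from pathlib import Path
from typing import Dict, List

# Phase 2 rename from the RFC. The later switch to FLA can reuse the same
# helper with {"cula_fla.kda": "fla.ops.kda"}.
RENAMES: Dict[str, str] = {"cula.kda": "cula_fla.kda"}


def rewrite_imports(source: str, renames: Dict[str, str] = RENAMES) -> str:
    """Rewrite `from <old>[.sub] import ...` lines to the new module path.

    Only `from` imports are handled: a plain `import cula.kda` also changes the
    bound name (`cula` -> `cula_fla`), so those lines need a manual edit.
    """
    for old, new in renames.items():
        # Whole dotted name only: `cula.kda_extra` and `mycula.kda` are untouched.
        pattern = re.compile(rf"^(\s*from\s+){re.escape(old)}(?=\.|\s)", re.MULTILINE)
        source = pattern.sub(lambda m: m.group(1) + new, source)
    return source


def rewrite_tree(root: Path, renames: Dict[str, str] = RENAMES) -> List[Path]:
    """Apply rewrite_imports to every *.py file under root; return changed files."""
    changed: List[Path] = []
    for path in sorted(root.rglob("*.py")):
        before = path.read_text()
        after = rewrite_imports(before, renames)
        if after != before:
            path.write_text(after)
            changed.append(path)
    return changed
```

Running it on a version-controlled tree and reviewing the diff is safer than applying it blindly.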

Phase 3: FLA backend registration

Work with FLA-org to add cuLA backend registration in FLA's codebase:

  1. FLA adds fla/ops/kda/backends/cula.py (or fla/ops/common/backends/cula.py)
  2. @dispatch decorated functions in FLA automatically try cuLA when installed
  3. FLA users get cuLA acceleration with zero code changes:
    pip install cula  # just install the core package
    from fla.ops.kda import chunk_kda  # automatically uses cuLA kernels on Blackwell

Phase 4: Expand cuLA core, shrink cula-fla

As cuLA implements more CUDA kernels, they move into core and FLA adds dispatch points:

| Kernel | Today | Future |
|---|---|---|
| fwd_h (chunk_delta_h) | cuLA core | cuLA core |
| fwd_o | cuLA core | cuLA core |
| recompute_wu | cuLA core (WIP) | cuLA core |
| chunk_intra | cula-fla (uses FLA Triton) | cuLA core (CUDA) |
| gate_cumsum | cula-fla (uses FLA Triton) | cuLA core (CUDA) |
| backward | cula-fla (uses FLA Triton) | cuLA core (CUDA) |
| l2norm | cula-fla (uses FLA) | cuLA core (CUDA) |
| Lightning kernels | cuLA core | cuLA core |

End state: cula-fla is deprecated and archived. Users use from fla.ops.kda import chunk_kda which transparently dispatches to cuLA CUDA kernels. The transitional wrapper has served its purpose and is no longer needed.

Alternatives Considered

1. No split: vendor the needed FLA code into cuLA

Copy all needed FLA code (gate, backward, intra-chunk, l2norm) into cuLA.

Rejected:

  • Duplicates ~2000 lines of algorithm logic with ongoing maintenance burden
  • Backward pass alone is ~500 lines of complex Triton code that evolves with FLA
  • Doesn't scale as cuLA adds GDN, GLA, and other algorithms

2. Single package with optional FLA extras (cula[fla])

Keep everything in one package, gate FLA-dependent code behind try: import fla.

Rejected:

  • cula/kda/ would fail at import time without FLA — confusing to users
  • Makes it unclear which parts of cuLA work standalone vs. require FLA
  • Circular dependency still exists when FLA tries to import cula

3. Monorepo (merge cuLA into FLA)

Rejected for now due to:

  • cuLA uses CuTe DSL + CUTLASS C++ — fundamentally different build system from FLA's pure Triton
  • Different contributor pools

However, cuLA plans to migrate all its kernels fully to CuTe DSL, phasing out the current CUTLASS C++ code. Once cuLA completes this transition, the build system and toolchain gap between the two projects shrinks significantly — both would use Python-based kernel authoring (Triton for FLA, CuTe DSL for cuLA) with similar packaging and CI requirements. At that point, merging cuLA into FLA as a monorepo becomes a viable and potentially preferable option — a single repo with unified CI, shared infrastructure, and a single pip install fla that includes both Triton fallbacks and optimized CUDA kernels. This should be revisited with the FLA-org maintainers once cuLA's CuTe DSL migration is further along.

Open Questions

  1. Package naming: cula-fla or keep it as a subpackage within the same repo (cula/contrib/fla_wrapper/)?
    • Separate package is cleaner for dependency management
    • Same repo subdirectory is easier to maintain during rapid development
