Refactor-the-forward-pass-in-Mamba2-SSM-and-Titans-nodes. #22

@david-thrower

Refactor the forward pass in Mamba2/SSM and Titans nodes.


1. Mamba2SSD._ssd_parallel (helix_lm/mamba2.py:183–200)

def _ssd_parallel(self, A_bar, B_bar, x_conv, C):
    # ...
    h = torch.zeros(batch * d_inner, d_state, ...)
    ys = []
    for t in range(seq_len):           # ← 256 iterations
        h = A_flat[:, t] * h + B_flat[:, t] * x_flat[:, t]
        y_t = (h * C_flat[:, t]).sum(dim=-1)
        ys.append(y_t)
    y = torch.stack(ys, dim=1)

The "parallel" path is not parallel. The comment even admits it: "For now, use a batched sequential approach... True parallel scan requires custom CUDA kernels."

With seq_len=256, this loop runs 256 times per Mamba2 node forward. Each iteration launches 2–3 tiny CUDA kernels. Assuming ~0.5ms of overhead per kernel (an estimate that lumps Python dispatch in with launch latency, not a measured figure), that's 256 × 2–3 × 0.5ms ≈ 300ms per node forward. With n_loops=2 and 2 Mamba2 nodes in the graph, that's 4 × 300ms = 1.2s per batch forward. Backward pass is 2–3× slower → 3–4s. Still not 61s, but we're not done.

2. SSMNode.forward (helix_lm/nodes.py:270–282)

Same exact pattern:

for t in range(T):
    h = A_bar[:, t] * h + B_bar[:, t] * x_conv[:, t].unsqueeze(-1)
    y = (h * C[:, t].unsqueeze(1)).sum(dim=-1)
    ys.append(y)

3. TitansMemoryNode.forward (helix_lm/nodes.py:458–478)

for t in range(T):                     # ← 256 iterations
    k_t = k[:, t, :]
    v_t = v[:, t, :]
    v_pred = torch.einsum('bf,bfd->bd', k_t, M)
    surprise = torch.norm(v_t - v_pred, ...)
    delta = torch.matmul(k_t.unsqueeze(-1), v_t.unsqueeze(1))
    M = M + eta... * surprise... * delta
    M = F.layer_norm(M, M.shape[-2:])

This one is even worse — it does 5–6 kernel launches per token (einsum, norm, matmul, add, layer_norm). 256 × 6 = 1,536 launches. At 0.5ms each = 768ms per Titans node forward.


Why This Explains the Slow Throughput:

Your NAS samples use_ssm ∈ {False, True} and use_titans ∈ {False, True}. If a trial samples both True with n_loops=2 and n_columns=2:

| Component | Iterations per batch forward | Time |
|---|---|---|
| 2× Mamba2 nodes × 2 loops × 256 seq | 1,024 token-steps | ~1.5s |
| 1× Titans node × 2 loops × 256 seq | 512 token-steps | ~1.5s |
| Backward pass (2–3× forward) | | ~6–9s |
| ACT halting syncs + graph overhead | | ~2s |
| Total per batch | | ~10–15s |
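
The forward and backward rows can be reproduced with a few lines of arithmetic. This is only a sketch built from the figures quoted in this issue (0.5ms per kernel, 256 tokens, 2 loops, and the upper end of each kernels-per-step range); the helper name `forward_ms` is made up for illustration, and none of the numbers are measurements.

```python
# Back-of-envelope launch budget using the figures quoted in this issue.
LAUNCH_MS = 0.5   # assumed per-kernel overhead from the issue, not measured
SEQ_LEN = 256     # tokens per sequence
N_LOOPS = 2       # n_loops sampled by the NAS trial


def forward_ms(nodes, kernels_per_step):
    """Estimated forward time in ms for `nodes` recurrent nodes."""
    return nodes * N_LOOPS * SEQ_LEN * kernels_per_step * LAUNCH_MS


mamba_ms = forward_ms(nodes=2, kernels_per_step=3)    # upper end of 2-3 kernels
titans_ms = forward_ms(nodes=1, kernels_per_step=6)   # upper end of 5-6 kernels
forward_total = mamba_ms + titans_ms

# The issue assumes backward costs 2-3x the forward pass.
backward_low, backward_high = 2 * forward_total, 3 * forward_total

print(f"Mamba2 forward: {mamba_ms / 1000:.2f}s")
print(f"Titans forward: {titans_ms / 1000:.2f}s")
print(f"Backward: {backward_low / 1000:.1f}s to {backward_high / 1000:.1f}s")
```

Profiling a real trial (for example with torch.profiler) would be the way to replace the assumed 0.5ms with a measured value.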

But your log shows 61s, not 15s. The remaining factor is that PyTorch backward on these Python-loop graphs has to walk an enormous computation graph: every loop iteration adds its own set of autograd nodes. With 1,500+ token-steps per batch forward, each replayed in backward, the autograd graph traversal alone takes 30–40 seconds.
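
One way to check the graph-growth claim is to count the backward nodes that a looped scan creates. The sketch below is an assumption-laden toy, not the helix_lm code: `count_autograd_nodes`, `looped_scan`, and `graph_size` are hypothetical helpers, and the tensor shapes are arbitrary. It relies only on the `grad_fn` and `next_functions` attributes that PyTorch exposes on autograd nodes.

```python
import torch


def count_autograd_nodes(output):
    """Count distinct backward nodes reachable from `output.grad_fn`."""
    seen, stack = set(), [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # next_functions holds (node, input_index) pairs; node may be None
        stack.extend(parent for parent, _ in fn.next_functions)
    return len(seen)


def looped_scan(a, b):
    """Toy stand-in for the per-token loops above: h_t = a_t * h_{t-1} + b_t."""
    h = torch.zeros_like(b[:, 0])
    ys = []
    for t in range(a.shape[1]):
        h = a[:, t] * h + b[:, t]
        ys.append(h)
    return torch.stack(ys, dim=1)


def graph_size(seq_len):
    """Number of autograd nodes produced by one looped forward of length seq_len."""
    a = torch.rand(4, seq_len, 8, requires_grad=True)
    b = torch.rand(4, seq_len, 8, requires_grad=True)
    return count_autograd_nodes(looped_scan(a, b).sum())


for T in (8, 64, 256):
    print(T, graph_size(T))
```

The node count should grow roughly linearly with the sequence length, which is the behaviour described above; whether that accounts for the full 30–40 seconds would still need to be confirmed with a profiler.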


Proper Fix for Mamba2/SSM/Titans (Post-Screening)

Replace the Python sequential loops with vectorized associative scans using torch.cumsum or a custom CUDA kernel.

For Mamba2 with diagonal A (which yours is, since A is a 1D vector per channel), the recurrence is:

h_t = a_t * h_{t-1} + b_t

This can be rewritten as:

h_t = (prod_{i=1..t} a_i) * h_0 + sum_{i=1..t} (b_i * prod_{j=i+1..t} a_j)

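This can be computed in O(log T) parallel steps via a parallel scan, or evaluated with torch.cumprod + torch.cumsum in O(T) work but fully vectorized (no Python loop). The cumprod/cumsum form is mathematically exact, but it divides by the running product of a_i, so it loses precision or overflows once that product gets close to zero.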
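
The reason a parallel scan applies here is that each step h -> a*h + b can be represented as a pair (a, b), and composing two such steps is associative. The following pure-Python sketch illustrates that with a Hillis-Steele style scan on plain lists; the function names are illustrative, and a real implementation would run each round as one batched tensor operation rather than a list comprehension.

```python
def combine(left, right):
    """Compose two steps h -> a*h + b, applying `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a, b, h0=0.0):
    """Reference: the plain recurrence h_t = a_t * h_{t-1} + b_t."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def hillis_steele_scan(a, b, h0=0.0):
    """Inclusive scan in ceil(log2 T) rounds; each round is elementwise over t."""
    elems = list(zip(a, b))
    offset = 1
    while offset < len(elems):
        # Position t absorbs the already-combined block ending at t - offset.
        elems = [
            combine(elems[t - offset], elems[t]) if t >= offset else elems[t]
            for t in range(len(elems))
        ]
        offset *= 2
    # elems[t] now maps h_0 directly to h_t.
    return [a_t * h0 + b_t for a_t, b_t in elems]


a = [0.9, 0.5, 0.8, 0.7, 0.6]
b = [1.0, 2.0, 3.0, 4.0, 5.0]
print(sequential_scan(a, b, h0=0.3))
print(hillis_steele_scan(a, b, h0=0.3))
```

Both calls should print the same values up to floating-point rounding. This variant does O(T log T) total work; work-efficient scans exist but are more involved.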

Quick vectorized fix for Mamba2SSD._ssd_parallel:

import torch

def _ssd_parallel(self, A_bar, B_bar, x_conv, C):
    # A_bar: (B, T, D, N), diagonal discretized A with entries in (0, 1]
    # B_bar: (B, T, D, N)
    # x_conv: (B, T, D)
    # C: (B, T, N)

    # Recurrence with h_0 = 0: h_t = A_t * h_{t-1} + B_t * x_t
    # With P_t = prod_{i<=t} A_i, the closed form is
    #   h_t = P_t * sum_{i<=t} (B_i * x_i / P_i)
    Bx = B_bar * x_conv.unsqueeze(-1)  # (B, T, D, N)

    # Cumulative product of A, kept in log space
    log_A = torch.log(A_bar.clamp(min=1e-10))
    cum_log_A = torch.cumsum(log_A, dim=1)  # (B, T, D, N), log P_t

    # "cumsum with decay" trick: pre-weight Bx by 1 / P_t, accumulate,
    # then re-apply P_t.
    # Caveat: exp(-cum_log_A) overflows once P_t becomes very small, so this
    # is only safe for short sequences or mild decay.
    cum_decayed = torch.cumsum(Bx * torch.exp(-cum_log_A), dim=1)
    h = torch.exp(cum_log_A) * cum_decayed  # (B, T, D, N)

    # Output: contract the state dimension N against C
    y = (h * C.unsqueeze(2)).sum(dim=-1)  # (B, T, D)
    return y + self.D * x_conv, None

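(Note: this is a sketch. The cumprod/cumsum identity holds whether A is constant or varies per token, but dividing by the cumulative product becomes numerically unstable as that product approaches zero over long sequences. For Mamba-2's selective SSM, where A varies per token and can decay quickly, a chunked scan or a true parallel scan kernel is the safer choice. The mamba-ssm package provides this.)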
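
A middle ground between the per-token loop and a custom kernel is to apply the cumsum trick inside short chunks and carry the state across chunk boundaries. The sketch below is a hypothetical helper, not part of helix_lm or mamba-ssm. It assumes h_0 = 0, that `a` and `b` share a shape with time on dim=1, and that decay is mild enough for `a_min ** -chunk` to stay finite in the working dtype.

```python
import torch


def chunked_linear_scan(a, b, chunk=16):
    """Compute h_t = a_t * h_{t-1} + b_t along dim=1 with h_0 = 0."""
    log_a = torch.log(a.clamp(min=1e-10))
    h_prev = torch.zeros_like(b[:, 0])
    outs = []
    # The Python loop runs ceil(T / chunk) times instead of T times.
    for log_a_c, b_c in zip(log_a.split(chunk, dim=1), b.split(chunk, dim=1)):
        cum = torch.cumsum(log_a_c, dim=1)   # log of the in-chunk running product
        decay = torch.exp(cum)
        # State built from inputs inside this chunk (cumsum trick, bounded range).
        local = decay * torch.cumsum(b_c * torch.exp(-cum), dim=1)
        # Add the decayed state carried in from the previous chunk.
        h_c = local + decay * h_prev.unsqueeze(1)
        outs.append(h_c)
        h_prev = h_c[:, -1]
    return torch.cat(outs, dim=1)


def reference_scan(a, b):
    """Per-token loop, used only to check the chunked version."""
    h = torch.zeros_like(b[:, 0])
    outs = []
    for t in range(a.shape[1]):
        h = a[:, t] * h + b[:, t]
        outs.append(h)
    return torch.stack(outs, dim=1)


a = 0.5 + 0.5 * torch.rand(2, 256, 4, 8)   # decay factors in [0.5, 1.0)
b = torch.randn(2, 256, 4, 8)
err = (chunked_linear_scan(a, b) - reference_scan(a, b)).abs().max()
print(f"max abs difference vs loop: {err.item():.2e}")
```

With seq_len=256 and chunk=16 the loop body runs 16 times rather than 256. The chunk size trades loop iterations against numerical headroom, so it would need tuning against the actual range of A_bar.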

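For Titans: The memory update is inherently sequential (M_t depends on M_{t-1}). The outer products k_t.outer(v_t) do not depend on M, though, so they can be computed for the full sequence in a single torch.einsum before the loop. Only the surprise-weighted accumulation has to stay sequential:

# Inside the loop over t, the update is:
# M_t = M_{t-1} + eta * surprise_t * k_t.outer(v_t)
# It is tempting to collapse the whole loop into:
# M_T = M_0 + eta * sum_t (surprise_t * k_t.outer(v_t))
# But surprise_t depends on M_{t-1} (and layer_norm is applied after every step), so this is not trivially parallelizable.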
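
As a rough illustration of hoisting the outer products out of the loop, the sketch below computes them once with torch.einsum and keeps only the M-dependent work per step. Several details are assumptions, because the loop shown earlier elides them: `eta` is treated as a scalar, `surprise` as a per-sample L2 norm, and the function name `titans_memory_precomputed` is hypothetical.

```python
import torch
import torch.nn.functional as F


def titans_memory_precomputed(k, v, M0, eta=0.1):
    """k: (B, T, F), v: (B, T, D), M0: (B, F, D). Returns the final memory."""
    # All outer products k_t v_t^T in one call: (B, T, F, D).
    deltas = torch.einsum('btf,btd->btfd', k, v)
    M = M0
    for t in range(k.shape[1]):   # still sequential: surprise_t needs M_{t-1}
        v_pred = torch.einsum('bf,bfd->bd', k[:, t], M)
        # Assumed form of the surprise: per-sample L2 norm of the prediction error.
        surprise = torch.norm(v[:, t] - v_pred, dim=-1)   # (B,)
        M = M + eta * surprise.view(-1, 1, 1) * deltas[:, t]
        M = F.layer_norm(M, M.shape[-2:])
    return M


k = torch.randn(2, 256, 16)
v = torch.randn(2, 256, 32)
M_final = titans_memory_precomputed(k, v, torch.zeros(2, 16, 32))
print(M_final.shape)
```

This removes one kernel launch per token but materializes a (B, T, F, D) tensor, so it trades memory for launches and leaves the loop itself in place. It is a partial mitigation, not a parallel scan.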
