Refactor the forward pass in Mamba2/SSM and Titans nodes.
1. Mamba2SSD._ssd_parallel (helix_lm/mamba2.py:183–200)
    def _ssd_parallel(self, A_bar, B_bar, x_conv, C):
        # ...
        h = torch.zeros(batch * d_inner, d_state, ...)
        ys = []
        for t in range(seq_len):  # ← 256 iterations
            h = A_flat[:, t] * h + B_flat[:, t] * x_flat[:, t]
            y_t = (h * C_flat[:, t]).sum(dim=-1)
            ys.append(y_t)
        y = torch.stack(ys, dim=1)
The "parallel" path is not parallel. The comment even admits it: "For now, use a batched sequential approach... True parallel scan requires custom CUDA kernels."
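With seq_len=256, this loop runs 256 times per Mamba2 node forward, and each iteration launches 2–3 tiny CUDA kernels. Assuming ~0.5 ms of overhead per launch (a pessimistic figure that lumps Python dispatch and autograd bookkeeping in with the launch itself; raw launch latency is usually in the microsecond range), that's ~300 ms per node forward. With n_loops=2 and 2 Mamba2 nodes in the graph, that's ~1.2 s per batch forward. The backward pass is 2–3× slower, so roughly another 2.4–3.6 s. Still not 61s, but we're not done.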
2. SSMNode.forward (helix_lm/nodes.py:270–282)
Same exact pattern:
    for t in range(T):
        h = A_bar[:, t] * h + B_bar[:, t] * x_conv[:, t].unsqueeze(-1)
        y = (h * C[:, t].unsqueeze(1)).sum(dim=-1)
        ys.append(y)
3. TitansMemoryNode.forward (helix_lm/nodes.py:458–478)
    for t in range(T):  # ← 256 iterations
        k_t = k[:, t, :]
        v_t = v[:, t, :]
        v_pred = torch.einsum('bf,bfd->bd', k_t, M)
        surprise = torch.norm(v_t - v_pred, ...)
        delta = torch.matmul(k_t.unsqueeze(-1), v_t.unsqueeze(1))
        M = M + eta... * surprise... * delta
        M = F.layer_norm(M, M.shape[-2:])
This one is even worse — it does 5–6 kernel launches per token (einsum, norm, matmul, add, layer_norm). 256 × 6 = 1,536 launches. At 0.5ms each = 768ms per Titans node forward.
Why This Explains the Slow Throughput:
Your NAS samples use_ssm ∈ {False, True} and use_titans ∈ {False, True}. If a trial samples both True with n_loops=2 and n_columns=2:
| Component | Iterations per batch forward | Time |
| --- | --- | --- |
| 2× Mamba2 nodes × 2 loops × 256 seq | 1,024 token-steps | ~1.5s |
| 1× Titans node × 2 loops × 256 seq | 512 token-steps | ~1.5s |
| Backward pass (2–3× forward) | n/a | ~6–9s |
| ACT halting syncs + graph overhead | n/a | ~2s |
| Total per batch | | ~10–15s |
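The forward-pass rows of this estimate can be reproduced with a few lines of arithmetic. This is only a back-of-the-envelope model: the helper name `loop_forward_seconds` is made up for illustration, the 2.5 launches per Mamba2 step is an assumed midpoint of the 2–3 range, and the 0.5 ms overhead is the assumed figure from above, not a measurement.

```python
def loop_forward_seconds(n_nodes, n_loops, seq_len, launches_per_step,
                         overhead_ms=0.5):
    """Estimated forward time (seconds) spent purely on per-launch overhead."""
    launches = n_nodes * n_loops * seq_len * launches_per_step
    return launches * overhead_ms / 1000.0


# 2 Mamba2 nodes, 2 loops, 256 tokens, ~2.5 launches per step (assumed midpoint)
mamba_s = loop_forward_seconds(2, 2, 256, 2.5)
# 1 Titans node, 2 loops, 256 tokens, 6 launches per step
titans_s = loop_forward_seconds(1, 2, 256, 6)

forward_s = mamba_s + titans_s
backward_lo, backward_hi = 2 * forward_s, 3 * forward_s  # "2-3x forward"

print(f"Mamba2 forward: {mamba_s:.2f} s")
print(f"Titans forward: {titans_s:.2f} s")
print(f"Backward range: {backward_lo:.1f}-{backward_hi:.1f} s")
```

The results land close to the rounded figures in the table, which suggests the table is internally consistent under the 0.5 ms assumption. It says nothing about whether that assumption holds on your hardware.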
But your log shows 61s, not 15s. The remaining factor is most likely autograd: in these Python-loop graphs, every op in every loop iteration becomes a separate autograd node, so the backward pass has to walk an enormous graph. With 1,500+ token-steps per forward, and the same again in backward, graph construction and traversal alone could plausibly account for the missing 30–40 seconds. That figure is an estimate, not a measurement.
Proper Fix for Mamba2/SSM/Titans (Post-Screening)
Replace the Python sequential loops with vectorized associative scans using torch.cumsum or a custom CUDA kernel.
For Mamba2 with diagonal A (which yours is, since A is a 1D vector per channel), the recurrence is:
h_t = a_t * h_{t-1} + b_t
This can be rewritten as:
h_t = (prod_{i=1..t} a_i) * h_0 + sum_{i=1..t} (b_i * prod_{j=i+1..t} a_j)
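This can be computed in O(log T) sequential steps via a parallel (associative) scan, or evaluated with torch.cumprod + torch.cumsum in O(T) work, fully vectorized (no Python loop). The cumprod/cumsum form is exact on paper, but it divides by a running product, so it can overflow or lose precision when the a_i are small or T is large.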
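To make the O(log T) option concrete, here is a minimal sketch of a Hillis-Steele style scan written with plain tensor ops. The function names `scan_sequential` and `scan_log_steps` are illustrative and not part of helix_lm, h_0 = 0 is assumed, and inputs are assumed to have time on dim 1. It does O(T log T) total work, so it trades extra arithmetic for far fewer sequential steps.

```python
import torch


def scan_sequential(a, b):
    """Reference: h_t = a_t * h_{t-1} + b_t with h_0 = 0, one step at a time."""
    h = torch.zeros_like(b[:, 0])
    out = []
    for t in range(b.shape[1]):
        h = a[:, t] * h + b[:, t]
        out.append(h)
    return torch.stack(out, dim=1)


def scan_log_steps(a, b):
    """Same recurrence in ceil(log2(T)) vectorized passes (Hillis-Steele scan)."""
    T = a.shape[1]
    d = 1
    while d < T:
        # Compose the affine map ending at t-d with the one ending at t:
        #   (a1, b1) followed by (a2, b2)  ->  (a1 * a2, a2 * b1 + b2)
        new_a = torch.cat([a[:, :d], a[:, d:] * a[:, :-d]], dim=1)
        new_b = torch.cat([b[:, :d], a[:, d:] * b[:, :-d] + b[:, d:]], dim=1)
        a, b = new_a, new_b
        d *= 2
    return b  # b[:, t] now holds h_t


torch.manual_seed(0)
a = torch.rand(2, 37, 4, dtype=torch.float64)   # decay factors in [0, 1)
b = torch.randn(2, 37, 4, dtype=torch.float64)
print(torch.allclose(scan_sequential(a, b), scan_log_steps(a, b)))
```

For seq_len=256 this replaces 256 Python iterations with 8. It never divides by a running product, so it avoids the overflow risk of the cumprod/cumsum form.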
Quick vectorized fix for Mamba2SSD._ssd_parallel:
    def _ssd_parallel(self, A_bar, B_bar, x_conv, C):
        # A_bar:  (B, T, D, N), diagonal discretized A
        # B_bar:  (B, T, D, N)
        # x_conv: (B, T, D)
        # C:      (B, T, N)
        # Vectorized scan using cumprod + cumsum, assuming h_0 = 0:
        #   h_t = A_t * h_{t-1} + B_t * x_t
        # For diagonal A, we can compute the full sequence in parallel.
        Bx = B_bar * x_conv.unsqueeze(-1)              # (B, T, D, N)
        # Cumulative product of A, accumulated in log space
        log_A = torch.log(A_bar.clamp(min=1e-10))
        cum_log_A = torch.cumsum(log_A, dim=1)         # (B, T, D, N)
        cum_A = torch.exp(cum_log_A)
        # "Cumsum with decay" trick for first-order linear recurrences:
        #   h_t = cum_A_t * sum_{i<=t} (Bx_i / cum_A_i)
        # Caveat: cum_A can underflow, which makes this division overflow.
        decayed_Bx = Bx / cum_A                        # pre-weight
        cum_decayed = torch.cumsum(decayed_Bx, dim=1)
        h = cum_A * cum_decayed                        # (B, T, D, N)
        # Output
        y = (h * C.unsqueeze(2)).sum(dim=-1)           # (B, T, D)
        return y + self.D * x_conv, None
(Note: this is a sketch. The identity itself holds even when A varies per token, but dividing by cum_A is numerically fragile: with strong decay or long sequences, cum_A underflows and Bx / cum_A overflows to inf or NaN. For Mamba-2's selective SSM, prefer a chunked formulation or a true parallel scan kernel. The mamba-ssm package provides this.)
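One way to sidestep the division is to work with differences of cumulative log-decays, which stay non-positive whenever every a_t is at most 1. The sketch below is an assumption-laden illustration, not the project's code: the names `scan_cumprod_naive` and `scan_pairwise_decay` are hypothetical, inputs are assumed to be shaped (B, T, K) with h_0 = 0, and memory grows as O(T^2), so it only suits short sequences or per-chunk use.

```python
import torch


def scan_cumprod_naive(log_a, b):
    """cumprod/cumsum form: divides by the running product, so it can overflow."""
    cum_a = torch.exp(torch.cumsum(log_a, dim=1))
    return cum_a * torch.cumsum(b / cum_a, dim=1)


def scan_pairwise_decay(log_a, b):
    """h_t = sum_{i<=t} exp(cum_t - cum_i) * b_i, where cum = cumsum(log_a).

    log_a, b: (B, T, K). Uses O(T^2) memory.
    """
    T = log_a.shape[1]
    cum = torch.cumsum(log_a, dim=1)                      # (B, T, K)
    diff = cum[:, :, None, :] - cum[:, None, :, :]        # (B, T, T, K): cum_t - cum_i
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=log_a.device))
    # Mask out future positions (i > t) before exponentiating.
    diff = diff.masked_fill(~causal[None, :, :, None], float("-inf"))
    weights = torch.exp(diff)                             # zero wherever i > t
    return (weights * b[:, None, :, :]).sum(dim=2)        # (B, T, K)


# Strong decay over a long-ish sequence: the running product underflows to 0.
torch.manual_seed(0)
log_a = torch.full((1, 200, 3), 1e-3, dtype=torch.float64).log()
b = torch.randn(1, 200, 3, dtype=torch.float64)

naive = scan_cumprod_naive(log_a, b)
stable = scan_pairwise_decay(log_a, b)
print("naive finite: ", bool(torch.isfinite(naive).all()))
print("stable finite:", bool(torch.isfinite(stable).all()))
```

A practical design would likely apply this within fixed-size chunks and carry the state across chunk boundaries, keeping the quadratic term small.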
For Titans: the memory update is inherently sequential (M_t depends on M_{t-1}), and the per-step layer_norm makes it nonlinear on top of that. The outer-product updates could be batched into a single torch.einsum over the full sequence only if surprise_t did not depend on the running memory:
    # Instead of looping over t:
    #   M_t = M_{t-1} + eta * surprise_t * k_t.outer(v_t)
    # Rewrite as:
    #   M_T = M_0 + eta * sum_t (surprise_t * k_t.outer(v_t))
    # But surprise_t depends on M_{t-1}, so this is not trivially parallelizable.
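One possible compromise is to freeze M within fixed-size chunks, compute every surprise in the chunk against that frozen snapshot, and apply a single batched update plus one layer_norm per chunk. The sketch below is an approximation that changes the model's semantics, not a drop-in replacement. The function name `titans_update_chunked`, the scalar `eta`, the norm over the last dimension, and the `chunk` parameter are all assumptions, since the real eta and surprise terms are elided in the excerpt above.

```python
import torch
import torch.nn.functional as F


def titans_update_chunked(k, v, M, eta=0.1, chunk=32):
    """Approximate memory update with M frozen inside each chunk.

    k: (B, T, F_dim), v: (B, T, D), M: (B, F_dim, D). Returns the final M.
    """
    T = k.shape[1]
    for s in range(0, T, chunk):
        k_c, v_c = k[:, s:s + chunk], v[:, s:s + chunk]
        # Predictions for the whole chunk against the same memory snapshot.
        v_pred = torch.einsum('btf,bfd->btd', k_c, M)
        surprise = torch.norm(v_c - v_pred, dim=-1)             # (B, chunk)
        # Surprise-weighted sum of outer products k_t v_t^T over the chunk.
        delta = torch.einsum('bt,btf,btd->bfd', surprise, k_c, v_c)
        M = F.layer_norm(M + eta * delta, M.shape[-2:])
    return M


torch.manual_seed(0)
k = torch.randn(2, 256, 8)
v = torch.randn(2, 256, 16)
M0 = torch.zeros(2, 8, 16)
M_final = titans_update_chunked(k, v, M0, chunk=32)  # 8 iterations instead of 256
print(M_final.shape)
```

With chunk=1 this reduces to the per-token loop under the stated assumptions. Larger chunks cut the number of Python iterations proportionally but let the surprise signal go stale within a chunk, so any quality impact would need to be checked empirically.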