[WS1] Backward-pass consistency across all ops #153

Description

@Flink-ddd

Part of WS1 — Full Batch-Invariant Forward Chain (epic: #)

Why

A forward-aligned chain still breaks training if gradient reductions drift with batch shape: the optimizer then sees batch-dependent gradients, and the run can diverge even though inference looked aligned. Backward invariance is a separate, explicit requirement that the forward checks do not cover. This issue makes "backward also invariant" a first-class, cross-op acceptance condition.
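To make "gradient reductions drift with batch shape" concrete, here is a minimal pure-Python sketch (not taken from any kernel in this repo). Floating-point addition is not associative, so a reduction whose split count depends on batch shape can return different bits for the same gradient contributions. The helper names `sequential_sum` and `split_sum` are hypothetical.

```python
# Floating-point addition is not associative, so the same set of gradient
# contributions can sum to different values depending on how a reduction is
# split. Kernels often choose the split from the batch shape.


def sequential_sum(values):
    """Left-to-right reduction into one accumulator: a single fixed order."""
    acc = 0.0
    for v in values:
        acc += v
    return acc


def split_sum(values, num_splits):
    """Reduce `num_splits` contiguous partial sums, then combine them.

    Stands in for a kernel that picks its split count from the batch shape.
    """
    n = len(values)
    size = -(-n // num_splits)  # ceiling division
    partials = [sequential_sum(values[i:i + size]) for i in range(0, n, size)]
    return sequential_sum(partials)


contribs = [1e17, 1.0, -1e17, 1.0]
print(sequential_sum(contribs))  # ((1e17 + 1) - 1e17) + 1, prints 1.0
print(split_sum(contribs, 2))    # (1e17 + 1) + (-1e17 + 1), prints 0.0
```

Each `+ 1.0` next to `1e17` is absorbed by rounding (the spacing between doubles near `1e17` is 16), so which additions happen before the large terms cancel decides the answer.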

Scope

Make batch-invariant backward a required, tested property of every WS1 op.

  • Define the backward-invariance check in the #108 harness ([WS1] Ground-truth harness + numerical contract for batch-invariant ops): compare gradient outputs across batch configs using the same tolerance policy as forward.
  • Ensure each op's gradient reduction (dx, dweight, dW, etc.) uses a fixed, batch-shape-independent order — no atomicAdd in backward.
  • Cover the most reduction-heavy backward first: RMSNorm dweight, GEMM dW/dX, attention backward, embedding-grad scatter.
  • Validate gradients across batch=1/N, chunked-prefill on/off, and padding layouts.
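A minimal pure-Python sketch of what a fixed, batch-shape-independent order and the batch=1/N, chunked-prefill and padding checks could look like for RMSNorm. The function names, the row-major list layout and the use of bitwise equality (zero tolerance) are assumptions for illustration, not the #108 harness API; a GPU kernel would need to reproduce the same ordering. Note that per-row `dx` is compared across batch=1/N, while `dweight` (which sums over rows) is compared across chunking and padding layouts of the same logical rows.

```python
import math
from typing import List, Optional, Sequence, Tuple

Row = List[float]


def rmsnorm_row_backward(x: Sequence[float], w: Sequence[float],
                         dy: Sequence[float], eps: float = 1e-6
                         ) -> Tuple[Row, Row]:
    """Backward of one RMSNorm row, y = x / rms(x) * w.

    Returns (dx, dweight contribution). Every within-row reduction runs left
    to right over columns, so the result depends only on this row's data.
    """
    n = len(x)
    sq = 0.0
    for v in x:
        sq += v * v
    inv_rms = 1.0 / math.sqrt(sq / n + eps)
    xhat = [v * inv_rms for v in x]
    g = [d * wj for d, wj in zip(dy, w)]
    dot = 0.0
    for gj, xj in zip(g, xhat):
        dot += gj * xj
    mean_gx = dot / n
    dx = [inv_rms * (gj - xj * mean_gx) for gj, xj in zip(g, xhat)]
    return dx, [d * xj for d, xj in zip(dy, xhat)]


def rmsnorm_backward(X: List[Row], w: Row, dY: List[Row],
                     chunk_size: Optional[int] = None,
                     mask: Optional[List[bool]] = None
                     ) -> Tuple[List[Row], Row]:
    """Batch backward that accumulates dweight in ascending row order.

    chunk_size mimics chunked prefill: rows arrive in chunks, but they all
    feed one running accumulator, so the sequence of additions is unchanged.
    mask marks valid rows; padding rows are skipped instead of summed.
    """
    step = chunk_size or max(len(X), 1)
    dX: List[Row] = []
    dweight = [0.0] * len(w)
    for start in range(0, len(X), step):
        for r in range(start, min(start + step, len(X))):
            dx, contrib = rmsnorm_row_backward(X[r], w, dY[r])
            dX.append(dx)
            if mask is not None and not mask[r]:
                continue  # padding row: no dweight contribution
            for j, c in enumerate(contrib):
                dweight[j] += c  # fixed row order (what atomicAdd cannot promise)
    return dX, dweight


X = [[0.5, -1.25, 2.0], [3.0, 0.1, -0.7], [1e3, -2e-3, 0.33]]
dY = [[1.0, 0.5, -2.0], [0.25, -1.0, 4.0], [7.0, 0.0, -0.125]]
w = [1.5, -0.5, 2.0]
ref_dX, ref_dw = rmsnorm_backward(X, w, dY)

# batch=1 vs batch=N: each row's dx is bitwise identical.
for r in range(len(X)):
    assert rmsnorm_backward([X[r]], w, [dY[r]])[0][0] == ref_dX[r]
# chunked prefill on vs off: dweight is bitwise identical.
assert rmsnorm_backward(X, w, dY, chunk_size=2)[1] == ref_dw
# left padding with a masked row: dweight is bitwise identical.
padded = rmsnorm_backward([[9.0] * 3] + X, w, [[1.0] * 3] + dY,
                          mask=[False, True, True, True])
assert padded[1] == ref_dw
```

The design choice is that batching, chunking and padding change only *which loop iteration* a row lands in, never the order in which its contribution reaches the accumulator.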

Out of scope

  • Re-implementing each op (each op issue owns its own backward kernel and fix; this issue owns the cross-cutting requirement, reusable gradient check, and status matrix).
  • Optimizer / training-loop changes; multi-GPU gradient synchronization (WS2).
  • FP8 backward.

Acceptance criteria

Notes

Planned PRs

  • Add a reusable gradient-invariance assertion to the #108 harness ([WS1] Ground-truth harness + numerical contract for batch-invariant ops)
  • Per-op backward test requirements (RMSNorm dweight, GEMM dW/dX, attention backward first)
  • Gradient-diff reporting utility (max abs / relative diff, tensor name, first failing op)
  • Enforce no-atomicAdd / fixed-order accumulation in backward paths
  • Per-op forward+backward invariance status matrix; wire full-chain backward into CI
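One possible shape for the gradient-diff reporting utility, as a minimal pure-Python sketch. The `(op_name, {tensor_name: values})` recording format, the `GradDiff` record and the function names are hypothetical, not the harness's actual interface. The pass rule `|test - ref| <= atol + rtol * |ref|` mirrors the form used by `torch.allclose`; defaulting both tolerances to zero turns it into a bitwise check.

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical layout: gradients recorded per op, in execution order.
OpGrads = Tuple[str, Dict[str, List[float]]]


@dataclass
class GradDiff:
    op: str
    tensor: str
    max_abs: float
    max_rel: float
    passed: bool


def diff_tensor(test: Sequence[float], ref: Sequence[float],
                atol: float, rtol: float) -> Tuple[float, float, bool]:
    """Max abs/rel diff and a pass flag for |test - ref| <= atol + rtol*|ref|."""
    if len(test) != len(ref):
        raise ValueError(f"size mismatch: {len(test)} vs {len(ref)}")
    max_abs = max_rel = 0.0
    passed = True
    for t, r in zip(test, ref):
        d = 0.0 if t == r else abs(t - r)
        if math.isnan(d):
            d = math.inf  # NaN mismatches count as infinitely large
        max_abs = max(max_abs, d)
        if d > 0.0:
            rel = d / abs(r) if 0.0 < abs(r) < math.inf else math.inf
            max_rel = max(max_rel, rel)
        if not d <= atol + rtol * abs(r):  # written so NaN bounds also fail
            passed = False
    return max_abs, max_rel, passed


def report_grad_diffs(ref: List[OpGrads], test: List[OpGrads],
                      atol: float = 0.0, rtol: float = 0.0
                      ) -> Tuple[List[GradDiff], Optional[GradDiff]]:
    """Diff every gradient tensor; also return the first failure in op order."""
    if [op for op, _ in ref] != [op for op, _ in test]:
        raise ValueError("op sequences differ between runs")
    diffs: List[GradDiff] = []
    for (op, ref_grads), (_, test_grads) in zip(ref, test):
        for name, ref_vals in ref_grads.items():
            max_abs, max_rel, ok = diff_tensor(test_grads[name], ref_vals,
                                               atol, rtol)
            diffs.append(GradDiff(op, name, max_abs, max_rel, ok))
    first_fail = next((d for d in diffs if not d.passed), None)
    return diffs, first_fail


ref = [("rmsnorm", {"dweight": [1.0, 2.0]}),
       ("gemm", {"dW": [0.5, 0.25], "dX": [1.0, -3.0]})]
test = [("rmsnorm", {"dweight": [1.0, 2.0]}),
        ("gemm", {"dW": [0.5, 0.25 + 2**-20], "dX": [1.0, -3.0]})]
diffs, first = report_grad_diffs(ref, test)
if first is not None:
    # prints "first failing op: gemm/dW max_abs=9.537e-07 max_rel=3.815e-06"
    print(f"first failing op: {first.op}/{first.tensor} "
          f"max_abs={first.max_abs:.3e} max_rel={first.max_rel:.3e}")
```

Reporting the first failing op in execution order, rather than only the largest diff, points at where batch dependence enters the chain; later ops usually inherit the drift.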

Metadata

Labels

  • component: testing (Add test cases and benchmark-related tasks)
  • feature
  • platform: cuda (Specific optimizations or bugs in NVIDIA graphics cards, such as FlashInfer, TMA optimizations)
  • priority: high (Severe congestion issues require the highest priority for resolution.)
  • sprint-0615

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
