Input-side action-dim masking for heterogeneous-DoF mixtures (action_in_proj contamination) #345

@shuheng-liu

Background

PR #344 introduced per-dim masking of the flow-matching MSE for heterogeneous-DoF co-training: the velocity-field loss no longer scores zero-padded action dims, so the action expert isn't supervised against (0, 0)-stats targets it doesn't actually need to predict.

The PR review surfaced an unaddressed residual contamination path:

  • Loss-side (fixed in PR #344, "fix(policies): mask zero-pad action dims in flow-matching MSE"): zero-padded action dims used to receive a "predict 0 here" signal. The dim mask is now ANDed into the MSE reduction, so there is no gradient pressure on the padded tail.
  • Input-side (this issue): even with the loss-side fix, the action expert still sees noise[padded_dim] ~ N(0, 1) at the input to action_in_proj during training, because flow matching forms x_t = (1-t)·noise + t·actions and actions are zero at padded dims but noise isn't. Those noise samples flow through the embedding into attention and contribute a small mean-zero perturbation to predictions at real dims. For heterogeneous-DoF mixtures the padded dims aren't sample-independent (they vary per source dataset), so the perturbation doesn't cleanly cancel.
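
The two paths above can be illustrated with a small NumPy sketch. It uses the interpolation exactly as written in this issue; the `make_action_dim_mask` signature, the tensor shapes, the velocity target `u_t = actions - noise` (the time derivative of that interpolation), and the zero stand-in prediction are all assumptions for illustration, not the repository's code.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_action_dim_mask(real_dims, max_action_dim):
    """Hypothetical signature: (batch, max_action_dim) bool mask, True on real dims."""
    return np.arange(max_action_dim)[None, :] < np.asarray(real_dims)[:, None]


batch, chunk, max_dim = 2, 4, 8
real_dims = [6, 8]  # heterogeneous DoF: sample 0 has two padded dims
mask = make_action_dim_mask(real_dims, max_dim)[:, None, :]  # (B, 1, D)

actions = rng.normal(size=(batch, chunk, max_dim)) * mask  # zero-padded tail
noise = rng.normal(size=(batch, chunk, max_dim))           # NOT masked
t = rng.uniform(size=(batch, 1, 1))

x_t = (1 - t) * noise + t * actions  # interpolation as stated in this issue
u_t = actions - noise                # assumed target: d(x_t)/dt

# Loss side (PR #344): squared error is averaged over real dims only.
v_pred = np.zeros_like(u_t)          # stand-in for the action expert's output
sq_err = (v_pred - u_t) ** 2
masked_mse = (sq_err * mask).sum() / (mask.sum() * chunk)

# Input side (this issue): padded dims of x_t still carry (1 - t) * noise.
padded_input = x_t[0, :, 6:]
print(masked_mse, np.abs(padded_input).max())
```

The loss ignores the padded tail, yet `padded_input` is non-zero, and that is the value `action_in_proj` would embed.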

What was tried and reverted

Commit 1b1d9cf (reverted in c46ed20) added training-time input-side masking: multiply noise *= dim_mask_expanded right after sampling, so x_t and u_t are both zero at padded dims and action_in_proj sees clean zeros.
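
A minimal sketch of what that change amounts to, assuming the same interpolation as above and a velocity target of `actions - noise`. The shapes and the way `dim_mask_expanded` is built here are illustrative guesses, not the contents of 1b1d9cf.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, chunk, max_dim = 2, 4, 8
real_dims = np.array([6, 8])

dim_mask = np.arange(max_dim)[None, :] < real_dims[:, None]   # (B, D)
dim_mask_expanded = dim_mask[:, None, :].astype(np.float64)   # (B, 1, D)

actions = rng.normal(size=(batch, chunk, max_dim)) * dim_mask_expanded
noise = rng.normal(size=(batch, chunk, max_dim))
noise *= dim_mask_expanded  # the reverted change: mask right after sampling

t = rng.uniform(size=(batch, 1, 1))
x_t = (1 - t) * noise + t * actions
u_t = actions - noise       # assumed velocity target

# Both the network input and the regression target are zero on padded dims.
assert np.all(x_t[0, :, 6:] == 0)
assert np.all(u_t[0, :, 6:] == 0)
```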

The review (item 14, see the PR comment thread) flagged a train/inference distribution mismatch that makes the partial fix arguably worse than no fix:

  • The inference path (sample_actions) was not updated: it still samples unmasked noise (actions_shape = (bsize, chunk_size, max_action_dim), so noise[padded] ~ N(0,1)).
  • Training-time masking means action_in_proj's W[:, padded] columns receive zero gradient and stay at random initialization.
  • At inference those same columns multiply nonzero (1-t)·noise[padded] inputs, producing a small random perturbation that the model never learned to suppress.
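
The mechanism in these three bullets can be reproduced with a single linear layer standing in for `action_in_proj`. This is a sketch under stated assumptions: the layer is treated as a plain matrix `W` with no bias, and the sizes and upstream gradient are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
max_dim, width, real_dim = 8, 16, 6
W = rng.normal(size=(width, max_dim)) * 0.1  # stand-in for the action_in_proj weight
mask = (np.arange(max_dim) < real_dim).astype(np.float64)

# Training with masked input: for y = W @ x, dL/dW = outer(dL/dy, x),
# so every column that multiplies a zero input gets a zero gradient.
x_train = rng.normal(size=max_dim) * mask
grad_y = rng.normal(size=width)              # arbitrary upstream gradient
grad_W = np.outer(grad_y, x_train)
padded_grad = grad_W[:, real_dim:]

# Inference with unmasked noise: the never-updated columns see non-zero input.
x_infer = rng.normal(size=max_dim)
perturbation = W @ x_infer - W @ (x_infer * mask)
expected = W[:, real_dim:] @ x_infer[real_dim:]

print(np.abs(padded_grad).max(), np.linalg.norm(perturbation))
```

The extra term at inference is exactly `W[:, padded] @ x[padded]`, and its scale is set by the initialization of those columns rather than by anything learned.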

Net: contamination at training eliminated, contamination at inference unchanged (and now with no training-time pressure to suppress it). Before 1b1d9cf, training and inference were symmetric (both saw unmasked noise at padded dims); the revert restores that symmetry.

Options to consider

  1. Mask both training and inference. Apply the same make_action_dim_mask to noise at inference time, using config.action_feature.shape[0] (the deployed homogeneous DoF, already used for output truncation via original_action_dim) as the per-sample real dim. Keeps train/inference symmetric and removes the contamination in both regimes. Requires updating sample_actions in all six low-level policies plus the matching denoise_step / iterative-Euler paths.

  2. Leave as-is (current state after revert). Accept the input-side contamination as a documented footnote. The loss-side fix already does the heavy lifting; the input-side contribution is mean-zero noise that adds variance, not bias. The reviewer's framing was: "This is much smaller than the loss-side contamination (no direct supervision pressure)."

  3. Empirically benchmark. Run a controlled comparison (one mixture, both states) and measure whether the input-side fix moves a downstream metric (real-policy regression on a held-out heterogeneous mixture, or LIBERO eval, or both). Decide based on signal.
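
For option 1, a rough sketch of what a masked inference path might look like. Everything here is hypothetical: `sample_actions_masked`, its arguments, and `velocity_fn` are illustrative names, the integration direction follows the interpolation written in this issue (noise at t=0, actions at t=1), and re-masking the predicted velocity at each Euler step is a design assumption (the loss no longer supervises padded dims, so their predicted velocity is otherwise unconstrained).

```python
import numpy as np


def make_action_dim_mask(real_dims, max_action_dim):
    """Hypothetical signature: (batch, max_action_dim) bool mask, True on real dims."""
    return np.arange(max_action_dim)[None, :] < np.asarray(real_dims)[:, None]


def sample_actions_masked(velocity_fn, real_dim, batch, chunk, max_action_dim,
                          num_steps=10, rng=None):
    """Euler integration from noise (t=0) to actions (t=1) with padded dims held at zero."""
    if rng is None:
        rng = np.random.default_rng()
    mask = make_action_dim_mask([real_dim] * batch, max_action_dim)[:, None, :]

    # Same masking as the training-side change: zero the padded tail of the noise.
    x = rng.normal(size=(batch, chunk, max_action_dim)) * mask
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        v = velocity_fn(x, t) * mask  # assumption: keep padded dims at zero every step
        x = x + dt * v

    # Output truncation, as already done via original_action_dim.
    return x[:, :, :real_dim]


# Toy usage with a constant velocity field in place of the action expert.
out = sample_actions_masked(lambda x, t: np.ones_like(x), real_dim=6,
                            batch=2, chunk=4, max_action_dim=8, num_steps=5,
                            rng=np.random.default_rng(0))
print(out.shape)
```

Here `real_dim` plays the role of config.action_feature.shape[0]; a per-sample value would only be needed if a single inference batch mixed embodiments.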

Open questions

  • Does input-side masking at inference actually help, given that the model trained without it has already (presumably) learned to attend less to padded-dim embeddings via gradient on the output dims that depend on them through attention?
  • Are there other downstream callers of sample_actions (rollouts, eval scripts) that would need a matching change?
  • How does this interact with the original_action_dim truncation at inference (actions = actions[:, :, :original_action_dim])? Is masking-then-truncating equivalent to truncating-the-input?
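
One narrow piece of the last question can be checked analytically. For a linear input projection, zeroing the padded inputs gives the same result as dropping those inputs together with the matching weight columns. The sketch below assumes `action_in_proj` is an affine map `W @ x + b` (whether it has a bias is an assumption) and says nothing about the attention layers or the output-side truncation.

```python
import numpy as np

rng = np.random.default_rng(0)
max_dim, width, real_dim = 8, 16, 6
W = rng.normal(size=(width, max_dim))  # stand-in weight
b = rng.normal(size=width)             # stand-in bias
x = rng.normal(size=max_dim)
mask = np.arange(max_dim) < real_dim

masked = W @ (x * mask) + b                     # mask the input, full weight
truncated = W[:, :real_dim] @ x[:real_dim] + b  # truncate input and weight columns
assert np.allclose(masked, truncated)
```

So at the embedding, input masking behaves like a projection that was only ever `real_dim` wide; any remaining difference has to come from elsewhere in the network.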

Acceptance criteria

A decision (option 1, 2, or 3) with empirical or analytical justification, recorded on this issue. If option 1, a follow-up PR implementing inference-side masking in all six policies; if option 3, the benchmark results posted before deciding.

References

  • PR #344 — loss-side fix that landed
  • Review thread item 11 (originally flagged as future work)
  • Review thread item 14 (train/inference mismatch in the reverted attempt)
  • Reverted commit: 1b1d9cf
  • Revert commit: c46ed20
