Summary
XLA's AliasInfo::GetFusionInstructionInPlaceInputOutputPairs only traces backwards through kBitcast when determining whether a fusion operates in-place on its inputs. When a kTranspose (or kReshape) sits between a kParameter and an in-place operation like scatter, the function fails to identify the in-place relationship. This appears to cause downstream passes (buffer assignment / copy removal) to incorrectly allow concurrent writes to the same buffer, resulting in intermittent NaN/inf corruption on TPU.
I initially discovered this bug through a JAX model that uses a tied embedding matrix (one matrix used for both the input token lookup and the output projection). During backpropagation, this creates two gradient paths that both write to the gradient of the same parameter: a sparse one via scatter and a dense one via dot.
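To make the two gradient paths concrete, here is a plain C++ sketch. It is a hypothetical toy, not the JAX model or any XLA-generated code. It assumes a row-major embedding matrix E of shape [vocab x dim] used for the lookup x_t = E[ids[t]] and for the projection logits_t = E * h_t. The name TiedEmbeddingGrad is illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy (not JAX/XLA code): row-major matrices stored as flat vectors.
using Matrix = std::vector<double>;

// Accumulates both gradient contributions for a tied embedding E into one buffer.
// ids:      [T]          token ids used for the input lookup x_t = E[ids[t]]
// d_x:      [T x dim]    upstream gradient w.r.t. the looked-up rows
// d_logits: [T x vocab]  upstream gradient w.r.t. logits_t = E * h_t
// h:        [T x dim]    hidden states fed into the output projection
Matrix TiedEmbeddingGrad(const std::vector<std::size_t>& ids, const Matrix& d_x,
                         const Matrix& d_logits, const Matrix& h,
                         std::size_t vocab, std::size_t dim) {
  const std::size_t T = ids.size();
  Matrix grad_E(vocab * dim, 0.0);
  // Sparse path (scatter-add): only the rows selected by ids are updated.
  for (std::size_t t = 0; t < T; ++t)
    for (std::size_t d = 0; d < dim; ++d)
      grad_E[ids[t] * dim + d] += d_x[t * dim + d];
  // Dense path (dot): every row v receives sum_t d_logits[t][v] * h[t].
  for (std::size_t v = 0; v < vocab; ++v)
    for (std::size_t d = 0; d < dim; ++d)
      for (std::size_t t = 0; t < T; ++t)
        grad_E[v * dim + d] += d_logits[t * vocab + v] * h[t * dim + d];
  return grad_E;
}
```

In this toy the two loops run one after the other on a single thread, so the result is well defined. The issue described here is that XLA appears to let the corresponding fusions write to one buffer without ordering them.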
Root cause
The function GetFusionInstructionInPlaceInputOutputPairs in xla/hlo/analysis/alias_info.cc is responsible for identifying when a fusion modifies one of its input buffers in-place. It does this by tracing backwards from the in-place operation (e.g., scatter) to see if it connects to a kParameter (a fusion input).
The current logic traces through Tuple and at most one kBitcast:
```cpp
// Skip bitcast
if (in_place_input_source != nullptr &&
    in_place_input_source->opcode() == HloOpcode::kBitcast) {
  in_place_input_source = in_place_input_source->operand(0);
}
```
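The following toy C++ model reproduces the control flow of the current snippet and shows why a kTranspose between the parameter and the in-place operand defeats it. Op, Instr, and TracesToParameterCurrent are hypothetical stand-ins for HloOpcode and HloInstruction, not XLA APIs. Tuple handling is omitted.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for HloOpcode / HloInstruction (not XLA's classes).
enum class Op { kParameter, kBitcast, kTranspose };

struct Instr {
  Op op;
  std::vector<const Instr*> operands;
  const Instr* operand(int i) const { return operands[i]; }
};

// Models the current check: skip at most one kBitcast, then test whether the
// walk reached a fusion parameter. Tuple handling is omitted.
bool TracesToParameterCurrent(const Instr* source) {
  if (source != nullptr && source->op == Op::kBitcast) {
    source = source->operand(0);
  }
  return source != nullptr && source->op == Op::kParameter;
}
```

In this model, a parameter reached through a single bitcast is detected. A transpose, or a second bitcast, between the parameter and the in-place operand makes the check return false, which matches the behavior described for the reproduction.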
In the HLO generated from my reproduction, the backward pass introduces a kTranspose between the kParameter and the scatter. Because kTranspose is not kBitcast, the tracing stops and the function reports no in-place relationship.
As best I can tell, this causes CopyRemover / BufferAssignment to conclude that the sparse fusion and the dense gradient path don't interfere, allowing them to share the same physical buffer without proper control dependencies. The two paths then write concurrently to the same memory, producing NaN/inf values.
Proposed fix
The fix is to trace backwards through the shape/layout reinterpretation ops kBitcast, kTranspose, and kReshape, treating them as views of the same underlying data rather than as producers of new, independent data:
```diff
-  // Skip bitcast
-  if (in_place_input_source != nullptr &&
-      in_place_input_source->opcode() == HloOpcode::kBitcast) {
-    in_place_input_source = in_place_input_source->operand(0);
-  }
+  // Skip shape/layout modifiers to find the originating parameter.
+  // These are zero-copy ops that reinterpret the same buffer without
+  // creating new independent data.
+  while (in_place_input_source != nullptr &&
+         (in_place_input_source->opcode() == HloOpcode::kBitcast ||
+          in_place_input_source->opcode() == HloOpcode::kTranspose ||
+          in_place_input_source->opcode() == HloOpcode::kReshape)) {
+    in_place_input_source = in_place_input_source->operand(0);
+  }
```
The if is also changed to a while to handle chains of multiple such operations.
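To check the proposed control flow in isolation, here is a small self-contained C++ model. Op, Instr, IsSkippedView, and FindInPlaceParameter are hypothetical names, not XLA APIs. kAdd stands in for any op that computes new data. The real function also handles tuple outputs and the in-place operand index, which this sketch omits.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for HloOpcode / HloInstruction (not XLA's classes).
// kAdd represents any op that computes new data.
enum class Op { kParameter, kBitcast, kTranspose, kReshape, kAdd };

struct Instr {
  Op op;
  std::vector<const Instr*> operands;
  const Instr* operand(int i) const { return operands[i]; }
};

// The proposal treats these opcodes as views of their single operand's buffer.
bool IsSkippedView(const Instr* i) {
  return i->op == Op::kBitcast || i->op == Op::kTranspose ||
         i->op == Op::kReshape;
}

// Models the proposed logic: a while loop walks through any chain of skipped
// ops, then returns the fusion parameter the walk ended at, or nullptr.
const Instr* FindInPlaceParameter(const Instr* source) {
  while (source != nullptr && IsSkippedView(source)) {
    source = source->operand(0);
  }
  return (source != nullptr && source->op == Op::kParameter) ? source : nullptr;
}
```

In the model, a bitcast, transpose, reshape chain ending at a parameter resolves to that parameter. A transpose of a computed value (kAdd) still returns nullptr, so the loop only extends the walk through the three listed opcodes.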
System info
jax: 0.9.0.1
jaxlib: 0.9.0.1
numpy: 2.4.2
python: 3.12.12 (main, Feb 12 2026, 00:42:14) [Clang 21.1.4 ]
device info: TPU v6 lite-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-43e8b932-w-0', release='6.8.0-1015-gcp', version='#17~22.04.1-Ubuntu SMP Tue Sep 3 16:11:52 UTC 2024', machine='x86_64')