A small self-contained JAX reproducer shows that legal indexed operations using fill/drop semantics can fail with NRT_EXEC_OOB.
On Neuron, the same graphs can fail during execution with NRT_EXEC_OOB. Runtime logs report scatter/gather indirect memory copy via vector DGE attempting an out-of-bounds access. This suggests the inactive sentinel lane is being routed through an indirect memory access before fill/drop semantics are applied.
I am reporting the related cases together because they appear to be the same semantic area: inactive indexed lanes represented by max-int sentinels should not touch memory.
N/A. This is a self-contained JAX reproducer.
Minimal JAX indexed gather/scatter/update graphs.
import argparse
import importlib.metadata
import os
import sys
if __name__ == "__main__":
    # Command-line options: which repro case to run, plus the knobs that the
    # individual case functions read from the module-level ``args``.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--jax-platforms", default=None)
    parser.add_argument(
        "--case",
        choices=("drop_1d", "drop_rank4", "gather_fill_drop", "dense_zero_drop"),
        required=True,
    )
    parser.add_argument("--index-dtype", default="int32")
    parser.add_argument("--index-rank", choices=("vector", "scalar"), default="vector")
    parser.add_argument("--seq-len", type=int, default=32)
    parser.add_argument("--rows", type=int, default=8)
    parser.add_argument("--hidden-dim", type=int, default=16)
    parser.add_argument("--vocab-size", type=int, default=64)
    parser.add_argument("--target-index", type=int, default=7)
    parser.add_argument("--target-token", type=int, default=23)
    args = parser.parse_args()

    # Export the platform selection before jax is imported; the import is
    # deliberately deferred until after this point.
    if args.jax_platforms:
        os.environ["JAX_PLATFORMS"] = args.jax_platforms

    import jax

    # 64-bit index dtypes are only exercised with x64 mode switched on.
    if args.index_dtype in ("int64", "uint64"):
        jax.config.update("jax_enable_x64", True)

    import jax.numpy as jnp
    import numpy as np

    # Environment banner for the bug report.
    print(f"python={sys.version.split()[0]}")
    print(f"numpy={np.__version__}")
    print(f"jax={jax.__version__}")
    # The Neuron packages are absent on CPU-only hosts. Report that instead of
    # raising PackageNotFoundError, so the CPU reference runs still reach the
    # selected case.
    for package in ("jaxlib", "jax-neuronx", "libneuronxla", "neuronx-cc"):
        try:
            package_version = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            package_version = "not installed"
        print(f"{package}={package_version}")
    print(f"jax_default_backend={jax.default_backend()}")
    print(f"jax_devices={jax.devices()}")
def run_drop_1d():
    """Run the 1-D ``mode="drop"`` scatter case and compare it with NumPy.

    A length-8 ``int32`` zero vector is updated through ``.at[...].set`` using
    indices of dtype ``args.index_dtype``. When ``args.index_rank`` is
    ``"vector"`` the indices are ``[3, sentinel]``; otherwise a single scalar
    ``sentinel`` is used. ``sentinel`` is the largest value of the index dtype,
    and the expected result contains only the in-range write. The StableHLO of
    the jitted function is printed before it is executed.

    Raises:
        AssertionError: If the executed result differs from the expected array.
    """
    print("\n=== drop_1d ===")
    index_dtype = np.dtype(args.index_dtype)
    sentinel = np.iinfo(index_dtype).max

    def repro(values, indices, updates):
        return values.at[indices].set(updates, mode="drop")

    values = jnp.zeros((8,), dtype=jnp.int32)
    jax_index_dtype = jnp.dtype(args.index_dtype)
    # Start from "nothing written"; only the vector variant has a live lane.
    expected = np.zeros((8,), dtype=np.int32)
    if args.index_rank == "vector":
        indices = jnp.array([3, sentinel], dtype=jax_index_dtype)
        updates = jnp.array([33, 99], dtype=jnp.int32)
        expected[3] = 33
    else:
        indices = jnp.array(sentinel, dtype=jax_index_dtype)
        updates = jnp.array(99, dtype=jnp.int32)

    print(f"index_dtype={args.index_dtype}")
    print(f"index_rank={args.index_rank}")
    print(f"sentinel={int(sentinel)}")

    # One jitted callable serves both the IR dump and the actual execution.
    jitted = jax.jit(repro)
    print("\n--- stablehlo ---")
    print(jitted.lower(values, indices, updates).compiler_ir(dialect="stablehlo"))
    actual = np.asarray(jitted(values, indices, updates))

    print(f"actual={actual.tolist()}")
    print(f"expected={expected.tolist()}")
    matches = np.array_equal(actual, expected)
    if not matches:
        raise AssertionError("drop_1d output mismatch")
    print("PASS drop_1d")
def run_drop_rank4():
    """Run the rank-4 ``mode="drop"`` scatter case and compare it with NumPy.

    A ``(2, 4, 2, 3)`` ``int32`` zero array is flattened over its first two
    axes, rows ``[5, int32 max]`` are written with ``.at[...].set(mode="drop")``
    and the result is reshaped back. The expected array has only flat row 5
    written. The StableHLO of the jitted function is printed before it is
    executed.

    Raises:
        AssertionError: If the executed result differs from the expected array.
    """
    print("\n=== drop_rank4 ===")
    max_i32 = np.iinfo(np.int32).max

    def repro(values, indices, updates):
        outer, inner, height, width = values.shape
        flat = values.reshape((outer * inner, height, width))
        return flat.at[indices].set(updates, mode="drop").reshape(values.shape)

    values = jnp.zeros((2, 4, 2, 3), dtype=jnp.int32)
    indices = jnp.array([5, max_i32], dtype=jnp.int32)
    updates = jnp.arange(1, 13, dtype=jnp.int32).reshape((2, 2, 3))

    # Build the reference in the flat layout: only row 5 receives the first
    # update slab (values 1..6).
    expected_flat = np.zeros((8, 2, 3), dtype=np.int32)
    expected_flat[5] = np.arange(1, 7, dtype=np.int32).reshape((2, 3))
    expected = expected_flat.reshape((2, 4, 2, 3))

    # One jitted callable serves both the IR dump and the actual execution.
    jitted = jax.jit(repro)
    print("\n--- stablehlo ---")
    print(jitted.lower(values, indices, updates).compiler_ir(dialect="stablehlo"))
    actual = np.asarray(jitted(values, indices, updates))

    print(f"indices={[5, int(max_i32)]}")
    print(f"actual_flat={actual.reshape((8, 2, 3)).tolist()}")
    print(f"expected_flat={expected.reshape((8, 2, 3)).tolist()}")
    matches = np.array_equal(actual, expected)
    if not matches:
        raise AssertionError("drop_rank4 output mismatch")
    print("PASS drop_rank4")
def run_gather_fill_drop():
    """Run the gather-with-fill followed by scatter-with-drop case.

    Rows ``[2, int32 max]`` are taken from an ``(8, 4)`` table with
    ``mode="fill"``, projected against an ``(8, 4)`` matrix, arg-maxed, and the
    winners are written back into a length-8 ``int32`` vector at the same
    indices with ``mode="drop"``. The expected vector has ``5`` at position 2
    and zeros elsewhere. The StableHLO of the jitted function is printed before
    it is executed.

    Raises:
        AssertionError: If the executed result differs from the expected array.
    """
    print("\n=== gather_fill_drop ===")
    max_i32 = np.iinfo(np.int32).max

    def repro(table, indices, projection, output):
        rows = jnp.take(table, indices, axis=0, mode="fill", fill_value=0)
        scores = rows.astype(jnp.float32) @ projection.T
        winners = jnp.argmax(scores, axis=-1).astype(jnp.int32)
        return output.at[indices].set(winners, mode="drop")

    # Single non-zero entries make table row 2 select projection row 5.
    table_np = np.zeros((8, 4), dtype=np.float32)
    projection_np = np.zeros((8, 4), dtype=np.float32)
    table_np[2, 0] = 1.0
    projection_np[5, 0] = 1.0

    table = jnp.array(table_np)
    projection = jnp.array(projection_np)
    output = jnp.zeros((8,), dtype=jnp.int32)
    indices = jnp.array([2, max_i32], dtype=jnp.int32)
    expected = np.array([0, 0, 5, 0, 0, 0, 0, 0], dtype=np.int32)

    # One jitted callable serves both the IR dump and the actual execution.
    jitted = jax.jit(repro)
    print("\n--- stablehlo ---")
    print(jitted.lower(table, indices, projection, output).compiler_ir(dialect="stablehlo"))
    actual = np.asarray(jitted(table, indices, projection, output))

    print(f"indices={[2, int(max_i32)]}")
    print(f"actual={actual.tolist()}")
    print(f"expected={expected.tolist()}")
    matches = np.array_equal(actual, expected)
    if not matches:
        raise AssertionError("gather_fill_drop output mismatch")
    print("PASS gather_fill_drop")
def run_dense_zero_drop():
    """Run the parameterised gather-fill / scatter-drop case on a dense zero base.

    Sizes and targets come from ``args`` (``seq_len``, ``rows``, ``hidden_dim``,
    ``vocab_size``, ``target_index``, ``target_token``). Rows
    ``[target_index, int32 max]`` are taken from the bfloat16-cast activations
    with ``mode="fill"``, projected, arg-maxed to ``uint32`` and scattered with
    ``mode="drop"`` into a zero ``uint32`` vector of length ``seq_len``. That
    vector is then written over another zero vector with
    ``jax.lax.dynamic_update_slice`` at offset 0. The expected result holds
    ``target_token`` at ``target_index`` and zeros elsewhere. The StableHLO of
    the jitted function is printed before it is executed.

    Raises:
        ValueError: If ``--target-index``, ``--target-token`` or
            ``--hidden-dim`` is outside the range the case can represent.
        AssertionError: If the executed result differs from the expected array.
    """
    print("\n=== dense_zero_drop ===")
    max_i32 = np.iinfo(np.int32).max
    if not (0 <= args.target_index < min(args.seq_len, args.rows)):
        raise ValueError("--target-index must fit in both --seq-len and --rows")
    if not (0 <= args.target_token < args.vocab_size):
        raise ValueError("--target-token must be non-negative and smaller than --vocab-size")
    # Column 0 of both matrices is written below, so it has to exist.
    if args.hidden_dim < 1:
        raise ValueError("--hidden-dim must be at least 1")

    zero_literal = np.zeros((args.seq_len,), dtype=np.uint32)

    def repro(activations, projection, index):
        selected = jnp.take(activations.astype(jnp.bfloat16), index, axis=0, mode="fill", fill_value=0)
        logits = selected.astype(jnp.float32) @ projection.astype(jnp.bfloat16).T.astype(jnp.float32)
        choice = jnp.argmax(logits, axis=-1).astype(jnp.uint32)
        base = jnp.asarray(zero_literal)
        sparse = base.at[index].set(choice, mode="drop")
        out = jnp.asarray(zero_literal)
        return jax.lax.dynamic_update_slice(out, sparse, (jnp.array(0, dtype=jnp.int32),))

    # Single non-zero entries make the target row select the target token.
    activations_np = np.zeros((args.rows, args.hidden_dim), dtype=np.float32)
    activations_np[args.target_index, 0] = 1.0
    projection_np = np.zeros((args.vocab_size, args.hidden_dim), dtype=np.float32)
    projection_np[args.target_token, 0] = 1.0
    index_np = np.array([args.target_index, max_i32], dtype=np.int32)

    expected = np.zeros((args.seq_len,), dtype=np.uint32)
    expected[args.target_index] = np.uint32(args.target_token)

    activations = jnp.array(activations_np)
    projection = jnp.array(projection_np)
    index = jnp.array(index_np)

    # One jitted callable serves both the IR dump and the actual execution.
    jitted = jax.jit(repro)
    print("\n--- stablehlo ---")
    print(jitted.lower(activations, projection, index).compiler_ir(dialect="stablehlo"))
    actual = np.asarray(jitted(activations, projection, index))

    print(f"shape seq_len={args.seq_len} rows={args.rows} hidden_dim={args.hidden_dim} vocab_size={args.vocab_size}")
    print(f"indices={[args.target_index, int(max_i32)]}")
    print(f"actual_prefix={actual[: min(16, actual.size)].tolist()}")
    print(f"expected_prefix={expected[: min(16, expected.size)].tolist()}")
    if not np.array_equal(actual, expected):
        raise AssertionError("dense_zero_drop output mismatch")
    print("PASS dense_zero_drop")
# Dispatch table from each --case value to the function implementing it;
# argparse has already restricted args.case to these keys.
cases = dict(
    drop_1d=run_drop_1d,
    drop_rank4=run_drop_rank4,
    gather_fill_drop=run_gather_fill_drop,
    dense_zero_drop=run_dense_zero_drop,
)
selected_case = cases[args.case]
selected_case()
Describe the bug
A small self-contained JAX reproducer shows that legal indexed operations using fill/drop semantics can fail with NRT_EXEC_OOB.
The common pattern is:
mode="fill" and/or scatter/update with mode="drop"
On CPU, these operations behave as expected:
On Neuron, the same graphs can fail during execution with NRT_EXEC_OOB. Runtime logs report scatter/gather indirect memory copy via vector DGE attempting an out-of-bounds access. This suggests the inactive sentinel lane is being routed through an indirect memory access before fill/drop semantics are applied.
I am reporting the related cases together because they appear to be the same semantic area: inactive indexed lanes represented by max-int sentinels should not touch memory.
Model Name
N/A. This is a self-contained JAX reproducer.
Describe the workload type
Minimal JAX indexed gather/scatter/update graphs.
The attached reproducer covers:
Instance Type
inf2.8xlarge
Release version
python=3.12.12
numpy=2.4.6
jax=0.7.0
jaxlib=0.7.0
jax-neuronx=0.7.0.1.0.8181+1e892be0
libneuronxla=2.2.16408.0+50c26cbd
neuronx-cc=2.24.8799.0+6f62ff7c
Reproduction Steps
CPU references:
python main.py --jax-platforms=cpu --case=drop_1d
python main.py --jax-platforms=cpu --case=drop_rank4
python main.py --jax-platforms=cpu --case=gather_fill_drop
python main.py --jax-platforms=cpu --case=dense_zero_drop
Neuron runs:
NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=drop_1d
NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=drop_rank4
NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=gather_fill_drop
NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=dense_zero_drop
Optional dtype/rank probe:
python main.py --jax-platforms=cpu --case=drop_1d --index-dtype=uint32 --index-rank=scalar
NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=drop_1d --index-dtype=uint32 --index-rank=scalar
main.py:
Regression Issue
Possible Solution
N/A
Logs/Context/Additional Information
No response