
JAX indexed fill/drop semantics with max-int sentinel can execute OOB indirect memory accesses #1335

@hugomano

Description


Describe the bug

A small, self-contained JAX reproducer shows that legal indexed operations using fill/drop semantics can fail with NRT_EXEC_OOB.

The common pattern is:

  • one live integer index
  • one inactive integer sentinel index set to the max value for that index dtype
  • gather with mode="fill" and/or scatter/update with mode="drop"
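Here is a minimal CPU-only sketch of this pattern with int32 indices. The sizes, the table contents and the fill value of -1 are arbitrary choices and are not taken from the reproducer:

```python
import os

os.environ.setdefault("JAX_PLATFORMS", "cpu")  # force the CPU reference backend

import jax.numpy as jnp
import numpy as np

SENTINEL = np.iinfo(np.int32).max  # inactive lane: largest int32 index

table = jnp.arange(8 * 2, dtype=jnp.int32).reshape(8, 2)
indices = jnp.array([3, SENTINEL], dtype=jnp.int32)  # one live lane, one inactive lane

# Gather with fill semantics: the sentinel row comes back as fill_value.
gathered = jnp.take(table, indices, axis=0, mode="fill", fill_value=-1)

# Scatter with drop semantics: the sentinel update is discarded.
out = jnp.zeros((8,), dtype=jnp.int32).at[indices].set(
    jnp.array([33, 99], dtype=jnp.int32), mode="drop"
)

print(gathered.tolist())  # [[6, 7], [-1, -1]]
print(out.tolist())       # [0, 0, 0, 33, 0, 0, 0, 0]
```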

On CPU, these operations behave as expected:

  • fill-mode gather returns the fill value for the sentinel row
  • drop-mode scatter ignores the sentinel update
  • only live rows are read or written
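These semantics can also be written out with explicit masking so that no lane ever addresses memory out of bounds. Each index is clamped into range for the access itself, and the result is then selected using an in-bounds mask. The helpers below (masked_take and masked_set) are hypothetical illustrations, not JAX APIs. masked_set assumes a 1-D target and unique live indices, and it uses a dense compare, so it is only practical for small tables. Whether formulations like this avoid the Neuron failure has not been tested:

```python
import os

os.environ.setdefault("JAX_PLATFORMS", "cpu")  # compare against the CPU reference semantics

import jax.numpy as jnp
import numpy as np


def masked_take(table, indices, fill_value):
    """Row gather with fill semantics in which every memory read is in bounds."""
    n = table.shape[0]
    valid = (indices >= 0) & (indices < n)  # live lanes
    safe = jnp.clip(indices, 0, n - 1)  # clamp inactive lanes onto a real row
    rows = table[safe]  # in-bounds gather
    mask = valid.reshape(valid.shape + (1,) * (rows.ndim - valid.ndim))
    return jnp.where(mask, rows, jnp.asarray(fill_value, dtype=table.dtype))


def masked_set(values, indices, updates):
    """1-D set with drop semantics via a dense compare (assumes unique live indices)."""
    n = values.shape[0]
    # hits[i, j] is True when lane j targets position i; an out-of-range sentinel matches nothing.
    hits = jnp.arange(n)[:, None] == indices[None, :]
    touched = jnp.any(hits, axis=1)
    picked = jnp.sum(jnp.where(hits, updates[None, :], 0), axis=1).astype(values.dtype)
    return jnp.where(touched, picked, values)


sentinel = np.iinfo(np.int32).max
table = jnp.arange(16, dtype=jnp.int32).reshape(8, 2)
values = jnp.zeros((8,), dtype=jnp.int32)
indices = jnp.array([3, sentinel], dtype=jnp.int32)
updates = jnp.array([33, 99], dtype=jnp.int32)

taken = masked_take(table, indices, -1)
updated = masked_set(values, indices, updates)

# Compare with JAX's built-in fill/drop modes.
same_take = bool(jnp.array_equal(taken, jnp.take(table, indices, axis=0, mode="fill", fill_value=-1)))
same_set = bool(jnp.array_equal(updated, values.at[indices].set(updates, mode="drop")))
print(taken.tolist(), updated.tolist(), same_take, same_set)
# [[6, 7], [-1, -1]] [0, 0, 0, 33, 0, 0, 0, 0] True True
```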

On Neuron, the same graphs can fail during execution with NRT_EXEC_OOB. The runtime logs report that a scatter/gather indirect memory copy via vector DGE attempted an out-of-bounds access. This suggests that the inactive sentinel lane is routed through an indirect memory access before the fill/drop semantics are applied.

I am reporting the related cases together because they appear to be the same semantic area: inactive indexed lanes represented by max-int sentinels should not touch memory.

Model Name

N/A. These are self-contained JAX reproducers.

Describe the workload type

Minimal JAX indexed gather/scatter/update graphs.

The reproducer (main.py below) covers:

  1. 1D scatter/update with mode="drop" and a max-int sentinel index.
  2. Rank-4 batched slice update flattened to an indexed table, with mode="drop".
  3. Gather with mode="fill" followed by scatter/update with mode="drop" using the same sentinel indices.
  4. Dense integer zero/update graph followed by a sentinel scatter with mode="drop"; the generated StableHLO includes dense integer zero tensors near the scatter/update operations.
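Lowering only builds the StableHLO module and does not compile or run it. This means the ops a case lowers to can be checked on CPU without touching the Neuron runtime. Here is a minimal sketch for the drop_1d case. The substring test is a quick heuristic, not a structured parse of the module:

```python
import os

os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp
import numpy as np


def drop_1d(values, indices, updates):
    # Same body as the drop_1d case in main.py.
    return values.at[indices].set(updates, mode="drop")


values = jnp.zeros((8,), dtype=jnp.int32)
indices = jnp.array([3, np.iinfo(np.int32).max], dtype=jnp.int32)
updates = jnp.array([33, 99], dtype=jnp.int32)

# Build the StableHLO text for the jitted function without executing it.
stablehlo_text = jax.jit(drop_1d).lower(values, indices, updates).as_text()
has_scatter = "stablehlo.scatter" in stablehlo_text
print(f"stablehlo.scatter present: {has_scatter}")
```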

Instance Type

inf2.8xlarge

Release version

python=3.12.12
numpy=2.4.6
jax=0.7.0
jaxlib=0.7.0
jax-neuronx=0.7.0.1.0.8181+1e892be0
libneuronxla=2.2.16408.0+50c26cbd
neuronx-cc=2.24.8799.0+6f62ff7c

Reproduction Steps

CPU references:

  • python main.py --jax-platforms=cpu --case=drop_1d
  • python main.py --jax-platforms=cpu --case=drop_rank4
  • python main.py --jax-platforms=cpu --case=gather_fill_drop
  • python main.py --jax-platforms=cpu --case=dense_zero_drop

Neuron runs:

  • NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=drop_1d
  • NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=drop_rank4
  • NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=gather_fill_drop
  • NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=dense_zero_drop

Optional dtype/rank probe:

  • python main.py --jax-platforms=cpu --case=drop_1d --index-dtype=uint32 --index-rank=scalar
  • NEURON_RT_LOG_LEVEL=warning NEURON_RT_VISIBLE_CORES=0 python main.py --jax-platforms=neuron,cpu --case=drop_1d --index-dtype=uint32 --index-rank=scalar
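For reference, these are the sentinel values the probe produces for each integer index dtype. main.py computes the sentinel as np.iinfo(index_dtype).max and enables jax_enable_x64 for the 64-bit dtypes. Only the drop_1d case reads --index-dtype; the other cases always use an int32 sentinel.

```python
import numpy as np

# Sentinel = largest representable value of each index dtype, as in main.py's drop_1d case.
SENTINELS = {
    name: int(np.iinfo(np.dtype(name)).max)
    for name in ("int32", "uint32", "int64", "uint64")
}

for name, value in SENTINELS.items():
    print(f"{name}: {value}")
```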

main.py:

"""Minimal JAX reproducers for indexed fill/drop ops with max-int sentinel indices."""

import argparse
import importlib.metadata
import os
import sys


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--jax-platforms", default=None)
    parser.add_argument(
        "--case",
        choices=("drop_1d", "drop_rank4", "gather_fill_drop", "dense_zero_drop"),
        required=True,
    )
    parser.add_argument("--index-dtype", default="int32")
    parser.add_argument("--index-rank", choices=("vector", "scalar"), default="vector")
    parser.add_argument("--seq-len", type=int, default=32)
    parser.add_argument("--rows", type=int, default=8)
    parser.add_argument("--hidden-dim", type=int, default=16)
    parser.add_argument("--vocab-size", type=int, default=64)
    parser.add_argument("--target-index", type=int, default=7)
    parser.add_argument("--target-token", type=int, default=23)
    args = parser.parse_args()

    if args.jax_platforms:
        os.environ["JAX_PLATFORMS"] = args.jax_platforms

    import jax

    if args.index_dtype in ("int64", "uint64"):
        jax.config.update("jax_enable_x64", True)

    import jax.numpy as jnp
    import numpy as np

    print(f"python={sys.version.split()[0]}")
    print(f"numpy={np.__version__}")
    print(f"jax={jax.__version__}")
    def pkg_version(name):
        # Neuron packages may be absent on CPU-only hosts.
        try:
            return importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            return "not installed"

    for pkg in ("jaxlib", "jax-neuronx", "libneuronxla", "neuronx-cc"):
        print(f"{pkg}={pkg_version(pkg)}")
    print(f"jax_default_backend={jax.default_backend()}")
    print(f"jax_devices={jax.devices()}")

    def run_drop_1d():
        print("\n=== drop_1d ===")
        index_dtype = np.dtype(args.index_dtype)
        # Inactive lane: the largest value representable by the index dtype.
        sentinel = np.iinfo(index_dtype).max

        def repro(values, indices, updates):
            return values.at[indices].set(updates, mode="drop")

        values = jnp.zeros((8,), dtype=jnp.int32)
        if args.index_rank == "vector":
            indices = jnp.array([3, sentinel], dtype=jnp.dtype(args.index_dtype))
            updates = jnp.array([33, 99], dtype=jnp.int32)
            expected = np.array([0, 0, 0, 33, 0, 0, 0, 0], dtype=np.int32)
        else:
            indices = jnp.array(sentinel, dtype=jnp.dtype(args.index_dtype))
            updates = jnp.array(99, dtype=jnp.int32)
            expected = np.zeros((8,), dtype=np.int32)

        print(f"index_dtype={args.index_dtype}")
        print(f"index_rank={args.index_rank}")
        print(f"sentinel={int(sentinel)}")
        print("\n--- stablehlo ---")
        print(jax.jit(repro).lower(values, indices, updates).compiler_ir(dialect="stablehlo"))

        actual = np.asarray(jax.jit(repro)(values, indices, updates))
        print(f"actual={actual.tolist()}")
        print(f"expected={expected.tolist()}")
        if not np.array_equal(actual, expected):
            raise AssertionError("drop_1d output mismatch")
        print("PASS drop_1d")

    def run_drop_rank4():
        print("\n=== drop_rank4 ===")
        max_i32 = np.iinfo(np.int32).max

        def repro(values, indices, updates):
            flat = values.reshape((values.shape[0] * values.shape[1], values.shape[2], values.shape[3]))
            updated = flat.at[indices].set(updates, mode="drop")
            return updated.reshape(values.shape)

        values = jnp.zeros((2, 4, 2, 3), dtype=jnp.int32)
        indices = jnp.array([5, max_i32], dtype=jnp.int32)
        updates = jnp.arange(12, dtype=jnp.int32).reshape((2, 2, 3)) + 1
        expected = np.zeros((2, 4, 2, 3), dtype=np.int32)
        expected.reshape((8, 2, 3))[5, :, :] = np.arange(6, dtype=np.int32).reshape((2, 3)) + 1

        print("\n--- stablehlo ---")
        print(jax.jit(repro).lower(values, indices, updates).compiler_ir(dialect="stablehlo"))

        actual = np.asarray(jax.jit(repro)(values, indices, updates))
        print(f"indices={[5, int(max_i32)]}")
        print(f"actual_flat={actual.reshape((8, 2, 3)).tolist()}")
        print(f"expected_flat={expected.reshape((8, 2, 3)).tolist()}")
        if not np.array_equal(actual, expected):
            raise AssertionError("drop_rank4 output mismatch")
        print("PASS drop_rank4")

    def run_gather_fill_drop():
        print("\n=== gather_fill_drop ===")
        max_i32 = np.iinfo(np.int32).max

        def repro(table, indices, projection, output):
            gathered = jnp.take(table, indices, axis=0, mode="fill", fill_value=0)
            logits = gathered.astype(jnp.float32) @ projection.T
            updates = jnp.argmax(logits, axis=-1).astype(jnp.int32)
            return output.at[indices].set(updates, mode="drop")

        table_np = np.zeros((8, 4), dtype=np.float32)
        table_np[2, 0] = 1.0
        projection_np = np.zeros((8, 4), dtype=np.float32)
        projection_np[5, 0] = 1.0
        table = jnp.array(table_np)
        projection = jnp.array(projection_np)
        output = jnp.zeros((8,), dtype=jnp.int32)
        indices = jnp.array([2, max_i32], dtype=jnp.int32)
        expected = np.array([0, 0, 5, 0, 0, 0, 0, 0], dtype=np.int32)

        print("\n--- stablehlo ---")
        print(jax.jit(repro).lower(table, indices, projection, output).compiler_ir(dialect="stablehlo"))

        actual = np.asarray(jax.jit(repro)(table, indices, projection, output))
        print(f"indices={[2, int(max_i32)]}")
        print(f"actual={actual.tolist()}")
        print(f"expected={expected.tolist()}")
        if not np.array_equal(actual, expected):
            raise AssertionError("gather_fill_drop output mismatch")
        print("PASS gather_fill_drop")

    def run_dense_zero_drop():
        print("\n=== dense_zero_drop ===")
        max_i32 = np.iinfo(np.int32).max
        if not (0 <= args.target_index < min(args.seq_len, args.rows)):
            raise ValueError("--target-index must fit in both --seq-len and --rows")
        if not (0 <= args.target_token < args.vocab_size):
            raise ValueError("--target-token must be smaller than --vocab-size")

        zero_literal = np.zeros((args.seq_len,), dtype=np.uint32)

        def repro(activations, projection, index):
            selected = jnp.take(activations.astype(jnp.bfloat16), index, axis=0, mode="fill", fill_value=0)
            logits = selected.astype(jnp.float32) @ projection.astype(jnp.bfloat16).T.astype(jnp.float32)
            choice = jnp.argmax(logits, axis=-1).astype(jnp.uint32)
            base = jnp.asarray(zero_literal)
            sparse = base.at[index].set(choice, mode="drop")
            out = jnp.asarray(zero_literal)
            return jax.lax.dynamic_update_slice(out, sparse, (jnp.array(0, dtype=jnp.int32),))

        activations_np = np.zeros((args.rows, args.hidden_dim), dtype=np.float32)
        activations_np[args.target_index, 0] = 1.0
        projection_np = np.zeros((args.vocab_size, args.hidden_dim), dtype=np.float32)
        projection_np[args.target_token, 0] = 1.0
        index_np = np.array([args.target_index, max_i32], dtype=np.int32)
        expected = np.zeros((args.seq_len,), dtype=np.uint32)
        expected[args.target_index] = np.uint32(args.target_token)

        activations = jnp.array(activations_np)
        projection = jnp.array(projection_np)
        index = jnp.array(index_np)

        print("\n--- stablehlo ---")
        print(jax.jit(repro).lower(activations, projection, index).compiler_ir(dialect="stablehlo"))

        actual = np.asarray(jax.jit(repro)(activations, projection, index))
        print(f"shape seq_len={args.seq_len} rows={args.rows} hidden_dim={args.hidden_dim} vocab_size={args.vocab_size}")
        print(f"indices={[args.target_index, int(max_i32)]}")
        print(f"actual_prefix={actual[: min(16, actual.size)].tolist()}")
        print(f"expected_prefix={expected[: min(16, expected.size)].tolist()}")
        if not np.array_equal(actual, expected):
            raise AssertionError("dense_zero_drop output mismatch")
        print("PASS dense_zero_drop")

    cases = {
        "drop_1d": run_drop_1d,
        "drop_rank4": run_drop_rank4,
        "gather_fill_drop": run_gather_fill_drop,
        "dense_zero_drop": run_dense_zero_drop,
    }
    cases[args.case]()


Possible Solution

N/A

Logs/Context/Additional Information

No response


Labels: Inf2, bug

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions