From 79682e61101af5eb778c7e5ea389ff0ba202bd87 Mon Sep 17 00:00:00 2001
From: Kshitij Lakhani
Date: Thu, 19 Feb 2026 13:54:56 -0800
Subject: [PATCH 01/10] Fix batcher for the case where the segment ids
 received are batched/vmapped but the TE-constructed segment pos are not,
 causing shape mismatches in impl()

Signed-off-by: Kshitij Lakhani
---
 .../jax/cpp_extensions/attention.py           | 87 +++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index e5d75e1501..6ffd68257d 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -624,12 +624,57 @@ def convert_to_2d(offsets, batch, max_seqlen):
         )
         return output, softmax_aux, rng_state
 
+    # Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=seed,
+    # 6,7=seqlens, 8,9=seq_offsets, 10,11=segment_ids, 12,13=segment_pos.
+    _SEGMENT_IDS_BATCH_DIMS_IDX = (10, 11)
+    _SEGMENT_POS_BATCH_DIMS_IDX = (12, 13)
+
     @staticmethod
    def batcher(batched_args, batch_dims, *, config):
+        # batch_dims: tuple of length len(batched_args); each element is the axis index
+        # that is the batch axis (0, 1, ...) or None if that arg has no batch dim.
+        # check_valid_batch_dims: only 0 or None allowed (single leading batch or no batch).
         check_valid_batch_dims(batch_dims)
         assert FusedAttnFwdPrimitive.outer_primitive is not None
         q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
+
+        # Ensure segment_pos are batched like segment_ids so impl sees matching shapes.
+        # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None (i.e. not batched) when
+        # segment_pos were generated inside a vmapped function (e.g. single or nested vmap).
+        batched_args_list = list(batched_args)
+        seg_id_bdim = batch_dims[FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]]
+        seg_pos_bdim = batch_dims[FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX[0]]
+        # Expand the segment_pos batch dim to match the segment_ids batch dim, if no batch dim exists for segment_pos
+        if seg_id_bdim is not None and seg_pos_bdim is None and batched_args_list[10].size > 0:
+            # Pair (segment_ids idx, segment_pos idx): (10, 12) for q, (11, 13) for kv.
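For reference, a minimal standalone sketch of the expansion this hunk performs, assuming toy shapes (vmap size 4, batch 2, seqlen 128; not TE code):

    import jax.numpy as jnp
    from jax import lax

    segment_ids = jnp.ones((4, 2, 128), jnp.int32)             # (vmap, batch, seq), batched
    segment_pos = jnp.broadcast_to(jnp.arange(128), (2, 128))  # (batch, seq), unbatched

    # Add one leading axis per missing batch dim, then broadcast to match.
    leading = segment_ids.ndim - segment_pos.ndim
    expanded = segment_pos
    for _ in range(leading):
        expanded = lax.expand_dims(expanded, (0,))             # (1, batch, seq)
    expanded = jnp.broadcast_to(expanded, segment_ids.shape[:leading] + segment_pos.shape)
    assert expanded.shape == (4, 2, 128)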
+ for seg_id_idx, seg_pos_idx in zip( + FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX, + FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX, + ): + segment_ids = batched_args_list[seg_id_idx] + segment_pos = batched_args_list[seg_pos_idx] + assert segment_ids.ndim > segment_pos.ndim, ( + "segment_ids must have more dims than segment_pos when adding batch dims; " + f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}" + ) + assert segment_ids.shape[segment_pos.ndim :] == segment_pos.shape, ( + "segment_pos must have same trailing shape as segment_ids when adding batch dims; " + f"got segment_ids.shape={segment_ids.shape}, segment_pos.shape={segment_pos.shape}" + ) + # Expand the segment_pos by as many batch dims as the segment_ids has + leading_bdim = segment_ids.ndim - segment_pos.ndim + target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape + expanded = segment_pos + for _ in range(leading_bdim): + expanded = lax.expand_dims(expanded, (0,)) + batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape) + # Update the batch_dims to use 0 instead of None for segment_pos batch dims + batch_dims = tuple( + 0 if i in FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX else b + for i, b in enumerate(batch_dims) + ) + batched_args = tuple(batched_args_list) + out_bdims = q_bdim, q_bdim, seed_bdim return ( FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), @@ -1079,12 +1124,54 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return dq, dk, dv, dbias, dsoftmax_offset + # Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=softmax_aux, + # 6=rng_state, 7=output, 8=doutput, 9,10=seqlens, 11,12=seq_offsets, + # 13,14=segment_ids, 15,16=segment_pos. + _SEGMENT_IDS_BATCH_DIMS_IDX = (13, 14) + _SEGMENT_POS_BATCH_DIMS_IDX = (15, 16) + @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnBwdPrimitive.outer_primitive is not None q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims + # Ensure segment_pos are batched like segment_ids so impl sees matching shapes. + # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None (i.e. not batched) when + # segment_pos were generated inside a vmapped function (e.g. single or nested vmap). 
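As a self-contained illustration of that split, an operand mapped by jax.vmap reaches the batching rule with batch_dim=0, while an operand passed with in_axes=None (or created inside the mapped function) arrives with batch_dim=None; toy shapes assumed:

    import jax
    import jax.numpy as jnp

    def f(ids, pos):
        # Inside the map both appear as (batch, seq); at the primitive level,
        # ids carried batch_dim=0 and pos carried batch_dim=None.
        return jnp.where(ids > 0, pos, -1)

    ids = jnp.ones((4, 2, 16), jnp.int32)            # mapped over axis 0
    pos = jnp.broadcast_to(jnp.arange(16), (2, 16))  # shared across the map
    out = jax.vmap(f, in_axes=(0, None))(ids, pos)
    assert out.shape == (4, 2, 16)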
+ batched_args_list = list(batched_args) + seg_id_bdim = batch_dims[FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]] + seg_pos_bdim = batch_dims[FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX[0]] + # Expand the segment_pos batch dim to match the segment_ids batch dim, if no batch dim exists for segment_pos + if seg_id_bdim is not None and seg_pos_bdim is None and batched_args_list[13].size > 0: + for seg_id_idx, seg_pos_idx in zip( + FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX, + FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX, + ): + segment_ids = batched_args_list[seg_id_idx] + segment_pos = batched_args_list[seg_pos_idx] + assert segment_ids.ndim > segment_pos.ndim, ( + "segment_ids must have more dims than segment_pos when adding batch dims; " + f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}" + ) + assert segment_ids.shape[segment_pos.ndim :] == segment_pos.shape, ( + "segment_pos must have same trailing shape as segment_ids when adding batch dims; " + f"got segment_ids.shape={segment_ids.shape}, segment_pos.shape={segment_pos.shape}" + ) + leading_bdim = segment_ids.ndim - segment_pos.ndim + target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape + # Expand the segment_pos batch dim to match the segment_ids batch dim, if no batch dim exists for segment_pos + expanded = segment_pos + for _ in range(leading_bdim): + expanded = lax.expand_dims(expanded, (0,)) + batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape) + # Update the batch_dims to use 0 instead of None for segment_pos batch dims + batch_dims = tuple( + 0 if i in FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX else b + for i, b in enumerate(batch_dims) + ) + batched_args = tuple(batched_args_list) + out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim return ( FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), From 35d6d0f77db55d14b4392c00330e061f253bfcb7 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Mon, 23 Feb 2026 17:27:12 -0800 Subject: [PATCH 02/10] nit: Fix the shape check for assert Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/cpp_extensions/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6ffd68257d..6119dfb7df 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -657,7 +657,7 @@ def batcher(batched_args, batch_dims, *, config): "segment_ids must have more dims than segment_pos when adding batch dims; " f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}" ) - assert segment_ids.shape[segment_pos.ndim :] == segment_pos.shape, ( + assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, ( "segment_pos must have same trailing shape as segment_ids when adding batch dims; " f"got segment_ids.shape={segment_ids.shape}, segment_pos.shape={segment_pos.shape}" ) @@ -1154,7 +1154,7 @@ def batcher(batched_args, batch_dims, *, config): "segment_ids must have more dims than segment_pos when adding batch dims; " f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}" ) - assert segment_ids.shape[segment_pos.ndim :] == segment_pos.shape, ( + assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, ( "segment_pos must have same trailing shape as segment_ids when adding batch dims; " f"got 
segment_ids.shape={segment_ids.shape}, segment_pos.shape={segment_pos.shape}" ) From 1ec442d6cd3880391e01fd1e58bdd431f7c74569 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Feb 2026 01:28:37 +0000 Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/attention.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6119dfb7df..dca512b8ce 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -658,8 +658,9 @@ def batcher(batched_args, batch_dims, *, config): f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}" ) assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, ( - "segment_pos must have same trailing shape as segment_ids when adding batch dims; " - f"got segment_ids.shape={segment_ids.shape}, segment_pos.shape={segment_pos.shape}" + "segment_pos must have same trailing shape as segment_ids when adding batch" + f" dims; got segment_ids.shape={segment_ids.shape}," + f" segment_pos.shape={segment_pos.shape}" ) # Expand the segment_pos by as many batch dims as the segment_ids has leading_bdim = segment_ids.ndim - segment_pos.ndim @@ -1155,8 +1156,9 @@ def batcher(batched_args, batch_dims, *, config): f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}" ) assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, ( - "segment_pos must have same trailing shape as segment_ids when adding batch dims; " - f"got segment_ids.shape={segment_ids.shape}, segment_pos.shape={segment_pos.shape}" + "segment_pos must have same trailing shape as segment_ids when adding batch" + f" dims; got segment_ids.shape={segment_ids.shape}," + f" segment_pos.shape={segment_pos.shape}" ) leading_bdim = segment_ids.ndim - segment_pos.ndim target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape From 0967b86bd31b5d134217ca458114ab43e763b410 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 24 Feb 2026 18:12:37 -0800 Subject: [PATCH 04/10] Fix the batcher logic to check for q and kv seg ids separately Signed-off-by: Kshitij Lakhani --- .../jax/cpp_extensions/attention.py | 125 ++++++++++-------- 1 file changed, 70 insertions(+), 55 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index dca512b8ce..f616593370 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -641,39 +641,48 @@ def batcher(batched_args, batch_dims, *, config): # Ensure segment_pos are batched like segment_ids so impl sees matching shapes. # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None (i.e. not batched) when # segment_pos were generated inside a vmapped function (e.g. single or nested vmap). + # Check expansion per (q, kv) pair so q and kv can be batched/vmapped independently. 
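Patch 02 above switched the trailing-shape check to a negative slice; a short worked example of why, with assumed shapes:

    ids_shape, pos_shape = (4, 2, 128), (2, 128)
    # pos has ndim 2, so the positive slice drops too much:
    assert ids_shape[len(pos_shape):] == (128,)        # old slice: compares the wrong dims
    assert ids_shape[-len(pos_shape):] == pos_shape    # fixed slice: the trailing dims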
        batched_args_list = list(batched_args)
-        seg_id_bdim = batch_dims[FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]]
-        seg_pos_bdim = batch_dims[FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX[0]]
-        # Expand the segment_pos batch dim to match the segment_ids batch dim, if no batch dim exists for segment_pos
-        if seg_id_bdim is not None and seg_pos_bdim is None and batched_args_list[10].size > 0:
-            # Pair (segment_ids idx, segment_pos idx): (10, 12) for q, (11, 13) for kv.
-            for seg_id_idx, seg_pos_idx in zip(
-                FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
-                FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
+        updated_batch_dims = list(batch_dims)
+        for seg_id_idx, seg_pos_idx in zip(
+            FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
+            FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
+        ):
+            seg_id_bdim = batch_dims[seg_id_idx]
+            seg_pos_bdim = batch_dims[seg_pos_idx]
+            if not (
+                seg_id_bdim is not None
+                and seg_pos_bdim is None
+                and batched_args_list[seg_id_idx].size > 0
+                and batched_args_list[seg_pos_idx].size > 0
             ):
-                segment_ids = batched_args_list[seg_id_idx]
-                segment_pos = batched_args_list[seg_pos_idx]
-                assert segment_ids.ndim > segment_pos.ndim, (
-                    "segment_ids must have more dims than segment_pos when adding batch dims; "
+                # Skip batch dim expansion if nothing is vmapped
+                continue
+            segment_ids = batched_args_list[seg_id_idx]
+            segment_pos = batched_args_list[seg_pos_idx]
+            # The segment_ids, at the very least, must have the same number of dimensions as segment_pos.
+            # Either because the user created them or because TE generated them.
+            if segment_ids.ndim < segment_pos.ndim:
+                raise AssertionError(
+                    "segment_ids must not have fewer dims than segment_pos; "
                     f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
                 )
-                assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, (
-                    "segment_pos must have same trailing shape as segment_ids when adding batch"
-                    f" dims; got segment_ids.shape={segment_ids.shape},"
-                    f" segment_pos.shape={segment_pos.shape}"
-                )
-                # Expand the segment_pos by as many batch dims as the segment_ids has
-                leading_bdim = segment_ids.ndim - segment_pos.ndim
-                target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape
-                expanded = segment_pos
-                for _ in range(leading_bdim):
-                    expanded = lax.expand_dims(expanded, (0,))
-                batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
-            # Update the batch_dims to use 0 instead of None for segment_pos batch dims
-            batch_dims = tuple(
-                0 if i in FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX else b
-                for i, b in enumerate(batch_dims)
+            if segment_ids.ndim == segment_pos.ndim:
+                # Skip batch dim expansion if there's no dim mismatch.
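The three ndim cases this loop distinguishes, with toy shapes for illustration:

    segment_ids.shape    segment_pos.shape    action
    (2, 128)             (2, 128)             skip: dims already match
    (4, 2, 128)          (2, 128)             expand segment_pos by one leading dim
    (2, 128)             (4, 2, 128)          invalid: raises AssertionError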
+ continue + assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, ( + "segment_pos must have same trailing shape as segment_ids when adding batch" + f" dims; got segment_ids.shape={segment_ids.shape}," + f" segment_pos.shape={segment_pos.shape}" ) + leading_bdim = segment_ids.ndim - segment_pos.ndim + target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape + expanded = segment_pos + for _ in range(leading_bdim): + expanded = lax.expand_dims(expanded, (0,)) + batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape) + updated_batch_dims[seg_pos_idx] = 0 + batch_dims = tuple(updated_batch_dims) batched_args = tuple(batched_args_list) out_bdims = q_bdim, q_bdim, seed_bdim @@ -1140,38 +1149,44 @@ def batcher(batched_args, batch_dims, *, config): # Ensure segment_pos are batched like segment_ids so impl sees matching shapes. # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None (i.e. not batched) when # segment_pos were generated inside a vmapped function (e.g. single or nested vmap). + # Check expansion per (q, kv) pair so q and kv can be batched/vmapped independently. batched_args_list = list(batched_args) - seg_id_bdim = batch_dims[FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]] - seg_pos_bdim = batch_dims[FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX[0]] - # Expand the segment_pos batch dim to match the segment_ids batch dim, if no batch dim exists for segment_pos - if seg_id_bdim is not None and seg_pos_bdim is None and batched_args_list[13].size > 0: - for seg_id_idx, seg_pos_idx in zip( - FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX, - FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX, + updated_batch_dims = list(batch_dims) + for seg_id_idx, seg_pos_idx in zip( + FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX, + FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX, + ): + seg_id_bdim = batch_dims[seg_id_idx] + seg_pos_bdim = batch_dims[seg_pos_idx] + if not ( + seg_id_bdim is not None + and seg_pos_bdim is None + and batched_args_list[seg_id_idx].size > 0 + and batched_args_list[seg_pos_idx].size > 0 ): - segment_ids = batched_args_list[seg_id_idx] - segment_pos = batched_args_list[seg_pos_idx] - assert segment_ids.ndim > segment_pos.ndim, ( - "segment_ids must have more dims than segment_pos when adding batch dims; " + continue + segment_ids = batched_args_list[seg_id_idx] + segment_pos = batched_args_list[seg_pos_idx] + if segment_ids.ndim < segment_pos.ndim: + raise AssertionError( + "segment_ids must not have fewer dims than segment_pos; " f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}" ) - assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, ( - "segment_pos must have same trailing shape as segment_ids when adding batch" - f" dims; got segment_ids.shape={segment_ids.shape}," - f" segment_pos.shape={segment_pos.shape}" - ) - leading_bdim = segment_ids.ndim - segment_pos.ndim - target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape - # Expand the segment_pos batch dim to match the segment_ids batch dim, if no batch dim exists for segment_pos - expanded = segment_pos - for _ in range(leading_bdim): - expanded = lax.expand_dims(expanded, (0,)) - batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape) - # Update the batch_dims to use 0 instead of None for segment_pos batch dims - batch_dims = tuple( - 0 if i in FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX else b - for i, b in enumerate(batch_dims) + if 
segment_ids.ndim == segment_pos.ndim:
+                continue
+            assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, (
+                "segment_pos must have same trailing shape as segment_ids when adding batch"
+                f" dims; got segment_ids.shape={segment_ids.shape},"
+                f" segment_pos.shape={segment_pos.shape}"
+            )
+            leading_bdim = segment_ids.ndim - segment_pos.ndim
+            target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape
+            expanded = segment_pos
+            for _ in range(leading_bdim):
+                expanded = lax.expand_dims(expanded, (0,))
+            batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
+            updated_batch_dims[seg_pos_idx] = 0
+        batch_dims = tuple(updated_batch_dims)
         batched_args = tuple(batched_args_list)
 
         out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
         return (
             FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),

From ea8e5d6979a376535e869641b570dd8972aebc84 Mon Sep 17 00:00:00 2001
From: Kshitij Lakhani
Date: Fri, 27 Feb 2026 10:22:51 -0800
Subject: [PATCH 05/10] Remove batcher logic to expand segment pos. Keep the
 shape check asserts.

Signed-off-by: Kshitij Lakhani
---
 .../jax/cpp_extensions/attention.py           | 48 ++++---------------
 1 file changed, 10 insertions(+), 38 deletions(-)

diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index f616593370..41daa70585 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -638,10 +638,10 @@ def batcher(batched_args, batch_dims, *, config):
         assert FusedAttnFwdPrimitive.outer_primitive is not None
         q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
 
-        # Ensure segment_pos are batched like segment_ids so impl sees matching shapes.
-        # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None (i.e. not batched) when
-        # segment_pos were generated inside a vmapped function (e.g. single or nested vmap).
-        # Check expansion per (q, kv) pair so q and kv can be batched/vmapped independently.
+        # When segment_ids are batched (vmap) and segment_pos are not, do not expand segment_pos to match.
+        # The impl() layer treats segment_pos as replicated and computes seqlens/offsets per batch index
+        # without materializing the full expanded segment_pos array.
+        # Assert on invalid case (segment_ids.ndim < segment_pos.ndim)
         batched_args_list = list(batched_args)
         updated_batch_dims = list(batch_dims)
         for seg_id_idx, seg_pos_idx in zip(
@@ -656,32 +656,19 @@ def batcher(batched_args, batch_dims, *, config):
             ):
-                # Skip batch dim expansion if nothing is vmapped
                 continue
             segment_ids = batched_args_list[seg_id_idx]
             segment_pos = batched_args_list[seg_pos_idx]
-            # The segment_ids, at the very least, must have the same number of dimensions as segment_pos.
-            # Either because the user created them or because TE generated them.
             if segment_ids.ndim < segment_pos.ndim:
                 raise AssertionError(
                     "segment_ids must not have fewer dims than segment_pos; "
                     f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
                 )
-            if segment_ids.ndim == segment_pos.ndim:
+            # Do not expand segment_pos: leave it unexpanded so the impl() layer
+            # treats it as replicated and computes seqlens/offsets per batch index
+            # without materializing the full expanded segment_pos array.
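Rough arithmetic for the memory this avoids, with assumed sizes:

    V, B, S = 8, 2, 4096   # vmap size, batch size, sequence length
    V * B * S              # 65536 int32 elements if segment_pos were materialized expanded
    B * S                  # 8192 elements when kept replicated and reused per batch index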
+            if segment_ids.ndim >= segment_pos.ndim:
                 continue
-            assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, (
-                "segment_pos must have same trailing shape as segment_ids when adding batch"
-                f" dims; got segment_ids.shape={segment_ids.shape},"
-                f" segment_pos.shape={segment_pos.shape}"
-            )
-            leading_bdim = segment_ids.ndim - segment_pos.ndim
-            target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape
-            expanded = segment_pos
-            for _ in range(leading_bdim):
-                expanded = lax.expand_dims(expanded, (0,))
-            batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
-            updated_batch_dims[seg_pos_idx] = 0
         batch_dims = tuple(updated_batch_dims)
         batched_args = tuple(batched_args_list)
 
@@ -1146,10 +1133,7 @@ def batcher(batched_args, batch_dims, *, config):
         assert FusedAttnBwdPrimitive.outer_primitive is not None
         q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims
 
-        # Ensure segment_pos are batched like segment_ids so impl sees matching shapes.
-        # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None (i.e. not batched) when
-        # segment_pos were generated inside a vmapped function (e.g. single or nested vmap).
-        # Check expansion per (q, kv) pair so q and kv can be batched/vmapped independently.
+        # Memory-efficient approach: do not expand segment_pos; the conversion layer treats it as replicated.
         batched_args_list = list(batched_args)
         updated_batch_dims = list(batch_dims)
         for seg_id_idx, seg_pos_idx in zip(
@@ -1172,20 +1156,8 @@ def batcher(batched_args, batch_dims, *, config):
                     "segment_ids must not have fewer dims than segment_pos; "
                     f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
                 )
-            if segment_ids.ndim == segment_pos.ndim:
+            if segment_ids.ndim >= segment_pos.ndim:
                 continue
-            assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, (
-                "segment_pos must have same trailing shape as segment_ids when adding batch"
-                f" dims; got segment_ids.shape={segment_ids.shape},"
-                f" segment_pos.shape={segment_pos.shape}"
-            )
-            leading_bdim = segment_ids.ndim - segment_pos.ndim
-            target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape
-            expanded = segment_pos
-            for _ in range(leading_bdim):
-                expanded = lax.expand_dims(expanded, (0,))
-            batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
-            updated_batch_dims[seg_pos_idx] = 0
         batch_dims = tuple(updated_batch_dims)
         batched_args = tuple(batched_args_list)

From 5138ac305b1235387345cdaed944d02b076f19d4 Mon Sep 17 00:00:00 2001
From: Kshitij Lakhani
Date: Fri, 27 Feb 2026 10:23:52 -0800
Subject: [PATCH 06/10] Add support for vmapped seg ids and non-vmapped seg
 pos when computing the seqlens and offsets for fused attn

Signed-off-by: Kshitij Lakhani
---
 transformer_engine/jax/attention.py | 87 +++++++++++++++++++++++++----
 1 file changed, 75 insertions(+), 12 deletions(-)

diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 21db296c34..aa48ac2b10 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -693,26 +693,89 @@ def get_seqlens_and_offsets(
         self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq
     ):
         """
-        Acquire the seqlens/offsets for cuDNN backend
+        Acquire the seqlens/offsets for cuDNN backend.
""" q_segment_ids, kv_segment_ids = self.segment_ids q_segment_pos, kv_segment_pos = self.segment_pos - assert q_segment_ids.shape == q_segment_pos.shape - assert kv_segment_ids.shape == kv_segment_pos.shape # No segment_ids/segment_pos if q_segment_ids.size + kv_segment_ids.size == 0: return self.seqlens, self.seq_offsets - if qkv_layout.is_thd(): - q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( - q_segment_ids, - kv_segment_ids, - q_segment_pos, - kv_segment_pos, - attn_mask_type, - window_size, - max_segments_per_seq, + # Allow segment_pos to have fewer leading dims than segment_ids if vmapped segment_ids and non-vmapped segment_pos + # e.g. when using from_segment_ids_and_pos() for segment_pos generation from segment_ids it is acceptable to have + # something like : segment_ids (B, batch, seq), segment_pos (batch, seq)). + if q_segment_ids.ndim < q_segment_pos.ndim or kv_segment_ids.ndim < kv_segment_pos.ndim: + raise AssertionError( + "segment_ids must not have fewer dims than segment_pos; " + f"got q_segment_ids.ndim={q_segment_ids.ndim}, q_segment_pos.ndim={q_segment_pos.ndim}, " + f"kv_segment_ids.ndim={kv_segment_ids.ndim}, kv_segment_pos.ndim={kv_segment_pos.ndim}" ) + if not ( + q_segment_ids.shape[-q_segment_pos.ndim :] == q_segment_pos.shape + and kv_segment_ids.shape[-kv_segment_pos.ndim :] == kv_segment_pos.shape + ): + raise AssertionError( + "segment_pos trailing shape must match segment_ids; " + f"got q_segment_ids.shape={q_segment_ids.shape}, q_segment_pos.shape={q_segment_pos.shape}, " + f"kv_segment_ids.shape={kv_segment_ids.shape}, kv_segment_pos.shape={kv_segment_pos.shape}" + ) + + if qkv_layout.is_thd(): + # THD: compute seqlens/offsets. Replicated segment_pos (more leading dims on segment_ids, e.g. if vmap) + # i) Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims, + # ii) vmap seqlens/offsets computation with segment_pos broadcast, + # iii) reshape back to the original leading batch dims. 
+ if ( + q_segment_ids.ndim > q_segment_pos.ndim + or kv_segment_ids.ndim > kv_segment_pos.ndim + ): + n_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim + n_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim + batch_shape_q = q_segment_ids.shape[:n_batch_dims_q] + batch_shape_kv = kv_segment_ids.shape[:n_batch_dims_kv] + flat_batch_q = jnp.prod(batch_shape_q) + flat_batch_kv = jnp.prod(batch_shape_kv) + # assert flat_batch_q == flat_batch_kv, ( + # f"segment_ids batch size mismatch: {batch_shape_q} vs {batch_shape_kv}" + # ) + q_flat = q_segment_ids.reshape( + flat_batch_q, *q_segment_ids.shape[n_batch_dims_q:] + ) + kv_flat = kv_segment_ids.reshape( + flat_batch_kv, *kv_segment_ids.shape[n_batch_dims_kv:] + ) + + def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv): + return _segment_ids_pos_to_seqlens_offsets( + seg_id_q, + seg_id_kv, + seg_pos_q, + seg_pos_kv, + attn_mask_type, + window_size, + max_segments_per_seq, + ) + + q_sl, kv_sl, q_off, kv_off = jax.vmap( + single_batch, in_axes=(0, 0, None, None) + )(q_flat, kv_flat, q_segment_pos, kv_segment_pos) + + q_seqlens = q_sl.reshape(*batch_shape_q, *q_sl.shape[1:]) + kv_seqlens = kv_sl.reshape(*batch_shape_kv, *kv_sl.shape[1:]) + q_offsets = q_off.reshape(*batch_shape_q, *q_off.shape[1:]) + kv_offsets = kv_off.reshape(*batch_shape_kv, *kv_off.shape[1:]) + else: + q_seqlens, kv_seqlens, q_offsets, kv_offsets = ( + _segment_ids_pos_to_seqlens_offsets( + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + attn_mask_type, + window_size, + max_segments_per_seq, + ) + ) else: q_seqlens, kv_seqlens = _segment_ids_to_seqlens( q_segment_ids, From 395ac54979e9fdfeb2ce0cc7d8c50820c3d138e2 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 27 Feb 2026 10:35:53 -0800 Subject: [PATCH 07/10] Undo batcher check logic for seg pos and seg ids as it is already moved to get_seqlens_and_offsets() Signed-off-by: Kshitij Lakhani --- .../jax/cpp_extensions/attention.py | 82 +------------------ 1 file changed, 4 insertions(+), 78 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 41daa70585..f4d914062d 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -624,54 +624,14 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return output, softmax_aux, rng_state - # Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=seed, - # 6,7=seqlens, 8,9=seq_offsets, 10,11=segment_ids, 12,13=segment_pos. - _SEGMENT_IDS_BATCH_DIMS_IDX = (10, 11) - _SEGMENT_POS_BATCH_DIMS_IDX = (12, 13) - @staticmethod def batcher(batched_args, batch_dims, *, config): - # batch_dims: tuple of length len(batched_args); each element is the axis index - # that is the batch axis (0, 1, ...) or None if that arg has no batch dim. - # check_valid_batch_dims: only 0 or None allowed (single leading batch or no batch). + # batch_dims: each element is the batch axis (0, ...) or None. Only 0 or None allowed. check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims - - # When segment_ids are batched (vmap) and segment_pos are not, do not expand segment_pos to match. - # The impl() layer treats segment_pos as replicated and computes seqlens/offsets per batch index - # without materializing the full expanded segment_pos array. 
-        # Assert on invalid case (segment_ids.ndim < segment_pos.ndim)
-        batched_args_list = list(batched_args)
-        updated_batch_dims = list(batch_dims)
-        for seg_id_idx, seg_pos_idx in zip(
-            FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
-            FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
-        ):
-            seg_id_bdim = batch_dims[seg_id_idx]
-            seg_pos_bdim = batch_dims[seg_pos_idx]
-            if not (
-                seg_id_bdim is not None
-                and seg_pos_bdim is None
-                and batched_args_list[seg_id_idx].size > 0
-                and batched_args_list[seg_pos_idx].size > 0
-            ):
-                continue
-            segment_ids = batched_args_list[seg_id_idx]
-            segment_pos = batched_args_list[seg_pos_idx]
-            if segment_ids.ndim < segment_pos.ndim:
-                raise AssertionError(
-                    "segment_ids must not have fewer dims than segment_pos; "
-                    f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
-                )
-            # Do not expand segment_pos: leave it unexpanded so the impl() layer
-            # treats it as replicated and computes seqlens/offsets per batch index
-            # without materializing the full expanded segment_pos array.
-            if segment_ids.ndim >= segment_pos.ndim:
-                continue
-        batch_dims = tuple(updated_batch_dims)
-        batched_args = tuple(batched_args_list)
-
+        # Pass through; segment_ids/segment_pos may have different batch dims (e.g. vmapped ids,
+        # replicated pos). get_seqlens_and_offsets() in attention.py handles conversion without expanding.
         out_bdims = q_bdim, q_bdim, seed_bdim
         return (
             FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
@@ -1121,46 +1081,12 @@ def convert_to_2d(offsets, batch, max_seqlen):
         )
         return dq, dk, dv, dbias, dsoftmax_offset
 
-    # Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=softmax_aux,
-    # 6=rng_state, 7=output, 8=doutput, 9,10=seqlens, 11,12=seq_offsets,
-    # 13,14=segment_ids, 15,16=segment_pos.
-    _SEGMENT_IDS_BATCH_DIMS_IDX = (13, 14)
-    _SEGMENT_POS_BATCH_DIMS_IDX = (15, 16)
-
     @staticmethod
     def batcher(batched_args, batch_dims, *, config):
         check_valid_batch_dims(batch_dims)
         assert FusedAttnBwdPrimitive.outer_primitive is not None
         q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims
-
-        # Memory-efficient approach: do not expand segment_pos; the conversion layer treats it as replicated.
-        batched_args_list = list(batched_args)
-        updated_batch_dims = list(batch_dims)
-        for seg_id_idx, seg_pos_idx in zip(
-            FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
-            FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
-        ):
-            seg_id_bdim = batch_dims[seg_id_idx]
-            seg_pos_bdim = batch_dims[seg_pos_idx]
-            if not (
-                seg_id_bdim is not None
-                and seg_pos_bdim is None
-                and batched_args_list[seg_id_idx].size > 0
-                and batched_args_list[seg_pos_idx].size > 0
-            ):
-                continue
-            segment_ids = batched_args_list[seg_id_idx]
-            segment_pos = batched_args_list[seg_pos_idx]
-            if segment_ids.ndim < segment_pos.ndim:
-                raise AssertionError(
-                    "segment_ids must not have fewer dims than segment_pos; "
-                    f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
-                )
-            if segment_ids.ndim >= segment_pos.ndim:
-                continue
-        batch_dims = tuple(updated_batch_dims)
-        batched_args = tuple(batched_args_list)
-
+        # Pass through; segment_ids/segment_pos may have different batch dims. Conversion is in attention.py.
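For reference, a sketch of the contract check_valid_batch_dims enforces, per the comment in the forward batcher above (assumed behavior, not the actual TE helper):

    def check_valid_batch_dims_sketch(batch_dims):
        # Only a single leading batch axis (0) or no batch axis (None) is
        # supported by these batching rules.
        for bdim in batch_dims:
            assert bdim in (0, None), f"unsupported batch dim {bdim}"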
out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim return ( FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), From 2f9dcc5683a7181fec7c4275aec13f18042a85cc Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 27 Feb 2026 11:05:21 -0800 Subject: [PATCH 08/10] nit: Remove unnecessary assert check Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index aa48ac2b10..d17319426f 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -569,6 +569,8 @@ def _segment_ids_pos_to_seqlens_offsets( # using the segment ids and pos along with mask type (causal or brcm) is sufficient. # It does not need to involve SW for this mask's creation + # Currently, this function is only exercised for THD qkv_layout. + # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well if (attn_mask_type.is_causal() and window_size is None) or ( window_size == (-1, -1) and not attn_mask_type.is_bottom_right() @@ -733,11 +735,9 @@ def get_seqlens_and_offsets( n_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim batch_shape_q = q_segment_ids.shape[:n_batch_dims_q] batch_shape_kv = kv_segment_ids.shape[:n_batch_dims_kv] - flat_batch_q = jnp.prod(batch_shape_q) - flat_batch_kv = jnp.prod(batch_shape_kv) - # assert flat_batch_q == flat_batch_kv, ( - # f"segment_ids batch size mismatch: {batch_shape_q} vs {batch_shape_kv}" - # ) + flat_batch_q = jnp.prod(jnp.array(batch_shape_q)) + flat_batch_kv = jnp.prod(jnp.array(batch_shape_kv)) + # vmap below requires same batch size on axis 0 for q_flat and kv_flat; JAX will raise if they differ. q_flat = q_segment_ids.reshape( flat_batch_q, *q_segment_ids.shape[n_batch_dims_q:] ) From 693ba652171017d77e8e5d51fed652038d66c65e Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 27 Feb 2026 11:18:44 -0800 Subject: [PATCH 09/10] nit: Code clean up Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 44 ++++++++++++++--------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index d17319426f..7e6c435fc0 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -721,31 +721,28 @@ def get_seqlens_and_offsets( f"got q_segment_ids.shape={q_segment_ids.shape}, q_segment_pos.shape={q_segment_pos.shape}, " f"kv_segment_ids.shape={kv_segment_ids.shape}, kv_segment_pos.shape={kv_segment_pos.shape}" ) - + # THD: compute seqlens/offsets. if qkv_layout.is_thd(): - # THD: compute seqlens/offsets. Replicated segment_pos (more leading dims on segment_ids, e.g. if vmap) - # i) Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims, - # ii) vmap seqlens/offsets computation with segment_pos broadcast, - # iii) reshape back to the original leading batch dims. - if ( - q_segment_ids.ndim > q_segment_pos.ndim - or kv_segment_ids.ndim > kv_segment_pos.ndim - ): - n_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim - n_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim - batch_shape_q = q_segment_ids.shape[:n_batch_dims_q] - batch_shape_kv = kv_segment_ids.shape[:n_batch_dims_kv] - flat_batch_q = jnp.prod(jnp.array(batch_shape_q)) - flat_batch_kv = jnp.prod(jnp.array(batch_shape_kv)) + # If there are more leading dims on segment_ids, e.g. 
vmap
+            if (q_segment_ids.ndim > q_segment_pos.ndim or kv_segment_ids.ndim > kv_segment_pos.ndim):
+                # Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims,
+                # vmap seqlens/offsets computation with segment_pos broadcast,
+                # reshape back to the original leading batch dims.
+                n_extra_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim
+                n_extra_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim
+                extra_batch_shape_q = q_segment_ids.shape[:n_extra_batch_dims_q]
+                extra_batch_shape_kv = kv_segment_ids.shape[:n_extra_batch_dims_kv]
+                extra_flat_batch_size_q = jnp.prod(extra_batch_shape_q)
+                extra_flat_batch_size_kv = jnp.prod(extra_batch_shape_kv)
                 # vmap below requires same batch size on axis 0 for q_flat and kv_flat; JAX will raise if they differ.
                 q_flat = q_segment_ids.reshape(
-                    flat_batch_q, *q_segment_ids.shape[n_batch_dims_q:]
+                    extra_flat_batch_size_q, *q_segment_ids.shape[n_extra_batch_dims_q:]
                 )
                 kv_flat = kv_segment_ids.reshape(
-                    flat_batch_kv, *kv_segment_ids.shape[n_batch_dims_kv:]
+                    extra_flat_batch_size_kv, *kv_segment_ids.shape[n_extra_batch_dims_kv:]
                 )
 
-            def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
+            def single_extra_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
                 return _segment_ids_pos_to_seqlens_offsets(
                     seg_id_q,
                     seg_id_kv,
                     seg_pos_q,
                     seg_pos_kv,
                     attn_mask_type,
                     window_size,
                     max_segments_per_seq,
                 )
 
             q_sl, kv_sl, q_off, kv_off = jax.vmap(
-                single_batch, in_axes=(0, 0, None, None)
+                single_extra_batch, in_axes=(0, 0, None, None)
             )(q_flat, kv_flat, q_segment_pos, kv_segment_pos)
 
-            q_seqlens = q_sl.reshape(*batch_shape_q, *q_sl.shape[1:])
-            kv_seqlens = kv_sl.reshape(*batch_shape_kv, *kv_sl.shape[1:])
-            q_offsets = q_off.reshape(*batch_shape_q, *q_off.shape[1:])
-            kv_offsets = kv_off.reshape(*batch_shape_kv, *kv_off.shape[1:])
+            q_seqlens = q_sl.reshape(*extra_batch_shape_q, *q_sl.shape[1:])
+            kv_seqlens = kv_sl.reshape(*extra_batch_shape_kv, *kv_sl.shape[1:])
+            q_offsets = q_off.reshape(*extra_batch_shape_q, *q_off.shape[1:])
+            kv_offsets = kv_off.reshape(*extra_batch_shape_kv, *kv_off.shape[1:])
         else:
             q_seqlens, kv_seqlens, q_offsets, kv_offsets = (
                 _segment_ids_pos_to_seqlens_offsets(
                     q_segment_ids,
                     kv_segment_ids,
                     q_segment_pos,
                     kv_segment_pos,
                     attn_mask_type,
                     window_size,
                     max_segments_per_seq,
                 )
             )
+        # BSHD: compute seqlens/offsets.
         else:
             q_seqlens, kv_seqlens = _segment_ids_to_seqlens(
                 q_segment_ids,

From d630082674eff27fdb2746fce68fdb67448b7c75 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 27 Feb 2026 19:20:16 +0000
Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/jax/attention.py | 40 +++++++++++++++--------------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 7e6c435fc0..217827a266 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -708,25 +708,29 @@ def get_seqlens_and_offsets(
         # something like: segment_ids (B, batch, seq), segment_pos (batch, seq).
if q_segment_ids.ndim < q_segment_pos.ndim or kv_segment_ids.ndim < kv_segment_pos.ndim: raise AssertionError( - "segment_ids must not have fewer dims than segment_pos; " - f"got q_segment_ids.ndim={q_segment_ids.ndim}, q_segment_pos.ndim={q_segment_pos.ndim}, " - f"kv_segment_ids.ndim={kv_segment_ids.ndim}, kv_segment_pos.ndim={kv_segment_pos.ndim}" + "segment_ids must not have fewer dims than segment_pos; got" + f" q_segment_ids.ndim={q_segment_ids.ndim}," + f" q_segment_pos.ndim={q_segment_pos.ndim}," + f" kv_segment_ids.ndim={kv_segment_ids.ndim}," + f" kv_segment_pos.ndim={kv_segment_pos.ndim}" ) if not ( q_segment_ids.shape[-q_segment_pos.ndim :] == q_segment_pos.shape and kv_segment_ids.shape[-kv_segment_pos.ndim :] == kv_segment_pos.shape ): raise AssertionError( - "segment_pos trailing shape must match segment_ids; " - f"got q_segment_ids.shape={q_segment_ids.shape}, q_segment_pos.shape={q_segment_pos.shape}, " - f"kv_segment_ids.shape={kv_segment_ids.shape}, kv_segment_pos.shape={kv_segment_pos.shape}" + "segment_pos trailing shape must match segment_ids; got" + f" q_segment_ids.shape={q_segment_ids.shape}," + f" q_segment_pos.shape={q_segment_pos.shape}," + f" kv_segment_ids.shape={kv_segment_ids.shape}," + f" kv_segment_pos.shape={kv_segment_pos.shape}" ) - # THD: compute seqlens/offsets. + # THD: compute seqlens/offsets. if qkv_layout.is_thd(): # If there are more leading dims on segment_ids, e.g. vmap - if (q_segment_ids.ndim > q_segment_pos.ndim or kv_segment_ids.ndim > kv_segment_pos.ndim): + if q_segment_ids.ndim > q_segment_pos.ndim or kv_segment_ids.ndim > kv_segment_pos.ndim: # Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims, - # vmap seqlens/offsets computation with segment_pos broadcast, + # vmap seqlens/offsets computation with segment_pos broadcast, # reshape back to the original leading batch dims. n_extra_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim n_extra_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim @@ -762,16 +766,14 @@ def single_extra_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv): q_offsets = q_off.reshape(*extra_batch_shape_q, *q_off.shape[1:]) kv_offsets = kv_off.reshape(*extra_batch_shape_kv, *kv_off.shape[1:]) else: - q_seqlens, kv_seqlens, q_offsets, kv_offsets = ( - _segment_ids_pos_to_seqlens_offsets( - q_segment_ids, - kv_segment_ids, - q_segment_pos, - kv_segment_pos, - attn_mask_type, - window_size, - max_segments_per_seq, - ) + q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + attn_mask_type, + window_size, + max_segments_per_seq, ) # BSHD: compute seqlens/offsets. else:
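Taken together, an end-to-end toy of the scenario this series fixes: vmapped segment_ids with replicated segment_pos, and seqlens computed per vmap index without expanding segment_pos (standalone sketch, not the TE API):

    import jax
    import jax.numpy as jnp

    def seqlens_from_ids(ids, pos):
        # Toy: count tokens whose segment id is set and whose position is valid.
        return jnp.sum((ids > 0) & (pos >= 0), axis=-1)

    V, B, S = 4, 2, 32
    ids = jnp.broadcast_to((jnp.arange(S) < 20).astype(jnp.int32), (V, B, S))
    pos = jnp.broadcast_to(jnp.arange(S), (B, S))

    lens = jax.vmap(seqlens_from_ids, in_axes=(0, None))(ids, pos)
    assert lens.shape == (V, B)
    assert int(lens[0, 0]) == 20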