Skip to content

Commit 395ac54

Browse files
Undo batcher check logic for seg pos and seg ids, as it has already been moved to get_seqlens_and_offsets()
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
1 parent 5138ac3 commit 395ac54

1 file changed

Lines changed: 4 additions & 78 deletions

File tree

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 4 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -624,54 +624,14 @@ def convert_to_2d(offsets, batch, max_seqlen):
624624
)
625625
return output, softmax_aux, rng_state
626626

627-
# Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=seed,
628-
# 6,7=seqlens, 8,9=seq_offsets, 10,11=segment_ids, 12,13=segment_pos.
629-
_SEGMENT_IDS_BATCH_DIMS_IDX = (10, 11)
630-
_SEGMENT_POS_BATCH_DIMS_IDX = (12, 13)
631-
632627
@staticmethod
633628
def batcher(batched_args, batch_dims, *, config):
634-
# batch_dims: tuple of length len(batched_args); each element is the axis index
635-
# that is the batch axis (0, 1, ...) or None if that arg has no batch dim.
636-
# check_valid_batch_dims: only 0 or None allowed (single leading batch or no batch).
629+
# batch_dims: each element is the batch axis (0, ...) or None. Only 0 or None allowed.
637630
check_valid_batch_dims(batch_dims)
638631
assert FusedAttnFwdPrimitive.outer_primitive is not None
639632
q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
640-
641-
# When segment_ids are batched (vmap) and segment_pos are not, do not expand segment_pos to match.
642-
# The impl() layer treats segment_pos as replicated and computes seqlens/offsets per batch index
643-
# without materializing the full expanded segment_pos array.
644-
# Assert on invalid case (segment_ids.ndim < segment_pos.ndim)
645-
batched_args_list = list(batched_args)
646-
updated_batch_dims = list(batch_dims)
647-
for seg_id_idx, seg_pos_idx in zip(
648-
FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
649-
FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
650-
):
651-
seg_id_bdim = batch_dims[seg_id_idx]
652-
seg_pos_bdim = batch_dims[seg_pos_idx]
653-
if not (
654-
seg_id_bdim is not None
655-
and seg_pos_bdim is None
656-
and batched_args_list[seg_id_idx].size > 0
657-
and batched_args_list[seg_pos_idx].size > 0
658-
):
659-
continue
660-
segment_ids = batched_args_list[seg_id_idx]
661-
segment_pos = batched_args_list[seg_pos_idx]
662-
if segment_ids.ndim < segment_pos.ndim:
663-
raise AssertionError(
664-
"segment_ids must not have fewer dims than segment_pos; "
665-
f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
666-
)
667-
# Do not expand segment_pos: leave it unexpanded so the impl() layer
668-
# treats it as replicated and computes seqlens/offsets per batch index
669-
# without materializing the full expanded segment_pos array.
670-
if segment_ids.ndim >= segment_pos.ndim:
671-
continue
672-
batch_dims = tuple(updated_batch_dims)
673-
batched_args = tuple(batched_args_list)
674-
633+
# Pass through; segment_ids/segment_pos may have different batch dims (e.g. vmapped ids,
634+
# replicated pos). get_seqlens_and_offsets() in attention.py handles conversion without expanding.
675635
out_bdims = q_bdim, q_bdim, seed_bdim
676636
return (
677637
FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
@@ -1121,46 +1081,12 @@ def convert_to_2d(offsets, batch, max_seqlen):
11211081
)
11221082
return dq, dk, dv, dbias, dsoftmax_offset
11231083

1124-
# Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=softmax_aux,
1125-
# 6=rng_state, 7=output, 8=doutput, 9,10=seqlens, 11,12=seq_offsets,
1126-
# 13,14=segment_ids, 15,16=segment_pos.
1127-
_SEGMENT_IDS_BATCH_DIMS_IDX = (13, 14)
1128-
_SEGMENT_POS_BATCH_DIMS_IDX = (15, 16)
1129-
11301084
@staticmethod
11311085
def batcher(batched_args, batch_dims, *, config):
11321086
check_valid_batch_dims(batch_dims)
11331087
assert FusedAttnBwdPrimitive.outer_primitive is not None
11341088
q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims
1135-
1136-
# Option 3 (memory-efficient): do not expand segment_pos; conversion layer treats as replicated.
1137-
batched_args_list = list(batched_args)
1138-
updated_batch_dims = list(batch_dims)
1139-
for seg_id_idx, seg_pos_idx in zip(
1140-
FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
1141-
FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
1142-
):
1143-
seg_id_bdim = batch_dims[seg_id_idx]
1144-
seg_pos_bdim = batch_dims[seg_pos_idx]
1145-
if not (
1146-
seg_id_bdim is not None
1147-
and seg_pos_bdim is None
1148-
and batched_args_list[seg_id_idx].size > 0
1149-
and batched_args_list[seg_pos_idx].size > 0
1150-
):
1151-
continue
1152-
segment_ids = batched_args_list[seg_id_idx]
1153-
segment_pos = batched_args_list[seg_pos_idx]
1154-
if segment_ids.ndim < segment_pos.ndim:
1155-
raise AssertionError(
1156-
"segment_ids must not have fewer dims than segment_pos; "
1157-
f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
1158-
)
1159-
if segment_ids.ndim >= segment_pos.ndim:
1160-
continue
1161-
batch_dims = tuple(updated_batch_dims)
1162-
batched_args = tuple(batched_args_list)
1163-
1089+
# Pass through; segment_ids/segment_pos may have different batch dims. Conversion is in attention.py.
11641090
out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
11651091
return (
11661092
FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),

0 commit comments

Comments (0)