@@ -624,54 +624,14 @@ def convert_to_2d(offsets, batch, max_seqlen):
624624 )
625625 return output , softmax_aux , rng_state
626626
627- # Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=seed,
628- # 6,7=seqlens, 8,9=seq_offsets, 10,11=segment_ids, 12,13=segment_pos.
629- _SEGMENT_IDS_BATCH_DIMS_IDX = (10 , 11 )
630- _SEGMENT_POS_BATCH_DIMS_IDX = (12 , 13 )
631-
632627 @staticmethod
633628 def batcher (batched_args , batch_dims , * , config ):
634- # batch_dims: tuple of length len(batched_args); each element is the axis index
635- # that is the batch axis (0, 1, ...) or None if that arg has no batch dim.
636- # check_valid_batch_dims: only 0 or None allowed (single leading batch or no batch).
629+ # batch_dims: one entry per batched arg, giving its batch axis or None if that arg is unbatched; check_valid_batch_dims only allows 0 or None.
637630 check_valid_batch_dims (batch_dims )
638631 assert FusedAttnFwdPrimitive .outer_primitive is not None
639632 q_bdim , _ , _ , _ , _ , seed_bdim , * _ = batch_dims
640-
641- # When segment_ids are batched (vmap) and segment_pos are not, do not expand segment_pos to match.
642- # The impl() layer treats segment_pos as replicated and computes seqlens/offsets per batch index
643- # without materializing the full expanded segment_pos array.
644- # Assert on invalid case (segment_ids.ndim < segment_pos.ndim)
645- batched_args_list = list (batched_args )
646- updated_batch_dims = list (batch_dims )
647- for seg_id_idx , seg_pos_idx in zip (
648- FusedAttnFwdPrimitive ._SEGMENT_IDS_BATCH_DIMS_IDX ,
649- FusedAttnFwdPrimitive ._SEGMENT_POS_BATCH_DIMS_IDX ,
650- ):
651- seg_id_bdim = batch_dims [seg_id_idx ]
652- seg_pos_bdim = batch_dims [seg_pos_idx ]
653- if not (
654- seg_id_bdim is not None
655- and seg_pos_bdim is None
656- and batched_args_list [seg_id_idx ].size > 0
657- and batched_args_list [seg_pos_idx ].size > 0
658- ):
659- continue
660- segment_ids = batched_args_list [seg_id_idx ]
661- segment_pos = batched_args_list [seg_pos_idx ]
662- if segment_ids .ndim < segment_pos .ndim :
663- raise AssertionError (
664- "segment_ids must not have fewer dims than segment_pos; "
665- f"got segment_ids.ndim={ segment_ids .ndim } , segment_pos.ndim={ segment_pos .ndim } "
666- )
667- # Do not expand segment_pos: leave it unexpanded so the impl() layer
668- # treats it as replicated and computes seqlens/offsets per batch index
669- # without materializing the full expanded segment_pos array.
670- if segment_ids .ndim >= segment_pos .ndim :
671- continue
672- batch_dims = tuple (updated_batch_dims )
673- batched_args = tuple (batched_args_list )
674-
633+ # Pass through; segment_ids/segment_pos may have different batch dims (e.g. vmapped ids,
634+ # replicated pos). get_seqlens_and_offsets() in attention.py handles conversion without expanding.
675635 out_bdims = q_bdim , q_bdim , seed_bdim
676636 return (
677637 FusedAttnFwdPrimitive .outer_primitive .bind (* batched_args , config = config ),
@@ -1121,46 +1081,12 @@ def convert_to_2d(offsets, batch, max_seqlen):
11211081 )
11221082 return dq , dk , dv , dbias , dsoftmax_offset
11231083
1124- # Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=softmax_aux,
1125- # 6=rng_state, 7=output, 8=doutput, 9,10=seqlens, 11,12=seq_offsets,
1126- # 13,14=segment_ids, 15,16=segment_pos.
1127- _SEGMENT_IDS_BATCH_DIMS_IDX = (13 , 14 )
1128- _SEGMENT_POS_BATCH_DIMS_IDX = (15 , 16 )
1129-
11301084 @staticmethod
11311085 def batcher (batched_args , batch_dims , * , config ):
11321086 check_valid_batch_dims (batch_dims )
11331087 assert FusedAttnBwdPrimitive .outer_primitive is not None
11341088 q_bdim , k_bdim , v_bdim , bias_bdim , softmax_offset_bdim , * _ = batch_dims
1135-
1136- # Option 3 (memory-efficient): do not expand segment_pos; conversion layer treats as replicated.
1137- batched_args_list = list (batched_args )
1138- updated_batch_dims = list (batch_dims )
1139- for seg_id_idx , seg_pos_idx in zip (
1140- FusedAttnBwdPrimitive ._SEGMENT_IDS_BATCH_DIMS_IDX ,
1141- FusedAttnBwdPrimitive ._SEGMENT_POS_BATCH_DIMS_IDX ,
1142- ):
1143- seg_id_bdim = batch_dims [seg_id_idx ]
1144- seg_pos_bdim = batch_dims [seg_pos_idx ]
1145- if not (
1146- seg_id_bdim is not None
1147- and seg_pos_bdim is None
1148- and batched_args_list [seg_id_idx ].size > 0
1149- and batched_args_list [seg_pos_idx ].size > 0
1150- ):
1151- continue
1152- segment_ids = batched_args_list [seg_id_idx ]
1153- segment_pos = batched_args_list [seg_pos_idx ]
1154- if segment_ids .ndim < segment_pos .ndim :
1155- raise AssertionError (
1156- "segment_ids must not have fewer dims than segment_pos; "
1157- f"got segment_ids.ndim={ segment_ids .ndim } , segment_pos.ndim={ segment_pos .ndim } "
1158- )
1159- if segment_ids .ndim >= segment_pos .ndim :
1160- continue
1161- batch_dims = tuple (updated_batch_dims )
1162- batched_args = tuple (batched_args_list )
1163-
1089+ # Pass through; segment_ids/segment_pos may have different batch dims. Conversion is in attention.py.
11641090 out_bdims = q_bdim , k_bdim , v_bdim , bias_bdim , softmax_offset_bdim
11651091 return (
11661092 FusedAttnBwdPrimitive .outer_primitive .bind (* batched_args , config = config ),
0 commit comments