@@ -721,31 +721,28 @@ def get_seqlens_and_offsets(
721721 f"got q_segment_ids.shape={ q_segment_ids .shape } , q_segment_pos.shape={ q_segment_pos .shape } , "
722722 f"kv_segment_ids.shape={ kv_segment_ids .shape } , kv_segment_pos.shape={ kv_segment_pos .shape } "
723723 )
724-
724+ # THD: compute seqlens/offsets.
725725 if qkv_layout .is_thd ():
726- # THD: compute seqlens/offsets. Replicated segment_pos (more leading dims on segment_ids, e.g. if vmap)
727- # i) Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims,
728- # ii) vmap seqlens/offsets computation with segment_pos broadcast,
729- # iii) reshape back to the original leading batch dims.
730- if (
731- q_segment_ids .ndim > q_segment_pos .ndim
732- or kv_segment_ids .ndim > kv_segment_pos .ndim
733- ):
734- n_batch_dims_q = q_segment_ids .ndim - q_segment_pos .ndim
735- n_batch_dims_kv = kv_segment_ids .ndim - kv_segment_pos .ndim
736- batch_shape_q = q_segment_ids .shape [:n_batch_dims_q ]
737- batch_shape_kv = kv_segment_ids .shape [:n_batch_dims_kv ]
738- flat_batch_q = jnp .prod (jnp .array (batch_shape_q ))
739- flat_batch_kv = jnp .prod (jnp .array (batch_shape_kv ))
726+ # If there are more leading dims on segment_ids, e.g. vmap
727+ if (q_segment_ids .ndim > q_segment_pos .ndim or kv_segment_ids .ndim > kv_segment_pos .ndim ):
728+ # Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims,
729+ # vmap seqlens/offsets computation with segment_pos broadcast,
730+ # reshape back to the original leading batch dims.
731+ n_extra_batch_dims_q = q_segment_ids .ndim - q_segment_pos .ndim
732+ n_extra_batch_dims_kv = kv_segment_ids .ndim - kv_segment_pos .ndim
733+ extra_batch_shape_q = q_segment_ids .shape [:n_extra_batch_dims_q ]
734+ extra_batch_shape_kv = kv_segment_ids .shape [:n_extra_batch_dims_kv ]
735+ extra_flat_batch_size_q = jnp .prod (extra_batch_shape_q )
736+ extra_flat_batch_size_kv = jnp .prod (extra_batch_shape_kv )
740737 # vmap below requires same batch size on axis 0 for q_flat and kv_flat; JAX will raise if they differ.
741738 q_flat = q_segment_ids .reshape (
742- flat_batch_q , * q_segment_ids .shape [n_batch_dims_q :]
739+ extra_flat_batch_size_q , * q_segment_ids .shape [n_extra_batch_dims_q :]
743740 )
744741 kv_flat = kv_segment_ids .reshape (
745- flat_batch_kv , * kv_segment_ids .shape [n_batch_dims_kv :]
742+ extra_flat_batch_size_kv , * kv_segment_ids .shape [n_extra_batch_dims_kv :]
746743 )
747744
748- def single_batch (seg_id_q , seg_id_kv , seg_pos_q , seg_pos_kv ):
745+ def single_extra_batch (seg_id_q , seg_id_kv , seg_pos_q , seg_pos_kv ):
749746 return _segment_ids_pos_to_seqlens_offsets (
750747 seg_id_q ,
751748 seg_id_kv ,
@@ -757,13 +754,13 @@ def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
757754 )
758755
759756 q_sl , kv_sl , q_off , kv_off = jax .vmap (
760- single_batch , in_axes = (0 , 0 , None , None )
757+ single_extra_batch , in_axes = (0 , 0 , None , None )
761758 )(q_flat , kv_flat , q_segment_pos , kv_segment_pos )
762759
763- q_seqlens = q_sl .reshape (* batch_shape_q , * q_sl .shape [1 :])
764- kv_seqlens = kv_sl .reshape (* batch_shape_kv , * kv_sl .shape [1 :])
765- q_offsets = q_off .reshape (* batch_shape_q , * q_off .shape [1 :])
766- kv_offsets = kv_off .reshape (* batch_shape_kv , * kv_off .shape [1 :])
760+ q_seqlens = q_sl .reshape (* extra_batch_shape_q , * q_sl .shape [1 :])
761+ kv_seqlens = kv_sl .reshape (* extra_batch_shape_kv , * kv_sl .shape [1 :])
762+ q_offsets = q_off .reshape (* extra_batch_shape_q , * q_off .shape [1 :])
763+ kv_offsets = kv_off .reshape (* extra_batch_shape_kv , * kv_off .shape [1 :])
767764 else :
768765 q_seqlens , kv_seqlens , q_offsets , kv_offsets = (
769766 _segment_ids_pos_to_seqlens_offsets (
@@ -776,6 +773,7 @@ def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
776773 max_segments_per_seq ,
777774 )
778775 )
776+ # BSHD: compute seqlens/offsets.
779777 else :
780778 q_seqlens , kv_seqlens = _segment_ids_to_seqlens (
781779 q_segment_ids ,
0 commit comments