@@ -706,29 +706,30 @@ def get_seqlens_and_offsets(
706706 # something like : segment_ids (B, batch, seq), segment_pos (batch, seq)).
707707 if q_segment_ids .ndim < q_segment_pos .ndim or kv_segment_ids .ndim < kv_segment_pos .ndim :
708708 raise AssertionError (
709- "segment_ids must not have fewer dims than segment_pos; "
710- f"got q_segment_ids.ndim={ q_segment_ids .ndim } , q_segment_pos.ndim={ q_segment_pos .ndim } , "
711- f"kv_segment_ids.ndim={ kv_segment_ids .ndim } , kv_segment_pos.ndim={ kv_segment_pos .ndim } "
709+ "segment_ids must not have fewer dims than segment_pos; got"
710+ f" q_segment_ids.ndim={ q_segment_ids .ndim } ,"
711+ f" q_segment_pos.ndim={ q_segment_pos .ndim } ,"
712+ f" kv_segment_ids.ndim={ kv_segment_ids .ndim } ,"
713+ f" kv_segment_pos.ndim={ kv_segment_pos .ndim } "
712714 )
713715 if not (
714716 q_segment_ids .shape [- q_segment_pos .ndim :] == q_segment_pos .shape
715717 and kv_segment_ids .shape [- kv_segment_pos .ndim :] == kv_segment_pos .shape
716718 ):
717719 raise AssertionError (
718- "segment_pos trailing shape must match segment_ids; "
719- f"got q_segment_ids.shape={ q_segment_ids .shape } , q_segment_pos.shape={ q_segment_pos .shape } , "
720- f"kv_segment_ids.shape={ kv_segment_ids .shape } , kv_segment_pos.shape={ kv_segment_pos .shape } "
720+ "segment_pos trailing shape must match segment_ids; got"
721+ f" q_segment_ids.shape={ q_segment_ids .shape } ,"
722+ f" q_segment_pos.shape={ q_segment_pos .shape } ,"
723+ f" kv_segment_ids.shape={ kv_segment_ids .shape } ,"
724+ f" kv_segment_pos.shape={ kv_segment_pos .shape } "
721725 )
722726
723727 if qkv_layout .is_thd ():
724728 # THD: compute seqlens/offsets. Replicated segment_pos (more leading dims on segment_ids, e.g. if vmap)
725729 # i) Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims,
726- # ii) vmap seqlens/offsets computation with segment_pos broadcast,
730+ # ii) vmap seqlens/offsets computation with segment_pos broadcast,
727731 # iii) reshape back to the original leading batch dims.
728- if (
729- q_segment_ids .ndim > q_segment_pos .ndim
730- or kv_segment_ids .ndim > kv_segment_pos .ndim
731- ):
732+ if q_segment_ids .ndim > q_segment_pos .ndim or kv_segment_ids .ndim > kv_segment_pos .ndim :
732733 n_batch_dims_q = q_segment_ids .ndim - q_segment_pos .ndim
733734 n_batch_dims_kv = kv_segment_ids .ndim - kv_segment_pos .ndim
734735 batch_shape_q = q_segment_ids .shape [:n_batch_dims_q ]
@@ -738,9 +739,7 @@ def get_seqlens_and_offsets(
738739 # assert flat_batch_q == flat_batch_kv, (
739740 # f"segment_ids batch size mismatch: {batch_shape_q} vs {batch_shape_kv}"
740741 # )
741- q_flat = q_segment_ids .reshape (
742- flat_batch_q , * q_segment_ids .shape [n_batch_dims_q :]
743- )
742+ q_flat = q_segment_ids .reshape (flat_batch_q , * q_segment_ids .shape [n_batch_dims_q :])
744743 kv_flat = kv_segment_ids .reshape (
745744 flat_batch_kv , * kv_segment_ids .shape [n_batch_dims_kv :]
746745 )
@@ -756,25 +755,23 @@ def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
756755 max_segments_per_seq ,
757756 )
758757
759- q_sl , kv_sl , q_off , kv_off = jax .vmap (
760- single_batch , in_axes = ( 0 , 0 , None , None )
761- )( q_flat , kv_flat , q_segment_pos , kv_segment_pos )
758+ q_sl , kv_sl , q_off , kv_off = jax .vmap (single_batch , in_axes = ( 0 , 0 , None , None ))(
759+ q_flat , kv_flat , q_segment_pos , kv_segment_pos
760+ )
762761
763762 q_seqlens = q_sl .reshape (* batch_shape_q , * q_sl .shape [1 :])
764763 kv_seqlens = kv_sl .reshape (* batch_shape_kv , * kv_sl .shape [1 :])
765764 q_offsets = q_off .reshape (* batch_shape_q , * q_off .shape [1 :])
766765 kv_offsets = kv_off .reshape (* batch_shape_kv , * kv_off .shape [1 :])
767766 else :
768- q_seqlens , kv_seqlens , q_offsets , kv_offsets = (
769- _segment_ids_pos_to_seqlens_offsets (
770- q_segment_ids ,
771- kv_segment_ids ,
772- q_segment_pos ,
773- kv_segment_pos ,
774- attn_mask_type ,
775- window_size ,
776- max_segments_per_seq ,
777- )
767+ q_seqlens , kv_seqlens , q_offsets , kv_offsets = _segment_ids_pos_to_seqlens_offsets (
768+ q_segment_ids ,
769+ kv_segment_ids ,
770+ q_segment_pos ,
771+ kv_segment_pos ,
772+ attn_mask_type ,
773+ window_size ,
774+ max_segments_per_seq ,
778775 )
779776 else :
780777 q_seqlens , kv_seqlens = _segment_ids_to_seqlens (
0 commit comments