diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 21db296c34..217827a266 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -569,6 +569,8 @@ def _segment_ids_pos_to_seqlens_offsets( # using the segment ids and pos along with mask type (causal or brcm) is sufficient. # It does not need to involve SW for this mask's creation + # Currently, this function is only exercised for THD qkv_layout. + # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well if (attn_mask_type.is_causal() and window_size is None) or ( window_size == (-1, -1) and not attn_mask_type.is_bottom_right() @@ -693,26 +695,87 @@ def get_seqlens_and_offsets( self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq ): """ - Acquire the seqlens/offsets for cuDNN backend + Acquire the seqlens/offsets for cuDNN backend. """ q_segment_ids, kv_segment_ids = self.segment_ids q_segment_pos, kv_segment_pos = self.segment_pos - assert q_segment_ids.shape == q_segment_pos.shape - assert kv_segment_ids.shape == kv_segment_pos.shape # No segment_ids/segment_pos if q_segment_ids.size + kv_segment_ids.size == 0: return self.seqlens, self.seq_offsets - if qkv_layout.is_thd(): - q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( - q_segment_ids, - kv_segment_ids, - q_segment_pos, - kv_segment_pos, - attn_mask_type, - window_size, - max_segments_per_seq, + # Allow segment_pos to have fewer leading dims than segment_ids if vmapped segment_ids and non-vmapped segment_pos + # e.g. when using from_segment_ids_and_pos() for segment_pos generation from segment_ids it is acceptable to have + # something like : segment_ids (B, batch, seq), segment_pos (batch, seq)). + if q_segment_ids.ndim < q_segment_pos.ndim or kv_segment_ids.ndim < kv_segment_pos.ndim: + raise AssertionError( + "segment_ids must not have fewer dims than segment_pos; got" + f" q_segment_ids.ndim={q_segment_ids.ndim}," + f" q_segment_pos.ndim={q_segment_pos.ndim}," + f" kv_segment_ids.ndim={kv_segment_ids.ndim}," + f" kv_segment_pos.ndim={kv_segment_pos.ndim}" + ) + if not ( + q_segment_ids.shape[-q_segment_pos.ndim :] == q_segment_pos.shape + and kv_segment_ids.shape[-kv_segment_pos.ndim :] == kv_segment_pos.shape + ): + raise AssertionError( + "segment_pos trailing shape must match segment_ids; got" + f" q_segment_ids.shape={q_segment_ids.shape}," + f" q_segment_pos.shape={q_segment_pos.shape}," + f" kv_segment_ids.shape={kv_segment_ids.shape}," + f" kv_segment_pos.shape={kv_segment_pos.shape}" ) + # THD: compute seqlens/offsets. + if qkv_layout.is_thd(): + # If there are more leading dims on segment_ids, e.g. vmap + if q_segment_ids.ndim > q_segment_pos.ndim or kv_segment_ids.ndim > kv_segment_pos.ndim: + # Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims, + # vmap seqlens/offsets computation with segment_pos broadcast, + # reshape back to the original leading batch dims. + n_extra_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim + n_extra_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim + extra_batch_shape_q = q_segment_ids.shape[:n_extra_batch_dims_q] + extra_batch_shape_kv = kv_segment_ids.shape[:n_extra_batch_dims_kv] + extra_flat_batch_size_q = jnp.prod(extra_batch_shape_q) + extra_flat_batch_size_kv = jnp.prod(extra_batch_shape_kv) + # vmap below requires same batch size on axis 0 for q_flat and kv_flat; JAX will raise if they differ. + q_flat = q_segment_ids.reshape( + extra_flat_batch_size_q, *q_segment_ids.shape[n_extra_batch_dims_q:] + ) + kv_flat = kv_segment_ids.reshape( + extra_flat_batch_size_kv, *kv_segment_ids.shape[n_extra_batch_dims_kv:] + ) + + def single_extra_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv): + return _segment_ids_pos_to_seqlens_offsets( + seg_id_q, + seg_id_kv, + seg_pos_q, + seg_pos_kv, + attn_mask_type, + window_size, + max_segments_per_seq, + ) + + q_sl, kv_sl, q_off, kv_off = jax.vmap( + single_extra_batch, in_axes=(0, 0, None, None) + )(q_flat, kv_flat, q_segment_pos, kv_segment_pos) + + q_seqlens = q_sl.reshape(*extra_batch_shape_q, *q_sl.shape[1:]) + kv_seqlens = kv_sl.reshape(*extra_batch_shape_kv, *kv_sl.shape[1:]) + q_offsets = q_off.reshape(*extra_batch_shape_q, *q_off.shape[1:]) + kv_offsets = kv_off.reshape(*extra_batch_shape_kv, *kv_off.shape[1:]) + else: + q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + attn_mask_type, + window_size, + max_segments_per_seq, + ) + # BSHD: compute seqlens/offsets. else: q_seqlens, kv_seqlens = _segment_ids_to_seqlens( q_segment_ids, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index e5d75e1501..f4d914062d 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -626,10 +626,12 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): + # batch_dims: each element is the batch axis (0, ...) or None. Only 0 or None allowed. check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims - + # Pass through; segment_ids/segment_pos may have different batch dims (e.g. vmapped ids, + # replicated pos). get_seqlens_and_offsets() in attention.py handles conversion without expanding. out_bdims = q_bdim, q_bdim, seed_bdim return ( FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), @@ -1084,7 +1086,7 @@ def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnBwdPrimitive.outer_primitive is not None q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims - + # Pass through; segment_ids/segment_pos may have different batch dims. Conversion is in attention.py. out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim return ( FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),