Skip to content

Commit 693ba65

Browse files
nit: Code clean up
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
1 parent 2f9dcc5 commit 693ba65

1 file changed

Lines changed: 21 additions & 23 deletions

File tree

transformer_engine/jax/attention.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -721,31 +721,28 @@ def get_seqlens_and_offsets(
721721
f"got q_segment_ids.shape={q_segment_ids.shape}, q_segment_pos.shape={q_segment_pos.shape}, "
722722
f"kv_segment_ids.shape={kv_segment_ids.shape}, kv_segment_pos.shape={kv_segment_pos.shape}"
723723
)
724-
724+
# THD: compute seqlens/offsets.
725725
if qkv_layout.is_thd():
726-
# THD: compute seqlens/offsets. Replicated segment_pos (more leading dims on segment_ids, e.g. if vmap)
727-
# i) Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims,
728-
# ii) vmap seqlens/offsets computation with segment_pos broadcast,
729-
# iii) reshape back to the original leading batch dims.
730-
if (
731-
q_segment_ids.ndim > q_segment_pos.ndim
732-
or kv_segment_ids.ndim > kv_segment_pos.ndim
733-
):
734-
n_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim
735-
n_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim
736-
batch_shape_q = q_segment_ids.shape[:n_batch_dims_q]
737-
batch_shape_kv = kv_segment_ids.shape[:n_batch_dims_kv]
738-
flat_batch_q = jnp.prod(jnp.array(batch_shape_q))
739-
flat_batch_kv = jnp.prod(jnp.array(batch_shape_kv))
726+
# If segment_ids has more leading dims than segment_pos (e.g. under vmap), segment_pos is shared across the extra dims:
727+
if (q_segment_ids.ndim > q_segment_pos.ndim or kv_segment_ids.ndim > kv_segment_pos.ndim):
728+
# 1) Flatten the extra leading batch dims of segment_ids into one axis so that, per slice, segment_ids and segment_pos have the same number of dims,
729+
# 2) vmap the seqlens/offsets computation over that axis, with segment_pos broadcast,
730+
# 3) reshape the results back to the original leading batch dims.
731+
n_extra_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim
732+
n_extra_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim
733+
extra_batch_shape_q = q_segment_ids.shape[:n_extra_batch_dims_q]
734+
extra_batch_shape_kv = kv_segment_ids.shape[:n_extra_batch_dims_kv]
735+
extra_flat_batch_size_q = jnp.prod(extra_batch_shape_q)
736+
extra_flat_batch_size_kv = jnp.prod(extra_batch_shape_kv)
740737
# vmap below requires same batch size on axis 0 for q_flat and kv_flat; JAX will raise if they differ.
741738
q_flat = q_segment_ids.reshape(
742-
flat_batch_q, *q_segment_ids.shape[n_batch_dims_q:]
739+
extra_flat_batch_size_q, *q_segment_ids.shape[n_extra_batch_dims_q:]
743740
)
744741
kv_flat = kv_segment_ids.reshape(
745-
flat_batch_kv, *kv_segment_ids.shape[n_batch_dims_kv:]
742+
extra_flat_batch_size_kv, *kv_segment_ids.shape[n_extra_batch_dims_kv:]
746743
)
747744

748-
def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
745+
def single_extra_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
749746
return _segment_ids_pos_to_seqlens_offsets(
750747
seg_id_q,
751748
seg_id_kv,
@@ -757,13 +754,13 @@ def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
757754
)
758755

759756
q_sl, kv_sl, q_off, kv_off = jax.vmap(
760-
single_batch, in_axes=(0, 0, None, None)
757+
single_extra_batch, in_axes=(0, 0, None, None)
761758
)(q_flat, kv_flat, q_segment_pos, kv_segment_pos)
762759

763-
q_seqlens = q_sl.reshape(*batch_shape_q, *q_sl.shape[1:])
764-
kv_seqlens = kv_sl.reshape(*batch_shape_kv, *kv_sl.shape[1:])
765-
q_offsets = q_off.reshape(*batch_shape_q, *q_off.shape[1:])
766-
kv_offsets = kv_off.reshape(*batch_shape_kv, *kv_off.shape[1:])
760+
q_seqlens = q_sl.reshape(*extra_batch_shape_q, *q_sl.shape[1:])
761+
kv_seqlens = kv_sl.reshape(*extra_batch_shape_kv, *kv_sl.shape[1:])
762+
q_offsets = q_off.reshape(*extra_batch_shape_q, *q_off.shape[1:])
763+
kv_offsets = kv_off.reshape(*extra_batch_shape_kv, *kv_off.shape[1:])
767764
else:
768765
q_seqlens, kv_seqlens, q_offsets, kv_offsets = (
769766
_segment_ids_pos_to_seqlens_offsets(
@@ -776,6 +773,7 @@ def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
776773
max_segments_per_seq,
777774
)
778775
)
776+
# BSHD: compute seqlens/offsets.
779777
else:
780778
q_seqlens, kv_seqlens = _segment_ids_to_seqlens(
781779
q_segment_ids,

0 commit comments

Comments
 (0)