Skip to content

Commit a7c398c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5138ac3 commit a7c398c

2 files changed

Lines changed: 25 additions & 28 deletions

File tree

transformer_engine/jax/attention.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -706,29 +706,30 @@ def get_seqlens_and_offsets(
706706
# something like : segment_ids (B, batch, seq), segment_pos (batch, seq)).
707707
if q_segment_ids.ndim < q_segment_pos.ndim or kv_segment_ids.ndim < kv_segment_pos.ndim:
708708
raise AssertionError(
709-
"segment_ids must not have fewer dims than segment_pos; "
710-
f"got q_segment_ids.ndim={q_segment_ids.ndim}, q_segment_pos.ndim={q_segment_pos.ndim}, "
711-
f"kv_segment_ids.ndim={kv_segment_ids.ndim}, kv_segment_pos.ndim={kv_segment_pos.ndim}"
709+
"segment_ids must not have fewer dims than segment_pos; got"
710+
f" q_segment_ids.ndim={q_segment_ids.ndim},"
711+
f" q_segment_pos.ndim={q_segment_pos.ndim},"
712+
f" kv_segment_ids.ndim={kv_segment_ids.ndim},"
713+
f" kv_segment_pos.ndim={kv_segment_pos.ndim}"
712714
)
713715
if not (
714716
q_segment_ids.shape[-q_segment_pos.ndim :] == q_segment_pos.shape
715717
and kv_segment_ids.shape[-kv_segment_pos.ndim :] == kv_segment_pos.shape
716718
):
717719
raise AssertionError(
718-
"segment_pos trailing shape must match segment_ids; "
719-
f"got q_segment_ids.shape={q_segment_ids.shape}, q_segment_pos.shape={q_segment_pos.shape}, "
720-
f"kv_segment_ids.shape={kv_segment_ids.shape}, kv_segment_pos.shape={kv_segment_pos.shape}"
720+
"segment_pos trailing shape must match segment_ids; got"
721+
f" q_segment_ids.shape={q_segment_ids.shape},"
722+
f" q_segment_pos.shape={q_segment_pos.shape},"
723+
f" kv_segment_ids.shape={kv_segment_ids.shape},"
724+
f" kv_segment_pos.shape={kv_segment_pos.shape}"
721725
)
722726

723727
if qkv_layout.is_thd():
724728
# THD: compute seqlens/offsets. Replicated segment_pos (more leading dims on segment_ids, e.g. if vmap)
725729
# i) Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims,
726-
# ii) vmap seqlens/offsets computation with segment_pos broadcast,
730+
# ii) vmap seqlens/offsets computation with segment_pos broadcast,
727731
# iii) reshape back to the original leading batch dims.
728-
if (
729-
q_segment_ids.ndim > q_segment_pos.ndim
730-
or kv_segment_ids.ndim > kv_segment_pos.ndim
731-
):
732+
if q_segment_ids.ndim > q_segment_pos.ndim or kv_segment_ids.ndim > kv_segment_pos.ndim:
732733
n_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim
733734
n_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim
734735
batch_shape_q = q_segment_ids.shape[:n_batch_dims_q]
@@ -738,9 +739,7 @@ def get_seqlens_and_offsets(
738739
# assert flat_batch_q == flat_batch_kv, (
739740
# f"segment_ids batch size mismatch: {batch_shape_q} vs {batch_shape_kv}"
740741
# )
741-
q_flat = q_segment_ids.reshape(
742-
flat_batch_q, *q_segment_ids.shape[n_batch_dims_q:]
743-
)
742+
q_flat = q_segment_ids.reshape(flat_batch_q, *q_segment_ids.shape[n_batch_dims_q:])
744743
kv_flat = kv_segment_ids.reshape(
745744
flat_batch_kv, *kv_segment_ids.shape[n_batch_dims_kv:]
746745
)
@@ -756,25 +755,23 @@ def single_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
756755
max_segments_per_seq,
757756
)
758757

759-
q_sl, kv_sl, q_off, kv_off = jax.vmap(
760-
single_batch, in_axes=(0, 0, None, None)
761-
)(q_flat, kv_flat, q_segment_pos, kv_segment_pos)
758+
q_sl, kv_sl, q_off, kv_off = jax.vmap(single_batch, in_axes=(0, 0, None, None))(
759+
q_flat, kv_flat, q_segment_pos, kv_segment_pos
760+
)
762761

763762
q_seqlens = q_sl.reshape(*batch_shape_q, *q_sl.shape[1:])
764763
kv_seqlens = kv_sl.reshape(*batch_shape_kv, *kv_sl.shape[1:])
765764
q_offsets = q_off.reshape(*batch_shape_q, *q_off.shape[1:])
766765
kv_offsets = kv_off.reshape(*batch_shape_kv, *kv_off.shape[1:])
767766
else:
768-
q_seqlens, kv_seqlens, q_offsets, kv_offsets = (
769-
_segment_ids_pos_to_seqlens_offsets(
770-
q_segment_ids,
771-
kv_segment_ids,
772-
q_segment_pos,
773-
kv_segment_pos,
774-
attn_mask_type,
775-
window_size,
776-
max_segments_per_seq,
777-
)
767+
q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets(
768+
q_segment_ids,
769+
kv_segment_ids,
770+
q_segment_pos,
771+
kv_segment_pos,
772+
attn_mask_type,
773+
window_size,
774+
max_segments_per_seq,
778775
)
779776
else:
780777
q_seqlens, kv_seqlens = _segment_ids_to_seqlens(

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def batcher(batched_args, batch_dims, *, config):
638638
assert FusedAttnFwdPrimitive.outer_primitive is not None
639639
q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
640640

641-
# When segment_ids are batched (vmap) and segment_pos are not, do not expand segment_pos to match.
641+
# When segment_ids are batched (vmap) and segment_pos are not, do not expand segment_pos to match.
642642
# The impl() layer treats segment_pos as replicated and computes seqlens/offsets per batch index
643643
# without materializing the full expanded segment_pos array.
644644
# Assert on invalid case (segment_ids.ndim < segment_pos.ndim)

0 commit comments

Comments
 (0)