Skip to content

Commit 35d6d0f

Browse files
nit: Fix the shape check for assert
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
1 parent 79682e6 commit 35d6d0f

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def batcher(batched_args, batch_dims, *, config):
657657
"segment_ids must have more dims than segment_pos when adding batch dims; "
658658
f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
659659
)
660-
assert segment_ids.shape[segment_pos.ndim :] == segment_pos.shape, (
660+
assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, (
661661
"segment_pos must have same trailing shape as segment_ids when adding batch dims; "
662662
f"got segment_ids.shape={segment_ids.shape}, segment_pos.shape={segment_pos.shape}"
663663
)
@@ -1154,7 +1154,7 @@ def batcher(batched_args, batch_dims, *, config):
11541154
"segment_ids must have more dims than segment_pos when adding batch dims; "
11551155
f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
11561156
)
1157-
assert segment_ids.shape[segment_pos.ndim :] == segment_pos.shape, (
1157+
assert segment_ids.shape[-segment_pos.ndim :] == segment_pos.shape, (
11581158
"segment_pos must have same trailing shape as segment_ids when adding batch dims; "
11591159
f"got segment_ids.shape={segment_ids.shape}, segment_pos.shape={segment_pos.shape}"
11601160
)

0 commit comments

Comments
 (0)