Skip to content

Commit 41e22ef

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ce5ee86 commit 41e22ef

2 files changed

Lines changed: 140 additions & 54 deletions

File tree

tests/pytorch/test_permutation.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -304,6 +304,7 @@ def _test_permutation_index_map(
304304
torch._dynamo.reset()
305305
# Disable donated buffers to allow retain_graph=True
306306
import torch._functorch.config as functorch_config
307+
307308
old_donated_buffer = functorch_config.donated_buffer
308309
functorch_config.donated_buffer = False
309310

@@ -337,9 +338,7 @@ def unpermute_wrapper(inp, row_map, probs_val):
337338

338339
# Compile with fullgraph=True
339340
compiled_unpermute = torch.compile(unpermute_wrapper, fullgraph=True)
340-
te_unpermute_output = compiled_unpermute(
341-
te_unpermute_fwd_input, row_id_map, te_probs
342-
)
341+
te_unpermute_output = compiled_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
343342
else:
344343
te_unpermute_output = te_unpermute(
345344
te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"

transformer_engine/pytorch/permutation.py

Lines changed: 138 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -93,12 +93,8 @@ def _moe_permute_index_map_fake(
9393
output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK
9494

9595
# row_id_map is 1D with size = num_tokens * topK (see permutation.cpp line 59-60)
96-
fake_output = torch.empty(
97-
(output_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device
98-
)
99-
fake_row_id_map = torch.empty(
100-
(num_tokens * topK,), dtype=torch.int32, device=inp.device
101-
)
96+
fake_output = torch.empty((output_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device)
97+
fake_row_id_map = torch.empty((num_tokens * topK,), dtype=torch.int32, device=inp.device)
10298

10399
return fake_output, fake_row_id_map
104100

@@ -167,6 +163,7 @@ def _moe_permute_index_map_backward_wrapper(ctx, grad_permuted_act, grad_row_id_
167163

168164
# ---------------------------------- Forward custom op ----------------------------------
169165

166+
170167
@torch.library.custom_op("te_moe::unpermute_index_map_fwd", mutates_args=[])
171168
def moe_unpermute_index_map_forward(
172169
inp: torch.Tensor,
@@ -190,13 +187,12 @@ def _moe_unpermute_index_map_forward_fake(
190187
) -> torch.Tensor:
191188
"""Fake implementation for shape inference."""
192189
# Output shape: (num_tokens, hidden_size) — see permutation.cpp line 95-97
193-
return torch.empty(
194-
(num_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device
195-
)
190+
return torch.empty((num_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device)
196191

197192

198193
# ---------------------------------- Backward custom op ----------------------------------
199194

195+
200196
@torch.library.custom_op("te_moe::unpermute_index_map_bwd", mutates_args=[])
201197
def moe_unpermute_index_map_backward(
202198
unpermuted_act_grad: torch.Tensor,
@@ -237,6 +233,7 @@ def _moe_unpermute_index_map_backward_fake(
237233

238234
# ---------------------------------- Autograd glue ----------------------------------
239235

236+
240237
def _moe_unpermute_index_map_setup_context(ctx, inputs, output):
241238
"""Save context for backward pass."""
242239
inp, row_id_map, probs, num_tokens, topK = inputs
@@ -272,6 +269,7 @@ def _moe_unpermute_index_map_backward_wrapper(ctx, unpermuted_act_grad):
272269

273270
# ===================== _moe_permute_mask_map custom ops =====================
274271

272+
275273
@torch.library.custom_op("te_moe::permute_mask_map_fwd", mutates_args=[])
276274
def moe_permute_mask_map_forward(
277275
inp: torch.Tensor,
@@ -296,7 +294,9 @@ def moe_permute_mask_map_forward(
296294
if pad_offsets is not None:
297295
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
298296
assert inp.size(0) == routing_map.size(0), "Permute not possible"
299-
assert num_out_tokens is not None, "num_out_tokens must be provided to the fused permute function."
297+
assert (
298+
num_out_tokens is not None
299+
), "num_out_tokens must be provided to the fused permute function."
300300

301301
num_tokens, hidden_size = inp.size()
302302
num_experts = routing_map.size(1)
@@ -335,38 +335,58 @@ def moe_permute_mask_map_forward(
335335
scale_hidden_dim = None
336336

337337
output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
338-
inp, row_id_map, probs, fp8_scale, pad_offsets,
339-
num_tokens, num_experts, num_out_tokens, hidden_size, scale_hidden_dim,
338+
inp,
339+
row_id_map,
340+
probs,
341+
fp8_scale,
342+
pad_offsets,
343+
num_tokens,
344+
num_experts,
345+
num_out_tokens,
346+
hidden_size,
347+
scale_hidden_dim,
340348
)
341349

342350
if fp8:
343351
if per_tensor_recipe:
344352
output = Float8Tensor(
345-
data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv,
346-
shape=output.shape, dtype=fake_dtype,
353+
data=output,
354+
fp8_dtype=fp8_dtype,
355+
fp8_scale_inv=fp8_scale_inv,
356+
shape=output.shape,
357+
dtype=fake_dtype,
347358
)
348359
elif blockwise_recipe:
349360
output = Float8BlockwiseQTensor(
350-
shape=output.shape, dtype=fake_dtype, rowwise_data=output,
361+
shape=output.shape,
362+
dtype=fake_dtype,
363+
rowwise_data=output,
351364
rowwise_scale_inv=permuted_scale.T.contiguous(),
352-
columnwise_data=None, columnwise_scale_inv=None,
353-
fp8_dtype=fp8_dtype, quantizer=None, is_2D_scaled=False,
365+
columnwise_data=None,
366+
columnwise_scale_inv=None,
367+
fp8_dtype=fp8_dtype,
368+
quantizer=None,
369+
is_2D_scaled=False,
354370
requires_grad=output.requires_grad,
355371
)
356372
elif mxfp8_recipe:
357373
output = MXFP8Tensor(
358-
shape=output.shape, dtype=fake_dtype, fp8_dtype=fp8_dtype,
359-
rowwise_data=output, rowwise_scale_inv=permuted_scale.contiguous(),
360-
columnwise_data=None, columnwise_scale_inv=None,
361-
quantizer=None, requires_grad=output.requires_grad,
374+
shape=output.shape,
375+
dtype=fake_dtype,
376+
fp8_dtype=fp8_dtype,
377+
rowwise_data=output,
378+
rowwise_scale_inv=permuted_scale.contiguous(),
379+
columnwise_data=None,
380+
columnwise_scale_inv=None,
381+
quantizer=None,
382+
requires_grad=output.requires_grad,
362383
with_gemm_swizzled_scales=False,
363384
)
364385

365386
# If permuted_probs is None, return empty tensor (custom ops need concrete tensors)
366387
if permuted_probs is None:
367388
permuted_probs = torch.empty(0, device=inp.device)
368389

369-
370390
return output, row_id_map, permuted_probs
371391

372392

@@ -406,8 +426,14 @@ def moe_permute_mask_map_backward(
406426
) -> Tuple[torch.Tensor, torch.Tensor]:
407427
"""Backward pass for MoE permute with mask router map."""
408428
act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
409-
permuted_act_grad, row_id_map, None, permuted_probs_grad, pad_offsets,
410-
num_tokens, num_experts, hidden_size,
429+
permuted_act_grad,
430+
row_id_map,
431+
None,
432+
permuted_probs_grad,
433+
pad_offsets,
434+
num_tokens,
435+
num_experts,
436+
hidden_size,
411437
)
412438
if probs_grad is None:
413439
probs_grad = torch.empty(0, device=permuted_act_grad.device)
@@ -430,7 +456,8 @@ def _moe_permute_mask_map_backward_fake(
430456
)
431457
if permuted_probs_grad is not None:
432458
probs_grad = torch.empty(
433-
(num_tokens, num_experts), dtype=permuted_probs_grad.dtype,
459+
(num_tokens, num_experts),
460+
dtype=permuted_probs_grad.dtype,
434461
device=permuted_act_grad.device,
435462
)
436463
else:
@@ -471,8 +498,13 @@ def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, gr
471498
probs_grad_input = grad_permuted_probs if grad_permuted_probs.numel() > 0 else None
472499

473500
act_grad, probs_grad = moe_permute_mask_map_backward(
474-
grad_output, probs_grad_input, row_id_map, pad_offsets,
475-
ctx.num_tokens, ctx.num_experts, ctx.hidden_size,
501+
grad_output,
502+
probs_grad_input,
503+
row_id_map,
504+
pad_offsets,
505+
ctx.num_tokens,
506+
ctx.num_experts,
507+
ctx.hidden_size,
476508
)
477509

478510
if not ctx.needs_probs_grad or probs_grad.numel() == 0:
@@ -489,6 +521,7 @@ def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, gr
489521

490522
# ===================== _moe_unpermute_mask_map custom ops =====================
491523

524+
492525
@torch.library.custom_op("te_moe::unpermute_mask_map_fwd", mutates_args=[])
493526
def moe_unpermute_mask_map_forward(
494527
inp: torch.Tensor,
@@ -508,8 +541,14 @@ def moe_unpermute_mask_map_forward(
508541
inp, QuantizedTensor
509542
), "The forward of moe_unpermute does not support FP8."
510543
unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
511-
inp, row_id_map, merging_probs, None, pad_offsets,
512-
num_tokens, num_experts, hidden_size,
544+
inp,
545+
row_id_map,
546+
merging_probs,
547+
None,
548+
pad_offsets,
549+
num_tokens,
550+
num_experts,
551+
hidden_size,
513552
)
514553
return unpermuted_output
515554

@@ -542,8 +581,15 @@ def moe_unpermute_mask_map_backward_with_probs(
542581
) -> Tuple[torch.Tensor, torch.Tensor]:
543582
"""Backward pass for MoE unpermute with merging probs."""
544583
act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
545-
unpermuted_act_grad, row_id_map, fwd_input, merging_probs, pad_offsets,
546-
num_tokens, num_experts, num_permuted_tokens, hidden_size,
584+
unpermuted_act_grad,
585+
row_id_map,
586+
fwd_input,
587+
merging_probs,
588+
pad_offsets,
589+
num_tokens,
590+
num_experts,
591+
num_permuted_tokens,
592+
hidden_size,
547593
)
548594
return act_grad, probs_grad
549595

@@ -563,11 +609,13 @@ def _moe_unpermute_mask_map_bwd_with_probs_fake(
563609
"""Fake for backward shape inference with merging probs."""
564610
act_grad = torch.empty(
565611
(num_permuted_tokens, hidden_size),
566-
dtype=unpermuted_act_grad.dtype, device=unpermuted_act_grad.device,
612+
dtype=unpermuted_act_grad.dtype,
613+
device=unpermuted_act_grad.device,
567614
)
568615
probs_grad = torch.empty(
569616
(num_tokens, num_experts),
570-
dtype=merging_probs.dtype, device=unpermuted_act_grad.device,
617+
dtype=merging_probs.dtype,
618+
device=unpermuted_act_grad.device,
571619
)
572620
return act_grad, probs_grad
573621

@@ -615,30 +663,51 @@ def moe_unpermute_mask_map_backward_no_probs(
615663
fp8_scale = None
616664

617665
act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
618-
unpermuted_act_grad, row_id_map, None, fp8_scale, pad_offsets,
619-
num_tokens, num_experts, num_permuted_tokens, hidden_size, scale_hidden_dim,
666+
unpermuted_act_grad,
667+
row_id_map,
668+
None,
669+
fp8_scale,
670+
pad_offsets,
671+
num_tokens,
672+
num_experts,
673+
num_permuted_tokens,
674+
hidden_size,
675+
scale_hidden_dim,
620676
)
621677

622678
if fp8:
623679
if per_tensor_recipe:
624680
act_grad = Float8Tensor(
625-
data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv,
626-
shape=act_grad.shape, dtype=fake_dtype,
681+
data=act_grad,
682+
fp8_dtype=fp8_dtype,
683+
fp8_scale_inv=fp8_scale_inv,
684+
shape=act_grad.shape,
685+
dtype=fake_dtype,
627686
)
628687
elif blockwise_recipe:
629688
act_grad = Float8BlockwiseQTensor(
630-
shape=act_grad.shape, dtype=fake_dtype, rowwise_data=act_grad,
689+
shape=act_grad.shape,
690+
dtype=fake_dtype,
691+
rowwise_data=act_grad,
631692
rowwise_scale_inv=permuted_scale.T.contiguous(),
632-
columnwise_data=None, columnwise_scale_inv=None,
633-
fp8_dtype=fp8_dtype, quantizer=None, is_2D_scaled=False,
693+
columnwise_data=None,
694+
columnwise_scale_inv=None,
695+
fp8_dtype=fp8_dtype,
696+
quantizer=None,
697+
is_2D_scaled=False,
634698
requires_grad=act_grad.requires_grad,
635699
)
636700
elif mxfp8_recipe:
637701
act_grad = MXFP8Tensor(
638-
shape=act_grad.shape, dtype=fake_dtype, fp8_dtype=fp8_dtype,
639-
rowwise_data=act_grad, rowwise_scale_inv=permuted_scale.contiguous(),
640-
columnwise_data=None, columnwise_scale_inv=None,
641-
quantizer=None, requires_grad=act_grad.requires_grad,
702+
shape=act_grad.shape,
703+
dtype=fake_dtype,
704+
fp8_dtype=fp8_dtype,
705+
rowwise_data=act_grad,
706+
rowwise_scale_inv=permuted_scale.contiguous(),
707+
columnwise_data=None,
708+
columnwise_scale_inv=None,
709+
quantizer=None,
710+
requires_grad=act_grad.requires_grad,
642711
with_gemm_swizzled_scales=False,
643712
)
644713

@@ -658,7 +727,8 @@ def _moe_unpermute_mask_map_bwd_no_probs_fake(
658727
"""Fake for backward shape inference without probs."""
659728
return torch.empty(
660729
(num_permuted_tokens, hidden_size),
661-
dtype=unpermuted_act_grad.dtype, device=unpermuted_act_grad.device,
730+
dtype=unpermuted_act_grad.dtype,
731+
device=unpermuted_act_grad.device,
662732
)
663733

664734

@@ -697,14 +767,26 @@ def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad):
697767
unpermuted_act_grad, QuantizedTensor
698768
), "The backward of moe_unpermute with merging probs does not support FP8."
699769
act_grad, probs_grad = moe_unpermute_mask_map_backward_with_probs(
700-
unpermuted_act_grad, row_id_map, fwd_input, merging_probs, pad_offsets,
701-
ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size,
770+
unpermuted_act_grad,
771+
row_id_map,
772+
fwd_input,
773+
merging_probs,
774+
pad_offsets,
775+
ctx.num_tokens,
776+
ctx.num_experts,
777+
ctx.num_permuted_tokens,
778+
ctx.hidden_size,
702779
)
703780
else:
704781
row_id_map, pad_offsets = ctx.saved_tensors
705782
act_grad = moe_unpermute_mask_map_backward_no_probs(
706-
unpermuted_act_grad, row_id_map, pad_offsets,
707-
ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size,
783+
unpermuted_act_grad,
784+
row_id_map,
785+
pad_offsets,
786+
ctx.num_tokens,
787+
ctx.num_experts,
788+
ctx.num_permuted_tokens,
789+
ctx.hidden_size,
708790
)
709791

710792
if not ctx.needs_probs_grad:
@@ -945,8 +1027,13 @@ def moe_unpermute(
9451027
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
9461028

9471029
return moe_unpermute_mask_map_forward(
948-
inp, row_id_map, merging_probs,
949-
num_tokens, num_experts, hidden_size, pad_offsets,
1030+
inp,
1031+
row_id_map,
1032+
merging_probs,
1033+
num_tokens,
1034+
num_experts,
1035+
hidden_size,
1036+
pad_offsets,
9501037
)
9511038
raise ValueError("map_type should be one of 'mask' or 'index'")
9521039

0 commit comments

Comments
 (0)