@@ -93,12 +93,8 @@ def _moe_permute_index_map_fake(
9393 output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK
9494
9595 # row_id_map is 1D with size = num_tokens * topK (see permutation.cpp line 59-60)
96- fake_output = torch .empty (
97- (output_tokens , inp .shape [1 ]), dtype = inp .dtype , device = inp .device
98- )
99- fake_row_id_map = torch .empty (
100- (num_tokens * topK ,), dtype = torch .int32 , device = inp .device
101- )
96+ fake_output = torch .empty ((output_tokens , inp .shape [1 ]), dtype = inp .dtype , device = inp .device )
97+ fake_row_id_map = torch .empty ((num_tokens * topK ,), dtype = torch .int32 , device = inp .device )
10298
10399 return fake_output , fake_row_id_map
104100
@@ -167,6 +163,7 @@ def _moe_permute_index_map_backward_wrapper(ctx, grad_permuted_act, grad_row_id_
167163
168164# ---------------------------------- Forward custom op ----------------------------------
169165
166+
170167@torch .library .custom_op ("te_moe::unpermute_index_map_fwd" , mutates_args = [])
171168def moe_unpermute_index_map_forward (
172169 inp : torch .Tensor ,
@@ -190,13 +187,12 @@ def _moe_unpermute_index_map_forward_fake(
190187) -> torch .Tensor :
191188 """Fake implementation for shape inference."""
192189 # Output shape: (num_tokens, hidden_size) — see permutation.cpp line 95-97
193- return torch .empty (
194- (num_tokens , inp .shape [1 ]), dtype = inp .dtype , device = inp .device
195- )
190+ return torch .empty ((num_tokens , inp .shape [1 ]), dtype = inp .dtype , device = inp .device )
196191
197192
198193# ---------------------------------- Backward custom op ----------------------------------
199194
195+
200196@torch .library .custom_op ("te_moe::unpermute_index_map_bwd" , mutates_args = [])
201197def moe_unpermute_index_map_backward (
202198 unpermuted_act_grad : torch .Tensor ,
@@ -237,6 +233,7 @@ def _moe_unpermute_index_map_backward_fake(
237233
238234# ---------------------------------- Autograd glue ----------------------------------
239235
236+
240237def _moe_unpermute_index_map_setup_context (ctx , inputs , output ):
241238 """Save context for backward pass."""
242239 inp , row_id_map , probs , num_tokens , topK = inputs
@@ -272,6 +269,7 @@ def _moe_unpermute_index_map_backward_wrapper(ctx, unpermuted_act_grad):
272269
273270# ===================== _moe_permute_mask_map custom ops =====================
274271
272+
275273@torch .library .custom_op ("te_moe::permute_mask_map_fwd" , mutates_args = [])
276274def moe_permute_mask_map_forward (
277275 inp : torch .Tensor ,
@@ -296,7 +294,9 @@ def moe_permute_mask_map_forward(
296294 if pad_offsets is not None :
297295 assert pad_offsets .is_cuda , "TransformerEngine needs CUDA."
298296 assert inp .size (0 ) == routing_map .size (0 ), "Permute not possible"
299- assert num_out_tokens is not None , "num_out_tokens must be provided to the fused permute function."
297+ assert (
298+ num_out_tokens is not None
299+ ), "num_out_tokens must be provided to the fused permute function."
300300
301301 num_tokens , hidden_size = inp .size ()
302302 num_experts = routing_map .size (1 )
@@ -335,38 +335,58 @@ def moe_permute_mask_map_forward(
335335 scale_hidden_dim = None
336336
337337 output , permuted_scale , permuted_probs = triton_permutation .permute_with_mask_map (
338- inp , row_id_map , probs , fp8_scale , pad_offsets ,
339- num_tokens , num_experts , num_out_tokens , hidden_size , scale_hidden_dim ,
338+ inp ,
339+ row_id_map ,
340+ probs ,
341+ fp8_scale ,
342+ pad_offsets ,
343+ num_tokens ,
344+ num_experts ,
345+ num_out_tokens ,
346+ hidden_size ,
347+ scale_hidden_dim ,
340348 )
341349
342350 if fp8 :
343351 if per_tensor_recipe :
344352 output = Float8Tensor (
345- data = output , fp8_dtype = fp8_dtype , fp8_scale_inv = fp8_scale_inv ,
346- shape = output .shape , dtype = fake_dtype ,
353+ data = output ,
354+ fp8_dtype = fp8_dtype ,
355+ fp8_scale_inv = fp8_scale_inv ,
356+ shape = output .shape ,
357+ dtype = fake_dtype ,
347358 )
348359 elif blockwise_recipe :
349360 output = Float8BlockwiseQTensor (
350- shape = output .shape , dtype = fake_dtype , rowwise_data = output ,
361+ shape = output .shape ,
362+ dtype = fake_dtype ,
363+ rowwise_data = output ,
351364 rowwise_scale_inv = permuted_scale .T .contiguous (),
352- columnwise_data = None , columnwise_scale_inv = None ,
353- fp8_dtype = fp8_dtype , quantizer = None , is_2D_scaled = False ,
365+ columnwise_data = None ,
366+ columnwise_scale_inv = None ,
367+ fp8_dtype = fp8_dtype ,
368+ quantizer = None ,
369+ is_2D_scaled = False ,
354370 requires_grad = output .requires_grad ,
355371 )
356372 elif mxfp8_recipe :
357373 output = MXFP8Tensor (
358- shape = output .shape , dtype = fake_dtype , fp8_dtype = fp8_dtype ,
359- rowwise_data = output , rowwise_scale_inv = permuted_scale .contiguous (),
360- columnwise_data = None , columnwise_scale_inv = None ,
361- quantizer = None , requires_grad = output .requires_grad ,
374+ shape = output .shape ,
375+ dtype = fake_dtype ,
376+ fp8_dtype = fp8_dtype ,
377+ rowwise_data = output ,
378+ rowwise_scale_inv = permuted_scale .contiguous (),
379+ columnwise_data = None ,
380+ columnwise_scale_inv = None ,
381+ quantizer = None ,
382+ requires_grad = output .requires_grad ,
362383 with_gemm_swizzled_scales = False ,
363384 )
364385
365386 # If permuted_probs is None, return empty tensor (custom ops need concrete tensors)
366387 if permuted_probs is None :
367388 permuted_probs = torch .empty (0 , device = inp .device )
368389
369-
370390 return output , row_id_map , permuted_probs
371391
372392
@@ -406,8 +426,14 @@ def moe_permute_mask_map_backward(
406426) -> Tuple [torch .Tensor , torch .Tensor ]:
407427 """Backward pass for MoE permute with mask router map."""
408428 act_grad , probs_grad = triton_permutation .unpermute_with_mask_map (
409- permuted_act_grad , row_id_map , None , permuted_probs_grad , pad_offsets ,
410- num_tokens , num_experts , hidden_size ,
429+ permuted_act_grad ,
430+ row_id_map ,
431+ None ,
432+ permuted_probs_grad ,
433+ pad_offsets ,
434+ num_tokens ,
435+ num_experts ,
436+ hidden_size ,
411437 )
412438 if probs_grad is None :
413439 probs_grad = torch .empty (0 , device = permuted_act_grad .device )
@@ -430,7 +456,8 @@ def _moe_permute_mask_map_backward_fake(
430456 )
431457 if permuted_probs_grad is not None :
432458 probs_grad = torch .empty (
433- (num_tokens , num_experts ), dtype = permuted_probs_grad .dtype ,
459+ (num_tokens , num_experts ),
460+ dtype = permuted_probs_grad .dtype ,
434461 device = permuted_act_grad .device ,
435462 )
436463 else :
@@ -471,8 +498,13 @@ def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, gr
471498 probs_grad_input = grad_permuted_probs if grad_permuted_probs .numel () > 0 else None
472499
473500 act_grad , probs_grad = moe_permute_mask_map_backward (
474- grad_output , probs_grad_input , row_id_map , pad_offsets ,
475- ctx .num_tokens , ctx .num_experts , ctx .hidden_size ,
501+ grad_output ,
502+ probs_grad_input ,
503+ row_id_map ,
504+ pad_offsets ,
505+ ctx .num_tokens ,
506+ ctx .num_experts ,
507+ ctx .hidden_size ,
476508 )
477509
478510 if not ctx .needs_probs_grad or probs_grad .numel () == 0 :
@@ -489,6 +521,7 @@ def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, gr
489521
490522# ===================== _moe_unpermute_mask_map custom ops =====================
491523
524+
492525@torch .library .custom_op ("te_moe::unpermute_mask_map_fwd" , mutates_args = [])
493526def moe_unpermute_mask_map_forward (
494527 inp : torch .Tensor ,
@@ -508,8 +541,14 @@ def moe_unpermute_mask_map_forward(
508541 inp , QuantizedTensor
509542 ), "The forward of moe_unpermute does not support FP8."
510543 unpermuted_output , _ = triton_permutation .unpermute_with_mask_map (
511- inp , row_id_map , merging_probs , None , pad_offsets ,
512- num_tokens , num_experts , hidden_size ,
544+ inp ,
545+ row_id_map ,
546+ merging_probs ,
547+ None ,
548+ pad_offsets ,
549+ num_tokens ,
550+ num_experts ,
551+ hidden_size ,
513552 )
514553 return unpermuted_output
515554
@@ -542,8 +581,15 @@ def moe_unpermute_mask_map_backward_with_probs(
542581) -> Tuple [torch .Tensor , torch .Tensor ]:
543582 """Backward pass for MoE unpermute with merging probs."""
544583 act_grad , probs_grad = triton_permutation .unpermute_with_mask_map_bwd_with_merging_probs (
545- unpermuted_act_grad , row_id_map , fwd_input , merging_probs , pad_offsets ,
546- num_tokens , num_experts , num_permuted_tokens , hidden_size ,
584+ unpermuted_act_grad ,
585+ row_id_map ,
586+ fwd_input ,
587+ merging_probs ,
588+ pad_offsets ,
589+ num_tokens ,
590+ num_experts ,
591+ num_permuted_tokens ,
592+ hidden_size ,
547593 )
548594 return act_grad , probs_grad
549595
@@ -563,11 +609,13 @@ def _moe_unpermute_mask_map_bwd_with_probs_fake(
563609 """Fake for backward shape inference with merging probs."""
564610 act_grad = torch .empty (
565611 (num_permuted_tokens , hidden_size ),
566- dtype = unpermuted_act_grad .dtype , device = unpermuted_act_grad .device ,
612+ dtype = unpermuted_act_grad .dtype ,
613+ device = unpermuted_act_grad .device ,
567614 )
568615 probs_grad = torch .empty (
569616 (num_tokens , num_experts ),
570- dtype = merging_probs .dtype , device = unpermuted_act_grad .device ,
617+ dtype = merging_probs .dtype ,
618+ device = unpermuted_act_grad .device ,
571619 )
572620 return act_grad , probs_grad
573621
@@ -615,30 +663,51 @@ def moe_unpermute_mask_map_backward_no_probs(
615663 fp8_scale = None
616664
617665 act_grad , permuted_scale , _ = triton_permutation .permute_with_mask_map (
618- unpermuted_act_grad , row_id_map , None , fp8_scale , pad_offsets ,
619- num_tokens , num_experts , num_permuted_tokens , hidden_size , scale_hidden_dim ,
666+ unpermuted_act_grad ,
667+ row_id_map ,
668+ None ,
669+ fp8_scale ,
670+ pad_offsets ,
671+ num_tokens ,
672+ num_experts ,
673+ num_permuted_tokens ,
674+ hidden_size ,
675+ scale_hidden_dim ,
620676 )
621677
622678 if fp8 :
623679 if per_tensor_recipe :
624680 act_grad = Float8Tensor (
625- data = act_grad , fp8_dtype = fp8_dtype , fp8_scale_inv = fp8_scale_inv ,
626- shape = act_grad .shape , dtype = fake_dtype ,
681+ data = act_grad ,
682+ fp8_dtype = fp8_dtype ,
683+ fp8_scale_inv = fp8_scale_inv ,
684+ shape = act_grad .shape ,
685+ dtype = fake_dtype ,
627686 )
628687 elif blockwise_recipe :
629688 act_grad = Float8BlockwiseQTensor (
630- shape = act_grad .shape , dtype = fake_dtype , rowwise_data = act_grad ,
689+ shape = act_grad .shape ,
690+ dtype = fake_dtype ,
691+ rowwise_data = act_grad ,
631692 rowwise_scale_inv = permuted_scale .T .contiguous (),
632- columnwise_data = None , columnwise_scale_inv = None ,
633- fp8_dtype = fp8_dtype , quantizer = None , is_2D_scaled = False ,
693+ columnwise_data = None ,
694+ columnwise_scale_inv = None ,
695+ fp8_dtype = fp8_dtype ,
696+ quantizer = None ,
697+ is_2D_scaled = False ,
634698 requires_grad = act_grad .requires_grad ,
635699 )
636700 elif mxfp8_recipe :
637701 act_grad = MXFP8Tensor (
638- shape = act_grad .shape , dtype = fake_dtype , fp8_dtype = fp8_dtype ,
639- rowwise_data = act_grad , rowwise_scale_inv = permuted_scale .contiguous (),
640- columnwise_data = None , columnwise_scale_inv = None ,
641- quantizer = None , requires_grad = act_grad .requires_grad ,
702+ shape = act_grad .shape ,
703+ dtype = fake_dtype ,
704+ fp8_dtype = fp8_dtype ,
705+ rowwise_data = act_grad ,
706+ rowwise_scale_inv = permuted_scale .contiguous (),
707+ columnwise_data = None ,
708+ columnwise_scale_inv = None ,
709+ quantizer = None ,
710+ requires_grad = act_grad .requires_grad ,
642711 with_gemm_swizzled_scales = False ,
643712 )
644713
@@ -658,7 +727,8 @@ def _moe_unpermute_mask_map_bwd_no_probs_fake(
658727 """Fake for backward shape inference without probs."""
659728 return torch .empty (
660729 (num_permuted_tokens , hidden_size ),
661- dtype = unpermuted_act_grad .dtype , device = unpermuted_act_grad .device ,
730+ dtype = unpermuted_act_grad .dtype ,
731+ device = unpermuted_act_grad .device ,
662732 )
663733
664734
@@ -697,14 +767,26 @@ def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad):
697767 unpermuted_act_grad , QuantizedTensor
698768 ), "The backward of moe_unpermute with merging probs does not support FP8."
699769 act_grad , probs_grad = moe_unpermute_mask_map_backward_with_probs (
700- unpermuted_act_grad , row_id_map , fwd_input , merging_probs , pad_offsets ,
701- ctx .num_tokens , ctx .num_experts , ctx .num_permuted_tokens , ctx .hidden_size ,
770+ unpermuted_act_grad ,
771+ row_id_map ,
772+ fwd_input ,
773+ merging_probs ,
774+ pad_offsets ,
775+ ctx .num_tokens ,
776+ ctx .num_experts ,
777+ ctx .num_permuted_tokens ,
778+ ctx .hidden_size ,
702779 )
703780 else :
704781 row_id_map , pad_offsets = ctx .saved_tensors
705782 act_grad = moe_unpermute_mask_map_backward_no_probs (
706- unpermuted_act_grad , row_id_map , pad_offsets ,
707- ctx .num_tokens , ctx .num_experts , ctx .num_permuted_tokens , ctx .hidden_size ,
783+ unpermuted_act_grad ,
784+ row_id_map ,
785+ pad_offsets ,
786+ ctx .num_tokens ,
787+ ctx .num_experts ,
788+ ctx .num_permuted_tokens ,
789+ ctx .hidden_size ,
708790 )
709791
710792 if not ctx .needs_probs_grad :
@@ -945,8 +1027,13 @@ def moe_unpermute(
9451027 assert pad_offsets .is_cuda , "TransformerEngine needs CUDA."
9461028
9471029 return moe_unpermute_mask_map_forward (
948- inp , row_id_map , merging_probs ,
949- num_tokens , num_experts , hidden_size , pad_offsets ,
1030+ inp ,
1031+ row_id_map ,
1032+ merging_probs ,
1033+ num_tokens ,
1034+ num_experts ,
1035+ hidden_size ,
1036+ pad_offsets ,
9501037 )
9511038 raise ValueError ("map_type should be one of 'mask' or 'index'" )
9521039