@@ -172,6 +172,8 @@ def run_2stage_moe(
172172
173173 # ========== Stage 2 ==========
174174 # Reference implementation using the function from fused_moe.py
175+ # NOTE: Both stage2 implementations should use the same intermediate tensor
176+ # to isolate stage 2 correctness from any stage 1 numerical differences
175177 output_ref = torch_moe_stage2 (
176178 inter_ref ,
177179 w1 ,
@@ -187,23 +189,44 @@ def run_2stage_moe(
187189 )
188190
189191 # CK implementation using metadata.stage2
190- output_out = torch .empty ((num_tokens , hidden_dim ), dtype = dtype , device = "cuda" )
191- output_ck , us_stage2 = run_perftest (
192+ # Use inter_ref (not inter_ck) to test stage 2 in isolation
193+ # NOTE: Stage 2 kernel uses AtomicAdd, so we need fresh output buffers for run_perftest
194+ # to avoid accumulation across iterations. We'll call it once for validation.
195+ output_out_validation = torch .zeros ((num_tokens , hidden_dim ), dtype = dtype , device = "cuda" )
196+ output_ck = metadata .stage2 (
197+ inter_ref ,
198+ w1_shuffle ,
199+ w2_shuffle ,
200+ sorted_token_ids ,
201+ sorted_expert_ids ,
202+ num_valid_ids ,
203+ output_out_validation ,
204+ topk ,
205+ w2_scale = w2_scale ,
206+ a2_scale = a2_scale ,
207+ block_m = block_size ,
208+ sorted_weights = sorted_weights ,
209+ )
210+
211+ # For performance measurement, create fresh buffers for each iteration
212+ output_out_perf = torch .zeros ((num_tokens , hidden_dim ), dtype = dtype , device = "cuda" )
213+ _ , us_stage2 = run_perftest (
192214 metadata .stage2 ,
193- inter_ck ,
215+ inter_ref ,
194216 w1_shuffle ,
195217 w2_shuffle ,
196218 sorted_token_ids ,
197219 sorted_expert_ids ,
198220 num_valid_ids ,
199- output_out ,
221+ output_out_perf ,
200222 topk ,
201223 w2_scale = w2_scale ,
202224 a2_scale = a2_scale ,
203225 block_m = block_size ,
204226 sorted_weights = sorted_weights ,
205227 num_iters = 10 ,
206228 num_warmup = 2 ,
229+ num_rotate_args = 12 , # Force creating 12 separate output buffers
207230 needTrace = False ,
208231 )
209232
0 commit comments