Skip to content

Commit 47722eb

Browse files
committed
fix validation of stage2
1 parent 25bc5d6 commit 47722eb

1 file changed

Lines changed: 27 additions & 4 deletions

File tree

op_tests/test_moe_ck_2stage.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def run_2stage_moe(
172172

173173
# ========== Stage 2 ==========
174174
# Reference implementation using the function from fused_moe.py
175+
# NOTE: Both stage2 implementations should use the same intermediate tensor
176+
# to isolate stage 2 correctness from any stage 1 numerical differences
175177
output_ref = torch_moe_stage2(
176178
inter_ref,
177179
w1,
@@ -187,23 +189,44 @@ def run_2stage_moe(
187189
)
188190

189191
# CK implementation using metadata.stage2
190-
output_out = torch.empty((num_tokens, hidden_dim), dtype=dtype, device="cuda")
191-
output_ck, us_stage2 = run_perftest(
192+
# Use inter_ref (not inter_ck) to test stage 2 in isolation
193+
# NOTE: Stage 2 kernel uses AtomicAdd, so we need fresh output buffers for run_perftest
194+
# to avoid accumulation across iterations. We'll call it once for validation.
195+
output_out_validation = torch.zeros((num_tokens, hidden_dim), dtype=dtype, device="cuda")
196+
output_ck = metadata.stage2(
197+
inter_ref,
198+
w1_shuffle,
199+
w2_shuffle,
200+
sorted_token_ids,
201+
sorted_expert_ids,
202+
num_valid_ids,
203+
output_out_validation,
204+
topk,
205+
w2_scale=w2_scale,
206+
a2_scale=a2_scale,
207+
block_m=block_size,
208+
sorted_weights=sorted_weights,
209+
)
210+
211+
# For performance measurement, create fresh buffers for each iteration
212+
output_out_perf = torch.zeros((num_tokens, hidden_dim), dtype=dtype, device="cuda")
213+
_, us_stage2 = run_perftest(
192214
metadata.stage2,
193-
inter_ck,
215+
inter_ref,
194216
w1_shuffle,
195217
w2_shuffle,
196218
sorted_token_ids,
197219
sorted_expert_ids,
198220
num_valid_ids,
199-
output_out,
221+
output_out_perf,
200222
topk,
201223
w2_scale=w2_scale,
202224
a2_scale=a2_scale,
203225
block_m=block_size,
204226
sorted_weights=sorted_weights,
205227
num_iters=10,
206228
num_warmup=2,
229+
num_rotate_args=12, # Force creating 12 separate output buffers
207230
needTrace=False,
208231
)
209232

0 commit comments

Comments
 (0)