diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index f4b1fb23ae..d3320fd70f 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -507,11 +507,6 @@ def hook_fn(
                 else:
                     grad_inputs = None
                 del outputs, grad_inputs
-                # The following code is added specifically for MCore's special requirements,
-                # aimed at preventing warmup from altering the control flow.
-                for module in func.modules():
-                    if hasattr(module, "is_first_microbatch"):
-                        module.is_first_microbatch = True
    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,