diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index a782dadc60..165cbad3dd 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -630,12 +630,185 @@ def test_make_graphed_callables_with_kwargs(
     assert_all_equal(outputs, graph_outputs)


+def test_make_graphed_callables_returns_owned_parameter_grads() -> None:
+    """Parameter grads returned from graph replay must not alias static graph buffers."""
+    reset_rng_states()
+    model_config = model_configs["small"]
+    dtype = torch.float32
+    model = torch.nn.Linear(
+        model_config.hidden_size,
+        model_config.hidden_size,
+        bias=False,
+        device="cuda",
+        dtype=dtype,
+    )
+    model = make_graphed_callables(
+        model,
+        (generate_data(model_config, dtype, warmup=True, requires_grad=False),),
+    )
+
+    seen_grads = []
+
+    def save_grad(grad):
+        seen_grads.append(grad)
+        return grad
+
+    hook = model.weight.register_hook(save_grad)
+    try:
+        output = model(generate_data(model_config, dtype, requires_grad=False))
+        output.backward(generate_data(model_config, dtype, requires_grad=False))
+
+        assert len(seen_grads) == 1
+        first_grad = seen_grads[0]
+        first_grad_ptr = first_grad.data_ptr()
+        first_grad_snapshot = first_grad.clone()
+
+        model.zero_grad(set_to_none=True)
+
+        output = model(generate_data(model_config, dtype, requires_grad=False))
+        output.backward(generate_data(model_config, dtype, requires_grad=False))
+
+        assert len(seen_grads) == 2
+        assert first_grad.data_ptr() == first_grad_ptr
+        assert seen_grads[1].data_ptr() != first_grad_ptr
+        torch.testing.assert_close(first_grad, first_grad_snapshot, rtol=0, atol=0)
+    finally:
+        hook.remove()
+        reset_graphs(model)
+
+
+def test_make_graphed_callables_accumulates_owned_parameter_grads() -> None:
+    """Parameter grad accumulation must not reuse overwritten static graph buffers."""
+    reset_rng_states()
+    model_config = model_configs["small"]
+    dtype = torch.float32
+    model = torch.nn.Linear(
+        model_config.hidden_size,
+        model_config.hidden_size,
+        bias=False,
+        device="cuda",
+        dtype=dtype,
+    )
+    model = make_graphed_callables(
+        model,
+        (generate_data(model_config, dtype, warmup=True, requires_grad=False),),
+    )
+
+    input_1 = generate_data(model_config, dtype, requires_grad=False)
+    grad_1 = generate_data(model_config, dtype, requires_grad=False)
+    input_2 = generate_data(model_config, dtype, requires_grad=False)
+    grad_2 = generate_data(model_config, dtype, requires_grad=False)
+    expected_grad = torch.einsum("...o,...i->oi", grad_1, input_1) + torch.einsum(
+        "...o,...i->oi", grad_2, input_2
+    )
+
+    try:
+        model.zero_grad(set_to_none=True)
+        model(input_1).backward(grad_1)
+        model(input_2).backward(grad_2)
+        torch.testing.assert_close(model.weight.grad, expected_grad, rtol=0, atol=0)
+    finally:
+        reset_graphs(model)
+
+
+def test_make_graphed_callables_preserves_skipped_parameter_grad_alias() -> None:
+    """Delayed-wgrad parameters are excluded from returned-grad clone handling."""
+    reset_rng_states()
+    model_config = model_configs["small"]
+    dtype = torch.float32
+    model = torch.nn.Linear(
+        model_config.hidden_size,
+        model_config.hidden_size,
+        bias=False,
+        device="cuda",
+        dtype=dtype,
+    )
+    model.weight.skip_backward_post_hook = True
+    model = make_graphed_callables(
+        model,
+        (generate_data(model_config, dtype, warmup=True, requires_grad=False),),
+    )
+
+    seen_grads = []
+
+    def save_grad(grad):
+        seen_grads.append(grad)
+        return grad
+
+    hook = model.weight.register_hook(save_grad)
+    try:
+        output = model(generate_data(model_config, dtype, requires_grad=False))
+        output.backward(generate_data(model_config, dtype, requires_grad=False))
+
+        assert len(seen_grads) == 1
+        first_grad_ptr = seen_grads[0].data_ptr()
+
+        model.zero_grad(set_to_none=True)
+
+        output = model(generate_data(model_config, dtype, requires_grad=False))
+        output.backward(generate_data(model_config, dtype, requires_grad=False))
+
+        assert len(seen_grads) == 2
+        assert seen_grads[1].data_ptr() == first_grad_ptr
+    finally:
+        hook.remove()
+        reset_graphs(model)
+
+
+def test_make_graphed_callables_snapshots_parameter_grad_clone_policy() -> None:
+    """Parameter grad clone policy is fixed at capture time."""
+    reset_rng_states()
+    model_config = model_configs["small"]
+    dtype = torch.float32
+    model = torch.nn.Linear(
+        model_config.hidden_size,
+        model_config.hidden_size,
+        bias=False,
+        device="cuda",
+        dtype=dtype,
+    )
+    model = make_graphed_callables(
+        model,
+        (generate_data(model_config, dtype, warmup=True, requires_grad=False),),
+    )
+    model.weight.skip_backward_post_hook = True
+
+    seen_grads = []
+
+    def save_grad(grad):
+        seen_grads.append(grad)
+        return grad
+
+    hook = model.weight.register_hook(save_grad)
+    try:
+        output = model(generate_data(model_config, dtype, requires_grad=False))
+        output.backward(generate_data(model_config, dtype, requires_grad=False))
+
+        assert len(seen_grads) == 1
+        first_grad = seen_grads[0]
+        first_grad_ptr = first_grad.data_ptr()
+        first_grad_snapshot = first_grad.clone()
+
+        model.zero_grad(set_to_none=True)
+
+        output = model(generate_data(model_config, dtype, requires_grad=False))
+        output.backward(generate_data(model_config, dtype, requires_grad=False))
+
+        assert len(seen_grads) == 2
+        assert seen_grads[1].data_ptr() != first_grad_ptr
+        torch.testing.assert_close(first_grad, first_grad_snapshot, rtol=0, atol=0)
+    finally:
+        hook.remove()
+        reset_graphs(model)
+
+
 def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
     *,
     with_graph: bool,
     model_config: ModelConfig,
     dtype: torch.dtype,
-) -> List[torch.Tensor]:
+    reuse_graph_input_output_buffers: bool = False,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     """Simulate Megatron-LM interleaved pipeline parallelism."""
     reset_rng_states()

@@ -675,6 +848,7 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
             sample_args,
             allow_unused_input=True,
             _order=layer_order,
+            _reuse_graph_input_output_buffers=reuse_graph_input_output_buffers,
         )
         layer_forwards = {
             (i // num_microbatches, i % num_microbatches): forward
@@ -701,11 +875,15 @@

     # Cache for layer outputs.
     outputs = {}
+    output_snapshots = {} if reuse_graph_input_output_buffers else None

     def forward(layer_idx: int, microbatch_idx: int):
         """Helper function for forward steps"""
         idxs = (layer_idx, microbatch_idx)
         outputs[idxs] = layer_forwards[idxs](inputs[idxs])
+        if output_snapshots is not None:
+            # Reused graph output buffers are only valid until their corresponding backward.
+            output_snapshots[idxs] = outputs[idxs].detach().clone()

     def backward(layer_idx: int, microbatch_idx: int):
         """Helper function for backward steps"""
@@ -728,11 +906,13 @@ def backward(layer_idx: int, microbatch_idx: int):

     # Optimizer step.
     optimizer.step()

-    outputs = [y for _, y in sorted(outputs.items())]
-    outputs = get_outputs(model, outputs)
+    output_values = output_snapshots if output_snapshots is not None else outputs
+    output_values = [y for _, y in sorted(output_values.items())]
+    outputs = get_outputs(model, output_values)
+    final_weights = [param.detach().clone() for param in model.parameters()]
     if with_graph:
         reset_graphs(layer_forwards)
-    return outputs
+    return outputs, final_weights
@@ -743,12 +923,34 @@ def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
     *,
     model_config: str = "small",
     dtype: torch.dtype = torch.float16,
 ) -> None:
     """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
     model_config = model_configs[model_config]
     kwargs = dict(model_config=model_config, dtype=dtype)
-    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+    outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+        with_graph=False,
+        **kwargs,
+    )
+    graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+        with_graph=True,
+        **kwargs,
+    )
+    assert_all_equal(outputs, graph_outputs)
+    assert_all_equal(weights, graph_weights)
+
+
+def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers(
+    *,
+    model_config: str = "small",
+    dtype: torch.dtype = torch.float16,
+) -> None:
+    """Test CUDA graphs with reused input/output buffers."""
+    model_config = model_configs[model_config]
+    kwargs = dict(model_config=model_config, dtype=dtype)
+    outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
         with_graph=False,
         **kwargs,
     )
-    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+    graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
         with_graph=True,
+        reuse_graph_input_output_buffers=True,
         **kwargs,
     )
     assert_all_equal(outputs, graph_outputs)
+    assert_all_equal(weights, graph_weights)
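# --- Editorial sketch (not part of the patch) --------------------------------
# Why the new tests insist on owned grads: if Graphed.backward returned a view
# of a static CUDA-graph buffer, the next replay would rewrite that buffer in
# place and silently clobber grads observed earlier (e.g. by a grad hook or a
# gradient-clipping routine). A minimal illustration, with a plain tensor
# standing in for the static grad buffer:
import torch

static_buf = torch.zeros(4, device="cuda")  # stand-in for a static grad buffer

def replay(value):
    static_buf.fill_(value)     # graph replay overwrites the buffer in place
    return static_buf.detach()  # aliasing return: the pre-patch behavior

g1 = replay(1.0)
g2 = replay(2.0)
assert g1.data_ptr() == g2.data_ptr()  # both grads alias one storage...
assert torch.equal(g1, g2)             # ...so g1 was clobbered by the second replay

g3 = replay(3.0).clone()               # owned copy: the post-patch behavior
replay(4.0)
assert g3.eq(3.0).all()                # g3 survives later replays
# ------------------------------------------------------------------------------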
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 075db1394b..fba2178786 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -407,6 +407,17 @@ def _make_graphed_callables(
     bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     graph_callables = [None for _ in range(len(flatten_sample_args))]

+    def _returned_param_grad_slots(static_grad_inputs, module_params):
+        """Snapshot which static grad input slots are returned through Graphed.backward."""
+        module_param_start = len(static_grad_inputs) - len(module_params)
+        return tuple(
+            idx >= module_param_start
+            and not getattr(
+                module_params[idx - module_param_start], "skip_backward_post_hook", False
+            )
+            for idx in range(len(static_grad_inputs))
+        )
+
     # For cases with multiple active RNG states, e.g. TP.
     if graph_safe_rng_available():
         for _, state in get_all_rng_states().items():
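# --- Editorial sketch (not part of the patch) --------------------------------
# What the helper above computes: static_grad_inputs lays out grad slots for
# user inputs first and module parameters last, so the mask flags exactly the
# trailing parameter slots, minus any parameter opted out via
# skip_backward_post_hook. A standalone restatement on toy inputs:
from types import SimpleNamespace

def returned_param_grad_slots(static_grad_inputs, module_params):
    module_param_start = len(static_grad_inputs) - len(module_params)
    return tuple(
        idx >= module_param_start
        and not getattr(
            module_params[idx - module_param_start], "skip_backward_post_hook", False
        )
        for idx in range(len(static_grad_inputs))
    )

params = [SimpleNamespace(), SimpleNamespace(skip_backward_post_hook=True)]
# Two user-input slots followed by the two parameter slots above:
assert returned_param_grad_slots([None] * 4, params) == (False, False, True, False)
# ------------------------------------------------------------------------------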
@@ -569,6 +580,7 @@ def hook_fn(
     per_callable_output_unflatten_spec = [None] * len(flatten_sample_args)
     per_callable_static_grad_outputs = [None] * len(flatten_sample_args)
     per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
+    per_callable_returned_param_grad_slots = [None] * len(flatten_sample_args)
     fwd_idx = [0] * num_model_chunks
     bwd_idx = [0] * num_model_chunks
     static_grad_outputs_dict = {}
@@ -716,6 +728,13 @@ def hook_fn(

                 per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
                 per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
+                returned_param_grad_slots = _returned_param_grad_slots(
+                    static_grad_inputs,
+                    per_callable_module_params[per_callable_bwd_idx],
+                )
+                per_callable_returned_param_grad_slots[per_callable_bwd_idx] = (
+                    returned_param_grad_slots
+                )

                 # Weak ref the static outputs and static grad inputs that are no longer needed
                 # in the following steps. These two type of tensors are both in cudagraph
@@ -728,6 +747,18 @@
                     static_outputs
                 )

+                # Parameter grads are cloned before being returned from
+                # Graphed.backward, so their static buffers can be weak-refed now.
+                static_grad_inputs = per_callable_static_grad_inputs[per_callable_bwd_idx]
+                per_callable_static_grad_inputs[per_callable_bwd_idx] = tuple(
+                    (
+                        make_weak_ref(grad_input)
+                        if returned_param_grad_slots[idx] and grad_input is not None
+                        else grad_input
+                    )
+                    for idx, grad_input in enumerate(static_grad_inputs)
+                )
+
                 # Weak ref the static grad inputs of the previous backward pass within the
                 # same chunk.
                 if previous_per_callable_bwd_idx is not None:
@@ -769,6 +800,7 @@ def hook_fn(
     # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
+    per_callable_returned_param_grad_slots = []
    for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip(
        reversed(per_callable_static_input_surfaces),
        reversed(per_callable_static_outputs),
@@ -813,10 +845,19 @@
        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)
+        per_callable_returned_param_grad_slots.append(
+            _returned_param_grad_slots(
+                static_grad_inputs,
+                per_callable_module_params[bwd_idx],
+            )
+        )

-    # Reverses the most recent two lists
+    # Reverse the most recent per-callable lists.
    per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
    per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
+    per_callable_returned_param_grad_slots = list(
+        reversed(per_callable_returned_param_grad_slots)
+    )
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
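# --- Editorial sketch (not part of the patch) --------------------------------
# The capture loop above records backward state last-callable-first, so every
# per-callable list (including the new per_callable_returned_param_grad_slots)
# is appended in reverse and flipped once at the end. The indexing invariant
# that the "Reverse the most recent per-callable lists" step restores, on toy
# data:
captured = []
for idx in reversed(range(3)):  # capture order: 2, 1, 0
    captured.append(f"slots-for-callable-{idx}")
per_callable = list(reversed(captured))
assert per_callable == [f"slots-for-callable-{i}" for i in range(3)]  # index i matches callable i
# ------------------------------------------------------------------------------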
     def make_graphed_autograd_function(
         static_outputs,
         static_grad_outputs,
         static_grad_inputs,
+        returned_param_grad_slots,
     ):
         class Graphed(torch.autograd.Function):
             """Autograd function for graph replay."""
@@ -911,9 +953,17 @@ def backward(ctx, *grads):
                         "Expected static_grad_inputs to be a tuple, but got"
                         f" {type(static_grad_inputs).__name__}"
                     )
-                return (None, None, None) + tuple(
-                    b.detach() if b is not None else b for b in static_grad_inputs
-                )
+                grad_inputs = []
+                for idx, grad_input in enumerate(static_grad_inputs):
+                    if grad_input is None:
+                        grad_inputs.append(None)
+                    elif returned_param_grad_slots[idx]:
+                        # Returned parameter grads may be installed directly as param.grad.
+                        # Clone to avoid exposing CUDA graph static buffers to autograd users.
+                        grad_inputs.append(grad_input.detach().clone())
+                    else:
+                        grad_inputs.append(grad_input.detach())
+                return (None, None, None) + tuple(grad_inputs)

     def functionalized(*user_args, **user_kwargs):

@@ -1008,6 +1058,7 @@ def reset():
             per_callable_static_outputs[i],
             per_callable_static_grad_outputs[i],
             per_callable_static_grad_inputs[i],
+            per_callable_returned_param_grad_slots[i],
         )
         func = graph_callables[i]
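# --- Editorial sketch (not part of the patch) --------------------------------
# The user-visible contract after this change, mirroring the new tests: grads
# surfaced through autograd are owned copies, so a grad captured from one
# replay is not clobbered by the next. Assumes a CUDA build of
# transformer_engine with this patch applied.
import torch
from transformer_engine.pytorch import make_graphed_callables

model = torch.nn.Linear(64, 64, bias=False, device="cuda")
model = make_graphed_callables(model, (torch.randn(8, 64, device="cuda"),))

grads = []
model.weight.register_hook(grads.append)
for _ in range(2):
    out = model(torch.randn(8, 64, device="cuda"))
    out.backward(torch.randn_like(out))
    model.zero_grad(set_to_none=True)

assert grads[0].data_ptr() != grads[1].data_ptr()  # each replay returns fresh storage
# ------------------------------------------------------------------------------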