123 changes: 71 additions & 52 deletions megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py
@@ -80,26 +80,35 @@ def __init__(
self._init_sub_optimizers()
self._register_load_state_dict_hooks()

def _iter_current_param_group_entries(self):
for group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group["params"]):
yield (group_id, param_id), param

def _get_current_param(self, param_group_index):
group_id, param_id = param_group_index
return self.param_groups[group_id]["params"][param_id]

def _set_sub_optimizer_grads(self):
if self.param_update_in_fp32:
for param in self.param_to_fp32_param:
if param in self.gpu_params_map_cpu_copy:
# Skip if the param is offloaded to CPU, it should be handled
# in the following part.
for param_group_index, param in self._iter_current_param_group_entries():
inner_param = self.param_group_index_to_inner_param[param_group_index]
if not inner_param.is_cuda:
# Offloaded params are handled in the following part.
continue
fp32_param = self.param_to_fp32_param[param]
grad = getattr(param, "decoupled_grad", param.grad)
if grad is not None:
fp32_param.grad = grad.to(fp32_param.dtype)
fp32_param.requires_grad = True
inner_param.grad = grad.to(inner_param.dtype)
inner_param.requires_grad = True
else:
fp32_param.requires_grad = False
inner_param.requires_grad = False

# Sync the grads from GPU to CPU.
for optimizer in self.cpu_optimizers:
for param in _param_generator(optimizer):
gpu_param = self.cpu_copys_map_gpu_param[param]
grad = getattr(gpu_param, "decoupled_grad", gpu_param.grad)
param_group_index = self.inner_param_to_param_group_index[param]
current_param = self._get_current_param(param_group_index)
grad = getattr(current_param, "decoupled_grad", current_param.grad)
if grad is None:
param.requires_grad = False
continue
@@ -120,24 +129,20 @@ def param_copy_back_gpu_hook(optimizer, args, kwargs):
self._h2d_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._h2d_stream):
for param in _param_generator(optimizer):
gpu_param = self.cpu_copys_map_gpu_param[param]
gpu_param.data.copy_(param.data, non_blocking=True)
param_group_index = self.inner_param_to_param_group_index[param]
current_param = self._get_current_param(param_group_index)
current_param.data.copy_(param.data, non_blocking=True)
self._h2d_stream.record_event().wait(torch.cuda.current_stream())

return param_copy_back_gpu_hook

def fp32_param_copy_back_gpu_hook_closure():
def fp32_param_copy_back_gpu_hook(optimizer, args, kwargs):
for group in self.param_groups:
for param in group["params"]:
if param in self.gpu_params_map_cpu_copy:
# Skip if the param is offloaded to GPU, it has been
# copied back in the previous hook.
continue

if param in self.param_to_fp32_param:
fp32_param = self.param_to_fp32_param[param]
param.data.copy_(fp32_param.data)
for param_group_index, param in self._iter_current_param_group_entries():
inner_param = self.param_group_index_to_inner_param[param_group_index]
if not inner_param.is_cuda or inner_param is param:
continue
param.data.copy_(inner_param.data)

return fp32_param_copy_back_gpu_hook

@@ -188,16 +193,23 @@ def _init_sub_optimizers(self):
) = self._get_sub_optimizer_param_groups(self.offload_fraction)
self.param_to_inner_param = {}
self.inner_param_to_orig_param = {}
for group in self.param_groups:
for param in group["params"]:
if param in self.param_to_fp32_param:
inner_param = self.param_to_fp32_param[param]
elif param in self.gpu_params_map_cpu_copy:
inner_param = self.gpu_params_map_cpu_copy[param]
else:
inner_param = param
self.param_to_inner_param[param] = inner_param
self.inner_param_to_orig_param[inner_param] = param
# Keep a stable logical mapping by param-group position so wrapped optimizers
# can replace tensor objects in-place without breaking HDO's bookkeeping.
self.param_group_index_to_orig_param = {}
self.param_group_index_to_inner_param = {}
self.inner_param_to_param_group_index = {}
for param_group_index, param in self._iter_current_param_group_entries():
self.param_group_index_to_orig_param[param_group_index] = param
if param in self.param_to_fp32_param:
inner_param = self.param_to_fp32_param[param]
elif param in self.gpu_params_map_cpu_copy:
inner_param = self.gpu_params_map_cpu_copy[param]
else:
inner_param = param
self.param_to_inner_param[param] = inner_param
self.inner_param_to_orig_param[inner_param] = param
self.param_group_index_to_inner_param[param_group_index] = inner_param
self.inner_param_to_param_group_index[inner_param] = param_group_index
Contributor comment on lines 194 to +212:
Nit: param_to_inner_param, inner_param_to_orig_param, and param_group_index_to_orig_param are still populated here but no longer read anywhere after this PR. They could be removed to avoid confusion about which mapping is authoritative (the new param_group_index_to_* / inner_param_to_param_group_index dicts).
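For illustration, a minimal sketch of how that mapping setup could look with the three unused dicts dropped — this is a hypothetical trimming, assuming nothing outside this method still reads the object-keyed mappings:

# Hypothetical trimmed-down version of the mapping setup: only the
# position-keyed dicts are populated, since they are the ones read elsewhere
# after this PR (assumption: no remaining readers of the object-keyed maps).
self.param_group_index_to_inner_param = {}
self.inner_param_to_param_group_index = {}
for param_group_index, param in self._iter_current_param_group_entries():
    if param in self.param_to_fp32_param:
        inner_param = self.param_to_fp32_param[param]
    elif param in self.gpu_params_map_cpu_copy:
        inner_param = self.gpu_params_map_cpu_copy[param]
    else:
        inner_param = param
    self.param_group_index_to_inner_param[param_group_index] = inner_param
    self.inner_param_to_param_group_index[inner_param] = param_group_index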

self.fp32_param_to_orig_param = {v: k for k, v in self.param_to_fp32_param.items()}

self.cpu_optimizers = []
@@ -314,39 +326,35 @@ def _sync_sub_optimizers_state_to_hdo(self):
new_state = defaultdict(dict)
for optimizer in self.sub_optimizers:
for param in optimizer.state:
orig_param = self.inner_param_to_orig_param[param]
new_state[orig_param] = optimizer.state[param]
param_group_index = self.inner_param_to_param_group_index[param]
current_param = self._get_current_param(param_group_index)
new_state[current_param] = optimizer.state[param]
if self.param_update_in_fp32:
new_state[orig_param]["master_param"] = param
new_state[current_param]["master_param"] = param
self.state = new_state

def _sync_hdo_state_to_sub_optimizers(self):
for optimizer in self.sub_optimizers:
new_state = defaultdict(dict)
for group in optimizer.param_groups:
for param in group["params"]:
orig_param = self.inner_param_to_orig_param[param]
new_state[param] = self.state[orig_param]
param_group_index = self.inner_param_to_param_group_index[param]
current_param = self._get_current_param(param_group_index)
new_state[param] = self.state[current_param]
optimizer.state = new_state
self._update_fp32_params_by_new_state()
self._move_new_state_to_right_device()

def _sync_hdo_param_groups_to_sub_optimizers(self):
"""Sync HDO new param_groups attribute (e.g. lr, wd, etc.) to sub-optimizers."""
param_in_param_group_index = {}
for i, group in enumerate(self.param_groups):
for p_id, param in enumerate(group["params"]):
inner_param = self.param_to_inner_param[param]
param_in_param_group_index[inner_param] = (i, p_id)

for optimizer in self.sub_optimizers:
new_param_groups = []
for group in optimizer.param_groups:
new_group = group.copy()
# After syncing up the sub-optimizer's last update, we need to sync the
# HDO's new param_groups attributes back to the sub-optimizer.
assert len(group["params"]) > 0, "param_groups should not be empty"
group_id, _ = param_in_param_group_index[group["params"][0]]
group_id, _ = self.inner_param_to_param_group_index[group["params"][0]]
update_group_attrs = self.param_groups[group_id].copy()
del update_group_attrs["params"]
new_group.update(update_group_attrs)
@@ -360,25 +368,36 @@ def _move_new_state_to_right_device(self):
for k, v in state.items():
if not isinstance(v, torch.Tensor):
continue
orig_param = self.inner_param_to_orig_param.get(param, param)
param_group_index = self.inner_param_to_param_group_index.get(param)
current_param = (
self._get_current_param(param_group_index)
if param_group_index is not None
else param
)
if isinstance(optimizer, self.defaults["cpu_optimizer_cls"]):
self.state[orig_param][k] = state[k] = v.to("cpu")
self.state[current_param][k] = state[k] = v.to("cpu")
else:
self.state[orig_param][k] = state[k] = v.to("cuda")
self.state[current_param][k] = state[k] = v.to("cuda")

def _update_fp32_params_by_new_state(self):
if not self.param_update_in_fp32:
return
for param, v in self.state.items():
fp32_param = self.param_to_fp32_param[param]
fp32_param.data.copy_(v["master_param"])
for param_group_index, inner_param in self.param_group_index_to_inner_param.items():
current_param = self._get_current_param(param_group_index)
state = self.state.get(current_param)
if state is None or "master_param" not in state:
continue
inner_param.data.copy_(state["master_param"])

def update_fp32_param_by_new_param(self):
"""
Update the fp32 parameters by the new parameters.
"""
for param, fp32_param in self.param_to_fp32_param.items():
fp32_param.data.copy_(param)
for param_group_index, inner_param in self.param_group_index_to_inner_param.items():
current_param = self._get_current_param(param_group_index)
if inner_param is current_param or inner_param.dtype != torch.float32:
continue
inner_param.data.copy_(current_param)

def _register_load_state_dict_hooks(self):
def pre_load_state_dict_hook(self, state_dict):
64 changes: 64 additions & 0 deletions tests/unit_tests/test_optimizer_cpu_offloading.py
@@ -1,5 +1,6 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import random
from unittest.mock import patch

import numpy as np
import pytest
@@ -17,6 +18,9 @@
from torch.optim import Adam as GPUAdam

from megatron.core.optimizer.cpu_offloading import HybridDeviceOptimizer
from megatron.core.optimizer.grad_scaler import ConstantGradScaler
from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params
from megatron.core.optimizer.optimizer_config import OptimizerConfig


class Net(nn.Module):
@@ -253,3 +257,63 @@ def test_overlap_cpu_optimizer_d2h_h2d_sync_correctness(
assert torch.allclose(
v, ref_params[k], atol=1e-03
), f"Weight {k} value mismatch, max error: {(v - ref_params[k]).abs().max()}"


@pytest.mark.parametrize(
'dtype',
[
torch.bfloat16,
torch.float16,
],
)
@pytest.mark.skipif(
torch.__version__ < '2.3.0',
reason=(
"Requires PyTorch 2.3.0 or higher; lower versions of PyTorch have "
"mismatched optimizer accuracy between CPU and GPU."
),
)
def test_hybrid_optimizer_with_float16_wrapper_first_step(dtype):
setup_seed(123)
net = Net().cuda().to(dtype=dtype)
original_param = next(net.parameters())
optimizer_config = OptimizerConfig(
optimizer='adam',
lr=1e-3,
bf16=dtype == torch.bfloat16,
fp16=dtype == torch.float16,
clip_grad=0.0,
)
hdo = HybridDeviceOptimizer(
list(net.parameters()),
offload_fraction=0.5,
lr=1e-3,
cpu_optimizer_cls=Adam,
gpu_optimizer_cls=GPUAdam,
overlap_cpu_optimizer_d2h_h2d=False,
param_update_in_fp32=True,
)
optimizer = Float16OptimizerWithFloat16Params(
hdo,
optimizer_config,
None if dtype == torch.bfloat16 else ConstantGradScaler(1.0),
lambda *args, **kwargs: None,
)
optimizer.grad_stats_parallel_group = None

wrapped_param = hdo.param_groups[0]['params'][0]
assert wrapped_param is not original_param
assert wrapped_param.dtype == torch.float32

input = torch.randn(1, 3, 32, 32, device='cuda', dtype=dtype)
before = original_param.detach().float().clone()
output = net(input)
output.sum().backward()

with patch('torch.distributed.all_reduce', return_value=None):
success, _, _ = optimizer.step()

assert success
assert len(hdo.state_dict()["state"]) != 0
assert len(optimizer.state_dict()["optimizer"]["state"]) != 0
assert not torch.allclose(original_param.detach().float(), before)