Skip to content

Commit 36df9ff

Browse files
wz337 and meta-codesync[bot]
authored and committed
sync D77694359 to OSS
Summary: sync D77694359 to OSS. Reviewed By: jialun-zhang. Differential Revision: D86571324. fbshipit-source-id: a81e8868a85f275ecf917f837e659e21e5e6b11a
1 parent f669cbf commit 36df9ff

6 files changed

Lines changed: 251 additions & 27 deletions

File tree

distributed_shampoo/distributed_shampoo.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,13 @@ def _instantiate_distributor(self) -> None:
549549
param_assignment_strategy=FSDPParamAssignmentStrategy.DEFAULT
550550
):
551551
distributor_cls = FullyShardDistributor
552-
case FullyShardDistributedConfig(
553-
param_assignment_strategy=FSDPParamAssignmentStrategy.REPLICATE
552+
case (
553+
FullyShardDistributedConfig(
554+
param_assignment_strategy=FSDPParamAssignmentStrategy.REPLICATE
555+
)
556+
| FullyShardDistributedConfig(
557+
param_assignment_strategy=FSDPParamAssignmentStrategy.ROUND_ROBIN
558+
)
554559
):
555560
distributor_cls = FullyShardLosslessDistributor
556561
case _:

distributed_shampoo/distributor/_shampoo_fully_shard_lossless_distributor.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
FullyShardDistributedConfig,
2222
PARAMS,
2323
)
24+
from distributed_shampoo.utils.shampoo_utils import (
25+
prepare_update_param_buffers,
26+
redistribute_and_update_params,
27+
)
2428
from torch import distributed as dist, Tensor
2529

2630
logger: logging.Logger = logging.getLogger(__name__)
@@ -50,14 +54,41 @@ def __init__(self, param_group: dict[str, Any]) -> None:
5054
logger.info(
5155
f"Shampoo FullyShardLosslessDistributor {self._param_assignment_strategy=}",
5256
)
53-
# Stores full parameters (as opposed to DTensors) for the model parameters assigned to this rank.
54-
# For example, when the strategy is REPLICATE, it stores the full parameters on all ranks.
57+
58+
self._group_size: int = dist.get_world_size()
59+
self._dist_group: dist.ProcessGroup = dist.new_subgroups(
60+
group_size=self._group_size
61+
)[0]
62+
self._group_rank: int = dist.get_rank(group=self._dist_group)
63+
64+
should_assign_param_idx = (
65+
lambda i: i % self._group_size == self._group_rank
66+
if self._param_assignment_strategy
67+
== FSDPParamAssignmentStrategy.ROUND_ROBIN
68+
else True
69+
)
70+
self._assigned_params_mask: tuple[bool, ...] = tuple(
71+
should_assign_param_idx(idx) for idx in range(len(param_group[PARAMS]))
72+
)
73+
74+
# Collects and stores the model parameters assigned to this rank.
5575
# Note that we explicitly disable the unnecessary gradient tracking for the all-gather collectives
5676
# used to initialize the full parameters.
5777
with torch.no_grad():
58-
self._assigned_full_params: tuple[torch.Tensor, ...] = tuple(
59-
p.full_tensor() for p in param_group[PARAMS]
60-
)
78+
full_params: list[Tensor] = [p.full_tensor() for p in param_group[PARAMS]]
79+
self._assigned_full_params: list[Tensor] = [
80+
p
81+
for p, assigned in zip(full_params, self._assigned_params_mask)
82+
if assigned
83+
]
84+
85+
# For ROUND_ROBIN strategy, creates a buffer for receiving the updated param shards.
86+
self._update_param_buffers: list[Tensor] | None = (
87+
prepare_update_param_buffers(param_group[PARAMS], self._group_size)
88+
if self._param_assignment_strategy
89+
== FSDPParamAssignmentStrategy.ROUND_ROBIN
90+
else None
91+
)
6192

6293
super().__init__(param_group)
6394

@@ -86,11 +117,18 @@ def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None
86117
if get_grad:
87118
# Getting grads at every optimizer step triggers implicit all-gather. Note that p.numel()
88119
# returns total number of elements in the tensor (as opposed to local shard of DTensor).
89-
return (
120+
full_grads = (
90121
None if p.grad is None else p.grad.full_tensor()
91122
for p in self._param_group[PARAMS]
92-
if p.numel() > 0
93123
)
124+
return (
125+
full_grad
126+
for full_grad, assigned in zip(
127+
full_grads, self._assigned_params_mask, strict=True
128+
)
129+
if assigned and (full_grad is None or full_grad.numel() > 0)
130+
)
131+
94132
else:
95133
return filter(
96134
lambda p: isinstance(p, Tensor) and p.numel() > 0,
@@ -115,7 +153,15 @@ def update_params(
115153
# For example, when the strategy is REPLICATE, we need to take each updated full parameter `full_param`,
116154
# redistribute it according to the device mesh to get the locally assigned slice, and copy the slice to the
117155
# corresponding local parameter `local_param` in the param group.
118-
if self._param_assignment_strategy == FSDPParamAssignmentStrategy.REPLICATE:
156+
if self._param_assignment_strategy == FSDPParamAssignmentStrategy.ROUND_ROBIN:
157+
redistribute_and_update_params(
158+
self._param_group[PARAMS],
159+
self._assigned_full_params,
160+
self._update_param_buffers, # type: ignore
161+
self._dist_group,
162+
)
163+
164+
elif self._param_assignment_strategy == FSDPParamAssignmentStrategy.REPLICATE:
119165
local_params = list(
120166
filter(lambda p: p.numel() > 0, self._param_group[PARAMS])
121167
)
@@ -143,5 +189,11 @@ def update_params(
143189
def _construct_local_block_info_list(self) -> tuple[BlockInfo, ...]:
144190
"""Construct local block info list from param_group."""
145191
return self._construct_local_block_info_list_with_params(
146-
params=filter(lambda p: p.numel() > 0, self._param_group[PARAMS])
192+
params=(
193+
p
194+
for assigned, p in zip(
195+
self._assigned_params_mask, self._param_group[PARAMS], strict=True
196+
)
197+
if assigned and p.numel() > 0
198+
),
147199
)

distributed_shampoo/distributor/gpu_tests/shampoo_fully_shard_lossless_distributor_test.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,20 @@ def _shampoo_optim_factory(
127127
@with_comms
128128
@skip_if_lt_x_gpu(2)
129129
@parametrize("model_linear_layers_dims", TEST_MODEL_LAYER_DIMS)
130+
@parametrize(
131+
"param_assignment_strategy",
132+
(
133+
FSDPParamAssignmentStrategy.REPLICATE,
134+
FSDPParamAssignmentStrategy.ROUND_ROBIN,
135+
),
136+
)
130137
def test_all_ranks_with_no_grads(
131138
self,
132139
model_linear_layers_dims: tuple[int, ...],
140+
param_assignment_strategy: FSDPParamAssignmentStrategy,
133141
) -> None:
134142
fully_shard_config = FullyShardDistributedConfig(
135-
param_assignment_strategy=FSDPParamAssignmentStrategy.REPLICATE
143+
param_assignment_strategy=param_assignment_strategy
136144
)
137145

138146
steps_without_gradients = 2
@@ -156,13 +164,21 @@ def test_all_ranks_with_no_grads(
156164

157165
@with_comms
158166
@skip_if_lt_x_gpu(2)
167+
@parametrize(
168+
"param_assignment_strategy",
169+
(
170+
FSDPParamAssignmentStrategy.REPLICATE,
171+
FSDPParamAssignmentStrategy.ROUND_ROBIN,
172+
),
173+
)
159174
@parametrize("model_linear_layers_dims", TEST_MODEL_LAYER_DIMS)
160175
def test_fully_shard_shampoo_against_default_shampoo(
161176
self,
177+
param_assignment_strategy: FSDPParamAssignmentStrategy,
162178
model_linear_layers_dims: tuple[int, ...],
163179
) -> None:
164180
fully_shard_config = FullyShardDistributedConfig(
165-
param_assignment_strategy=FSDPParamAssignmentStrategy.REPLICATE
181+
param_assignment_strategy=param_assignment_strategy
166182
)
167183
control_model_factory = partial(
168184
ShampooFullyShardLosslessDistributorTest._construct_model,

distributed_shampoo/tests/distributed_shampoo_test.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
DistributedConfig,
3737
EigendecomposedShampooPreconditionerConfig,
3838
EigenvalueCorrectedShampooPreconditionerConfig,
39-
FSDPParamAssignmentStrategy,
40-
FullyShardDistributedConfig,
4139
PreconditionerConfig,
4240
RootInvShampooPreconditionerConfig,
4341
ShampooPT2CompileConfig,
@@ -259,16 +257,6 @@ class NotSupportedDistributedConfig(DistributedConfig):
259257
distributed_config=NotSupportedDistributedConfig(),
260258
)
261259

262-
self.assertRaisesRegex(
263-
NotImplementedError,
264-
r"group\[DISTRIBUTED_CONFIG\]=.*FullyShardDistributedConfig\(.*ROUND_ROBIN.*\) not supported!",
265-
DistributedShampoo,
266-
params=self._model.parameters(),
267-
distributed_config=FullyShardDistributedConfig(
268-
param_assignment_strategy=FSDPParamAssignmentStrategy.ROUND_ROBIN
269-
),
270-
)
271-
272260

273261
class DistributedShampooTest(unittest.TestCase):
274262
def setUp(self) -> None:
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Copyright (c) Meta Platforms, Inc. and affiliates.
3+
All rights reserved.
4+
5+
This source code is licensed under the BSD-style license found in the
6+
LICENSE file in the root directory of this source tree.
7+
8+
"""
9+
10+
#!/usr/bin/env python3
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
from distributed_shampoo.utils.shampoo_utils import (
17+
prepare_update_param_buffers,
18+
redistribute_and_update_params,
19+
)
20+
from torch import distributed as dist
21+
from torch.distributed.device_mesh import init_device_mesh
22+
from torch.distributed.tensor import distribute_tensor, Shard
23+
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
24+
from torch.testing._internal.common_utils import (
25+
instantiate_parametrized_tests,
26+
parametrize,
27+
)
28+
from torch.testing._internal.distributed._tensor.common_dtensor import (
29+
DTensorTestBase,
30+
with_comms,
31+
)
32+
33+
34+
def generate_param_shapes(num_params: int) -> list[tuple[int, ...]]:
35+
"""Generate parameter shapes for testing.
36+
37+
For N parameters, we generate the following shapes:
38+
[(1, 2), (2, 3), (3, 4), ..., (N, N + 1)].
39+
"""
40+
return [(i, i + 1) for i in range(1, num_params + 1)]
41+
42+
43+
@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
44+
@instantiate_parametrized_tests
45+
class RedistributeAndUpdateParamsTest(DTensorTestBase):
46+
@property
47+
def world_size(self) -> int:
48+
return 4
49+
50+
@with_comms
51+
@skip_if_lt_x_gpu(4)
52+
@parametrize("num_params", (1, 4, 7))
53+
def test_redistribute_and_update_params(self, num_params: int) -> None:
54+
device_mesh = init_device_mesh("cuda", (4,))
55+
shapes = generate_param_shapes(num_params)
56+
params = [torch.zeros(s, device="cuda") for s in shapes]
57+
dtensor_params = tuple(
58+
distribute_tensor(t, device_mesh, [Shard(0)]) for t in params
59+
)
60+
61+
update_buffers = prepare_update_param_buffers(dtensor_params, self.world_size)
62+
self.assertEqual(
63+
len(update_buffers),
64+
int(np.ceil(num_params / self.world_size) * self.world_size),
65+
)
66+
for i, buffer in enumerate(update_buffers):
67+
if i < num_params:
68+
self.assertEqual(buffer.numel(), dtensor_params[i].to_local().numel())
69+
else:
70+
self.assertEqual(buffer.numel(), 0)
71+
72+
rank = dist.get_rank()
73+
dist_group = dist.distributed_c10d._get_default_group()
74+
# Fill the locally assigned parameters with the rank as value.
75+
local_full_params = [
76+
torch.zeros(s, device="cuda").fill_(rank)
77+
for i, s in enumerate(shapes)
78+
if i % self.world_size == rank
79+
]
80+
redistribute_and_update_params(
81+
dtensor_params, local_full_params, update_buffers, dist_group
82+
)
83+
for i, param in enumerate(dtensor_params):
84+
np.testing.assert_allclose(
85+
param.to_local().cpu().numpy(), i % self.world_size
86+
)

distributed_shampoo/utils/shampoo_utils.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
from types import TracebackType
1717
from typing import Any, TypeVar
1818

19+
import numpy as np
20+
1921
import torch
2022
from distributed_shampoo.shampoo_types import LoadBalancingConfig
2123
from distributed_shampoo.utils.load_balancing_utils import AlignedMemoryCostModel
22-
from torch import Tensor
24+
from torch import distributed as dist, Tensor
25+
from torch.distributed.tensor import DTensor
2326

2427

2528
@cache
@@ -329,4 +332,78 @@ def distribute_buffer_sizes(
329332

330333
buffer_size_ranks = tuple(zip(buffer_sizes_aligned, param_block_ranks, strict=True))
331334

332-
return buffer_size_ranks
335+
return tuple(buffer_size_ranks)
336+
337+
338+
def prepare_update_param_buffers(
339+
params: tuple[DTensor, ...], group_size: int
340+
) -> list[Tensor]:
341+
"""Allocates a persistent shadow copy of updated parameters."""
342+
if any(p.dtype != params[0].dtype for p in params):
343+
raise NotImplementedError(
344+
"When using round-robin assignment in FSDP Shampoo, parameters of "
345+
"different dtypes are not currently supported."
346+
)
347+
348+
param_sizes = [p.to_local().numel() for p in params]
349+
buffer_size = sum(param_sizes)
350+
buffer = params[0].to_local().new_zeros(buffer_size)
351+
buffer_offsets = np.cumsum(param_sizes).tolist()
352+
353+
def round_up_to_multiple_of(x: int, y: int) -> int:
354+
return ((x + y - 1) // y) * y
355+
356+
pad_len = round_up_to_multiple_of(len(buffer_offsets), group_size) - len(
357+
buffer_offsets
358+
)
359+
# Pad the list with empty tensors to ensure each rank participates in all-to-all.
360+
buffer_offsets.extend([buffer_size] * pad_len)
361+
# Drop the last element as torch.tensor_split takes indices as split points.
362+
buffer_offsets = buffer_offsets[:-1]
363+
364+
return list(torch.tensor_split(buffer, buffer_offsets))
365+
366+
367+
def redistribute_and_update_params(
368+
params: tuple[DTensor, ...],
369+
local_full_params: list[Tensor],
370+
update_param_buffers: list[Tensor],
371+
dist_group: torch.distributed.ProcessGroup,
372+
) -> None:
373+
"""Redistributes updated parameters to each parameter's rank."""
374+
group_size = dist_group.size()
375+
376+
# Run all-to-all collectives to exchange the updated parameters across
377+
# ranks in group. This implementation runs multiple rounds of a2a ops
378+
# if the number of parameters is larger than the world size.
379+
for a2a_round in range(len(update_param_buffers) // group_size):
380+
# Send either a valid full parameter, or a padding zero tensor.
381+
send_param = (
382+
local_full_params[a2a_round]
383+
if a2a_round < len(local_full_params)
384+
else params[0].to_local().new_zeros(0)
385+
)
386+
# Chunk the send_param to exactly group_size slices to distribute to
387+
# all ranks. We need to manually pad the result of torch.chunk since
388+
# it does not guarantee that the result has the desired chunks.
389+
send_list = [t.flatten() for t in torch.chunk(send_param, group_size, dim=0)]
390+
if len(send_list) < group_size:
391+
# NOTE: Intentionally use `torch.tensor_split` here to do a trivial
392+
# split to ensure that the padding is in contiguous memory space as
393+
# is required for all-to-all collectives.
394+
append_len = group_size - len(send_list)
395+
last_t = send_list[-1]
396+
split_indices = [send_list[-1].shape[0]] * append_len
397+
send_list.extend(torch.tensor_split(last_t, split_indices, dim=0)[1:])
398+
assert len(send_list) == group_size
399+
400+
# Specify receive list as a range of update_param_buffers.
401+
recv_list = update_param_buffers[
402+
a2a_round * group_size : (a2a_round + 1) * group_size
403+
]
404+
405+
dist.all_to_all(recv_list, send_list, dist_group)
406+
407+
torch._foreach_copy_(
408+
[p.to_local().flatten() for p in params], update_param_buffers[: len(params)]
409+
)

0 commit comments

Comments (0)