python/mlx/nn/utils.py: 62 changes (43 additions, 19 deletions)
@@ -257,52 +257,75 @@ def fsdp_apply_gradients(
flat_grads = tree_flatten(gradients)
flat_params = tree_flatten(parameters)

+ # FSDP state-saving improvement
+ # - tree_flatten order is not stable → sort by key for deterministic grouping
+ # - reconstruct the original tree structure before the optimizer step so the
+ # optimizer state is no longer tied to the number of communication groups
+ flat_grads = sorted(flat_grads, key=lambda x: x[0])
+ flat_params = sorted(flat_params, key=lambda x: x[0])
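+ # Note: every rank must end up with the same key order here; otherwise ranks
+ # would pack different parameters into the same group and the collectives
+ # below would exchange mismatched buffers.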

keys, shapes, sizes, dtypes = _extract_info(flat_grads)
itemsize = dtypes[0].size

groups = _group_by_size(keys, sizes, itemsize, communication_size)
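# `groups` maps each communication group to a list of indices into `keys`;
# _group_by_size is not shown in this hunk, but presumably it packs consecutive
# keys until a group reaches roughly communication_size bytes, so each group
# becomes one collective call below.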

S = fsdp_group.size()
fsdp_rank = fsdp_group.rank()
- # reduce-scatter gradients, shard parameters
- grad_slices = {}
- param_slices = {}

+ # reduce-scatter gradients + shard parameters (grouped for performance)
+ grad_shards = []
+ param_shards = []
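+ # Grouping trades one collective per parameter for one collective per group,
+ # so per-call latency does not dominate when there are many small arrays.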
for group_idx, arr_group in enumerate(groups):
big_grad = mx.concatenate(
[flat_grads[i][1].reshape(S, -1) for i in arr_group], axis=1
)
- grad_slices[group_idx] = (
- mx.distributed.sum_scatter(
- big_grad, group=fsdp_group, stream=communication_stream
- )
- / N
- )
+ grad_shard = mx.distributed.sum_scatter(
+ big_grad, group=fsdp_group, stream=communication_stream
+ ) / N
if dp_group is not None:
- grad_slices[group_idx] = mx.distributed.all_sum(
- grad_slices[group_idx], group=dp_group, stream=communication_stream
+ grad_shard = mx.distributed.all_sum(
+ grad_shard, group=dp_group, stream=communication_stream
)
+ grad_shards.append(grad_shard)
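+ # grad_shard now holds this rank's 1/S slice of the group's gradients, reduced
+ # across the FSDP group and, when a data-parallel group is present, across it too.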

big_param = mx.concatenate(
[flat_params[i][1].reshape(S, -1) for i in arr_group], axis=1
)
- param_slices[group_idx] = big_param[fsdp_rank]
+ param_shards.append(big_param[fsdp_rank])
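+ # Each rank keeps only its own row of the concatenated parameters, so the
+ # optimizer step below touches local shards only.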

+ # Reconstruct original tree with local shards
+ sharded_grads_flat = []
+ sharded_params_flat = []
+ for g_idx, arr_group in enumerate(groups):
+ split_sizes = [sizes[i] // S for i in arr_group]
+ split_indices = [sum(split_sizes[: j + 1]) for j in range(len(split_sizes))]
+ # split each group's shard once along the packed axis into per-key pieces
+ grad_parts = mx.split(grad_shards[g_idx], split_indices[:-1], axis=-1)
+ param_parts = mx.split(param_shards[g_idx], split_indices[:-1], axis=-1)
+ for idx_in_group, i in enumerate(arr_group):
+ sharded_grads_flat.append((keys[i], grad_parts[idx_in_group]))
+ sharded_params_flat.append((keys[i], param_parts[idx_in_group]))

+ sharded_grads = tree_unflatten(sharded_grads_flat)
+ sharded_params = tree_unflatten(sharded_params_flat)
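+ # From here on, grads and params carry one sharded entry per parameter key, so
+ # optimizer state is keyed by parameter and does not depend on how many
+ # communication groups were formed.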

# clip gradients if needed
grad_norm = None
if max_norm is not None:
- grad_slices, grad_norm = _clip_grads_fsdp(
- grad_slices, max_norm, group=fsdp_group
+ sharded_grads, grad_norm = _clip_grads_fsdp(
+ sharded_grads, max_norm, group=fsdp_group
)

# optimizer step
- updated_param_slices = optimizer.apply_gradients(grad_slices, param_slices)
+ updated_sharded_params = optimizer.apply_gradients(sharded_grads, sharded_params)

+ # all-gather and reconstruct full parameters
+ flat_updated = tree_flatten(updated_sharded_params)
+ flat_updated = sorted(flat_updated, key=lambda x: x[0])

- # all-gather and reconstruct
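+ # flat_updated is sorted with the same key ordering used to build `groups`, so
+ # flat_updated[idx] walks the parameters in group order in the loop below.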
new_flat = []
+ idx = 0
for group_idx, arr_group in enumerate(groups):
- big_gathered = mx.distributed.all_gather(
- updated_param_slices[group_idx],
- group=fsdp_group,
- stream=communication_stream,
- )
+ # re-concatenate this group's updated per-key shards so one all_gather per group
+ big_updated = mx.concatenate(
+ [flat_updated[idx + j][1].reshape(1, -1) for j in range(len(arr_group))], axis=1
+ )
+ big_gathered = mx.distributed.all_gather(
+ big_updated, group=fsdp_group, stream=communication_stream
+ )
split_sizes = [sizes[i] // S for i in arr_group]
split_indices = []
@@ -314,6 +337,7 @@ def fsdp_apply_gradients(
parts = mx.split(big_gathered, split_indices[:-1], axis=1)
for idx_in_group, i in enumerate(arr_group):
new_flat.append((keys[i], parts[idx_in_group].reshape(shapes[i])))
+ idx += 1

result = tree_unflatten(new_flat)
if max_norm is not None: