diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py
index b53e9efe21..a77a691489 100644
--- a/python/mlx/nn/utils.py
+++ b/python/mlx/nn/utils.py
@@ -257,6 +257,13 @@ def fsdp_apply_gradients(
     flat_grads = tree_flatten(gradients)
     flat_params = tree_flatten(parameters)
 
+    # FSDP state-saving improvement:
+    # - tree_flatten order is not stable → sort by key for deterministic grouping
+    # - reconstruct the original tree structure before the optimizer step so the
+    #   optimizer state is no longer tied to the number of communication groups
+    flat_grads = sorted(flat_grads, key=lambda x: x[0])
+    flat_params = sorted(flat_params, key=lambda x: x[0])
+
     keys, shapes, sizes, dtypes = _extract_info(flat_grads)
     itemsize = dtypes[0].size
 
@@ -264,45 +271,71 @@ def fsdp_apply_gradients(
     S = fsdp_group.size()
     fsdp_rank = fsdp_group.rank()
 
-    # reduce-scatter gradients, shard parameters
-    grad_slices = {}
-    param_slices = {}
+    # reduce-scatter gradients + shard parameters (grouped for performance)
+    grad_shards = []
+    param_shards = []
     for group_idx, arr_group in enumerate(groups):
         big_grad = mx.concatenate(
             [flat_grads[i][1].reshape(S, -1) for i in arr_group], axis=1
         )
-        grad_slices[group_idx] = (
-            mx.distributed.sum_scatter(
-                big_grad, group=fsdp_group, stream=communication_stream
-            )
-            / N
-        )
+        grad_shard = mx.distributed.sum_scatter(
+            big_grad, group=fsdp_group, stream=communication_stream
+        ) / N
         if dp_group is not None:
-            grad_slices[group_idx] = mx.distributed.all_sum(
-                grad_slices[group_idx], group=dp_group, stream=communication_stream
+            grad_shard = mx.distributed.all_sum(
+                grad_shard, group=dp_group, stream=communication_stream
             )
+        grad_shards.append(grad_shard)
 
         big_param = mx.concatenate(
             [flat_params[i][1].reshape(S, -1) for i in arr_group], axis=1
         )
-        param_slices[group_idx] = big_param[fsdp_rank]
+        param_shards.append(big_param[fsdp_rank])
+
+    # Reconstruct the original tree from per-parameter local shards so the
+    # optimizer state matches the parameter tree, not the communication groups
+    sharded_grads_flat = []
+    sharded_params_flat = []
+    for g_idx, arr_group in enumerate(groups):
+        split_sizes = [sizes[i] // S for i in arr_group]
+        # mx.split takes cumulative split indices, not chunk sizes
+        split_indices = []
+        offset = 0
+        for s in split_sizes[:-1]:
+            offset += s
+            split_indices.append(offset)
+        grad_parts = mx.split(grad_shards[g_idx], split_indices, axis=-1)
+        param_parts = mx.split(param_shards[g_idx], split_indices, axis=-1)
+        for idx_in_group, i in enumerate(arr_group):
+            sharded_grads_flat.append((keys[i], grad_parts[idx_in_group]))
+            sharded_params_flat.append((keys[i], param_parts[idx_in_group]))
+
+    sharded_grads = tree_unflatten(sharded_grads_flat)
+    sharded_params = tree_unflatten(sharded_params_flat)
 
     # clip gradients if needed
     grad_norm = None
     if max_norm is not None:
-        grad_slices, grad_norm = _clip_grads_fsdp(
-            grad_slices, max_norm, group=fsdp_group
+        sharded_grads, grad_norm = _clip_grads_fsdp(
+            sharded_grads, max_norm, group=fsdp_group
        )
 
     # optimizer step
-    updated_param_slices = optimizer.apply_gradients(grad_slices, param_slices)
+    updated_sharded_params = optimizer.apply_gradients(sharded_grads, sharded_params)
+
+    # all-gather and reconstruct full parameters
+    flat_updated = tree_flatten(updated_sharded_params)
+    flat_updated = sorted(flat_updated, key=lambda x: x[0])
 
-    # all-gather and reconstruct
     new_flat = []
+    idx = 0
     for group_idx, arr_group in enumerate(groups):
+        # groups were built over the same sorted key order, so each group's
+        # updated shards are contiguous in flat_updated; re-concatenate them
+        group_shard = mx.concatenate(
+            [flat_updated[idx + j][1] for j in range(len(arr_group))], axis=-1
+        )
         big_gathered = mx.distributed.all_gather(
-            updated_param_slices[group_idx],
-            group=fsdp_group,
-            stream=communication_stream,
+            group_shard, group=fsdp_group, stream=communication_stream
         )
         split_sizes = [sizes[i] // S for i in arr_group]
@@ -314,6 +347,7 @@ def fsdp_apply_gradients(
         parts = mx.split(big_gathered, split_indices[:-1], axis=1)
         for idx_in_group, i in enumerate(arr_group):
             new_flat.append((keys[i], parts[idx_in_group].reshape(shapes[i])))
+        idx += len(arr_group)
 
     result = tree_unflatten(new_flat)
     if max_norm is not None:
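
The patch relies on one invariant: flattening, sorting by key, and unflattening must round-trip to the same tree regardless of how parameters are grouped for communication, so optimizer state keyed by parameter path survives a change in group count. A minimal sketch of that round trip, using only `mlx.utils.tree_flatten`/`tree_unflatten`; the toy tree names and shapes are illustrative, not from the patch:

```python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

# Toy parameter tree; the names and shapes are illustrative only.
params = {
    "layers": [{"w": mx.zeros((4, 2)), "b": mx.zeros((2,))}],
    "emb": mx.ones((8,)),
}

# Sort the flattened (key, value) pairs so grouping is deterministic
# regardless of the order tree_flatten produced.
flat = sorted(tree_flatten(params), key=lambda kv: kv[0])
keys = [k for k, _ in flat]

# Rebuilding from per-key pairs recovers the same tree no matter how the
# entries were grouped for communication.
rebuilt = tree_unflatten(flat)
assert [k for k, _ in sorted(tree_flatten(rebuilt), key=lambda kv: kv[0])] == keys
```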
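
Both loops also depend on the `reshape(S, -1)` sharding layout: each rank keeps one row, and gathering the rows back along the first axis reconstructs the original parameter. A sketch under the assumption that `all_gather` concatenates per-rank shards along the first axis; `S` stands in for `fsdp_group.size()` and `mx.concatenate` stands in for the collective:

```python
import mlx.core as mx

S = 4                            # stand-in for fsdp_group.size()
p = mx.arange(24).reshape(6, 4)  # a parameter with 24 elements
rows = p.reshape(S, -1)          # (S, 6): row r is rank r's shard
shard = rows[1]                  # what rank 1 keeps locally

# Stacking the per-rank rows (what all_gather would do) and reshaping
# recovers the original parameter exactly.
gathered = mx.concatenate([rows[r][None] for r in range(S)], axis=0)
assert mx.array_equal(gathered[1], shard)
assert mx.array_equal(gathered.reshape(6, 4), p)
```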