From a7c8e3b8c2e4ca79526e8bb556c7d210bf19b007 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 11 Dec 2025 15:05:56 +0100 Subject: [PATCH 1/2] feat: collective operations --- optimum/neuron/accelerate/accelerator.py | 31 +- optimum/neuron/accelerate/utils/__init__.py | 10 + optimum/neuron/accelerate/utils/misc.py | 3 +- optimum/neuron/accelerate/utils/operations.py | 276 +++++++++++++++--- 4 files changed, 271 insertions(+), 49 deletions(-) diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index bcb1be71b..7d0d0d325 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -23,16 +23,18 @@ from typing import Any, Callable import torch +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from accelerate import Accelerator from accelerate.checkpointing import save_accelerator_state, save_custom_state from accelerate.utils import AutocastKwargs, DistributedType -from accelerate.utils.operations import gather_object, recursively_apply +from accelerate.utils.operations import recursively_apply from neuronx_distributed import parallel_layers from neuronx_distributed.optimizer import NeuronZero1Optimizer from neuronx_distributed.parallel_layers.parallel_state import ( get_context_model_parallel_size, + get_data_parallel_group, get_data_parallel_replica_groups, get_data_parallel_size, get_tensor_model_parallel_replica_groups, @@ -58,15 +60,13 @@ from .optimizer import NeuronAcceleratedOptimizer from .scheduler import NeuronAcceleratedScheduler from .state import NeuronAcceleratorState -from .utils import ( - patch_accelerate_is_torch_xla_available, -) from .utils.dataclasses import MixedPrecisionConfig, MixedPrecisionMode from .utils.misc import ( apply_activation_checkpointing, create_patched_save_pretrained, + patch_accelerate_is_torch_xla_available, ) -from .utils.operations import _xla_gather +from .utils.operations import gather_object # Setup logging so that the main process logs at the INFO level and the others are silent. @@ -390,7 +390,7 @@ def prepare_model( move_model_to_device(model, xm.xla_device()) model.tie_weights() - xm.mark_step() + torch_xla.sync() # Adding the model to the list of prepared models. 
self._models.append(model) @@ -474,7 +474,7 @@ def _inner(folder): logger.info(f"Saving current state to {output_dir}") # Finish running the previous step before checkpointing - xm.mark_step() + torch_xla.sync() # Save the models if save_model_func is not None: @@ -547,10 +547,18 @@ def save_state( output_dir=output_dir, safe_serialization=safe_serialization, **save_model_func_kwargs ) - def gather(self, tensor, out_of_graph: bool = False): - return _xla_gather(tensor, out_of_graph=out_of_graph) + def gather(self, tensor: torch.Tensor, sync: bool = False) -> torch.Tensor: + groups = get_data_parallel_group(as_list=True) + + # Ensure tensor is at least 1D for all_gather (scalars need to be unsqueezed) + input_tensor = tensor.unsqueeze(0) if tensor.ndim == 0 else tensor + gathered = xm.all_gather(input_tensor, dim=0, groups=groups, pin_layout=False) + + if sync: + torch_xla.sync() + return gathered - def gather_for_metrics(self, input_data, use_gather_object: bool = False): + def gather_for_metrics(self, input_data, use_gather_object: bool = False, sync: bool = False): try: recursively_apply(lambda x: x, input_data, error_on_other_type=True) all_tensors = True @@ -562,8 +570,7 @@ def gather_for_metrics(self, input_data, use_gather_object: bool = False): if use_gather_object: data = gather_object(input_data) else: - # It is needed to perform out-of-graph gather otherwise re-compilation happens at every evaluation step. - data = self.gather(input_data, out_of_graph=True) + data = self.gather(input_data, sync=sync) try: if self.gradient_state.end_of_dataloader: diff --git a/optimum/neuron/accelerate/utils/__init__.py b/optimum/neuron/accelerate/utils/__init__.py index 6218b6f31..294e53bb3 100644 --- a/optimum/neuron/accelerate/utils/__init__.py +++ b/optimum/neuron/accelerate/utils/__init__.py @@ -15,3 +15,13 @@ from .dataclasses import MixedPrecisionConfig, MixedPrecisionMode from .misc import patch_accelerate_is_torch_xla_available +from .operations import ( + broadcast_object, + broadcast_object_to_data_parallel_group, + broadcast_object_to_pipeline_model_parallel_group, + broadcast_object_to_tensor_model_parallel_group, + gather_object, + gather_object_from_data_parallel_group, + gather_object_from_pipeline_model_parallel_group, + gather_object_from_tensor_model_parallel_group, +) diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py index 6cd32e8bc..9aa60b13e 100644 --- a/optimum/neuron/accelerate/utils/misc.py +++ b/optimum/neuron/accelerate/utils/misc.py @@ -23,6 +23,7 @@ import accelerate import torch +import torch_xla import torch_xla.core.xla_model as xm from neuronx_distributed.parallel_layers.parallel_state import ( get_data_parallel_rank, @@ -119,7 +120,7 @@ def wrapper(*args, **kwargs): with patcher: output = orig_func(*args, **kwargs) self.load_state_dict(orig_state_dict, assign=True) - xm.mark_step() + torch_xla.sync() del cpu_state_dict gc.collect() return output diff --git a/optimum/neuron/accelerate/utils/operations.py b/optimum/neuron/accelerate/utils/operations.py index 683086684..ab6ffa407 100644 --- a/optimum/neuron/accelerate/utils/operations.py +++ b/optimum/neuron/accelerate/utils/operations.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
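The reworked `gather` collects tensors across the data parallel replica groups with `xm.all_gather`, and `gather_for_metrics` can fall back to the new `gather_object` helper via `use_gather_object=True`. A minimal sketch of how an evaluation loop might call it, assuming `NeuronAccelerator` is importable as shown below; the model, dataloader and loss handling are illustrative placeholders:

    import torch
    from optimum.neuron.accelerate import NeuronAccelerator  # assumed import path

    accelerator = NeuronAccelerator()
    model, eval_dataloader = accelerator.prepare(model, eval_dataloader)  # user-provided objects

    model.eval()
    per_rank_losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            loss = model(**batch).loss
        # One loss value per data parallel rank; sync=True triggers torch_xla.sync()
        # after the all_gather so the gathered tensor is materialized.
        per_rank_losses.append(accelerator.gather_for_metrics(loss, sync=True))
    mean_loss = torch.cat([t.view(-1) for t in per_rank_losses]).mean()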
@@ -13,44 +13,248 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle +from typing import Any, Callable + +import numpy as np import torch +import torch_xla import torch_xla.core.xla_model as xm -from accelerate.utils.operations import recursively_apply +import torch_xla.runtime as xr from neuronx_distributed.parallel_layers.parallel_state import ( - get_data_parallel_group, - model_parallel_is_initialized, + get_context_model_parallel_size, + get_data_parallel_replica_groups, + get_data_parallel_size, + get_pipeline_model_parallel_replica_groups, + get_tensor_model_parallel_replica_groups, ) +from ...utils.misc import is_precompilation + + +def broadcast_object( + obj: Any, + src: int = 0, + groups: list[list[int]] | None = None, + world_size_function: Callable[[], int] = xr.world_size, + get_rank_function: Callable[[], int] = xr.global_ordinal, + fixed_size: int | None = None, +) -> Any: + """ + Broadcasts arbitrary objects across XLA-distributed processes. + Returns the object from the source rank on all ranks. + If `groups` is specified, broadcast is done separately in each group, and the `src` rank is relative to each group. + """ + world_size = world_size_function() + if world_size == 1: + return obj + + rank = get_rank_function() + + if rank == src: + bytes_ = pickle.dumps(obj) + length = len(bytes_) + # Ensure the serialized object fits in the fixed size if specified. + # Otherwise we would corrupt the transferred data. + if fixed_size is not None and length > fixed_size: + raise ValueError(f"Serialized object size {length} exceeds the specified fixed_size {fixed_size}") + else: + bytes_ = b"" + length = 0 + + # First, broadcast the length of the serialized object. + max_length = xm.all_reduce("max", torch.tensor([length], dtype=torch.int64).to(xm.xla_device())) + max_length = max_length.cpu() + + # Ensure all ranks agree on the max length. + torch_xla.sync() + + max_length = int(max_length.item()) + + if fixed_size is not None: + target_length = fixed_size + else: + target_length = max_length + + if rank == src: + np_buffer = np.frombuffer(bytes_, dtype=np.uint8) + padding_length = target_length - length + if padding_length > 0: + padding = np.zeros([padding_length], dtype=np.uint8) + np_buffer = np.concatenate([np_buffer, padding], axis=0) + data_tensor = torch.from_numpy(np_buffer).to(xm.xla_device()) + else: + data_tensor = torch.zeros(target_length, dtype=torch.uint8, device=xm.xla_device()) + + data_tensor = xm.all_reduce("sum", data_tensor, groups=groups) + torch_xla.sync() + + data_tensor_cpu = data_tensor.cpu() + reduced_bytes = data_tensor_cpu.numpy().tobytes() + + reduced_bytes = reduced_bytes[:max_length] + + return pickle.loads(reduced_bytes) + + +def broadcast_object_to_data_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any: + """ + Broadcasts arbitrary objects across XLA-distributed data parallel group. + Returns the object from the source rank on all ranks in the data parallel group. + """ + groups = get_data_parallel_replica_groups() + return broadcast_object( + obj, + src=src, + groups=groups, + world_size_function=get_data_parallel_size, + get_rank_function=get_data_parallel_replica_groups, + fixed_size=fixed_size, + ) + + +def broadcast_object_to_tensor_model_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any: + """ + Broadcasts arbitrary objects across XLA-distributed tensor model parallel group. 
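+ The broadcast is performed independently within each tensor model parallel replica group, and `src` is interpreted relative to that group.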
+ Returns the object from the source rank on all ranks in the tensor model parallel group. + """ + groups = get_tensor_model_parallel_replica_groups() + return broadcast_object( + obj, + src=src, + groups=groups, + world_size_function=get_context_model_parallel_size, + get_rank_function=get_tensor_model_parallel_replica_groups, + fixed_size=fixed_size, + ) + + +def broadcast_object_to_pipeline_model_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any: + """ + Broadcasts arbitrary objects across XLA-distributed pipeline model parallel group. + Returns the object from the source rank on all ranks in the pipeline model parallel group. + """ + groups = get_pipeline_model_parallel_replica_groups() + return broadcast_object( + obj, + src=src, + groups=groups, + world_size_function=get_context_model_parallel_size, + get_rank_function=get_pipeline_model_parallel_replica_groups, + fixed_size=fixed_size, + ) + + +def gather_object( + obj: Any, + groups: list[list[int]] | None = None, + world_size_function: Callable[[], int] = xr.world_size, + fixed_size: int | None = None, +) -> list[Any]: + """ + Gathers arbitrary objects across XLA-distributed processes. + Returns list of objects from all ranks on all ranks. + If `groups` is specified, gather is done separately in each group. + """ + world_size = world_size_function() + + # Early exit for single process + if world_size == 1: + return [obj] + + serialized = pickle.dumps(obj) + length = len(serialized) + + if fixed_size is not None and length > fixed_size: + raise ValueError(f"Serialized object size {length} exceeds the specified fixed_size {fixed_size}") + + lengths = xm.all_gather( + torch.tensor([length], dtype=torch.int64).to(device=xm.xla_device()), + dim=0, + groups=groups, + pin_layout=False, + ) + torch_xla.sync() + lengths_cpu = lengths.cpu() + max_length = lengths_cpu.max() + max_length = int(max_length.item()) + + if fixed_size is not None: + target_length = fixed_size + else: + target_length = max_length + + np_buffer = np.frombuffer(serialized, dtype=np.uint8) + padding_length = target_length - length + if padding_length > 0: + padding = np.zeros([padding_length], dtype=np.uint8) + np_buffer = np.concatenate([np_buffer, padding], axis=0) + data_tensor = torch.from_numpy(np_buffer).to(xm.xla_device()) + + data_tensor = xm.all_gather( + data_tensor, + dim=0, + groups=groups, + pin_layout=False, + ) + torch_xla.sync() + + data_tensors_cpu = data_tensor.cpu().split(target_length) + data_bytes = [t.numpy().tobytes() for t in data_tensors_cpu] + + # During precompilation, all_gather returns tensors with uninitialized data or zeros, + # breaking the pickle.loads step below. So we return a list of the original object instead, + # it should not break anything since precompilation does not rely on the gathered objects. + if is_precompilation(): + return [obj for _ in range(world_size)] + + results = [] + for i in range(world_size): + length_i = lengths_cpu[i].item() + bytes_i = data_bytes[i][:length_i] + obj_i = pickle.loads(bytes_i) + results.append(obj_i) + + return results + + +def gather_object_from_data_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: + """ + Gathers arbitrary objects across XLA-distributed data parallel group. + Returns list of objects from all ranks in the data parallel group on all ranks. 
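+
+ Example (illustrative, assuming a data parallel group of size 2 where rank 0 passes `{"loss": 0.5}` and rank 1 passes `{"loss": 0.7}`):
+
+ >>> gather_object_from_data_parallel_group({"loss": 0.5})
+ [{'loss': 0.5}, {'loss': 0.7}]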
+ """ + groups = get_data_parallel_replica_groups() + return gather_object( + obj, + groups=groups, + world_size_function=get_data_parallel_size, + fixed_size=fixed_size, + ) + + +def gather_object_from_tensor_model_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: + """ + Gathers arbitrary objects across XLA-distributed tensor model parallel group. + Returns list of objects from all ranks in the tensor model parallel group on all ranks. + """ + groups = get_tensor_model_parallel_replica_groups() + return gather_object( + obj, + groups=groups, + world_size_function=get_context_model_parallel_size, + fixed_size=fixed_size, + ) + -def _xla_gather(tensor, out_of_graph: bool = False): - groups = None - if model_parallel_is_initialized(): - groups = get_data_parallel_group(as_list=True) - - def _xla_gather_one(tensor): - if tensor.ndim == 0: - tensor = tensor.clone()[None] - # Can only gather contiguous tensors - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - - if out_of_graph: - gathered_tensors = xm.mesh_reduce("nested_xla_gather", tensor, lambda x: x) - if groups is not None: - new_gathered_tensors = [] - # Since groups is containing list of group of replicas, we consider that visiting the first group of - # replicas is enough since the value should be the same across other axes. - replicas_to_consider = set(groups[0]) - for idx, tensor in enumerate(gathered_tensors): - if idx not in replicas_to_consider: - continue - new_gathered_tensors.append(tensor) - gathered_tensors = new_gathered_tensors - gathered = torch.cat(gathered_tensors) - else: - gathered = xm.all_gather(tensor, groups=groups, pin_layout=False) - return gathered - - res = recursively_apply(_xla_gather_one, tensor, error_on_other_type=True) - xm.mark_step() - return res +def gather_object_from_pipeline_model_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: + """ + Gathers arbitrary objects across XLA-distributed pipeline model parallel group. + Returns list of objects from all ranks in the pipeline model parallel group on all ranks. + """ + groups = get_pipeline_model_parallel_replica_groups() + return gather_object( + obj, + groups=groups, + world_size_function=get_context_model_parallel_size, + fixed_size=fixed_size, + ) From 9db1c6b285ccfcc0db406c67932f38af6922bc5a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 12 Dec 2025 15:50:31 +0100 Subject: [PATCH 2/2] feat: add docstring --- optimum/neuron/accelerate/utils/operations.py | 73 +++++++++++++++++-- 1 file changed, 65 insertions(+), 8 deletions(-) diff --git a/optimum/neuron/accelerate/utils/operations.py b/optimum/neuron/accelerate/utils/operations.py index ab6ffa407..eb6645a84 100644 --- a/optimum/neuron/accelerate/utils/operations.py +++ b/optimum/neuron/accelerate/utils/operations.py @@ -41,9 +41,25 @@ def broadcast_object( fixed_size: int | None = None, ) -> Any: """ - Broadcasts arbitrary objects across XLA-distributed processes. + Broadcasts arbitrary picklable objects across XLA-distributed processes. Returns the object from the source rank on all ranks. If `groups` is specified, broadcast is done separately in each group, and the `src` rank is relative to each group. + + Args: + obj (Any): The object to broadcast. Must be picklable (serializable via pickle). + src (int, defaults to `0`): The source rank within each group. + groups (list[list[int]] | None, defaults to `None`): Optional list of process groups for separate broadcasts. 
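+ When left as `None`, the broadcast spans all XLA processes rather than a specific replica group.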
+ world_size_function (Callable[[], int], defaults to `xr.world_size`): Function to get the world size. + get_rank_function (Callable[[], int], defaults to `xr.global_ordinal`): Function to get the current rank. + fixed_size (int | None, defaults to `None`): Optional fixed buffer size for serialization. If specified, + the serialized object must not exceed this size. + + Returns: + Any: The broadcast object from the source rank. + + Note: + Objects must be picklable. This includes most Python built-in types, but excludes + certain objects like lambdas, local functions, or objects with open file handles. """ world_size = world_size_function() if world_size == 1: @@ -99,8 +115,13 @@ def broadcast_object( def broadcast_object_to_data_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any: """ - Broadcasts arbitrary objects across XLA-distributed data parallel group. + Broadcasts arbitrary picklable objects across XLA-distributed data parallel group. Returns the object from the source rank on all ranks in the data parallel group. + + Args: + obj (Any): The object to broadcast. Must be picklable (serializable via pickle). + src (int, defaults to `0`): The source rank within the data parallel group. + fixed_size (int | None, defaults to `None`): Optional fixed buffer size for serialization. """ groups = get_data_parallel_replica_groups() return broadcast_object( @@ -115,8 +136,13 @@ def broadcast_object_to_data_parallel_group(obj: Any, src: int = 0, fixed_size: def broadcast_object_to_tensor_model_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any: """ - Broadcasts arbitrary objects across XLA-distributed tensor model parallel group. + Broadcasts arbitrary picklable objects across XLA-distributed tensor model parallel group. Returns the object from the source rank on all ranks in the tensor model parallel group. + + Args: + obj (Any): The object to broadcast. Must be picklable (serializable via pickle). + src (int, defaults to `0`): The source rank within the tensor model parallel group. + fixed_size (int | None, defaults to `None`): Optional fixed buffer size for serialization. """ groups = get_tensor_model_parallel_replica_groups() return broadcast_object( @@ -131,8 +157,13 @@ def broadcast_object_to_tensor_model_parallel_group(obj: Any, src: int = 0, fixe def broadcast_object_to_pipeline_model_parallel_group(obj: Any, src: int = 0, fixed_size: int | None = None) -> Any: """ - Broadcasts arbitrary objects across XLA-distributed pipeline model parallel group. + Broadcasts arbitrary picklable objects across XLA-distributed pipeline model parallel group. Returns the object from the source rank on all ranks in the pipeline model parallel group. + + Args: + obj (Any): The object to broadcast. Must be picklable (serializable via pickle). + src (int, defaults to `0`): The source rank within the pipeline model parallel group. + fixed_size (int | None, defaults to `None`): Optional fixed buffer size for serialization. """ groups = get_pipeline_model_parallel_replica_groups() return broadcast_object( @@ -152,9 +183,23 @@ def gather_object( fixed_size: int | None = None, ) -> list[Any]: """ - Gathers arbitrary objects across XLA-distributed processes. + Gathers arbitrary picklable objects across XLA-distributed processes. Returns list of objects from all ranks on all ranks. If `groups` is specified, gather is done separately in each group. + + Args: + obj (Any): The object to gather. Must be picklable (serializable via pickle). 
+ groups (list[list[int]] | None, defaults to `None`): Optional list of process groups for separate gathers. + world_size_function (Callable[[], int], defaults to `xr.world_size`): Function to get the world size. + fixed_size (int | None, defaults to `None`): Optional fixed buffer size for serialization. If specified, + the serialized object must not exceed this size. + + Returns: + list[Any]: List of objects from all ranks. + + Note: + Objects must be picklable. This includes most Python built-in types, but excludes + certain objects like lambdas, local functions, or objects with open file handles. """ world_size = world_size_function() @@ -220,8 +265,12 @@ def gather_object( def gather_object_from_data_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: """ - Gathers arbitrary objects across XLA-distributed data parallel group. + Gathers arbitrary picklable objects across XLA-distributed data parallel group. Returns list of objects from all ranks in the data parallel group on all ranks. + + Args: + obj (Any): The object to gather. Must be picklable (serializable via pickle). + fixed_size (int | None, defaults to `None`): Optional fixed buffer size for serialization. """ groups = get_data_parallel_replica_groups() return gather_object( @@ -234,8 +283,12 @@ def gather_object_from_data_parallel_group(obj: Any, fixed_size: int | None = No def gather_object_from_tensor_model_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: """ - Gathers arbitrary objects across XLA-distributed tensor model parallel group. + Gathers arbitrary picklable objects across XLA-distributed tensor model parallel group. Returns list of objects from all ranks in the tensor model parallel group on all ranks. + + Args: + obj (Any): The object to gather. Must be picklable (serializable via pickle). + fixed_size (int | None, defaults to `None`): Optional fixed buffer size for serialization. """ groups = get_tensor_model_parallel_replica_groups() return gather_object( @@ -248,8 +301,12 @@ def gather_object_from_tensor_model_parallel_group(obj: Any, fixed_size: int | N def gather_object_from_pipeline_model_parallel_group(obj: Any, fixed_size: int | None = None) -> list[Any]: """ - Gathers arbitrary objects across XLA-distributed pipeline model parallel group. + Gathers arbitrary picklable objects across XLA-distributed pipeline model parallel group. Returns list of objects from all ranks in the pipeline model parallel group on all ranks. + + Args: + obj (Any): The object to gather. Must be picklable (serializable via pickle). + fixed_size (int | None, defaults to `None`): Optional fixed buffer size for serialization. """ groups = get_pipeline_model_parallel_replica_groups() return gather_object(