diff --git a/deepspeed/compile/backend.py b/deepspeed/compile/backend.py index 0df72ed1666c..006240985d47 100644 --- a/deepspeed/compile/backend.py +++ b/deepspeed/compile/backend.py @@ -384,3 +384,4 @@ def compiler_fn(gm, sample_inputs): raise ValueError(f"Unsupported backend {backend}") return backend_fn + diff --git a/deepspeed/compile/config.py b/deepspeed/compile/config.py index 739add99271c..c10f54b06bd2 100644 --- a/deepspeed/compile/config.py +++ b/deepspeed/compile/config.py @@ -3,8 +3,10 @@ # DeepSpeed Team +from typing import List, Optional, Literal from deepspeed.runtime.config_utils import DeepSpeedConfigModel +PassName = Literal["z1", "z3", "autosp"] class CompileConfig(DeepSpeedConfigModel): """ Configure compile settings """ @@ -53,3 +55,12 @@ class CompileConfig(DeepSpeedConfigModel): keep_all_input_tensors: bool = False """ Keep real values for all input tensors in InputStorage instead of using dummy values """ + + passes: Optional[List[PassName]] = None + """ Composes different optimizations. """ + + sp_size: int = 1 + """ SP group-size """ + + dp_size: int = 1 + """ DP group-size """ diff --git a/deepspeed/compile/constants.py b/deepspeed/compile/constants.py new file mode 100644 index 000000000000..3580082a2499 --- /dev/null +++ b/deepspeed/compile/constants.py @@ -0,0 +1,6 @@ +######################################### +# AUTOSP +######################################### +AUTOSP_INPUT_ID_KEY = "input_id" +AUTOSP_LABEL_ID_KEY = "label_id" +AUTOSP_POSITION_ID_KEY = "position_id" diff --git a/deepspeed/compile/custom_ops/__init__.py b/deepspeed/compile/custom_ops/__init__.py new file mode 100644 index 000000000000..85164c7beabc --- /dev/null +++ b/deepspeed/compile/custom_ops/__init__.py @@ -0,0 +1,4 @@ +from .all_to_all import all_to_all +from . 
import sp_dp_registry + +__all__ = ["all_to_all", "sp_dp_registry"] diff --git a/deepspeed/compile/custom_ops/all_to_all.py b/deepspeed/compile/custom_ops/all_to_all.py new file mode 100644 index 000000000000..03d20aae83b0 --- /dev/null +++ b/deepspeed/compile/custom_ops/all_to_all.py @@ -0,0 +1,74 @@ +import torch +import torch.distributed as dist +from .sp_dp_registry import get_group, is_setup, sp_size, dp_size + +@torch.library.custom_op("autosp::all_to_all", mutates_args=()) +def all_to_all( + input: torch.Tensor, + scatter_idx: int, + gather_idx: int, + name: str, +) -> torch.Tensor: + """ + All-to-all collective for SDPA tensors [B, N, S, H]. + + For QKV (scatter_idx=1, gather_idx=2): + [B, N, S/P, H] -> [B, N/P, S, H] + For O (scatter_idx=2, gather_idx=1): + [B, N/P, S, H] -> [B, N, S/P, H] + """ + assert is_setup(), 'Incorrect initialization of SP/DP mesh.' + B, dim1, dim2, H = input.shape + gid = dist.get_rank() // sp_size() + group = get_group(gid) + + if scatter_idx == 1: + N, local_S = dim1, dim2 + input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H) + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + output = output.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(B, N // sp_size(), sp_size() * local_S, H) + else: # scatter_idx == 2, O: scatter sequence, gather heads + local_N, S = dim1, dim2 + input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H) + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + output = output.permute(1, 0, 2, 3, 4).contiguous() + output = output.reshape(B, sp_size() * local_N, S // sp_size(), H) + + return output + + +@torch.library.register_fake("autosp::all_to_all") +def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str): + B, dim1, dim2, H = input.shape + if 
scatter_idx == 1: + return input.new_empty(B, dim1 // sp_size(), dim2 * sp_size(), H) + else: + return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H) + + +def _all_to_all_backward_setup(ctx, inputs, output): + _, scatter_idx, gather_idx, name = inputs + ctx.scatter_idx = gather_idx + ctx.gather_idx = scatter_idx + ctx.name = name + "_grad" + + +def _all_to_all_backward(ctx, grad): + return ( + all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), + None, None, None + ) + + +torch.library.register_autograd( + "autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup +) diff --git a/deepspeed/compile/custom_ops/sp_dp_registry.py b/deepspeed/compile/custom_ops/sp_dp_registry.py new file mode 100644 index 000000000000..b7875699b663 --- /dev/null +++ b/deepspeed/compile/custom_ops/sp_dp_registry.py @@ -0,0 +1,48 @@ +import torch +import torch.distributed as dist + +GROUP_REGISTRY = {} # int -> dist.ProcessGroup + +def register_groups(groups): + """groups: List[List[int]], e.g. [[0,1],[2,3]]""" + for gid, ranks in enumerate(groups): + if gid not in GROUP_REGISTRY: + GROUP_REGISTRY[gid] = dist.new_group(ranks) + +def get_group(gid: int): + return GROUP_REGISTRY[gid] if gid is not None else dist.group.WORLD + +def get_registry(): + return GROUP_REGISTRY + +def is_setup(): + return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False + +def sp_size(): + assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.' + + return GROUP_REGISTRY['SP_SIZE'] + +def dp_size(): + assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly' + + return GROUP_REGISTRY['DP_SIZE'] + +def populate_registry(SP_SIZE, DP_SIZE): + """ Populate rank to SP/DP mesh index. 
""" + + if GROUP_REGISTRY.get('is_reg', False): + return + + group_listing = [] + offset = 0 + for _ in range(DP_SIZE): + group_listing.append([i + offset for i in range(SP_SIZE)]) + offset += SP_SIZE + + register_groups(group_listing) + + ## Extraneous metadata required for proper instantiation. ## + GROUP_REGISTRY['SP_SIZE'] = SP_SIZE + GROUP_REGISTRY['DP_SIZE'] = DP_SIZE + GROUP_REGISTRY['is_reg'] = True diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py index 7b3408b56afe..d745bbda4624 100644 --- a/deepspeed/compile/fx.py +++ b/deepspeed/compile/fx.py @@ -3,11 +3,11 @@ # DeepSpeed Team -from typing import Callable, Any, List, Dict +from typing import Callable, Any, List, Dict, Optional from collections import defaultdict import torch -from torch.fx import Node, Graph +from torch.fx import Node, Graph, GraphModule from .util import get_last_uses @@ -138,3 +138,28 @@ def free_tensors(tensors: List[torch.Tensor]): # Python version for debugging # graph.create_node('call_function', free_tensors, args, {}, name=node_name) + +def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]: + for node in gm.graph.nodes: + if node.name == name: + return node + return None + +def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]: + return node.meta.get("val") or node.meta.get("example_value") + +def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]: + input_id_node = None + for node in gm.graph.nodes: + # https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048 + tensor_dict = node.meta.get('tensor_dict') + if tensor_dict and tensor_dict.get('tag') == tag: + input_id_node = node + break + return input_id_node + +def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None): + exclude = exclude or [] + to_replace = [u for u in node.users if u not in exclude] + for user in to_replace: + user.replace_input_with(node, replacement) diff --git 
a/deepspeed/compile/init_sp.py b/deepspeed/compile/init_sp.py new file mode 100644 index 000000000000..de15bcb5f925 --- /dev/null +++ b/deepspeed/compile/init_sp.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from torch.fx import GraphModule +from .passes.sp_compile import apply_autosp + +def init_autosp(compile_config): + def backend_fn(gm: GraphModule, real_inputs): + apply_autosp(gm, real_inputs, debug=False, sp_size=compile_config.sp_size, dp_size=compile_config.dp_size) + return torch._inductor.compile(gm, real_inputs) + return backend_fn diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py new file mode 100644 index 000000000000..308187b45d26 --- /dev/null +++ b/deepspeed/compile/passes/sp_compile.py @@ -0,0 +1,246 @@ +"""AutoSP: Automatic Sequence Parallel (Ulysses) pass for graph modules. + +Ulysses Transformation: + Input: [B, N, S/P, H] (all heads, partitioned sequence) + After A2A on QKV: [B, N/P, S, H] (partitioned heads, full sequence) + After SDPA: [B, N/P, S, H] + After A2A on O: [B, N, S/P, H] (all heads, partitioned sequence) + +Where: + B = batch size, N = num heads, S = full sequence length, H = head dim, P = world size +""" + +import operator +from typing import Optional, List, Callable + +import torch +import torch.distributed as dist +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx import GraphModule, Node +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.experimental.symbolic_shapes import ShapeEnv + +from deepspeed.compile import constants + +from ..custom_ops import all_to_all, sp_dp_registry +from ..fx import find_node_by_name, get_node_shape_meta +from ..util import get_input_id_node, get_label_id_node, get_position_id_node, shard_tensor_node, get_sdpa_nodes + +def prepare_autosp_inputs(input_id: torch.Tensor, label_id: torch.Tensor, position_id: torch.Tensor = 
None, attention_mask: torch.Tensor = None, seq_dim: int = 1): + """ + Prepare inputs for AutoSP by marking dynamic dimensions and tagging tensors. + + Args: + input_id: Token IDs tensor (required) + label_id: Label IDs tensor (required) + position_id: Position IDs tensor (optional) + attention_mask: Attention mask tensor (optional) + seq_dim: Sequence dimension index to mark as dynamic (default: 1) + """ + + if input_id is None: + raise ValueError("input_id is required") + if label_id is None: + raise ValueError("label_id is required") + + if seq_dim < 0 or seq_dim >= input_id.ndim: + raise ValueError(f"seq_dim {seq_dim} must be a valid index for input_id with shape {input_id.shape}") + + if position_id is not None: + if seq_dim >= position_id.ndim: + raise ValueError(f"seq_dim {seq_dim} is out of bounds for position_id with shape {position_id.shape}") + + if attention_mask is not None: + if seq_dim >= attention_mask.ndim: + raise ValueError(f"seq_dim {seq_dim} is out of bounds for attention_mask with shape {attention_mask.shape}") + + torch._dynamo.decorators.mark_dynamic(input_id, seq_dim) + torch._dynamo.decorators.mark_dynamic(label_id, seq_dim) + if position_id is not None: + torch._dynamo.decorators.mark_dynamic(position_id, seq_dim) + if attention_mask is not None: + torch._dynamo.decorators.mark_dynamic(attention_mask, seq_dim) + + input_id.tag = constants.AUTOSP_INPUT_ID_KEY + label_id.tag = constants.AUTOSP_LABEL_ID_KEY + if position_id is not None: + position_id.tag = constants.AUTOSP_POSITION_ID_KEY + + return input_id, label_id, position_id, attention_mask + +def pass_shard_seq_dim(gm: GraphModule, example_inputs): + """ + Finds all direct and indirect consumers of the input sequence, label and position ids. + Shard the sequence dimension used by all such consumers. 
+ """ + sp_size = sp_dp_registry.sp_size() + + input_ids_node = get_input_id_node(gm) + val = get_node_shape_meta(input_ids_node) + seq_symint = val.shape[1] + assert isinstance(seq_symint, torch.SymInt), f"expected sequence dimension to be of type `torch.SymInt` but found `{type(seq_symint)}`" + + sym_seq_dim_node = find_node_by_name(gm, str(seq_symint)) + if sym_seq_dim_node is None: + print(f"WARNING: Could not find the symbolic node for the sequence dimension") + return + + with gm.graph.inserting_after(sym_seq_dim_node): + sharded_node = gm.graph.call_function( + operator.floordiv, + args=(sym_seq_dim_node, sp_size) + ) + + sharded_input_nodes = set() + label_ids_node = get_label_id_node(gm) + position_ids_node = get_position_id_node(gm) + + if input_ids_node is not None: + sharded_input_nodes.add(input_ids_node) + if label_ids_node is not None: + sharded_input_nodes.add(label_ids_node) + if position_ids_node is not None: + sharded_input_nodes.add(position_ids_node) + + # find all consumers of the sharded inputs + consumer_nodes = set() + worklist = list(sharded_input_nodes) + visited = set() + + while worklist: + node = worklist.pop(0) + if node in visited: + continue + visited.add(node) + consumer_nodes.add(node) + + for user in node.users: + if user not in visited: + worklist.append(user) + + to_replace = [] + for node in consumer_nodes: + if sym_seq_dim_node in node.all_input_nodes: + to_replace.append(node) + + for user in to_replace: + user.replace_input_with(sym_seq_dim_node, sharded_node) + + +def pass_shard_input_ids(gm: GraphModule, example_inputs): + input_ids_node = get_input_id_node(gm) + shard_tensor_node(gm, input_ids_node) + + +def pass_shard_label_ids(gm: GraphModule, example_inputs): + label_ids_node = get_label_id_node(gm) + shard_tensor_node(gm, label_ids_node) + +def pass_shard_position_ids(gm: GraphModule, example_inputs): + position_ids_node = get_position_id_node(gm) + if position_ids_node is None: + print("[WARNING] position id node 
not found. Skipping sharding of position ids.") + return + shard_tensor_node(gm, position_ids_node) + + +def pass_insert_attention_all_to_all(gm: GraphModule, real_inputs): + def insert_a2a(node: Node, scatter_idx: int, gather_idx: int, name: str) -> Node: + with gm.graph.inserting_after(node): + a2a_node = gm.graph.call_function( + torch.ops.autosp.all_to_all.default, + args=(node, scatter_idx, gather_idx, name), + ) + a2a_node.name = f"a2a_{name}" + node.replace_all_uses_with(a2a_node) + a2a_node.update_arg(0, node) + return a2a_node + + attention_nodes = get_sdpa_nodes(gm) + if len(attention_nodes) == 0: + raise RuntimeError( + "AutoSP currently supports torch.nn.functional.scaled_dot_product_attention as the " + "attention backend. No SDPA attention operations were found in the compiled graph. " + "Please ensure your model uses torch.nn.functional.scaled_dot_product_attention " + "for AutoSP to work as expected." + ) + + for idx, attn_node in enumerate(attention_nodes): + q, k, v = attn_node.args[:3] + suffix = f"_{idx}" if len(attention_nodes) > 1 else "" + + # QKV: [B, N, S/P, H] -> [B, N/P, S, H] + insert_a2a(q, scatter_idx=1, gather_idx=2, name=f"q{suffix}") + insert_a2a(k, scatter_idx=1, gather_idx=2, name=f"k{suffix}") + insert_a2a(v, scatter_idx=1, gather_idx=2, name=f"v{suffix}") + + # O: [B, N/P, S, H] -> [B, N, S/P, H] + insert_a2a(attn_node, scatter_idx=2, gather_idx=1, name=f"o{suffix}") + + +def pass_canonicalize(gm: GraphModule, real_inputs): + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + +def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs): + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + fake_inputs = [] + for t in real_inputs: + if isinstance(t, torch.Tensor): + fake_inputs.append(fake_mode.from_tensor(t)) + else: + fake_inputs.append(t) + FakeTensorProp(gm).propagate(*fake_inputs) + + +def apply_autosp( + gm: GraphModule, + real_inputs, + debug: bool = False, + passes: 
Optional[List[Callable]] = None, + sp_size: int = 2, + dp_size: int = 1 +): + """ + Apply AutoSP (Ulysses) transformation passes to the graph and setup either DP/SP (2D) or SP (1D) mesh. + + Args: + gm: GraphModule to transform + real_inputs: Example inputs for shape propagation + debug: If True, print graph before/after each pass + passes: Optional custom list of passes (default: DEFAULT_PASSES) + """ + assert sp_size * dp_size <= dist.get_world_size(), 'Insufficient device count for mesh size' + + sp_dp_registry.populate_registry(sp_size, dp_size) + + AUTOSP_PASSES = [ + pass_shard_seq_dim, + pass_shard_input_ids, + pass_shard_label_ids, + pass_shard_position_ids, + pass_insert_attention_all_to_all, + pass_propagate_shapes, + pass_canonicalize, + ] + + passes = passes or AUTOSP_PASSES + rank = dist.get_rank() + + for p in passes: + if debug and rank == 0: + print(f"\n{'='*60}") + print(f" BEFORE: {p.__name__}") + print(f"{'='*60}\n") + print(gm.print_readable(print_output=False)) + + p(gm, real_inputs) + + if debug and rank == 0: + print(f"\n{'='*60}") + print(f" AFTER: {p.__name__}") + print(f"{'='*60}\n") + print(gm.print_readable(print_output=False)) + diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index e8abcc2c8b3c..baf70f60e3fc 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -9,8 +9,9 @@ from collections import defaultdict import torch -from torch.fx import Node, Graph +from torch.fx import Node, Graph, GraphModule from torch.fx.node import map_aggregate, Argument, map_arg +import torch.nn.functional as F try: from torch._subclasses.fake_tensor import unset_fake_temporarily @@ -22,7 +23,9 @@ from deepspeed.accelerator import get_accelerator from deepspeed.utils.torch import required_torch_version from deepspeed.ops.op_builder.dc import DeepCompileBuilder +from deepspeed.compile import constants +from .custom_ops import sp_dp_registry def is_deepcompile_supported() -> bool: return 
required_torch_version(min_version=2.6, max_version=2.9) and get_accelerator().device_name() == "cuda" @@ -521,3 +524,84 @@ def pad_tensors(specs: List[Tuple[torch.Tensor, int, int]]) -> List[torch.Tensor padded.append(out) return padded + +def create_shard_offsets( + gm: GraphModule, + s0_node: Node +) -> Tuple[Node, Node]: + sp_size: int = sp_dp_registry.sp_size() + sp_rank: int = dist.get_rank() % sp_dp_registry.sp_size() + with gm.graph.inserting_after(s0_node): + chunk_size_node = gm.graph.call_function(operator.floordiv, args=(s0_node, sp_size)) + with gm.graph.inserting_after(chunk_size_node): + start_node = gm.graph.call_function(operator.mul, args=(sp_rank, chunk_size_node)) + with gm.graph.inserting_after(start_node): + end_node = gm.graph.call_function(operator.add, args=(start_node, chunk_size_node)) + + return start_node, end_node + + +def get_sdpa_nodes(gm: GraphModule) -> List[Node]: + return list(gm.graph.find_nodes( + op="call_function", + target=F.scaled_dot_product_attention, + )) + +def get_input_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_INPUT_ID_KEY) + if node is None: + raise RuntimeError("Failed to find a node for the input sequence.") + return node + +def get_label_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_LABEL_ID_KEY) + if node is None: + raise RuntimeError("Failed to find a node for the label.") + return node + +def get_position_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_POSITION_ID_KEY) + return node + + +def create_symbolic_slice_indices( + gm: GraphModule, + sym_seq_dim_node: Node, +) -> Tuple[Node, Node]: + start_node, end_node = create_shard_offsets(gm, sym_seq_dim_node) + + with gm.graph.inserting_after(end_node): + slice_all = gm.graph.call_function(slice, args=(None, None, None)) + with 
gm.graph.inserting_after(slice_all): + slice_range = gm.graph.call_function(slice, args=(start_node, end_node, None)) + + return slice_all, slice_range + +def shard_tensor_node( + gm: GraphModule, + tensor_node: Node +): + from .fx import find_node_by_name, get_node_shape_meta, replace_node_users + val = get_node_shape_meta(tensor_node) + assert val is not None, f"Node {tensor_node.name} has no shape metadata" + + seq_len = val.shape[1] + + assert isinstance(seq_len, torch.SymInt), f"Expected sequence dimension to be `torch.SymInt` but instead found `{type(seq_len)}`" + + symb_seq_int_node = find_node_by_name(gm, str(seq_len)) + assert symb_seq_int_node, f"Unable to find symbolic placeholder for {seq_len}" + + slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node) + indices = (slice_all, slice_range) + + with gm.graph.inserting_after(tensor_node): + sliced_node = gm.graph.call_function( + operator.getitem, + args=(tensor_node, indices), + ) + + replace_node_users(tensor_node, sliced_node, exclude=[sliced_node]) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 9e73bad73376..92abb58a2a49 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -501,3 +501,4 @@ class ValidationMode: ######################################### USE_DATA_BEFORE_EXPERT_PARALLEL = "use_data_before_expert_parallelism" USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT = False + diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9a4e4608c847..3e4f5c550417 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -127,6 +127,7 @@ from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states from deepspeed.compile.init_z1 import init_z1 from deepspeed.compile.init_z3 import init_z3 +from deepspeed.compile.init_sp import init_autosp MEMORY_OPT_ALLREDUCE_SIZE = 500000000 @@ -1004,6 +1005,14 @@ def zero_sub_group_size(self): def 
zero_optimization_stage(self): return self._config.zero_optimization_stage + def compile_zero_optimization_stage(self): + """Determines if zero-pass is set in deepcompile's passes attributes.""" + return "z1" in self._config.compile_config.passes or "z3" in self._config.compile_config.passes + + def compile_autosp(self): + """Determines if AutoSP is set in deepcompile's passes attributes.""" + return "autosp" in (getattr(self._config.compile_config, "passes", None) or []) + def mics_shard_size(self): return self._config.mics_shard_size @@ -2373,7 +2382,7 @@ def print_forward_breakdown(self, fwd_time): def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): # Skip gradient reduction when DeepCompile is enabled # DeepCompile handles its own gradient reduction through compiled graph operations - if self.is_deepcompile_active(): + if self.is_deepcompile_active() and not self.compile_autosp(): return # Pass (PP) gas boundary flag to optimizer (required for zero) @@ -4336,6 +4345,61 @@ def empty_partition_cache(self): gc.collect() get_accelerator().empty_cache() + def get_autosp_backend(self, compile_kwargs): + if self.compile_autosp() and self.zero_optimization_stage() not in [ZeroStageEnum.disabled, ZeroStageEnum.optimizer_states]: + logger.info( + f"Currently AutoSP does not compose with ZeRO stage 2 and 3. Falling back to the torch compiler." + ) + return None + + compile_config = self._config.compile_config + compile_kwargs['fullgraph'] = True + return init_autosp(compile_config) + + def get_deepcompile_backend(self, backend, compile_kwargs, schedule): + if self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ + and self.zero_optimization_stage() != ZeroStageEnum.weights \ + and self.zero_optimization_stage() != ZeroStageEnum.gradients: + logger.info( + f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." 
+ ) + return None + + compile_config = self._config.compile_config + if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] + and "offload_param" in self.config["zero_optimization"]) + and self._config.zero_config.offload_param.device == "cpu" + and self._config.zero_config.offload_optimizer.device == "cpu"): + compile_config.offload_parameters = True + if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: + return init_z1(self, backend, compile_config, compile_kwargs, schedule) + elif self.zero_optimization_stage() == ZeroStageEnum.gradients: + return init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) + elif self.zero_optimization_stage() == ZeroStageEnum.weights: + return init_z3(self, backend, compile_config, compile_kwargs, schedule) + return None + + def get_deepspeed_compile_backend(self, backend, compile_kwargs, schedule): + resolved_backend = None + + if schedule is not None: + + def passes_name_to_fn(passes): + for p in passes: + assert callable(p) or p in opt_passes, f"Unknown pass {p}" + return [p if callable(p) else opt_passes[p] for p in passes] + + schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] + + assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." 
+ + if self.compile_autosp(): + resolved_backend = self.get_autosp_backend(compile_kwargs) + else: + resolved_backend = self.get_deepcompile_backend(backend, compile_kwargs, schedule) + + return resolved_backend, schedule + def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, @@ -4358,53 +4422,23 @@ def compile(self, logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}") - enable_deepcompile = self.is_deepcompile_enabled() - if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ - and self.zero_optimization_stage() != ZeroStageEnum.weights \ - and self.zero_optimization_stage() != ZeroStageEnum.gradients: - logger.info( - f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." - ) - enable_deepcompile = False - - if enable_deepcompile: - - if schedule is not None: - - def passes_name_to_fn(passes): - for p in passes: - assert callable(p) or p in opt_passes, f"Unknown pass {p}" - return [p if callable(p) else opt_passes[p] for p in passes] - - schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] - - assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." 
- - compile_config = self._config.compile_config - if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] - and "offload_param" in self.config["zero_optimization"]) - and self._config.zero_config.offload_param.device == "cpu" - and self._config.zero_config.offload_optimizer.device == "cpu"): - compile_config.offload_parameters = True - if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: - backend = init_z1(self, backend, compile_config, compile_kwargs, schedule) - elif self.zero_optimization_stage() == ZeroStageEnum.gradients: - backend = init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) - elif self.zero_optimization_stage() == ZeroStageEnum.weights: - if required_torch_version(min_version=2.9): - raise RuntimeError( - "DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. " - "Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3.") - backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) + resolved_backend = None + if self.is_deepcompile_enabled(): + resolved_backend, schedule = self.get_deepspeed_compile_backend(backend, compile_kwargs, schedule) + + is_deepspeed_compile_backend = resolved_backend is not None + # default to torch.compiler backend if deepspeed config validation fails + backend = resolved_backend or backend + # Hook state must align with whether DeepCompile is active. - self._set_deepcompile_active(enable_deepcompile) + self._set_deepcompile_active(is_deepspeed_compile_backend) # create new dict to avoid modifying original dict try: self.module.compile(**{**compile_kwargs, 'backend': backend}) except Exception: - if enable_deepcompile: + if is_deepspeed_compile_backend: # Restore default hooks if compilation fails before completing. self._set_deepcompile_active(False) raise