From b57ccb87ccfcd058fbf7eeeaa88cdc95d72fd0ab Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sat, 14 Feb 2026 20:10:50 -0600 Subject: [PATCH 01/14] add autosp backend --- deepspeed/compile/custom_ops/all_to_all.py | 64 +++++++ deepspeed/compile/fx.py | 29 +++- deepspeed/compile/init_sp.py | 14 ++ deepspeed/compile/passes/sp_compile.py | 192 +++++++++++++++++++++ deepspeed/compile/util.py | 98 ++++++++++- 5 files changed, 394 insertions(+), 3 deletions(-) create mode 100644 deepspeed/compile/custom_ops/all_to_all.py create mode 100644 deepspeed/compile/init_sp.py create mode 100644 deepspeed/compile/passes/sp_compile.py diff --git a/deepspeed/compile/custom_ops/all_to_all.py b/deepspeed/compile/custom_ops/all_to_all.py new file mode 100644 index 000000000000..6f4b5172b67a --- /dev/null +++ b/deepspeed/compile/custom_ops/all_to_all.py @@ -0,0 +1,64 @@ +import torch +import torch.distributed as dist + +@torch.library.custom_op("autosp::all_to_all", mutates_args=()) +def all_to_all( + input: torch.Tensor, + scatter_idx: int, + gather_idx: int, + world_size: int, + name: str, +) -> torch.Tensor: + B, dim1, dim2, H = input.shape + + if scatter_idx == 1: + N, local_S = dim1, dim2 + input_t = input.reshape(B, world_size, N // world_size, local_S, H) + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=dist.group.WORLD) + + output = output.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(B, N // world_size, world_size * local_S, H) + else: + local_N, S = dim1, dim2 + input_t = input.reshape(B, local_N, world_size, S // world_size, H) + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=dist.group.WORLD) + + output = output.permute(1, 0, 2, 3, 4).contiguous() + output = output.reshape(B, world_size * local_N, S // world_size, H) + + return output + + 
+@torch.library.register_fake("autosp::all_to_all") +def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, world_size: int, name: str): + B, dim1, dim2, H = input.shape + if scatter_idx == 1: + return input.new_empty(B, dim1 // world_size, dim2 * world_size, H) + else: + return input.new_empty(B, dim1 * world_size, dim2 // world_size, H) + + +def _all_to_all_backward_setup(ctx, inputs, output): + _, scatter_idx, gather_idx, world_size, name = inputs + ctx.scatter_idx = gather_idx + ctx.gather_idx = scatter_idx + ctx.world_size = world_size + ctx.name = name + "_grad" + + +def _all_to_all_backward(ctx, grad): + return ( + all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.world_size, ctx.name), + None, None, None, None, + ) + + +torch.library.register_autograd( + "autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup +) diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py index 7b3408b56afe..fea046000516 100644 --- a/deepspeed/compile/fx.py +++ b/deepspeed/compile/fx.py @@ -3,11 +3,11 @@ # DeepSpeed Team -from typing import Callable, Any, List, Dict +from typing import Callable, Any, List, Dict, Optional from collections import defaultdict import torch -from torch.fx import Node, Graph +from torch.fx import Node, Graph, GraphModule from .util import get_last_uses @@ -138,3 +138,28 @@ def free_tensors(tensors: List[torch.Tensor]): # Python version for debugging # graph.create_node('call_function', free_tensors, args, {}, name=node_name) + +def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]: + for node in gm.graph.nodes: + if node.name == name: + return node + return None + +def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]: + return node.meta.get("val") or node.meta.get("example_value") + +def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]: + input_id_node = None + for node in gm.graph.nodes: + # 
https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048 + tensor_dict = node.meta.get('tensor_dict') + if tensor_dict and tensor_dict.get('tag') == tag: + input_id_node = node + break + return input_id_node + +def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None): + exclude = exclude or [] + to_replace = [u for u in node.users if u not in exclude] + for user in to_replace: + user.replace_input_with(node, replacement) \ No newline at end of file diff --git a/deepspeed/compile/init_sp.py b/deepspeed/compile/init_sp.py new file mode 100644 index 000000000000..17ce36a833df --- /dev/null +++ b/deepspeed/compile/init_sp.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from torch.fx import GraphModule +from .passes.autosp import apply_autosp + +def init_ulysses(): + def backend_fn(gm: GraphModule, real_inputs): + apply_autosp(gm, real_inputs, debug_log=False) + return torch._inductor.compile(gm, real_inputs) + return backend_fn diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py new file mode 100644 index 000000000000..726d998e518a --- /dev/null +++ b/deepspeed/compile/passes/sp_compile.py @@ -0,0 +1,192 @@ +"""AutoSP: Automatic Sequence Parallel (Ulysses) pass for graph modules. 
+ +Ulysses Transformation: + Input: [B, N, S/P, H] (all heads, partitioned sequence) + After A2A on QKV: [B, N/P, S, H] (partitioned heads, full sequence) + After SDPA: [B, N/P, S, H] + After A2A on O: [B, N, S/P, H] (all heads, partitioned sequence) + +Where: + B = batch size, N = num heads, S = full sequence length, H = head dim, P = world size +""" + +import operator +from typing import Optional, List, Callable + +import torch +import torch.distributed as dist +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx import GraphModule, Node +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +from ..custom_ops import all_to_all +from ..fx import find_node_by_name, get_node_shape_meta +from ..util import get_input_id_node, get_label_id_node, get_position_id_node, shard_tensor_node, get_sdpa_nodes, ShardingConfig + +def pass_shard_seq_dim(gm: GraphModule, example_inputs): + """ + Finds all direct and indirect consumers of the input sequence, label and position ids. + Shard the sequence dimension used by all such consumers. 
+ """ + world_size = dist.get_world_size() + + input_ids_node = get_input_id_node(gm) + val = get_node_shape_meta(input_ids_node) + seq_symint = val.shape[1] + assert isinstance(seq_symint, torch.SymInt), f"expected sequence dimension to be of type `torch.SymInt` but found `{type(seq_symint)}`" + + sym_seq_dim_node = find_node_by_name(gm, str(seq_symint)) + if sym_seq_dim_node is None: + print(f"WARNING: Could not find the symbolic node for the sequence dimension") + return + + with gm.graph.inserting_after(sym_seq_dim_node): + sharded_node = gm.graph.call_function( + operator.floordiv, + args=(sym_seq_dim_node, world_size) + ) + + sharded_input_nodes = set() + label_ids_node = get_label_id_node(gm) + position_ids_node = get_position_id_node(gm) + + if input_ids_node is not None: + sharded_input_nodes.add(input_ids_node) + if label_ids_node is not None: + sharded_input_nodes.add(label_ids_node) + if position_ids_node is not None: + sharded_input_nodes.add(position_ids_node) + + # find all consumers of the sharded inputs + consumer_nodes = set() + worklist = list(sharded_input_nodes) + visited = set() + + while worklist: + node = worklist.pop(0) + if node in visited: + continue + visited.add(node) + consumer_nodes.add(node) + + for user in node.users: + if user not in visited: + worklist.append(user) + + to_replace = [] + for node in consumer_nodes: + if sym_seq_dim_node in node.all_input_nodes: + to_replace.append(node) + + for user in to_replace: + user.replace_input_with(sym_seq_dim_node, sharded_node) + + +def pass_shard_input_ids(gm: GraphModule, example_inputs): + config = ShardingConfig.from_distributed() + input_ids_node = get_input_id_node(gm) + shard_tensor_node(gm, input_ids_node, config) + + +def pass_shard_label_ids(gm: GraphModule, example_inputs): + config = ShardingConfig.from_distributed() + label_ids_node = get_label_id_node(gm) + shard_tensor_node(gm, label_ids_node, config) + +def pass_shard_position_ids(gm: GraphModule, example_inputs): + config 
= ShardingConfig.from_distributed() + position_ids_node = get_position_id_node(gm) + if position_ids_node is None: + print("[WARNING] position id node not found. Skipping sharding of position ids.") + return + shard_tensor_node(gm, position_ids_node, config) + + +def pass_insert_attention_all_to_all(gm: GraphModule, real_inputs): + """ + Insert all-to-all collectives around SDPA for Ulysses parallelism. + + For each SDPA: + - Before Q, K, V: scatter heads (dim=1), gather sequence (dim=2) + - After O: scatter sequence (dim=2), gather heads (dim=1) + """ + world_size = dist.get_world_size() + attention_nodes = get_sdpa_nodes(gm) + + def insert_a2a(node: Node, scatter_idx: int, gather_idx: int, name: str) -> Node: + with gm.graph.inserting_after(node): + a2a_node = gm.graph.call_function( + torch.ops.autosp.all_to_all.default, + args=(node, scatter_idx, gather_idx, world_size, name), + ) + a2a_node.name = f"a2a_{name}" + node.replace_all_uses_with(a2a_node) + a2a_node.update_arg(0, node) + return a2a_node + + for idx, attn_node in enumerate(attention_nodes): + q, k, v = attn_node.args[:3] + suffix = f"_{idx}" if len(attention_nodes) > 1 else "" + + # QKV: [B, N, S/P, H] -> [B, N/P, S, H] + insert_a2a(q, scatter_idx=1, gather_idx=2, name=f"q{suffix}") + insert_a2a(k, scatter_idx=1, gather_idx=2, name=f"k{suffix}") + insert_a2a(v, scatter_idx=1, gather_idx=2, name=f"v{suffix}") + + # O: [B, N/P, S, H] -> [B, N, S/P, H] + insert_a2a(attn_node, scatter_idx=2, gather_idx=1, name=f"o{suffix}") + + +def pass_canonicalize(gm: GraphModule, real_inputs): + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + +def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs): + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + fake_inputs = [] + for t in real_inputs: + if isinstance(t, torch.Tensor): + fake_inputs.append(fake_mode.from_tensor(t)) + else: + fake_inputs.append(t) + FakeTensorProp(gm).propagate(*fake_inputs) + + +def 
apply_autosp( + gm: GraphModule, + real_inputs, + debug: bool = False, + passes: Optional[List[Callable]] = None, +): + AUTOSP_PASSES = [ + pass_shard_seq_dim, + pass_shard_input_ids, + pass_shard_label_ids, + pass_shard_position_ids, + pass_insert_attention_all_to_all, + pass_propagate_shapes, + pass_canonicalize, + ] + + passes = passes or AUTOSP_PASSES + rank = dist.get_rank() + + for p in passes: + if debug and rank == 0: + print(f"\n{'='*60}") + print(f" BEFORE: {p.__name__}") + print(f"{'='*60}\n") + print(gm.print_readable(print_output=False)) + + p(gm, real_inputs) + + if debug and rank == 0: + print(f"\n{'='*60}") + print(f" AFTER: {p.__name__}") + print(f"{'='*60}\n") + print(gm.print_readable(print_output=False)) + diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index e8abcc2c8b3c..5963468d98ec 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -6,11 +6,13 @@ import functools import operator from typing import List, Tuple, Dict, Optional +from dataclasses import dataclass from collections import defaultdict import torch -from torch.fx import Node, Graph +from torch.fx import Node, Graph, GraphModule from torch.fx.node import map_aggregate, Argument, map_arg +import torch.nn.functional as F try: from torch._subclasses.fake_tensor import unset_fake_temporarily @@ -23,6 +25,11 @@ from deepspeed.utils.torch import required_torch_version from deepspeed.ops.op_builder.dc import DeepCompileBuilder +from .fx import find_node_by_name, find_node_by_tag, get_node_shape_meta, replace_node_users + +INPUT_ID_KEY = "input_id" +LABEL_ID_KEY = "label_id" +POSITION_ID_KEY = "position_id" def is_deepcompile_supported() -> bool: return required_torch_version(min_version=2.6, max_version=2.9) and get_accelerator().device_name() == "cuda" @@ -521,3 +528,92 @@ def pad_tensors(specs: List[Tuple[torch.Tensor, int, int]]) -> List[torch.Tensor padded.append(out) return padded + +@dataclass +class ShardingConfig: + world_size: int + rank: 
int + + @classmethod + def from_distributed(cls) -> "ShardingConfig": + return cls( + world_size=dist.get_world_size(), + rank=dist.get_rank(), + ) + +def get_sdpa_nodes(gm: GraphModule) -> List[Node]: + return list(gm.graph.find_nodes( + op="call_function", + target=F.scaled_dot_product_attention, + )) + +def get_input_id_node(gm: GraphModule) -> Node: + node = find_node_by_tag(gm, INPUT_ID_KEY) + if node is None: + raise RuntimeError("Failed to find a node for the input sequence.") + return node + +def get_label_id_node(gm: GraphModule) -> Node: + node = find_node_by_tag(gm, LABEL_ID_KEY) + if node is None: + raise RuntimeError("Failed to find a node for the label.") + return node + +def get_position_id_node(gm: GraphModule) -> Node: + node = find_node_by_tag(gm, POSITION_ID_KEY) + return node + +def create_shard_offsets( + gm: GraphModule, + sym_seq_dim_node: Node, + world_size: int, + rank: int +) -> Tuple[Node, Node]: + with gm.graph.inserting_after(sym_seq_dim_node): + chunk_size_node = gm.graph.call_function(operator.floordiv, args=(sym_seq_dim_node, world_size)) + with gm.graph.inserting_after(chunk_size_node): + start_node = gm.graph.call_function(operator.mul, args=(rank, chunk_size_node)) + with gm.graph.inserting_after(start_node): + end_node = gm.graph.call_function(operator.add, args=(start_node, chunk_size_node)) + + return start_node, end_node + +def create_symbolic_slice_indices( + gm: GraphModule, + sym_seq_dim_node: Node, + config: ShardingConfig +) -> Tuple[Node, Node]: + start_node, end_node = create_shard_offsets(gm, sym_seq_dim_node, config.world_size, config.rank) + + with gm.graph.inserting_after(end_node): + slice_all = gm.graph.call_function(slice, args=(None, None, None)) + with gm.graph.inserting_after(slice_all): + slice_range = gm.graph.call_function(slice, args=(start_node, end_node, None)) + + return slice_all, slice_range + +def shard_tensor_node( + gm: GraphModule, + tensor_node: Node, + config: ShardingConfig +): + val = 
get_node_shape_meta(tensor_node) + assert val is not None, f"Node {tensor_node.name} has no shape metadata" + + seq_len = val.shape[1] + + assert isinstance(seq_len, torch.SymInt), f"Expected sequence dimension to be `torch.SymInt` but instead found `{type(seq_len)}`" + + symb_seq_int_node = find_node_by_name(gm, str(seq_len)) + assert symb_seq_int_node, f"Unable to find symbolic placeholder for {seq_len}" + + slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node, config) + indices = (slice_all, slice_range) + + with gm.graph.inserting_after(tensor_node): + sliced_node = gm.graph.call_function( + operator.getitem, + args=(tensor_node, indices), + ) + + replace_node_users(tensor_node, sliced_node, exclude=[sliced_node]) From 4df32d1fef2110ef6ed5c871bc7217306af3fbba Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Thu, 19 Feb 2026 00:11:00 -0600 Subject: [PATCH 02/14] add benchmarking script --- bench_dc_ulysses/.gitignore | 10 + bench_dc_ulysses/README.md | 51 ++ bench_dc_ulysses/configs/config.yaml | 16 + .../configs/deepcompile_config.json | 19 + .../configs/deepcompile_config.yaml | 16 + bench_dc_ulysses/configs/ds_config.json | 25 + .../configs/ds_config.json.template | 30 + .../configs/ds_config.yaml.template | 19 + .../configs/torchcompile_config.json | 14 + .../configs/torchcompile_config.yaml | 16 + bench_dc_ulysses/distributed_attention.py | 104 ++++ bench_dc_ulysses/gen_chart_acc_steps.py | 263 +++++++++ bench_dc_ulysses/generate_conf.py | 42 ++ bench_dc_ulysses/hostfile_n1 | 1 + bench_dc_ulysses/hostfile_n2 | 1 + bench_dc_ulysses/hostfile_n4 | 1 + bench_dc_ulysses/launch_jobs.sh | 13 + bench_dc_ulysses/ring_attention.py | 530 ++++++++++++++++++ bench_dc_ulysses/run.sh | 56 ++ bench_dc_ulysses/run_acc_lm.py | 251 +++++++++ bench_dc_ulysses/run_bench.sh | 19 + bench_dc_ulysses/run_bench_acc.sh | 42 ++ bench_dc_ulysses/run_correctness_test.sh | 4 + bench_dc_ulysses/run_multinode.sh | 112 ++++ bench_dc_ulysses/run_ulysses.sh | 65 +++ 
bench_dc_ulysses/sample.slurm | 43 ++ bench_dc_ulysses/sp_dp_registry.py | 45 ++ deepspeed/compile/init_sp.py | 6 +- deepspeed/compile/util.py | 6 +- deepspeed/runtime/engine.py | 6 +- 30 files changed, 1820 insertions(+), 6 deletions(-) create mode 100644 bench_dc_ulysses/.gitignore create mode 100644 bench_dc_ulysses/README.md create mode 100644 bench_dc_ulysses/configs/config.yaml create mode 100644 bench_dc_ulysses/configs/deepcompile_config.json create mode 100644 bench_dc_ulysses/configs/deepcompile_config.yaml create mode 100644 bench_dc_ulysses/configs/ds_config.json create mode 100644 bench_dc_ulysses/configs/ds_config.json.template create mode 100644 bench_dc_ulysses/configs/ds_config.yaml.template create mode 100644 bench_dc_ulysses/configs/torchcompile_config.json create mode 100644 bench_dc_ulysses/configs/torchcompile_config.yaml create mode 100644 bench_dc_ulysses/distributed_attention.py create mode 100644 bench_dc_ulysses/gen_chart_acc_steps.py create mode 100644 bench_dc_ulysses/generate_conf.py create mode 100644 bench_dc_ulysses/hostfile_n1 create mode 100644 bench_dc_ulysses/hostfile_n2 create mode 100644 bench_dc_ulysses/hostfile_n4 create mode 100755 bench_dc_ulysses/launch_jobs.sh create mode 100644 bench_dc_ulysses/ring_attention.py create mode 100755 bench_dc_ulysses/run.sh create mode 100644 bench_dc_ulysses/run_acc_lm.py create mode 100755 bench_dc_ulysses/run_bench.sh create mode 100755 bench_dc_ulysses/run_bench_acc.sh create mode 100755 bench_dc_ulysses/run_correctness_test.sh create mode 100755 bench_dc_ulysses/run_multinode.sh create mode 100755 bench_dc_ulysses/run_ulysses.sh create mode 100644 bench_dc_ulysses/sample.slurm create mode 100644 bench_dc_ulysses/sp_dp_registry.py diff --git a/bench_dc_ulysses/.gitignore b/bench_dc_ulysses/.gitignore new file mode 100644 index 000000000000..4b197544883c --- /dev/null +++ b/bench_dc_ulysses/.gitignore @@ -0,0 +1,10 @@ +*.log +*.pyc +profiles +results +slurm_jobs +slurm* +experiments 
+logs +*. +*.pt diff --git a/bench_dc_ulysses/README.md b/bench_dc_ulysses/README.md new file mode 100644 index 000000000000..4f0c68a3458d --- /dev/null +++ b/bench_dc_ulysses/README.md @@ -0,0 +1,51 @@ +# Benchmark for DeepCompile + +## Setup + +These experiment scripts require 1 node that has 2 A100/A40 GPUs. +We tested the scripts with Python 3.10.12 and CUDA 12.3. + +### Libraries + +In addition, you need to install the following: + +- PyTorch 2.5.1 +- [modified version of DeepSpeed](https://github.com/tohtana/DeepSpeed-internal/tree/neeld2/debug-loss) + +Here is an example of the installation commands: + +```bash +pip3 install torch==2.5.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 +pip3 install datasets==3.1 accelerate + +# Install DeepSpeed and DeepCompile +git clone -b neeld2/debug-loss https://github.com/tohtana/DeepSpeed-internal.git +cd DeepSpeed-internal +pip install -e transformers +cd .. +pip install -e DeepSpeed + +# Clone this repository +git clone https://github.com/neeldani/bench_dc_ulysses.git +``` + +## Running the scripts + +Test the setup by running the script: +```bash +bash run_ulysses.sh 6 [compile|deepcompile|eager|ringattn] +``` + +Here, 6 is the sequence length and is hardcoded because the input sequence inside run_acc_lm.py is hardcoded to easily verify the Q, K and V before and after the all-to-all. 
You may pass `compile` to run compiled Ulysses (Ulysses with graph breaks) or `deepcompile` to run deepcompiled Ulysses (all-to-all inserted within the compiler pass). + +We save the Q, K and V tensors before and after the all-to-all: +For deepcompiled Ulysses, the tensors are saved here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/deepspeed/compile/patch_aot_module.py#L243 + +For compiled Ulysses, the tensors are saved here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/deepspeed/sequence/layer.py#L381 + +You can then run the script [check_qkv.py](https://github.com/neeldani/bench_dc_ulysses/blob/main/check_qkv.py) to compare the tensors at various stages, i.e. before all2all, after all2all, attention outputs, etc. + +## Code walkthrough +1. Script: [run_ulysses.sh](https://github.com/neeldani/bench_dc_ulysses/blob/main/run_ulysses.sh) +2. The script calls: [run_acc_lm.py](https://github.com/neeldani/bench_dc_ulysses/blob/main/run_acc_lm.py). We have added support for another attention backend in HuggingFace called "ulysses" which uses DistributedAttention. The implementation can be found here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/transformers/src/transformers/models/llama/modeling_llama.py#L306 +3. If the `deepcompile` arg is passed to the config file, then a compiler pass will add the all2all ops directly at the Torch IR level. 
The code for it can be found here: https://github.com/tohtana/DeepSpeed-internal/blob/neeld2/debug-loss/deepspeed/compile/patch_aot_module.py diff --git a/bench_dc_ulysses/configs/config.yaml b/bench_dc_ulysses/configs/config.yaml new file mode 100644 index 000000000000..254b83004286 --- /dev/null +++ b/bench_dc_ulysses/configs/config.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + deepspeed_config_file: configs/ds_config.json +distributed_type: DEEPSPEED +machine_rank: 1 +main_training_function: main +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/bench_dc_ulysses/configs/deepcompile_config.json b/bench_dc_ulysses/configs/deepcompile_config.json new file mode 100644 index 000000000000..93deb7402e19 --- /dev/null +++ b/bench_dc_ulysses/configs/deepcompile_config.json @@ -0,0 +1,19 @@ +{ + + "bf16": { + "enabled": true + }, + + "zero_optimization":{ + "stage": 0 + }, + "compile": { + "deepcompile": true + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/bench_dc_ulysses/configs/deepcompile_config.yaml b/bench_dc_ulysses/configs/deepcompile_config.yaml new file mode 100644 index 000000000000..405eb7163508 --- /dev/null +++ b/bench_dc_ulysses/configs/deepcompile_config.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + deepspeed_config_file: configs/deepcompile_config.json +distributed_type: DEEPSPEED +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false 
+tpu_use_sudo: false +use_cpu: false diff --git a/bench_dc_ulysses/configs/ds_config.json b/bench_dc_ulysses/configs/ds_config.json new file mode 100644 index 000000000000..548aeaac6b78 --- /dev/null +++ b/bench_dc_ulysses/configs/ds_config.json @@ -0,0 +1,25 @@ +{ + + "bf16": { + "enabled": true + }, + + "zero_optimization":{ + "stage": 0 + }, + "compile": { + "deepcompile": true, + "offload_activation": false, + "offload_opt_states": false, + "double_buffer": true, + "symmetric_memory": false, + "free_activation": false, + "dump_graphs": false + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/bench_dc_ulysses/configs/ds_config.json.template b/bench_dc_ulysses/configs/ds_config.json.template new file mode 100644 index 000000000000..9afb8b6a7159 --- /dev/null +++ b/bench_dc_ulysses/configs/ds_config.json.template @@ -0,0 +1,30 @@ +{ + {% if fp16 %} + "fp16": { + "enabled": true, + "initial_scale_power": 8 + }, + {% else %} + "bf16": { + "enabled": true + }, + {% endif %} + "zero_optimization":{ + "stage": 0 + }, + "compile": { + "deepcompile": {{ deepcompile }}, + "offload_activation": false, + "offload_opt_states": false, + "double_buffer": true, + "symmetric_memory": false, + "free_activation": false, + "dump_graphs": false + }, + "gradient_accumulation_steps": {{ gradient_accumulation_steps }}, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/bench_dc_ulysses/configs/ds_config.yaml.template b/bench_dc_ulysses/configs/ds_config.yaml.template new file mode 100644 index 000000000000..f130fbea7f98 --- /dev/null +++ b/bench_dc_ulysses/configs/ds_config.yaml.template @@ -0,0 +1,19 @@ +compute_environment: 
LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + {%- if zero_stage == 3 %} + zero3_init_flag: true + {%- endif %} + deepspeed_config_file: configs/ds_config.json +distributed_type: DEEPSPEED +machine_rank: {{ machine_rank }} +main_training_function: main +num_machines: {{ num_machines }} +num_processes: {{ num_processes }} +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/bench_dc_ulysses/configs/torchcompile_config.json b/bench_dc_ulysses/configs/torchcompile_config.json new file mode 100644 index 000000000000..d61b17b9f047 --- /dev/null +++ b/bench_dc_ulysses/configs/torchcompile_config.json @@ -0,0 +1,14 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization":{ + "stage": 0 + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/bench_dc_ulysses/configs/torchcompile_config.yaml b/bench_dc_ulysses/configs/torchcompile_config.yaml new file mode 100644 index 000000000000..cebc281c2b97 --- /dev/null +++ b/bench_dc_ulysses/configs/torchcompile_config.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + deepspeed_config_file: configs/torchcompile_config.json +distributed_type: DEEPSPEED +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/bench_dc_ulysses/distributed_attention.py b/bench_dc_ulysses/distributed_attention.py new file mode 100644 index 000000000000..32d87eae2bef --- /dev/null +++ b/bench_dc_ulysses/distributed_attention.py @@ -0,0 +1,104 @@ +import os +import torch +import 
torch.distributed as dist +from deepspeed.sequence.layer import DistributedAttention +from sp_dp_registry import get_group, is_setup, sp_size + +#TODO: Hacky, need to fix it +_padding_mask_context = None + +def set_padding_mask(mask): + global _padding_mask_context + _padding_mask_context = mask + +def get_padding_mask(): + global _padding_mask_context + return _padding_mask_context + +def clear_padding_mask(): + global _padding_mask_context + _padding_mask_context = None + +def ulysses_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=None, + dropout=0.0, + is_causal=True, + **kwargs, +): + assert is_setup(), 'Incorrectly setup SP/DP Groups.' + + gid = dist.get_rank() // sp_size() + group = get_group(gid) + + # Ulysses expects (batch, seq, heads, dim) + # HF standard provides (batch, heads, seq, dim) + q = query_states.transpose(1, 2).contiguous() + k = key_states.transpose(1, 2).contiguous() + v = value_states.transpose(1, 2).contiguous() + + if not hasattr(self, "ulysses_engine"): + self.ulysses_engine = DistributedAttention( + sdpa_wrapper, + group, + scatter_idx=2, # Shard heads + gather_idx=1 # Gather sequences + ) + + # b, s, n, h + # Note: we don't pass attention_mask here because it's the 4D mask created by HF + # based on sharded dimensions. We'll create the correct mask in sdpa_wrapper + # using the original unsharded padding mask stored in context. + attn_output = self.ulysses_engine( + q, k, v, + batch_dim_idx=0, + attn_mask=None, + dropout_p=dropout, + is_causal=False, + scale=scaling + ) + + # Return to HF format: (batch, seq, heads, dim) -> (batch, heads, seq, dim) + # Note: Transformers usually expects (B, N, S, H) back, + # but Llama's forward handles the reshape if we are careful. 
+ return attn_output, None + +def sdpa_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None): + # Permute from [b, s, n, h] to [b, n, s, h] for SDPA + q = query.permute(0, 2, 1, 3).contiguous() + k = key.permute(0, 2, 1, 3).contiguous() + v = value.permute(0, 2, 1, 3).contiguous() + + # Create the attention mask from padding mask + causal mask + padding_mask = get_padding_mask() + combined_mask = None + + if padding_mask is not None: + B, S = padding_mask.shape # [B, S] + device = padding_mask.device + + causal_mask = torch.tril(torch.ones(S, S, device=device, dtype=torch.bool)) + padding_mask_bool = (padding_mask != 0).unsqueeze(1) # [B, 1, S] + causal_expanded = causal_mask.unsqueeze(0) # [1, S, S] + combined_mask = causal_expanded & padding_mask_bool # [B, S, S] + combined_mask = combined_mask.unsqueeze(1) # [B, 1, S, S] + + elif is_causal: + pass + + output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=combined_mask, + dropout_p=dropout_p, + is_causal=(combined_mask is None and is_causal), + scale=scale, + enable_gqa=False + ) + + # Permute back from [b, n, s, h] to [b, s, n, h] for all-to-all on output + output = output.permute(0, 2, 1, 3).contiguous() + return output diff --git a/bench_dc_ulysses/gen_chart_acc_steps.py b/bench_dc_ulysses/gen_chart_acc_steps.py new file mode 100644 index 000000000000..8b3cbd9201ab --- /dev/null +++ b/bench_dc_ulysses/gen_chart_acc_steps.py @@ -0,0 +1,263 @@ +import argparse +import re +import pandas as pd +import matplotlib.pyplot as plt +from pathlib import Path + +def throughput_calculator(micro_batch_size, acc_steps, np, elapsed_time_per_iter, + hidden_size, num_attention_heads, num_key_value_heads, + ffn_hidden_size, num_layers, padded_vocab_size, seq_len, + topk: int, swiglu: bool, checkpoint_activations: bool): + batch_size = micro_batch_size * acc_steps * np + samples_per_second = batch_size / elapsed_time_per_iter + + head_dim = hidden_size // 
num_attention_heads + gqa = num_attention_heads // num_key_value_heads + ffn_multiplier = 3 if swiglu else 2 + macs_per_flops = 2 + + pre_and_post_mha_gemm_macs = batch_size * num_layers * (1 + (2 // gqa) + 1) * (hidden_size**2) * seq_len + mha_bgemm_macs = batch_size * num_layers * 2 * head_dim * num_attention_heads * (seq_len**2) + ffn_gemm_macs = batch_size * num_layers * ffn_multiplier * ffn_hidden_size * hidden_size * seq_len * topk + logit_lmhead_gemm_macs = batch_size * padded_vocab_size * hidden_size * seq_len + + fwd_macs = pre_and_post_mha_gemm_macs + mha_bgemm_macs + ffn_gemm_macs + logit_lmhead_gemm_macs + bwd_macs = 2 * fwd_macs + fwd_bwd_macs = fwd_macs + bwd_macs + + if checkpoint_activations: + fwd_bwd_macs += fwd_macs + + flops_per_iteration = fwd_bwd_macs * macs_per_flops + tflops = flops_per_iteration / (elapsed_time_per_iter * np * (10**12)) + return samples_per_second, tflops + + +model_info = { + "meta-llama/Meta-Llama-3-8B": { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "ffn_hidden_size": 16384, + "num_layers": 32, + "padded_vocab_size": 32000, + "topk": 1, + "swiglu": True # Meta-Llama-3ではswigluが使われていると仮定 + }, + "meta-llama/Meta-Llama-3-70B-Instruct": { + "hidden_size": 8192, + "num_attention_heads": 64, + "num_key_value_heads": 8, + "ffn_hidden_size": 32768, + "num_layers": 80, + "padded_vocab_size": 32000, + "topk": 1, + "swiglu": True # Meta-Llama-3ではswigluが使われていると仮定 + }, + "mistralai/Mixtral-8x7B-v0.1": { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "ffn_hidden_size": 16384, + "num_layers": 32, + "padded_vocab_size": 32000, + "topk": 2, # MixtralではMoEで2エキスパート + "swiglu": False # Mistralはswigluを使っていないと仮定 + } +} + +parser = argparse.ArgumentParser(description="Plot performance metrics.") +parser.add_argument("--metric", choices=["iteration_time", "throughput", "flops", "mfu", "peak_mem"], required=True, + help="Metric to plot: 'iteration_time', 'flops', 'mfu', or 
parser = argparse.ArgumentParser(description="Plot performance metrics.")
parser.add_argument("--metric", choices=["iteration_time", "throughput", "flops", "mfu", "peak_mem"], required=True,
                    help="Metric to plot: 'iteration_time', 'throughput', 'flops', 'mfu', or 'peak_mem'")
parser.add_argument("--result_dir", type=str, required=True, help="Path to the directory containing results.txt")
parser.add_argument("--result_file", type=str, default="results.txt", help="Name of the result file")
args = parser.parse_args()


# --- Parse the result log -------------------------------------------------
# Three line formats are accepted; they share a prefix and suffix and differ
# only in the optional "schedule=" / "passes= compile_time=" fields.
# NOTE(review): the name of the leading numeric field is a guess ("timestamp");
# it is not read anywhere below.
_RUN_PREFIX = (
    r"(?P<timestamp>\d+) (?P<model>[\w./-]+) ds=(?P<ds>\w+) np=(?P<np>\d+) batch_size=(?P<batch_size>\d+) "
    r"seq=(?P<seq>\d+) acc=(?P<acc>\d+) ac=(?P<ac>\w+) compile=(?P<compile>\w+) "
)
_PASSES_FIELDS = r"passes=(?P<passes>[\w,_]+) compile_time=(?P<compile_time>[\d.]+) "
_RUN_SUFFIX = (
    r"iteration time: (?P<iteration_time>[\d.]+) "
    r"alloc_mem: (?P<alloc_mem>\d+) peak_mem: (?P<peak_mem>\d+)"
)
pattern = re.compile(_RUN_PREFIX + _RUN_SUFFIX)
pattern_ctime = re.compile(_RUN_PREFIX + _PASSES_FIELDS + _RUN_SUFFIX)
pattern_cs = re.compile(_RUN_PREFIX + r"schedule=(?P<schedule>\w+) " + _PASSES_FIELDS + _RUN_SUFFIX)

file = Path(args.result_dir) / args.result_file
matches = []
with open(file) as f:
    for line in f:
        match = pattern.match(line) or pattern_ctime.match(line) or pattern_cs.match(line)
        if not match:
            print(f"Not matched: {line}")
            continue
        d = match.groupdict()
        # Fill the fields that the shorter formats do not carry.
        d.setdefault("passes", "")
        d.setdefault("compile_time", 0)
        d.setdefault("schedule", d["compile"])
        matches.append(d)

if not matches:
    raise SystemExit(f"No result lines matched in {file}")

df = pd.DataFrame(matches)

# --- Type conversion (every regex capture is a string) ---------------------
df["ds"] = df["ds"] == "True"
df["compile"] = df["compile"] == "True"
df["np"] = df["np"].astype(int)
df["batch_size"] = df["batch_size"].astype(int)
df["seq"] = df["seq"].astype(int)
df["iteration_time"] = df["iteration_time"].astype(float)
df["alloc_mem"] = df["alloc_mem"].astype(float)
df["peak_mem"] = df["peak_mem"].astype(float)
df["acc"] = df["acc"].astype(int)
df["ac"] = df["ac"] == "True"
df["compile_time"] = df["compile_time"].astype(float)
df["schedule"] = df["schedule"] == "True"


# --- Per-configuration computation and plotting ----------------------------
grouped = df.groupby(["model", "np", "batch_size"])

theoretical_peak = 312  # theoretical peak performance (TFLOPS)


LABEL_ZERO3 = "ZeRO3"
LABEL_ZERO3_C = "ZeRO3 (C)"
LABEL_FSDP = "FSDP"
LABEL_DC_PS = "DeepCompile (P+S)"
LABEL_DC_P = "DeepCompile (P)"
LABEL_DC_S = "DeepCompile (S)"

# Left-to-right order of the bars inside each group.
_BAR_ORDER = (LABEL_ZERO3, LABEL_ZERO3_C, LABEL_FSDP, LABEL_DC_P, LABEL_DC_S, LABEL_DC_PS)

for (model, num_procs, batch_size), group in grouped:
    group = group.sort_values("acc")
    acc_labels = group["acc"].unique()
    acc_to_index = {int(acc): idx for idx, acc in enumerate(acc_labels)}

    print(f"acc_labels: {acc_labels}")

    metric_values = {label: [0] * len(acc_labels) for label in _BAR_ORDER}

    for _, row in group.iterrows():

        if row["ds"] and not row["compile"]:
            category = LABEL_ZERO3
        elif not row["ds"] and not row["compile"]:
            category = LABEL_FSDP
        elif row["ds"] and row["compile"]:
            if not row["schedule"]:
                category = LABEL_ZERO3_C
            elif row["passes"] == "" or row["passes"] == 'prefetch,selective_gather':
                category = LABEL_DC_PS
            elif row["passes"] == 'prefetch':
                category = LABEL_DC_P
            elif row["passes"] == 'selective_gather':
                category = LABEL_DC_S
            else:
                print(f"Unknown category: {row}")
                continue
        else:
            print(f"Unknown category: {row}")
            continue

        acc_index = acc_to_index[int(row["acc"])]
        if args.metric == "iteration_time":
            metric_values[category][acc_index] = row["iteration_time"]
        elif args.metric == "peak_mem":
            metric_values[category][acc_index] = row["peak_mem"] / (1024**3)
        elif args.metric == "throughput":
            metric_values[category][acc_index] = row["batch_size"] * row["seq"] * row["acc"] / row["iteration_time"]
        elif args.metric in ["flops", "mfu"]:
            # Compute FLOPs from the static model description.
            model_params = model_info[row["model"]]
            samples_per_second, tflops = throughput_calculator(
                micro_batch_size=row["batch_size"],
                acc_steps=row["acc"],  # from the log
                np=row["np"],
                elapsed_time_per_iter=row["iteration_time"],
                hidden_size=model_params["hidden_size"],
                num_attention_heads=model_params["num_attention_heads"],
                num_key_value_heads=model_params["num_key_value_heads"],
                ffn_hidden_size=model_params["ffn_hidden_size"],
                num_layers=model_params["num_layers"],
                padded_vocab_size=model_params["padded_vocab_size"],
                seq_len=row["seq"],
                topk=model_params["topk"],
                swiglu=model_params["swiglu"],  # from the model definition
                checkpoint_activations=row["ac"]  # from the log
            )
            if args.metric == "flops":
                metric_values[category][acc_index] = tflops
            elif args.metric == "mfu":
                metric_values[category][acc_index] = tflops / theoretical_peak

    # --- Draw the grouped bar chart ----------------------------------------
    x = range(len(acc_labels))
    width = 0.15  # bar width
    ylabel = {
        "iteration_time": "Iteration Time (s)",
        "flops": "TFLOPS",
        "throughput": "Throughput (tokens/s/GPU)",
        "mfu": "MFU",
        "peak_mem": "Peak Memory (GB)"
    }[args.metric]

    plt.figure(figsize=(10, 8))
    adjust = - 0.5 * width
    # Slots -2 .. 3 place the six bars side by side around each tick.
    for slot, label in enumerate(_BAR_ORDER, start=-2):
        plt.bar([i + slot * width + adjust for i in x], metric_values[label], width, label=label, alpha=0.7)

    # A missing ZeRO3 baseline leaves a 0 in its slot; report NaN instead of
    # crashing with ZeroDivisionError.
    gain_zero3 = [
        dc / base if base else float("nan")
        for dc, base in zip(metric_values[LABEL_DC_PS], metric_values[LABEL_ZERO3])
    ]
    print(f"model {model} np {num_procs} batch_size {batch_size} {LABEL_ZERO3} metric_values: {metric_values[LABEL_ZERO3]} gain_zero3: {gain_zero3}")
    print(f"model {model} np {num_procs} batch_size {batch_size} {LABEL_DC_PS} metric_values: {metric_values[LABEL_DC_PS]}")

    # Shorten "org/name" to a display name; [-1] also works without an org prefix.
    model = model.split('/')[-1]
    model = model.replace("Meta-Llama-3-8B", "Llama-3-8B")
    model = model.replace("Meta-Llama-3-70B-Instruct", "Llama-3-70B")
    model = model.replace("Mixtral-8x7B-v0.1", "Mixtral-8x7B")

    plt.title(f"Model: {model}, #GPUs: {num_procs}, Batch Size: {batch_size}", fontsize=24)
    plt.xlabel("Acc Steps", fontsize=24)
    plt.ylabel(ylabel, fontsize=24)
    plt.xticks(x, acc_labels, fontsize=24)

    if args.metric == "peak_mem":
        plt.ylim(0, 80)

    plt.yticks(fontsize=20)
    plt.legend(loc="lower right", fontsize=18)
    plt.grid(axis="y")

    # --- Save the figure ---------------------------------------------------
    metric_name = args.metric
    model = model.replace("/", "_")
    chart_dir = Path(args.result_dir) / Path(metric_name)
    chart_dir.mkdir(parents=True, exist_ok=True)
    conf_str = f"{metric_name}_{model}_np{num_procs}_bs{batch_size}"
    img_path = chart_dir / f"chart_{conf_str}.png"
    plt.savefig(str(img_path))
    plt.close()
def get_args(argv=None):
    """Parse the config-generation command line.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Config generation')

    parser.add_argument('--machine_rank', type=int, help='machine_rank')
    parser.add_argument('--num_machines', type=int, help='num_machines')
    parser.add_argument('--num_processes', type=int, help='num_processes')
    parser.add_argument('--zero_stage', type=int, choices=[0, 1, 2, 3], help='ZeRO stage')
    parser.add_argument('--fp16', action='store_true', help='Use fp16')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--deepcompile', action='store_true', help='Use deepcompile')

    # Both paths are opened unconditionally by main(), so fail at parse time
    # with a usage message instead of a TypeError from open(None).
    parser.add_argument('--template_file', type=Path, required=True, help='Template file')
    parser.add_argument('--output_file', type=Path, required=True, help='Output file')

    return parser.parse_args(argv)


def main(args):
    """Render ``args.template_file`` with the parsed options into ``args.output_file``."""
    template = Template(Path(args.template_file).read_text())

    # Render before touching the output so a template error does not leave a
    # truncated file behind.
    rendered = template.render(machine_rank=args.machine_rank,
                               num_machines=args.num_machines,
                               num_processes=args.num_processes,
                               zero_stage=args.zero_stage,
                               fp16=args.fp16,
                               gradient_accumulation_steps=args.gradient_accumulation_steps,
                               deepcompile=str(args.deepcompile).lower())
    Path(args.output_file).write_text(rendered)


if __name__ == '__main__':
    args = get_args()
    main(args)
# Submit every slurm_jobs/job_*.slurm file to Slurm.
# Usage: bash launch_jobs.sh

# delete .out files in slurm_out
# for out in slurm_out/*.out; do
#     rm -vf "$out"
# done


for job in slurm_jobs/job_*.slurm; do
    # When nothing matches, the glob is left as a literal string; skip it
    # rather than handing a non-existent path to sbatch.
    [ -e "$job" ] || continue
    # echo "Submitting job $job"
    sbatch "$job"
done
@cache
def _get_default_args(func):
    """Return ``{positional parameter name: default}`` for *func*.

    Parameters without a default map to ``None``.  A ``softcap`` parameter,
    when present, is forced to ``0.0``.  The result is cached per function,
    so callers must copy it before mutating.
    """
    spec = inspect.getfullargspec(func)
    defaults = spec.defaults or ()
    n_without_default = len(spec.args) - len(defaults)

    args = {name: None for name in spec.args[:n_without_default]}
    args.update(zip(spec.args[n_without_default:], defaults))
    if "softcap" in args:
        args["softcap"] = 0.0
    return args


def get_default_args(func):
    """Return the default-argument map of *func*.

    Plain Python functions are inspected directly; anything else is expected
    to expose the wrapped function as ``_init_fn`` (the original comment
    refers to ``CustomOpDef``).
    """
    target = func if inspect.isfunction(func) else func._init_fn
    return _get_default_args(target)


@torch.jit.script
def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge one block's output and log-sum-exp into the running ``out``/``lse``."""
    block_out_f32 = block_out.to(torch.float32)
    # Swap the last two dims and add a trailing singleton so the block LSE
    # broadcasts against ``out``.
    block_lse_col = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # Equivalent to
    #   new_lse = lse + log(1 + exp(block_lse - lse))
    #   out     = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out
    # written with sigmoid/logsigmoid; background discussion:
    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    block_weight = F.sigmoid(block_lse_col - lse)
    merged_out = out - block_weight * (out - block_out_f32)
    merged_lse = lse - F.logsigmoid(lse - block_lse_col)

    return merged_out, merged_lse


def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Accumulate a block result into ``out``/``lse``.

    * First call (``out is None``): the block itself becomes the accumulator
      (float32 output, reshaped LSE); passing ``slice_`` here raises
      ``RuntimeError``.
    * ``slice_`` given: only ``out[slice_]``/``lse[slice_]`` are merged, and
      the result is written back in place.
    * Otherwise the whole accumulator is merged.
    """
    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        return block_out.to(torch.float32), block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    if slice_ is None:
        return _update_out_and_lse(out, lse, block_out, block_lse)

    merged_out, merged_lse = _update_out_and_lse(out[slice_], lse[slice_], block_out, block_lse)
    out[slice_], lse[slice_] = merged_out, merged_lse
    return out, lse
@torch.jit.script
def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
    """Scatter a packed LSE tensor into a per-sequence padded tensor.

    ``cu_seqlens`` holds cumulative sequence boundaries; rows
    ``cu_seqlens[i]:cu_seqlens[i + 1]`` of ``lse`` are copied into sequence
    ``i``.  The result has shape ``(num_seq, num_head, max_seqlen)``.

    NOTE: the buffer comes from ``torch.empty``, so positions beyond a
    sequence's own length are left uninitialised.
    """
    seq_count = len(cu_seqlens) - 1
    head_count = lse.shape[-2]
    padded = torch.empty(
        (seq_count, max_seqlen, head_count, 1), dtype=torch.float32, device=lse.device
    )
    for i in range(seq_count):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        padded[i, : end - start] = lse[start:end]
    # Drop the trailing singleton, then move heads in front of positions.
    squeezed = padded.squeeze(dim=-1)
    return squeezed.transpose(1, 2).contiguous()
class AllGatherComm:
    """Collects handles of asynchronous all-gathers so they can be awaited together."""

    def __init__(self, group=None) -> None:
        # Process group forwarded to every all_gather call.
        self.group = group
        # Handles of launched, not-yet-awaited collectives.
        self.handles = []

    def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor):
        """Launch an async ``all_gather_into_tensor`` and remember its handle."""
        self.handles.append(
            dist.all_gather_into_tensor(
                output_tensor, input_tensor, group=self.group, async_op=True
            )
        )

    def wait(self):
        """Wait for every recorded handle in launch order, then forget them."""
        for pending in self.handles:
            pending.wait()
        self.handles = []
def ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    """Backward pass of ring attention.

    K/V blocks are rotated around the ring with one ``RingComm`` while the
    partially accumulated dK/dV travel with them on a second ``RingComm``.
    At each step the local flash-attention backward kernel is run against the
    K/V block currently held (skipped for blocks with ``step > rank`` when
    ``causal``), its dQ is accumulated locally, and its dK/dV are added to the
    gradients received from the previous rank.

    Returns:
        ``(dq, dk, dv)``, each cast to ``q.dtype``.
    """
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None

    # Scratch buffers the kernel writes each step's gradients into.
    block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    for step in range(kv_comm.world_size):
        if step + 1 != kv_comm.world_size:
            # Start fetching the next K/V block while this one is processed.
            next_k, next_v = kv_comm.send_recv_kv(k, v)

        if step <= kv_comm.rank or not causal:
            # Only the first block (the rank's own K/V) needs the causal mask.
            bwd_causal = causal and step == 0
            params = get_default_args(_flash_attn_backward).copy()
            params.update(
                {
                    "dout": dout,
                    "q": q,
                    "k": k,
                    "v": v,
                    "out": out,
                    "softmax_lse": softmax_lse,
                    "dq": block_dq_buffer,
                    "dk": block_dk_buffer,
                    "dv": block_dv_buffer,
                    "dropout_p": dropout_p,
                    "softmax_scale": softmax_scale,
                    "causal": bwd_causal,
                    "alibi_slopes": alibi_slopes,
                    "deterministic": deterministic,
                }
            )
            # The kernel takes either a ``window_size`` tuple or separate
            # left/right arguments, depending on its signature.
            if "window_size" in params:
                params.update({"window_size": window_size})
            else:
                params.update(
                    {
                        "window_size_left": window_size[0],
                        "window_size_right": window_size[1],
                    }
                )
            _flash_attn_backward(**params)

            if dq is None:
                # First contribution: start float32 accumulators.
                dq = block_dq_buffer.to(torch.float32)
                dk = block_dk_buffer.to(torch.float32)
                dv = block_dv_buffer.to(torch.float32)
            else:
                dq += block_dq_buffer
                # The dK/dV sent by the previous rank must have arrived
                # before they can be combined with this step's result.
                d_kv_comm.wait()
                dk = block_dk_buffer + next_dk
                dv = block_dv_buffer + next_dv
        elif step != 0:
            # Nothing to compute for this block; just pass the received
            # gradients along.
            d_kv_comm.wait()
            dk, dv = next_dk, next_dv

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k, v = next_k, next_v

        next_dk, next_dv = d_kv_comm.send_recv_kv(dk, dv)

    d_kv_comm.wait()

    # Cast every gradient back to the input dtype.  dq used to be cast to a
    # hard-coded bfloat16, which was wrong for any other input dtype.
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
# HuggingFace-compatible wrapper for ring attention
# This follows the same pattern as ulysses_attention_forward in distributed_attention.py
def ring_attention_forward(
    self,  # the attention module instance; not read here
    query_states,
    key_states,
    value_states,
    attention_mask=None,
    scaling=None,
    dropout=0.0,
    is_causal=True,
    **kwargs,
):
    """
    Ring attention entry point with a HuggingFace-style attention signature.

    Args:
        self: The attention module instance (unused).
        query_states, key_states, value_states: 4-D tensors whose dims 1 and 2
            are swapped before the ring kernel is called (the original
            comments describe the input as (batch, heads, seq, dim)).
        attention_mask: Accepted but not used.
        scaling: Forwarded as ``softmax_scale``.
        dropout: Forwarded as ``dropout_p``.
        is_causal: Forwarded as ``causal``.
        **kwargs: Ignored.

    Returns:
        tuple: ``(attn_output, None)``; ``attn_output`` is returned in the
        layout produced by ``ring_flash_attn_func`` (no transpose back).
    """
    assert is_setup(), 'Incorrectly setup SP/DP Groups.'

    # Ranks are grouped into consecutive blocks of sp_size(); the block index
    # selects this rank's process group.
    sp_group = get_group(dist.get_rank() // sp_size())

    # Swap dims 1 and 2 for the flash-attention layout.
    q, k, v = (
        states.transpose(1, 2).contiguous()
        for states in (query_states, key_states, value_states)
    )

    ring_kwargs = dict(
        dropout_p=dropout,
        softmax_scale=scaling,
        causal=is_causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        group=sp_group,
    )
    return ring_flash_attn_func(q, k, v, **ring_kwargs), None
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and pin cuDNN to deterministic mode."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and Python RNGs with ``12 + worker_id``."""
    derived_seed = 12 + worker_id
    np.random.seed(derived_seed)
    random.seed(derived_seed)
def get_args():
    """Parse the benchmark command line from ``sys.argv`` and return the namespace."""
    # (flag, add_argument keyword arguments), in --help display order.
    options = (
        ("--model_name", dict(type=str, default="meta-llama/Llama-2-7b-hf")),
        ("--batch_size", dict(type=int, default=1)),
        ("--num_epochs", dict(type=int, default=1)),
        ("--seq_length", dict(type=int, default=512)),
        ("--steps", dict(type=int, default=1)),
        ("--learning_rate", dict(type=float, default=2e-5)),
        ("--gradient_accumulation_steps", dict(type=int, default=1)),
        ("--activation_checkpointing", dict(action="store_true")),
        ("--dataset_name", dict(type=str, default="timdettmers/openassistant-guanaco")),
        ("--num_layers", dict(type=int, default=1)),
        ("--compile", dict(type=str, default="deepcompile")),
        ("--passes", dict(type=str, default=None)),
        ("--backend", dict(type=str, default="inductor")),
        ("--offload_opt_states", dict(action="store_true")),
        ("--profile", dict(action="store_true")),
        ("--profile_memory", dict(action="store_true")),
        ("--deterministic", dict(action="store_true")),
        ("--profile_dir", dict(type=str, default="profiles")),
        ("--bench_step", dict(type=int, default=1)),
        ("--warmup_step", dict(type=int, default=15)),
        ("--print_interval", dict(type=int, default=1)),
        ("--experiment_folder", dict(type=str, default="")),
        ("--sp_size", dict(type=int, default=2)),
        ("--dp_size", dict(type=int, default=1)),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
def main():
    """Run the sequence-parallel training benchmark.

    Parses the CLI, builds the model with the attention backend matching
    ``--compile`` (``deepcompile``, ``eager``, ``compile`` or ``ringattn``),
    tokenizes a slice of ``ag_news``, and runs forward/backward steps while
    writing the per-step loss to ``logs/loss_<mode>_<seq>_<rank>.csv``.

    Raises:
        ValueError: if ``--compile`` is not one of the supported modes.
    """
    args = get_args()
    supported_modes = ("deepcompile", "eager", "compile", "ringattn")
    if args.compile not in supported_modes:
        # Previously an unknown mode surfaced much later as an
        # UnboundLocalError on ``attention_backend``.
        raise ValueError(f"Unsupported --compile mode {args.compile!r}; expected one of {supported_modes}")
    set_seed(12)

    if args.deterministic:
        enable_full_determinism(12)
        from torch._inductor import config
        config.fallback_random = True
        torch.use_deterministic_algorithms(True)

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    device = accelerator.device
    assert accelerator.num_processes == args.sp_size * args.dp_size, 'Incorrect dp/sp sizing'

    ## Set sp/dp groups accordingly. ##
    if args.compile in ['compile', 'eager', 'ringattn']:
        populate_registry(args.sp_size, args.dp_size)

    if accelerator.is_main_process:
        print(f'GROUP_REGISTRY: {get_registry()}')

    # Load model and tokenizer
    if accelerator.is_main_process:
        print("Loading model and tokenizer...")

    model_name = args.model_name
    if args.compile == "deepcompile":
        attention_backend = "sdpa"
    else:
        from transformers.models.llama import modeling_llama
        if args.compile == "ringattn":
            # Imported lazily: ring_attention pulls in flash_attn at import
            # time, which the other modes do not need.
            from ring_attention import ring_attention_forward
            attention_backend = "ringattn"
            modeling_llama.ALL_ATTENTION_FUNCTIONS["ringattn"] = ring_attention_forward
        else:  # "eager" or "compile"
            attention_backend = "ulyssess"
            modeling_llama.ALL_ATTENTION_FUNCTIONS["ulyssess"] = ulysses_attention_forward

    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    if args.num_layers is not None:
        # Build a randomly initialised model truncated to num_layers.
        if accelerator.is_main_process:
            print(f"num_hidden_layers: {model_config.num_hidden_layers} -> {args.num_layers}")
        model_config.num_hidden_layers = args.num_layers
        model_config._attn_implementation = attention_backend
        model = AutoModelForCausalLM.from_config(model_config, trust_remote_code=True)
    else:
        model_config._attn_implementation = attention_backend
        model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config, trust_remote_code=True)

    if args.activation_checkpointing:
        model.gradient_checkpointing_enable()

    # Load dataset
    if accelerator.is_main_process:
        print("Loading dataset...")

    g = torch.Generator()
    g.manual_seed(12)
    dataset = load_dataset('ag_news', split='train[:1%]')

    # The tokenizer used to be loaded twice, with the first instance (and its
    # eos-based pad token) discarded; only the effective one is kept.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(2)

    def tokenize_function(examples):
        ## Fix max_length and generate fake data instead to not exhaust disk.
        return tokenizer(examples['text'], padding='max_length', max_length=args.seq_length, truncation=True)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

    # Every rank of one SP group must see the same samples, so data is sharded
    # over DP replicas only.
    num_replicas_ = args.dp_size
    rank_ = accelerator.process_index // args.sp_size

    sampler = DistributedSampler(tokenized_dataset, num_replicas=num_replicas_, rank=rank_, seed=12, shuffle=False)
    data_loader = DataLoader(tokenized_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4, worker_init_fn=seed_worker, generator=g)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
    print(f"Model prepared: {model.__class__}")

    if args.compile == "deepcompile":
        print(f"Running deepcompile with backend={args.backend}")
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        torch._dynamo.config.capture_scalar_outputs = True
        model.compile(backend=args.backend)
    elif args.compile in ["compile", "ringattn"]:
        print(f"Running torch.compile with backend={args.backend}")
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        torch._dynamo.config.capture_scalar_outputs = True
        model = torch.compile(model, backend=args.backend)
    else:
        print(f"Running eager")

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    model_name = args.model_name.split("/")[-1]
    exp_name = f"{model_name}_np{accelerator.num_processes}_{args.compile}_" \
               f"B{args.backend}_" \
               f"L{0 if args.num_layers is None else args.num_layers}_" \
               f"bs{args.batch_size}_seq{args.seq_length}_" \
               f"T{timestamp}"
    if args.profile_dir:
        if accelerator.is_main_process:
            os.makedirs(args.profile_dir, exist_ok=True)
            if args.profile:
                prof_dir = f"{args.profile_dir}/{exp_name}"
                os.makedirs(prof_dir, exist_ok=True)
        accelerator.wait_for_everyone()

    #HACK: the padding mask is stored globally so local attention can read it
    from distributed_attention import set_padding_mask

    # Training loop
    model.train()
    global_step = 0
    print(f"Using global sequence length: {args.seq_length}")

    os.makedirs("logs", exist_ok=True)
    loss_log_path = f"logs/loss_{args.compile}_{args.seq_length}_{accelerator.process_index}.csv"
    sp_rank = dist.get_rank() % args.sp_size

    # ``with`` guarantees the log is flushed and closed even if a step raises.
    with open(loss_log_path, "w") as loss_log_file:
        loss_log_file.write("step,loss\n")

        for epoch in range(args.num_epochs):
            start_iter = time.time()

            for batch in data_loader:
                input_ids = batch['input_ids'].to(device)  # [B, S]
                B, S = input_ids.shape

                label_ids = input_ids.clone()  # [B, S]
                position_ids = torch.arange(S, device=device).unsqueeze(0)
                attention_mask = batch['attention_mask'].to(device)

                set_padding_mask(attention_mask)

                if args.compile == 'deepcompile':
                    # The compiler pass shards the sequence itself; inputs are
                    # only tagged and marked dynamic along the sequence dim.
                    input_ids, label_ids, position_ids, attention_mask = prepare_autosp_inputs(input_ids, label_ids, position_ids, attention_mask, seq_dim=1)
                else:
                    # Manual sharding: each SP rank keeps its contiguous chunk.
                    chunk_size = S // args.sp_size
                    start = sp_rank * chunk_size
                    end = start + chunk_size
                    input_ids = input_ids[:, start:end]  # [B, S_shard]
                    label_ids = label_ids[:, start:end]  # [B, S_shard] - must match input_ids
                    position_ids = position_ids[:, start:end]

                outputs = model(input_ids=input_ids, labels=label_ids, position_ids=position_ids, attention_mask=attention_mask)
                loss = outputs.loss
                print(f"Epoch {epoch+1}, Step {global_step}, Loss: {loss.item()} time: {time.time() - start_iter} alloc_mem: {torch.cuda.memory_allocated() / (1024 ** 3)} peak_mem: {torch.cuda.max_memory_allocated() / (1024 ** 3)}")

                accelerator.backward(loss)

                loss_log_file.write(f"{global_step},{loss.item()}\n")
                loss_log_file.flush()

                global_step += 1
                if global_step > args.steps:
                    break

if __name__ == "__main__":
    torch._dynamo.config.accumulated_cache_size_limit = 256
    torch._dynamo.config.cache_size_limit = 128
    torch._dynamo.config.optimize_ddp = False
    main()
+COMPILE_OPTS="--compile" +N3Z_OPTS="--compile --deepcompile" +AC_OPTS="--activation-checkpointing" + +MODEL="meta-llama/Meta-Llama-3-70B-Instruct" +BATCH_SIZE_OPTS=(1) +SEQ_LENGTH_OPTS=(1024) +ACC_OPTS=(2 4 8 16) +for ACC_STEP in ${ACC_OPTS[@]}; do + for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do + for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do + ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${AC_OPTS} ${PROFILE_OPTS} --gradient-accumulation-steps ${ACC_STEP}" + bash ./run_multinode.sh --backend deepspeed ${ARGS} + bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS} + bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch,selective_gather + bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch + bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes selective_gather + cp -r logs ${PROFILE_DIR}/ + done + done +done + +MODEL="mistralai/Mixtral-8x7B-v0.1" +BATCH_SIZE_OPTS=(1) +SEQ_LENGTH_OPTS=(1024) +ACC_OPTS=(2 4 8 16) +for ACC_STEP in ${ACC_OPTS[@]}; do + for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do + for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do + ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${AC_OPTS} ${PROFILE_OPTS} --gradient-accumulation-steps ${ACC_STEP}" + bash ./run_multinode.sh --backend deepspeed ${ARGS} + bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS} + bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch,selective_gather + bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch + bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes selective_gather + cp -r logs ${PROFILE_DIR}/ + done + done +done diff --git a/bench_dc_ulysses/run_correctness_test.sh b/bench_dc_ulysses/run_correctness_test.sh new file mode 100755 index 000000000000..37decf60a9d4 --- /dev/null +++ b/bench_dc_ulysses/run_correctness_test.sh @@ -0,0 +1,4 
#!/bin/bash
# Launch run.sh on one node (directly) or on several nodes (through ds_ssh).
#
# Command-line flags are translated into:
#   * shell variables that are passed to run.sh positionally, and
#   * EXTRA_OPTS, a string of flags forwarded to run.sh after the positionals.
#
# Environment overrides: NUM_NODES, NGPUS_PER_NODE.

NUM_NODES=${NUM_NODES:-$(wc -l < hostfile_n4)}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
NUM_PROCESSES=$((${NUM_NODES} * ${NGPUS_PER_NODE}))

BACKEND="deepspeed" # ignore
MODEL="meta-llama/Meta-Llama-3-8B"
COMPILE=0
PASSES="ALL"
EXTRA_OPTS=""

EAGER=0
DEEPCOMPILE=0
GRADIENT_ACCUMULATION_STEPS=1
ACTIVATION_CHECKPOINTING=1
BATCH_SIZE=1
SEQ_LENGTH=512

while [[ $# -gt 0 ]]; do
    case $1 in
        --backend)
            BACKEND="$2"
            shift 2
            ;;
        --batch-size)
            BATCH_SIZE="$2"
            EXTRA_OPTS="${EXTRA_OPTS} --batch_size $2"
            shift 2
            ;;
        --seq-length)
            SEQ_LENGTH="$2"
            EXTRA_OPTS="${EXTRA_OPTS} --seq_length $2"
            shift 2
            ;;
        --gradient-accumulation-steps)
            # Passed to run.sh positionally, so it is not added to EXTRA_OPTS.
            GRADIENT_ACCUMULATION_STEPS="$2"
            shift 2
            ;;
        --activation-checkpointing)
            ACTIVATION_CHECKPOINTING=1
            EXTRA_OPTS="${EXTRA_OPTS} --activation_checkpointing"
            shift
            ;;
        --compile)
            COMPILE=1
            EXTRA_OPTS="${EXTRA_OPTS} $1"
            shift
            ;;
        --eager)
            EAGER=1
            EXTRA_OPTS="${EXTRA_OPTS} --backend eager"
            shift
            ;;
        --deepcompile)
            DEEPCOMPILE=1
            shift
            ;;
        --passes)
            PASSES="$2"
            EXTRA_OPTS="${EXTRA_OPTS} $1 $2"
            shift 2
            ;;
        --profile)
            EXTRA_OPTS="${EXTRA_OPTS} $1"
            shift
            ;;
        --profile-dir)
            EXTRA_OPTS="${EXTRA_OPTS} --profile_dir $2"
            shift 2
            ;;
        --model)
            MODEL="$2"
            shift 2
            ;;
        --num-layers)
            EXTRA_OPTS="${EXTRA_OPTS} --num_layers $2"
            shift 2
            ;;
        --num-gpus)
            NGPUS_PER_NODE="$2"
            NUM_PROCESSES=$((${NUM_NODES} * ${NGPUS_PER_NODE}))
            shift 2
            ;;
        *)
            echo "Unknown option: $1"
            exit 1
            ;;
    esac
done

# `hostname -i` can print several addresses; keep only the first one
# (same approach as run_ulysses.sh).
HOST_IP=$(hostname -i | awk '{print $1}')

mkdir -p logs

SCRIPT_DIR=$(dirname "$(realpath "$0")")

# replace , with _ in PASSES
PASSES=$(echo "${PASSES}" | tr ',' '_')

LOG_FILE=debug_b${BACKEND}np${NUM_PROCESSES}c${COMPILE}dc${DEEPCOMPILE}bs${BATCH_SIZE}seq${SEQ_LENGTH}.log

# EXTRA_OPTS is intentionally left unquoted so it splits into separate flags.
if [ "${NUM_NODES}" == "1" ]; then
    # avoid dependency on pdsh when possible
    cd "${SCRIPT_DIR}"; bash ./run.sh ${HOST_IP} ${NUM_NODES} ${NUM_PROCESSES} ${BACKEND} ${MODEL} ${GRADIENT_ACCUMULATION_STEPS} ${DEEPCOMPILE} ${EXTRA_OPTS} \
        2>&1 | tee logs/${LOG_FILE}
else
    # Pass the same positional arguments as the single-node branch. This branch
    # used to pass ${SCHEDULE} ${OFFLOAD_OPT_STATES}, which are never set in
    # this script, so they expanded to nothing and shifted every later argument.
    ds_ssh -f hostfile_n${NUM_NODES} "cd ${SCRIPT_DIR}; bash ./run.sh ${HOST_IP} ${NUM_NODES} ${NUM_PROCESSES} ${BACKEND} ${MODEL} ${GRADIENT_ACCUMULATION_STEPS} ${DEEPCOMPILE} ${EXTRA_OPTS}" \
        2>&1 | tee logs/${LOG_FILE}
fi
Choose from eager, compile, deepcompile, ringattn." + exit 1 +fi + +HOST_IP=$(hostname -i | awk '{print $1}') +PORT=$(python3 -c "import socket; s = socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()") +NUM_NODES=1 +NUM_PROCESSES=$((SP_SIZE * DP_SIZE)) +MODEL="meta-llama/Llama-2-7b-chat-hf" +# MODEL="meta-llama/Llama-3.1-8B" +# MODEL="meta-llama/Llama-3.2-1B" +# MODEL="meta-llama/Llama-3.2-3B" +PROFILE_DIR=${PROFILE_DIR:-profiles} +mkdir -p ${PROFILE_DIR} +PROFILE_OPTS="--profile_dir ${PROFILE_DIR}" + +COMPILE_OPTS="--compile ${COMPILE}" +CONFIG_FILE="configs/torchcompile_config.yaml" +if [ "${COMPILE}" == "deepcompile" ]; then + CONFIG_FILE="configs/deepcompile_config.yaml" +fi + + +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +LOG_FILE=logs/log_${COMPILE}_seq${SEQ_LEN}_${TIMESTAMP}.log + +echo "HOST_IP: ${HOST_IP}" +echo "PORT: ${PORT}" +echo "NUM_NODES: ${NUM_NODES}" +echo "NUM_PROCESSES: ${NUM_PROCESSES}" +echo "MODEL: ${MODEL}" +echo "COMPILE: ${COMPILE}" +echo "SEQ_LEN: ${SEQ_LEN}" +echo "LOG_FILE: ${LOG_FILE}" + +EXTRA_OPTS="--seq_length=${SEQ_LEN} --experiment_folder=${EXP_NAME} --sp_size=${SP_SIZE} --dp_size=${DP_SIZE}" + +# Only pass --num_layers if provided +NUM_LAYER_OPTS="" +if [[ -n "${LAYER_COUNT}" ]]; then + NUM_LAYER_OPTS="--num_layers ${LAYER_COUNT}" +fi + +( +accelerate launch --main_process_ip ${HOST_IP} --main_process_port ${PORT} \ +--num_machines ${NUM_NODES} --num_processes ${NUM_PROCESSES} --machine_rank 0 \ +--config_file ${CONFIG_FILE} \ +run_acc_lm.py \ +--model_name "${MODEL}" ${NUM_LAYER_OPTS} \ +${PROFILE_OPTS} \ +${EXTRA_OPTS} \ +${COMPILE_OPTS} +) 2>&1 | tee ${LOG_FILE} + + diff --git a/bench_dc_ulysses/sample.slurm b/bench_dc_ulysses/sample.slurm new file mode 100644 index 000000000000..199e49a53168 --- /dev/null +++ b/bench_dc_ulysses/sample.slurm @@ -0,0 +1,43 @@ +#!/bin/bash +# SLURM Job Submission OR Interactive Environment Setup +# For batch submission: sbatch sample.slurm +# For interactive: +# srun -A 
import torch
import torch.distributed as dist

# Process-wide registry shared by the benchmark scripts.
#   * int keys (group ids) map to the process group created for that id;
#   * the string keys 'SP_SIZE', 'DP_SIZE' and 'is_reg' hold metadata written
#     by populate_registry().
GROUP_REGISTRY = {}


def register_groups(groups):
    """Create one process group per rank list and store it under its index.

    Args:
        groups: List[List[int]], e.g. [[0,1],[2,3]]. The position of each rank
            list is used as its group id.

    Group ids that are already present in the registry are left untouched.
    """
    for gid, ranks in enumerate(groups):
        if gid not in GROUP_REGISTRY:
            GROUP_REGISTRY[gid] = dist.new_group(ranks)


def get_group(gid: int):
    """Return the registered group for ``gid``, or the WORLD group if ``gid`` is None.

    Raises:
        KeyError: if ``gid`` is not None and was never registered.
    """
    return GROUP_REGISTRY[gid] if gid is not None else dist.group.WORLD


def get_registry():
    """Return the registry dict itself (not a copy)."""
    return GROUP_REGISTRY


def is_setup():
    """Return True once populate_registry() has completed, False otherwise."""
    return bool(GROUP_REGISTRY.get('is_reg', False))


def sp_size():
    """Return the sequence-parallel size recorded by populate_registry()."""
    assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.'

    return GROUP_REGISTRY['SP_SIZE']


def dp_size():
    """Return the data-parallel size recorded by populate_registry()."""
    assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly'

    return GROUP_REGISTRY['DP_SIZE']


def populate_registry(SP_SIZE, DP_SIZE):
    """Register ``DP_SIZE`` groups of ``SP_SIZE`` consecutive ranks each.

    Group ``d`` contains ranks ``d * SP_SIZE .. (d + 1) * SP_SIZE - 1``.
    The sizes and a completion flag are stored in the registry afterwards.

    Raises:
        ValueError: if either size is smaller than 1.
    """
    ## We register in the run_acc_lm.py file for baselines to reduce code-duplication.
    ## Else the registration happens within the SP compiler pass within deepspeed.
    if SP_SIZE < 1 or DP_SIZE < 1:
        # Fail before any process group is created.
        raise ValueError(f"SP_SIZE and DP_SIZE must be >= 1, got SP_SIZE={SP_SIZE}, DP_SIZE={DP_SIZE}")

    group_listing = [
        [dp_idx * SP_SIZE + i for i in range(SP_SIZE)]
        for dp_idx in range(DP_SIZE)
    ]

    register_groups(group_listing)

    ## Extraneous metadata required for proper instantiation. ##
    GROUP_REGISTRY['SP_SIZE'] = SP_SIZE
    GROUP_REGISTRY['DP_SIZE'] = DP_SIZE
    GROUP_REGISTRY['is_reg'] = True
val = get_node_shape_meta(tensor_node) assert val is not None, f"Node {tensor_node.name} has no shape metadata" diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e6d838df5adf..eb71a7c51765 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -127,6 +127,7 @@ from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states from deepspeed.compile.init_z1 import init_z1 from deepspeed.compile.init_z3 import init_z3 +from deepspeed.compile.init_sp import init_autosp MEMORY_OPT_ALLREDUCE_SIZE = 500000000 @@ -4361,7 +4362,8 @@ def compile(self, enable_deepcompile = self.is_deepcompile_enabled() if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ and self.zero_optimization_stage() != ZeroStageEnum.weights \ - and self.zero_optimization_stage() != ZeroStageEnum.gradients: + and self.zero_optimization_stage() != ZeroStageEnum.gradients \ + and self.zero_optimization_stage() != ZeroStageEnum.disabled: logger.info( f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." ) @@ -4396,6 +4398,8 @@ def passes_name_to_fn(passes): "DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. " "Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3.") backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) + elif self.zero_optimization_stage() == ZeroStageEnum.disabled: + backend = init_autosp() # Hook state must align with whether DeepCompile is active. 
self._set_deepcompile_active(enable_deepcompile) From ecbc6eaad635ee08ed23b573133ae11f5d16c411 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sun, 22 Feb 2026 20:30:35 -0600 Subject: [PATCH 03/14] move bench scripts to DeepSpeedExamples --- bench_dc_ulysses/.gitignore | 10 - bench_dc_ulysses/README.md | 51 -- bench_dc_ulysses/configs/config.yaml | 16 - .../configs/deepcompile_config.json | 19 - .../configs/deepcompile_config.yaml | 16 - bench_dc_ulysses/configs/ds_config.json | 25 - .../configs/ds_config.json.template | 30 - .../configs/ds_config.yaml.template | 19 - .../configs/torchcompile_config.json | 14 - .../configs/torchcompile_config.yaml | 16 - bench_dc_ulysses/distributed_attention.py | 104 ---- bench_dc_ulysses/gen_chart_acc_steps.py | 263 --------- bench_dc_ulysses/generate_conf.py | 42 -- bench_dc_ulysses/hostfile_n1 | 1 - bench_dc_ulysses/hostfile_n2 | 1 - bench_dc_ulysses/hostfile_n4 | 1 - bench_dc_ulysses/launch_jobs.sh | 13 - bench_dc_ulysses/ring_attention.py | 530 ------------------ bench_dc_ulysses/run.sh | 56 -- bench_dc_ulysses/run_acc_lm.py | 251 --------- bench_dc_ulysses/run_bench.sh | 19 - bench_dc_ulysses/run_bench_acc.sh | 42 -- bench_dc_ulysses/run_correctness_test.sh | 4 - bench_dc_ulysses/run_multinode.sh | 112 ---- bench_dc_ulysses/run_ulysses.sh | 65 --- bench_dc_ulysses/sample.slurm | 43 -- bench_dc_ulysses/sp_dp_registry.py | 45 -- deepspeed/compile/fx.py | 2 +- 28 files changed, 1 insertion(+), 1809 deletions(-) delete mode 100644 bench_dc_ulysses/.gitignore delete mode 100644 bench_dc_ulysses/README.md delete mode 100644 bench_dc_ulysses/configs/config.yaml delete mode 100644 bench_dc_ulysses/configs/deepcompile_config.json delete mode 100644 bench_dc_ulysses/configs/deepcompile_config.yaml delete mode 100644 bench_dc_ulysses/configs/ds_config.json delete mode 100644 bench_dc_ulysses/configs/ds_config.json.template delete mode 100644 bench_dc_ulysses/configs/ds_config.yaml.template delete mode 100644 
bench_dc_ulysses/configs/torchcompile_config.json delete mode 100644 bench_dc_ulysses/configs/torchcompile_config.yaml delete mode 100644 bench_dc_ulysses/distributed_attention.py delete mode 100644 bench_dc_ulysses/gen_chart_acc_steps.py delete mode 100644 bench_dc_ulysses/generate_conf.py delete mode 100644 bench_dc_ulysses/hostfile_n1 delete mode 100644 bench_dc_ulysses/hostfile_n2 delete mode 100644 bench_dc_ulysses/hostfile_n4 delete mode 100755 bench_dc_ulysses/launch_jobs.sh delete mode 100644 bench_dc_ulysses/ring_attention.py delete mode 100755 bench_dc_ulysses/run.sh delete mode 100644 bench_dc_ulysses/run_acc_lm.py delete mode 100755 bench_dc_ulysses/run_bench.sh delete mode 100755 bench_dc_ulysses/run_bench_acc.sh delete mode 100755 bench_dc_ulysses/run_correctness_test.sh delete mode 100755 bench_dc_ulysses/run_multinode.sh delete mode 100755 bench_dc_ulysses/run_ulysses.sh delete mode 100644 bench_dc_ulysses/sample.slurm delete mode 100644 bench_dc_ulysses/sp_dp_registry.py diff --git a/bench_dc_ulysses/.gitignore b/bench_dc_ulysses/.gitignore deleted file mode 100644 index 4b197544883c..000000000000 --- a/bench_dc_ulysses/.gitignore +++ /dev/null @@ -1,10 +0,0 @@ -*.log -*.pyc -profiles -results -slurm_jobs -slurm* -experiments -logs -*. -*.pt diff --git a/bench_dc_ulysses/README.md b/bench_dc_ulysses/README.md deleted file mode 100644 index 4f0c68a3458d..000000000000 --- a/bench_dc_ulysses/README.md +++ /dev/null @@ -1,51 +0,0 @@ -# Benchmark for DeepCompile - -## Setup - -This experiment scripts require 1 node that has 2 A100/A40 GPUs -We tested the scripts with Python 3.10.12 and CUDA 12.3. 
- -### Libraries - -In addition, you need to install the following: - -- PyTorch 2.5.1 -- [modified version of DeepSpeed](https://github.com/tohtana/DeepSpeed-internal/tree/neeld2/debug-loss) - -Here is an example of installation commands: - -```bash -pip3 install torch==2.5.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 -pip3 install datasets==3.1 accelerate - -# Install DeepSpeed and DeepCompile -git clone -b neeld2/debug-loss https://github.com/tohtana/DeepSpeed-internal.git -cd DeepSpeed-internal -pip install -e transformers -cd .. -pip install -e DeepSpeed - -# Clone this repository -git clone https://github.com/neeldani/bench_dc_ulysses.git -``` - -## Running the scripts - -Test the setup by running the script: -```bash -bash run_ulysses.sh 6 [compile|deepcompile|eager|ringattn] -``` - -Here, 6 is the sequence length and is hardcoded because the input sequence inside run_acc_lm.py is hardcoded to easily verify the Q, K and V before and after the all-to-all. You may pass `compile` to run compiled Ulysses (Ulysses with graph breaks) or `deepcompile` to run deepcompiled Ulysses (all-to-all inserted within the compiler pass). - -We save the Q, K and V tensors before and after the all-to-all: -For deepcompiled Ulysses, the tensors are saved here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/deepspeed/compile/patch_aot_module.py#L243 - -For compiled Ulysses, the tensors are saved here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/deepspeed/sequence/layer.py#L381 - -You can then run the script [check_qkv.py](https://github.com/neeldani/bench_dc_ulysses/blob/main/check_qkv.py) to compare the tensors at various stages, i.e., before all2all, after all2all, attention outputs, etc. - -## Code walkthrough -1. Script: [run_ulysses.sh](https://github.com/neeldani/bench_dc_ulysses/blob/main/run_ulysses.sh) -2.
The script calls: [run_acc_lm.py](https://github.com/neeldani/bench_dc_ulysses/blob/main/run_acc_lm.py). We have added support for another attention backend in HuggingFace called "ulysses" which uses DistributedAttention. The implementation can be found here: https://github.com/tohtana/DeepSpeed-internal/blob/60feb352a6b0e22cf9a781b4e387d3919dc76833/transformers/src/transformers/models/llama/modeling_llama.py#L306 -3. If the `deepcompile` arg is passed to the config file, then a compiler pass will add the all2all's directy at the Torch IR level. The code for it can be found here: https://github.com/tohtana/DeepSpeed-internal/blob/neeld2/debug-loss/deepspeed/compile/patch_aot_module.py diff --git a/bench_dc_ulysses/configs/config.yaml b/bench_dc_ulysses/configs/config.yaml deleted file mode 100644 index 254b83004286..000000000000 --- a/bench_dc_ulysses/configs/config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - deepspeed_multinode_launcher: standard - deepspeed_config_file: configs/ds_config.json -distributed_type: DEEPSPEED -machine_rank: 1 -main_training_function: main -num_machines: 1 -num_processes: 2 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false \ No newline at end of file diff --git a/bench_dc_ulysses/configs/deepcompile_config.json b/bench_dc_ulysses/configs/deepcompile_config.json deleted file mode 100644 index 93deb7402e19..000000000000 --- a/bench_dc_ulysses/configs/deepcompile_config.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - - "bf16": { - "enabled": true - }, - - "zero_optimization":{ - "stage": 0 - }, - "compile": { - "deepcompile": true - }, - "gradient_accumulation_steps": 1, - "gradient_clipping": "auto", - "steps_per_print": 2000, - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} \ No newline at end of file diff --git 
a/bench_dc_ulysses/configs/deepcompile_config.yaml b/bench_dc_ulysses/configs/deepcompile_config.yaml deleted file mode 100644 index 405eb7163508..000000000000 --- a/bench_dc_ulysses/configs/deepcompile_config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - deepspeed_multinode_launcher: standard - deepspeed_config_file: configs/deepcompile_config.json -distributed_type: DEEPSPEED -machine_rank: 0 -main_training_function: main -num_machines: 1 -num_processes: 2 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/bench_dc_ulysses/configs/ds_config.json b/bench_dc_ulysses/configs/ds_config.json deleted file mode 100644 index 548aeaac6b78..000000000000 --- a/bench_dc_ulysses/configs/ds_config.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - - "bf16": { - "enabled": true - }, - - "zero_optimization":{ - "stage": 0 - }, - "compile": { - "deepcompile": true, - "offload_activation": false, - "offload_opt_states": false, - "double_buffer": true, - "symmetric_memory": false, - "free_activation": false, - "dump_graphs": false - }, - "gradient_accumulation_steps": 1, - "gradient_clipping": "auto", - "steps_per_print": 2000, - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} \ No newline at end of file diff --git a/bench_dc_ulysses/configs/ds_config.json.template b/bench_dc_ulysses/configs/ds_config.json.template deleted file mode 100644 index 9afb8b6a7159..000000000000 --- a/bench_dc_ulysses/configs/ds_config.json.template +++ /dev/null @@ -1,30 +0,0 @@ -{ - {% if fp16 %} - "fp16": { - "enabled": true, - "initial_scale_power": 8 - }, - {% else %} - "bf16": { - "enabled": true - }, - {% endif %} - "zero_optimization":{ - "stage": 0 - }, - "compile": { - "deepcompile": {{ deepcompile }}, - "offload_activation": false, - "offload_opt_states": false, - "double_buffer": true, - "symmetric_memory": 
false, - "free_activation": false, - "dump_graphs": false - }, - "gradient_accumulation_steps": {{ gradient_accumulation_steps }}, - "gradient_clipping": "auto", - "steps_per_print": 2000, - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} \ No newline at end of file diff --git a/bench_dc_ulysses/configs/ds_config.yaml.template b/bench_dc_ulysses/configs/ds_config.yaml.template deleted file mode 100644 index f130fbea7f98..000000000000 --- a/bench_dc_ulysses/configs/ds_config.yaml.template +++ /dev/null @@ -1,19 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - deepspeed_multinode_launcher: standard - {%- if zero_stage == 3 %} - zero3_init_flag: true - {%- endif %} - deepspeed_config_file: configs/ds_config.json -distributed_type: DEEPSPEED -machine_rank: {{ machine_rank }} -main_training_function: main -num_machines: {{ num_machines }} -num_processes: {{ num_processes }} -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false \ No newline at end of file diff --git a/bench_dc_ulysses/configs/torchcompile_config.json b/bench_dc_ulysses/configs/torchcompile_config.json deleted file mode 100644 index d61b17b9f047..000000000000 --- a/bench_dc_ulysses/configs/torchcompile_config.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "bf16": { - "enabled": true - }, - "zero_optimization":{ - "stage": 0 - }, - "gradient_accumulation_steps": 1, - "gradient_clipping": "auto", - "steps_per_print": 2000, - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "wall_clock_breakdown": false -} diff --git a/bench_dc_ulysses/configs/torchcompile_config.yaml b/bench_dc_ulysses/configs/torchcompile_config.yaml deleted file mode 100644 index cebc281c2b97..000000000000 --- a/bench_dc_ulysses/configs/torchcompile_config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - 
deepspeed_multinode_launcher: standard - deepspeed_config_file: configs/torchcompile_config.json -distributed_type: DEEPSPEED -machine_rank: 0 -main_training_function: main -num_machines: 1 -num_processes: 2 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false \ No newline at end of file diff --git a/bench_dc_ulysses/distributed_attention.py b/bench_dc_ulysses/distributed_attention.py deleted file mode 100644 index 32d87eae2bef..000000000000 --- a/bench_dc_ulysses/distributed_attention.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -import torch -import torch.distributed as dist -from deepspeed.sequence.layer import DistributedAttention -from sp_dp_registry import get_group, is_setup, sp_size - -#TODO: Hacky, need to fix it -_padding_mask_context = None - -def set_padding_mask(mask): - global _padding_mask_context - _padding_mask_context = mask - -def get_padding_mask(): - global _padding_mask_context - return _padding_mask_context - -def clear_padding_mask(): - global _padding_mask_context - _padding_mask_context = None - -def ulysses_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - scaling=None, - dropout=0.0, - is_causal=True, - **kwargs, -): - assert is_setup(), 'Incorrectly setup SP/DP Groups.' - - gid = dist.get_rank() // sp_size() - group = get_group(gid) - - # Ulysses expects (batch, seq, heads, dim) - # HF standard provides (batch, heads, seq, dim) - q = query_states.transpose(1, 2).contiguous() - k = key_states.transpose(1, 2).contiguous() - v = value_states.transpose(1, 2).contiguous() - - if not hasattr(self, "ulysses_engine"): - self.ulysses_engine = DistributedAttention( - sdpa_wrapper, - group, - scatter_idx=2, # Shard heads - gather_idx=1 # Gather sequences - ) - - # b, s, n, h - # Note: we don't pass attention_mask here because it's the 4D mask created by HF - # based on sharded dimensions. 
def throughput_calculator(micro_batch_size, acc_steps, np, elapsed_time_per_iter,
                          hidden_size, num_attention_heads, num_key_value_heads,
                          ffn_hidden_size, num_layers, padded_vocab_size, seq_len,
                          topk: int, swiglu: bool, checkpoint_activations: bool):
    """Estimate training throughput from model dimensions and iteration time.

    The estimate counts multiply-accumulate operations (MACs) of the GEMMs in
    one forward pass, takes the backward pass as twice the forward pass, adds
    one more forward pass when activations are recomputed, and converts MACs
    to FLOPs with a factor of 2.

    Args:
        micro_batch_size: samples per process per micro step.
        acc_steps: gradient accumulation steps per iteration.
        np: number of processes (name kept for caller compatibility).
        elapsed_time_per_iter: seconds per iteration.
        hidden_size, num_attention_heads, num_key_value_heads,
        ffn_hidden_size, num_layers, padded_vocab_size, seq_len: model and
            input dimensions.
        topk: multiplier applied to the FFN term (number of experts per token).
        swiglu: if True the FFN is counted as 3 GEMMs instead of 2.
        checkpoint_activations: if True one extra forward pass is counted.

    Returns:
        Tuple ``(samples_per_second, tflops)`` where ``tflops`` is the FLOP
        rate per process in units of 1e12 FLOP/s.
    """
    batch_size = micro_batch_size * acc_steps * np
    samples_per_second = batch_size / elapsed_time_per_iter

    head_dim = hidden_size // num_attention_heads
    gqa = num_attention_heads // num_key_value_heads
    ffn_multiplier = 3 if swiglu else 2
    flops_per_mac = 2

    # Projection GEMMs: Q (1) + K and V (2 / gqa, since KV heads are shared by
    # `gqa` query heads) + output projection (1). True division is required:
    # with floor division the K/V term became 0 for any gqa > 2.
    pre_and_post_mha_gemm_macs = batch_size * num_layers * (1 + (2 / gqa) + 1) * (hidden_size**2) * seq_len
    mha_bgemm_macs = batch_size * num_layers * 2 * head_dim * num_attention_heads * (seq_len**2)
    ffn_gemm_macs = batch_size * num_layers * ffn_multiplier * ffn_hidden_size * hidden_size * seq_len * topk
    logit_lmhead_gemm_macs = batch_size * padded_vocab_size * hidden_size * seq_len

    fwd_macs = pre_and_post_mha_gemm_macs + mha_bgemm_macs + ffn_gemm_macs + logit_lmhead_gemm_macs
    bwd_macs = 2 * fwd_macs
    fwd_bwd_macs = fwd_macs + bwd_macs

    if checkpoint_activations:
        # Recomputation replays the forward pass once more.
        fwd_bwd_macs += fwd_macs

    flops_per_iteration = fwd_bwd_macs * flops_per_mac
    tflops = flops_per_iteration / (elapsed_time_per_iter * np * (10**12))
    return samples_per_second, tflops
"num_key_value_heads": 8, - "ffn_hidden_size": 16384, - "num_layers": 32, - "padded_vocab_size": 32000, - "topk": 2, # MixtralではMoEで2エキスパート - "swiglu": False # Mistralはswigluを使っていないと仮定 - } -} - -parser = argparse.ArgumentParser(description="Plot performance metrics.") -parser.add_argument("--metric", choices=["iteration_time", "throughput", "flops", "mfu", "peak_mem"], required=True, - help="Metric to plot: 'iteration_time', 'flops', 'mfu', or 'peak_mem'") -parser.add_argument("--result_dir", type=str, required=True, help="Path to the directory containing results.txt") -parser.add_argument("--result_file", type=str, default="results.txt", help="Name of the result file") -args = parser.parse_args() - - -# データのパース -pattern = re.compile( - r"(?P\d+) (?P[\w./-]+) ds=(?P\w+) np=(?P\d+) batch_size=(?P\d+) " - r"seq=(?P\d+) acc=(?P\d+) ac=(?P\w+) compile=(?P\w+) iteration time: (?P[\d.]+) " - r"alloc_mem: (?P\d+) peak_mem: (?P\d+)" -) -pattern_ctime = re.compile( - r"(?P\d+) (?P[\w./-]+) ds=(?P\w+) np=(?P\d+) batch_size=(?P\d+) " - r"seq=(?P\d+) acc=(?P\d+) ac=(?P\w+) compile=(?P\w+) passes=(?P[\w,_]+) compile_time=(?P[\d.]+) iteration time: (?P[\d.]+) " - r"alloc_mem: (?P\d+) peak_mem: (?P\d+)" -) -pattern_cs = re.compile( - r"(?P\d+) (?P[\w./-]+) ds=(?P\w+) np=(?P\d+) batch_size=(?P\d+) " - r"seq=(?P\d+) acc=(?P\d+) ac=(?P\w+) compile=(?P\w+) schedule=(?P\w+) passes=(?P[\w,_]+) compile_time=(?P[\d.]+) iteration time: (?P[\d.]+) " - r"alloc_mem: (?P\d+) peak_mem: (?P\d+)" -) - -file = Path(args.result_dir) / args.result_file -matches = [] -with open(file) as f: - for line in f: - match = pattern.match(line) - if not match: - match = pattern_ctime.match(line) - if not match: - match = pattern_cs.match(line) - if not match: - print(f"Not matched: {line}") - if match: - d = match.groupdict() - if "passes" not in d: - d["passes"] = "" - if "compile_time" not in d: - d["compile_time"] = 0 - if "schedule" not in d: - d["schedule"] = d["compile"] - matches.append(d) - -df = 
pd.DataFrame(matches) - -# 型変換 -df["ds"] = df["ds"] == "True" -df["compile"] = df["compile"] == "True" -df["np"] = df["np"].astype(int) -df["batch_size"] = df["batch_size"].astype(int) # batch_sizeをfloatに変換 -df["seq"] = df["seq"].astype(int) -df["iteration_time"] = df["iteration_time"].astype(float) # iteration_timeをfloatに変換 -df["alloc_mem"] = df["alloc_mem"].astype(float) -df["peak_mem"] = df["peak_mem"].astype(float) -df["acc"] = df["acc"].astype(int) # accも明示的にint型へ -df["ac"] = df["ac"] == "True" # acを真偽値に変換 -df["compile_time"] = df["compile_time"].astype(float) -df["schedule"] = df["schedule"] == "True" - - -# モデルごとの計算とプロット -grouped = df.groupby(["model", "np", "batch_size"]) - -theoretical_peak = 312 # 理論ピーク性能 (TFLOPS) - - -LABEL_ZERO3 = "ZeRO3" -LABEL_ZERO3_C = "ZeRO3 (C)" -LABEL_FSDP = "FSDP" -LABEL_DC_PS = "DeepCompile (P+S)" -LABEL_DC_P = "DeepCompile (P)" -LABEL_DC_S = "DeepCompile (S)" - -for (model, np, batch_size), group in grouped: - group = group.sort_values("acc") - acc_labels = group["acc"].unique() - - print(f"acc_labels: {acc_labels}") - - metric_values = {LABEL_ZERO3: [0] * len(acc_labels), - LABEL_ZERO3_C: [0] * len(acc_labels), - LABEL_FSDP: [0] * len(acc_labels), - LABEL_DC_PS: [0] * len(acc_labels), - LABEL_DC_P: [0] * len(acc_labels), - LABEL_DC_S: [0] * len(acc_labels)} - - for _, row in group.iterrows(): - - if row["ds"] and not row["compile"]: - category = LABEL_ZERO3 - elif not row["ds"] and not row["compile"]: - category = LABEL_FSDP - elif row["ds"] and row["compile"]: - if not row["schedule"]: - category = LABEL_ZERO3_C - elif row["passes"] == "" or row["passes"] == 'prefetch,selective_gather': - category = LABEL_DC_PS - # print(f"found prefetch,selective_gather") - elif row["passes"] == 'prefetch': - category = LABEL_DC_P - # print(f"found prefetch") - elif row["passes"] == 'selective_gather': - category = LABEL_DC_S - # print(f"found selective_gather") - else: - print(f"Unknown category: {row}") - continue - else: - print(f"Unknown 
category: {row}") - continue - - acc_index = list(acc_labels).index(row["acc"]) - if args.metric == "iteration_time": - metric_values[category][acc_index] = row["iteration_time"] - elif args.metric == "peak_mem": - metric_values[category][acc_index] = row["peak_mem"] / (1024**3) - elif args.metric == "throughput": - metric_values[category][acc_index] = row["batch_size"] * row["seq"] * row["acc"] / row["iteration_time"] - elif args.metric in ["flops", "mfu"]: - # モデル情報を使用して FLOPs を計算 - model_params = model_info[row["model"]] - samples_per_second, tflops = throughput_calculator( - micro_batch_size=row["batch_size"], - acc_steps=row["acc"], # ログから取得 - np=row["np"], - elapsed_time_per_iter=row["iteration_time"], - hidden_size=model_params["hidden_size"], - num_attention_heads=model_params["num_attention_heads"], - num_key_value_heads=model_params["num_key_value_heads"], - ffn_hidden_size=model_params["ffn_hidden_size"], - num_layers=model_params["num_layers"], - padded_vocab_size=model_params["padded_vocab_size"], - seq_len=row["seq"], - topk=model_params["topk"], - swiglu=model_params["swiglu"], # モデル定義から取得 - checkpoint_activations=row["ac"] # ログから取得 - ) - if args.metric == "flops": - metric_values[category][acc_index] = tflops - elif args.metric == "mfu": - metric_values[category][acc_index] = tflops / theoretical_peak - - # グラフ作成 - x = range(len(acc_labels)) - width = 0.15 # 棒グラフの幅 - ylabel = { - "iteration_time": "Iteration Time (s)", - "flops": "TFLOPS", - "throughput": "Throughput (tokens/s/GPU)", - "mfu": "MFU", - "peak_mem": "Peak Memory (GB)" - }[args.metric] - - plt.figure(figsize=(10, 8)) - adjust = - 0.5 * width - plt.bar([i - width*2 + adjust for i in x], metric_values[LABEL_ZERO3], width, label=LABEL_ZERO3, alpha=0.7) - plt.bar([i - width + adjust for i in x], metric_values[LABEL_ZERO3_C], width, label=LABEL_ZERO3_C, alpha=0.7) - plt.bar([i + adjust for i in x], metric_values[LABEL_FSDP], width, label=LABEL_FSDP, alpha=0.7) - plt.bar([i + width + adjust 
for i in x], metric_values[LABEL_DC_P], width, label=LABEL_DC_P, alpha=0.7) - plt.bar([i + width*2 + adjust for i in x], metric_values[LABEL_DC_S], width, label=LABEL_DC_S, alpha=0.7) - plt.bar([i + width*3 + adjust for i in x], metric_values[LABEL_DC_PS], width, label=LABEL_DC_PS, alpha=0.7) - - gain_zero3 = [metric_values[LABEL_DC_PS][i] / metric_values[LABEL_ZERO3][i] for i in range(len(acc_labels))] - print(f"model {model} np {np} batch_size {batch_size} {LABEL_ZERO3} metric_values: {metric_values[LABEL_ZERO3]} gain_zero3: {gain_zero3}") - print(f"model {model} np {np} batch_size {batch_size} {LABEL_DC_PS} metric_values: {metric_values[LABEL_DC_PS]}") - - model = model.split('/')[1] - model = model.replace("Meta-Llama-3-8B", "Llama-3-8B") - model = model.replace("Meta-Llama-3-70B-Instruct", "Llama-3-70B") - model = model.replace("Mixtral-8x7B-v0.1", "Mixtral-8x7B") - - plt.title(f"Model: {model}, #GPUs: {np}, Batch Size: {batch_size}", fontsize=24) - plt.xlabel("Acc Steps", fontsize=24) - plt.ylabel(ylabel, fontsize=24) - plt.xticks(x, acc_labels, fontsize=24) - - if args.metric == "peak_mem": - plt.ylim(0, 80) - - plt.yticks(fontsize=20) - plt.legend(loc="lower right", fontsize=18) - plt.grid(axis="y") - - # ファイル保存 - metric_name = args.metric - model = model.replace("/", "_") - chart_dir = Path(args.result_dir) / Path(metric_name) - chart_dir.mkdir(parents=True, exist_ok=True) - conf_str = f"{metric_name}_{model}_np{np}_bs{batch_size}" - img_path = chart_dir / f"chart_{conf_str}.png" - plt.savefig(str(img_path)) - plt.close() diff --git a/bench_dc_ulysses/generate_conf.py b/bench_dc_ulysses/generate_conf.py deleted file mode 100644 index 29fa1c4f4c96..000000000000 --- a/bench_dc_ulysses/generate_conf.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) Microsoft Corporation. 
-# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -import argparse -from jinja2 import Template -from pathlib import Path - -def get_args(): - parser = argparse.ArgumentParser(description='Config generation') - - parser.add_argument('--machine_rank', type=int, help='machine_rank') - parser.add_argument('--num_machines', type=int, help='num_machines') - parser.add_argument('--num_processes', type=int, help='num_processes') - parser.add_argument('--zero_stage', type=int, choices=[0, 1, 2, 3], help='ZeRO stage') - parser.add_argument('--fp16', action='store_true', help='Use fp16') - parser.add_argument('--gradient_accumulation_steps', type=int, default=1) - parser.add_argument('--deepcompile', action='store_true', help='Use deepcompile') - - parser.add_argument('--template_file', type=Path, help='Template file') - parser.add_argument('--output_file', type=Path, help='Output file') - - return parser.parse_args() - - -def main(args): - with open(args.template_file, 'r') as f: - template = Template(f.read()) - - with open(args.output_file, 'w') as f: - f.write(template.render(machine_rank=args.machine_rank, - num_machines=args.num_machines, - num_processes=args.num_processes, - zero_stage=args.zero_stage, - fp16=args.fp16, - gradient_accumulation_steps=args.gradient_accumulation_steps, - deepcompile=str(args.deepcompile).lower())) - -if __name__ == '__main__': - args = get_args() - main(args) diff --git a/bench_dc_ulysses/hostfile_n1 b/bench_dc_ulysses/hostfile_n1 deleted file mode 100644 index f81666ed14a0..000000000000 --- a/bench_dc_ulysses/hostfile_n1 +++ /dev/null @@ -1 +0,0 @@ -node-0 slots=1 diff --git a/bench_dc_ulysses/hostfile_n2 b/bench_dc_ulysses/hostfile_n2 deleted file mode 100644 index 5d6bf941211b..000000000000 --- a/bench_dc_ulysses/hostfile_n2 +++ /dev/null @@ -1 +0,0 @@ -node-0 slots=2 diff --git a/bench_dc_ulysses/hostfile_n4 b/bench_dc_ulysses/hostfile_n4 deleted file mode 100644 index 5d6bf941211b..000000000000 --- 
a/bench_dc_ulysses/hostfile_n4 +++ /dev/null @@ -1 +0,0 @@ -node-0 slots=2 diff --git a/bench_dc_ulysses/launch_jobs.sh b/bench_dc_ulysses/launch_jobs.sh deleted file mode 100755 index e59e4a901706..000000000000 --- a/bench_dc_ulysses/launch_jobs.sh +++ /dev/null @@ -1,13 +0,0 @@ -# launch job_*.slurm in slurm_jobs -# Usage: bash launch_jobs.sh - -# delete .out files in slurm_out -# for out in slurm_out/*.out; do -# rm -vf $out -# done - - -for job in slurm_jobs/job_*.slurm; do - # echo "Submitting job $job" - sbatch $job -done \ No newline at end of file diff --git a/bench_dc_ulysses/ring_attention.py b/bench_dc_ulysses/ring_attention.py deleted file mode 100644 index 7b01da7b96cf..000000000000 --- a/bench_dc_ulysses/ring_attention.py +++ /dev/null @@ -1,530 +0,0 @@ -## Code is taken directly from the RingFlashAttention -## repository: https://github.com/zhuzilin/ring-flash-attention -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import inspect -from functools import cache - -from sp_dp_registry import get_group, is_setup, sp_size -from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward - -__all__ = ["update_out_and_lse", "RingComm", "get_default_args"] - -## Utility communication files. 
## -@cache -def _get_default_args(func): - spec = inspect.getfullargspec(func) - defaults = spec.defaults if spec.defaults is not None else () - padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults - args = dict(zip(spec.args, padded_defaults)) - if "softcap" in args: - args["softcap"] = 0.0 - return args - - -def get_default_args(func): - if inspect.isfunction(func): - return _get_default_args(func) - else: - # Use the origin _init_fn in CustomOpDef - return _get_default_args(func._init_fn) - - -@torch.jit.script -def _update_out_and_lse( - out: torch.Tensor, - lse: torch.Tensor, - block_out: torch.Tensor, - block_lse: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - - block_out = block_out.to(torch.float32) - block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - - # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out - # For additional context and discussion, please refer to: - # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - out = out - F.sigmoid(block_lse - lse) * (out - block_out) - lse = lse - F.logsigmoid(lse - block_lse) - - return out, lse - - -def update_out_and_lse( - out: Optional[torch.Tensor], - lse: Optional[torch.Tensor], - block_out: torch.Tensor, - block_lse: torch.Tensor, - slice_=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - if out is None: - if slice_ is not None: - raise RuntimeError("first update_out_and_lse should not pass slice_ args") - out = block_out.to(torch.float32) - lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - elif slice_ is not None: - slice_out, slice_lse = out[slice_], lse[slice_] - slice_out, slice_lse = _update_out_and_lse( - slice_out, slice_lse, block_out, block_lse - ) - out[slice_], lse[slice_] = slice_out, slice_lse - else: - out, lse = _update_out_and_lse(out, lse, block_out, block_lse) - return out, lse - - -@torch.jit.script -def flatten_varlen_lse(lse, 
cu_seqlens): - new_lse = [] - for i in range(len(cu_seqlens) - 1): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse.append(lse[i, :, : end - start]) - return torch.cat(new_lse, dim=1) - - -@torch.jit.script -def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): - num_seq = len(cu_seqlens) - 1 - num_head = lse.shape[-2] - new_lse = torch.empty( - (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device - ) - for i in range(num_seq): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - new_lse[i, : end - start] = lse[start:end] - return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() - - -class RingComm: - def __init__(self, process_group: dist.ProcessGroup): - self._process_group = process_group - self._ops = [] - self.rank = dist.get_rank(self._process_group) - self.world_size = dist.get_world_size(self._process_group) - self._reqs = None - - self.send_rank = (self.rank + 1) % self.world_size - self.recv_rank = (self.rank - 1) % self.world_size - - if process_group is not None: - self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) - self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) - - def send_recv( - self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None - ) -> torch.Tensor: - if recv_tensor is None: - res = torch.empty_like(to_send) - else: - res = recv_tensor - - send_op = dist.P2POp( - dist.isend, to_send, self.send_rank, group=self._process_group - ) - recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) - self._ops.append(send_op) - self._ops.append(recv_op) - return res - - def commit(self): - if self._reqs is not None: - raise RuntimeError("commit called twice") - self._reqs = dist.batch_isend_irecv(self._ops) - - def wait(self): - if self._reqs is None: - raise RuntimeError("wait called before commit") - for req in self._reqs: - req.wait() - self._reqs = None - self._ops = [] - - def send_recv_kv( - self, - k: torch.Tensor, - v: 
torch.Tensor, - k_buffer: Optional[torch.Tensor] = None, - v_buffer: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) - self.commit() - return next_k, next_v - - -class AllGatherComm: - def __init__(self, group=None) -> None: - self.group = group - self.handles = [] - - def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): - handle = dist.all_gather_into_tensor( - output_tensor, input_tensor, group=self.group, async_op=True - ) - self.handles.append(handle) - - def wait(self): - for handle in self.handles: - handle.wait() - self.handles = [] - - -def ring_flash_attn_forward( - process_group, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - comm = RingComm(process_group) - - out = None - lse = None - - next_k, next_v = None, None - - for step in range(comm.world_size): - if step + 1 != comm.world_size: - next_k, next_v = comm.send_recv_kv(k, v) - - if not causal or step <= comm.rank: - params = get_default_args(_flash_attn_forward).copy() - params.update( - { - "q": q, - "k": k, - "v": v, - "dropout_p": dropout_p, - "softmax_scale": softmax_scale, - "causal": causal and step == 0, - "alibi_slopes": alibi_slopes, - "return_softmax": True and dropout_p > 0, - } - ) - if "window_size" in params: - params.update({"window_size": window_size}) - else: - params.update( - { - "window_size_left": window_size[0], - "window_size_right": window_size[1], - } - ) - outputs = _flash_attn_forward(**params) - if len(outputs) == 8: - block_out, _, _, _, _, block_lse, _, _ = outputs - else: - assert len(outputs) == 4 - block_out, block_lse, _, _ = outputs - out, lse = update_out_and_lse(out, lse, block_out, block_lse) - - if step + 1 != comm.world_size: - comm.wait() - k, v = next_k, next_v - - out = out.to(q.dtype) - lse = 
lse.squeeze(dim=-1).transpose(1, 2) - return out, lse - - -def ring_flash_attn_backward( - process_group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - dropout_p=0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, -): - kv_comm = RingComm(process_group) - d_kv_comm = RingComm(process_group) - dq, dk, dv = None, None, None - next_dk, next_dv = None, None - - block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) - block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) - block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) - - next_dk, next_dv = None, None - next_k, next_v = None, None - - for step in range(kv_comm.world_size): - if step + 1 != kv_comm.world_size: - next_k, next_v = kv_comm.send_recv_kv(k, v) - - if step <= kv_comm.rank or not causal: - bwd_causal = causal and step == 0 - params = get_default_args(_flash_attn_backward).copy() - params.update( - { - "dout": dout, - "q": q, - "k": k, - "v": v, - "out": out, - "softmax_lse": softmax_lse, - "dq": block_dq_buffer, - "dk": block_dk_buffer, - "dv": block_dv_buffer, - "dropout_p": dropout_p, - "softmax_scale": softmax_scale, - "causal": bwd_causal, - "alibi_slopes": alibi_slopes, - "deterministic": deterministic, - } - ) - if "window_size" in params: - params.update({"window_size": window_size}) - else: - params.update( - { - "window_size_left": window_size[0], - "window_size_right": window_size[1], - } - ) - _flash_attn_backward(**params) - - if dq is None: - dq = block_dq_buffer.to(torch.float32) - dk = block_dk_buffer.to(torch.float32) - dv = block_dv_buffer.to(torch.float32) - else: - dq += block_dq_buffer - d_kv_comm.wait() - dk = block_dk_buffer + next_dk - dv = block_dv_buffer + next_dv - elif step != 0: - d_kv_comm.wait() - dk, dv = next_dk, next_dv - - if step + 1 != kv_comm.world_size: - kv_comm.wait() - k, v = next_k, next_v - - next_dk, next_dv = d_kv_comm.send_recv_kv(dk, dv) - - 
d_kv_comm.wait() - - return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) - - -class RingFlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - group, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - assert alibi_slopes is None - k = k.contiguous() - v = v.contiguous() - out, softmax_lse = ring_flash_attn_forward( - group, - q, - k, - v, - softmax_scale=softmax_scale, - dropout_p=dropout_p, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - deterministic=False, - ) - # this should be out_padded - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.group = group - return out if not return_softmax else (out, softmax_lse, None) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = ring_flash_attn_backward( - ctx.group, - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale=ctx.softmax_scale, - dropout_p=ctx.dropout_p, - causal=ctx.causal, - window_size=ctx.window_size, - alibi_slopes=ctx.alibi_slopes, - deterministic=ctx.deterministic, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None - - -def ring_flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - 
softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -def ring_flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=None, -): - return RingFlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - group, - ) - - -# HuggingFace-compatible wrapper for ring attention -# This follows the same pattern as ulysses_attention_forward in distributed_attention.py -def ring_attention_forward( - self, # This will be the LlamaAttention instance - query_states, - key_states, - value_states, - attention_mask=None, - scaling=None, - dropout=0.0, - is_causal=True, - **kwargs, -): - """ - Ring attention forward pass compatible with HuggingFace's attention interface. - - Args: - self: The LlamaAttention module instance - query_states: (batch, heads, seq, dim) - HuggingFace format - key_states: (batch, heads, seq, dim) - HuggingFace format - value_states: (batch, heads, seq, dim) - HuggingFace format - attention_mask: Not used (ring attention handles masking internally) - scaling: Softmax scaling factor - dropout: Dropout probability - is_causal: Whether to use causal masking - **kwargs: Additional arguments (ignored) - - Returns: - tuple: (attn_output, None) where attn_output is (batch, seq, heads, dim) - """ - # Convert from HF format (batch, heads, seq, dim) to flash_attn format (batch, seq, heads, dim) - assert is_setup(), 'Incorrectly setup SP/DP Groups.' 
- - gid = dist.get_rank() // sp_size() - group = get_group(gid) - - q = query_states.transpose(1, 2).contiguous() - k = key_states.transpose(1, 2).contiguous() - v = value_states.transpose(1, 2).contiguous() - - # Ring attention expects (batch, seq, heads, dim) - # Call the ring flash attention function - attn_output = ring_flash_attn_func( - q, - k, - v, - dropout_p=dropout, - softmax_scale=scaling, - causal=is_causal, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - group=group, - ) - - # Output is already in (batch, seq, heads, dim) format, which HF expects after attention - # Note: Llama's forward handles the reshape internally - return attn_output, None diff --git a/bench_dc_ulysses/run.sh b/bench_dc_ulysses/run.sh deleted file mode 100755 index 08bb41b09744..000000000000 --- a/bench_dc_ulysses/run.sh +++ /dev/null @@ -1,56 +0,0 @@ -HOST_IP=$1 -NUM_NODES=$2 -NUM_PROCESSES=$3 -BACKEND=$4 -MODEL=$5 -GRADIENT_ACCUMULATION_STEPS=$6 -DEEPCOMPILE=$7 -shift 7 -EXTRA_OPTS="$@" - -export NCCL_DEBUG=WARN - -CONFIG_TEMPLATE=configs/ds_config.yaml.template - -echo "HOST_IP: ${HOST_IP}" -echo "NUM_NODES: ${NUM_NODES}" -echo "NUM_PROCESSES: ${NUM_PROCESSES}" -echo "BACKEND: ${BACKEND}" -echo "MODEL: ${MODEL}" -echo "GRADIENT_ACCUMULATION_STEPS: ${GRADIENT_ACCUMULATION_STEPS}" -echo "EXTRA_OPTS: ${EXTRA_OPTS}" - -MACHINE_RANK=$(hostname | sed 's/[^0-9]*//g') - -python generate_conf.py \ - --machine_rank ${MACHINE_RANK} \ - --num_machines ${NUM_NODES} \ - --num_processes ${NUM_PROCESSES} \ - --template_file ${CONFIG_TEMPLATE} \ - --output_file configs/config.yaml - -GAS_OPTS="--gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS}" -DEEPCOMPILE_OPTS="" -if [ "${DEEPCOMPILE}" == "1" ]; then - DEEPCOMPILE_OPTS="--deepcompile" -fi - -if [ "${BACKEND}" == "deepspeed" ]; then - python generate_conf.py \ - --machine_rank ${MACHINE_RANK} \ - --num_machines ${NUM_NODES} \ - --num_processes ${NUM_PROCESSES} \ - 
--gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \ - ${DEEPCOMPILE_OPTS} \ - --template_file configs/ds_config.json.template \ - --output_file configs/ds_config.json -fi - -accelerate launch --main_process_ip ${HOST_IP} --main_process_port 12345 \ ---num_machines ${NUM_NODES} --num_processes ${NUM_PROCESSES} --machine_rank ${MACHINE_RANK} \ ---config_file configs/config.yaml \ -run_acc_lm.py \ ---model_name "${MODEL}" \ -${GAS_OPTS} \ -${EXTRA_OPTS} \ -2>&1 | tee ${LOG_FILE} \ No newline at end of file diff --git a/bench_dc_ulysses/run_acc_lm.py b/bench_dc_ulysses/run_acc_lm.py deleted file mode 100644 index ce3aeaff443c..000000000000 --- a/bench_dc_ulysses/run_acc_lm.py +++ /dev/null @@ -1,251 +0,0 @@ -import os - -# Suppress tokenizers parallelism warning (must be before importing transformers) -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -import argparse -from datetime import datetime - -import torch - -from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, enable_full_determinism -from datasets import load_dataset -from accelerate import Accelerator -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -import torch.distributed as dist - -import torch -import random -import numpy as np -import time -import os - -from distributed_attention import ulysses_attention_forward -# from ring_attention import ring_attention_forward -from sp_dp_registry import get_group, populate_registry, get_registry - -torch.set_float32_matmul_precision("high") - -def set_seed(seed: int = 42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - -def seed_worker(worker_id): - worker_seed = 12 + worker_id - np.random.seed(worker_seed) - random.seed(worker_seed) - - -def prepare_autosp_inputs(input_id: torch.Tensor, label_id: 
torch.Tensor, position_id: torch.Tensor, attention_mask: torch.Tensor, seq_dim: int): - torch._dynamo.decorators.mark_dynamic(input_id, seq_dim) - torch._dynamo.decorators.mark_dynamic(label_id, seq_dim) - torch._dynamo.decorators.mark_dynamic(position_id, seq_dim) - torch._dynamo.decorators.mark_dynamic(attention_mask, seq_dim) - input_id.tag = "input_id" - label_id.tag = "label_id" - position_id.tag = "position_id" - return input_id, label_id, position_id, attention_mask - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--model_name", type=str, default="meta-llama/Llama-2-7b-hf") - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--num_epochs", type=int, default=1) - parser.add_argument("--seq_length", type=int, default=512) - parser.add_argument("--steps", type=int, default=1) - parser.add_argument("--learning_rate", type=float, default=2e-5) - parser.add_argument("--gradient_accumulation_steps", type=int, default=1) - parser.add_argument("--activation_checkpointing", action="store_true") - parser.add_argument("--dataset_name", type=str, default="timdettmers/openassistant-guanaco") - parser.add_argument("--num_layers", type=int, default=1) - parser.add_argument("--compile", type=str, default="deepcompile") - parser.add_argument("--passes", type=str, default=None) - parser.add_argument("--backend", type=str, default="inductor") - parser.add_argument("--offload_opt_states", action="store_true") - parser.add_argument("--profile", action="store_true") - parser.add_argument("--profile_memory", action="store_true") - parser.add_argument("--deterministic", action="store_true") - parser.add_argument("--profile_dir", type=str, default="profiles") - parser.add_argument("--bench_step", type=int, default=1) - parser.add_argument("--warmup_step", type=int, default=15) - parser.add_argument("--print_interval", type=int, default=1) - parser.add_argument("--experiment_folder", type=str, default="") - 
parser.add_argument("--sp_size", type=int, default=2) - parser.add_argument("--dp_size", type=int, default=1) - - return parser.parse_args() - -def main(): - args = get_args() - set_seed(12) - - if args.deterministic: - enable_full_determinism(12) - from torch._inductor import config - config.fallback_random = True - torch.use_deterministic_algorithms(True) - - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - device = accelerator.device - is_deepspeed = accelerator.state.deepspeed_plugin is not None - assert accelerator.num_processes == args.sp_size * args.dp_size, 'Incorrect dp/sp sizing' - - ## Set sp/dp groups accordingly. ## - if args.compile in ['compile', 'eager', 'ringattn']: - populate_registry(args.sp_size, args.dp_size) - - if accelerator.is_main_process: - print(f'GROUP_REGISTRY: {get_registry()}') - - # Load model and tokenizer - if accelerator.is_main_process: - print("Loading model and tokenizer...") - - model_name = args.model_name - if args.compile == "deepcompile": - attention_backend = "sdpa" - else: - if args.compile == "eager" or args.compile == "compile": - from transformers.models.llama import modeling_llama - attention_backend = "ulyssess" - modeling_llama.ALL_ATTENTION_FUNCTIONS["ulyssess"] = ulysses_attention_forward - elif args.compile == "ringattn": - from transformers.models.llama import modeling_llama - attention_backend = "ringattn" - modeling_llama.ALL_ATTENTION_FUNCTIONS["ringattn"] = ring_attention_forward - - if args.num_layers is not None: - model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - if accelerator.is_main_process: - print(f"num_hidden_layers: {model_config.num_hidden_layers} -> {args.num_layers}") - model_config.num_hidden_layers = args.num_layers - model_config._attn_implementation = attention_backend - model = AutoModelForCausalLM.from_config(model_config, trust_remote_code=True) - else: - model_config = AutoConfig.from_pretrained(model_name, 
trust_remote_code=True) - model_config._attn_implementation = attention_backend - model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config, trust_remote_code=True) - - if args.activation_checkpointing: - model.gradient_checkpointing_enable() - - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - tokenizer.pad_token = tokenizer.eos_token - - # Load dataset - if accelerator.is_main_process: - print("Loading dataset...") - - g = torch.Generator() - g.manual_seed(12) - dataset = load_dataset('ag_news', split='train[:1%]') - - tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - tokenizer.pad_token = tokenizer.convert_ids_to_tokens(2) - - def tokenize_function(examples): - return tokenizer(examples['text'], padding='max_length', max_length=args.seq_length, truncation=True) ## Fix max_length and generate fake data instead to not exhaust disk. - - tokenized_dataset = dataset.map(tokenize_function, batched=True) - tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask']) - - num_replicas_ = args.dp_size - rank_ = accelerator.process_index // args.sp_size - - sampler = DistributedSampler(tokenized_dataset, num_replicas=num_replicas_, rank=rank_, seed=12, shuffle=False) - data_loader = DataLoader(tokenized_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4, worker_init_fn=seed_worker, generator=g) - - optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) - - model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader) - print(f"Model prepared: {model.__class__}") - - if args.compile == "deepcompile": - print(f"Running deepcompile with backend={args.backend}") - torch._dynamo.config.capture_dynamic_output_shape_ops = True - torch._dynamo.config.capture_scalar_outputs = True - model.compile(backend=args.backend) - elif args.compile in ["compile", "ringattn"]: - print(f"Running torch.compile with backend={args.backend}") 
- torch._dynamo.config.capture_dynamic_output_shape_ops = True - torch._dynamo.config.capture_scalar_outputs = True - model = torch.compile(model, backend=args.backend) - else: - print(f"Running eager") - - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - model_name = args.model_name.split("/")[-1] - exp_name = f"{model_name}_np{accelerator.num_processes}_{args.compile}_" \ - f"B{args.backend}_" \ - f"L{0 if args.num_layers is None else args.num_layers}_" \ - f"bs{args.batch_size}_seq{args.seq_length}_" \ - f"T{timestamp}" - if args.profile_dir: - if accelerator.is_main_process and args.profile_dir: - os.makedirs(args.profile_dir, exist_ok=True) - if args.profile: - prof_dir = f"{args.profile_dir}/{exp_name}" - os.makedirs(prof_dir, exist_ok=True) - accelerator.wait_for_everyone() - - # Training loop - model.train() - global_step = 0 - print(f"Using global sequence length: {args.seq_length}") - - os.makedirs("logs", exist_ok=True) - loss_log_file = open(f"logs/loss_{args.compile}_{args.seq_length}_{accelerator.process_index}.csv", "w") - loss_log_file.write("step,loss\n") - - sp_rank = dist.get_rank() % args.sp_size - for epoch in range(args.num_epochs): - start_iter = time.time() - - for step, batch in enumerate(data_loader): - input_ids = batch['input_ids'].to(device) # [B, S] - B, S = input_ids.shape - - label_ids = input_ids.clone() # [B, S] - position_ids = torch.arange(S, device=device).unsqueeze(0) - attention_mask = batch['attention_mask'].to(device) - - #HACK: store the padding mask to be accessed directly in local attention - from distributed_attention import set_padding_mask - set_padding_mask(attention_mask) - - if args.compile == 'deepcompile': - input_ids, label_ids, position_ids, attention_mask = prepare_autosp_inputs(input_ids, label_ids, position_ids, attention_mask, seq_dim=1) - else: - chunk_size = S // args.sp_size - start = sp_rank * chunk_size - end = start + chunk_size - input_ids = input_ids[:, start:end] # [B, S_shard] - label_ids = 
label_ids[:, start:end] # [B, S_shard] - must match input_ids - position_ids = position_ids[:, start:end] - - outputs = model(input_ids=input_ids, labels=label_ids, position_ids=position_ids, attention_mask=attention_mask) - loss = outputs.loss - print(f"Epoch {epoch+1}, Step {global_step}, Loss: {loss.item()} time: {time.time() - start_iter} alloc_mem: {torch.cuda.memory_allocated() / (1024 ** 3)} peak_mem: {torch.cuda.max_memory_allocated() / (1024 ** 3)}") - - accelerator.backward(loss) - - loss_log_file.write(f"{global_step},{loss.item()}\n") - loss_log_file.flush() - - global_step += 1 - if global_step > args.steps: - break - -if __name__ == "__main__": - torch._dynamo.config.accumulated_cache_size_limit = 256 - torch._dynamo.config.cache_size_limit = 128 - torch._dynamo.config.optimize_ddp = False - main() - - diff --git a/bench_dc_ulysses/run_bench.sh b/bench_dc_ulysses/run_bench.sh deleted file mode 100755 index a46d2412df2c..000000000000 --- a/bench_dc_ulysses/run_bench.sh +++ /dev/null @@ -1,19 +0,0 @@ -PROFILE_DIR=${PROFILE_DIR:-profiles} -mkdir -p ${PROFILE_DIR} -PROFILE_OPTS="--profile --profile-dir ${PROFILE_DIR}" -COMPILE_OPTS="--compile" -DC_OPTS="--compile --deepcompile" -ACC_OPTS="--gradient-accumulation-steps 1" -AC_OPTS="--activation-checkpointing" - -MODEL="meta-llama/Llama-2-7b-chat-hf" -BATCH_SIZE_OPTS=(1) -SEQ_LENGTH=$1 - -for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do - ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} ${ACC_OPTS} ${PROFILE_OPTS}" - - # compiled ulysses - bash ./run_multinode.sh --backend inductor ${ARGS} ${DC_OPTS} --num-layers 1 --num-gpus 2 --seq-length ${SEQ_LENGTH} - cp -r logs ${PROFILE_DIR}/ -done diff --git a/bench_dc_ulysses/run_bench_acc.sh b/bench_dc_ulysses/run_bench_acc.sh deleted file mode 100755 index a3b66844d279..000000000000 --- a/bench_dc_ulysses/run_bench_acc.sh +++ /dev/null @@ -1,42 +0,0 @@ -PROFILE_DIR=${PROFILE_DIR:-profiles} -mkdir -p ${PROFILE_DIR} -PROFILE_OPTS="--profile --profile-dir 
${PROFILE_DIR}" -COMPILE_OPTS="--compile" -N3Z_OPTS="--compile --deepcompile" -AC_OPTS="--activation-checkpointing" - -MODEL="meta-llama/Meta-Llama-3-70B-Instruct" -BATCH_SIZE_OPTS=(1) -SEQ_LENGTH_OPTS=(1024) -ACC_OPTS=(2 4 8 16) -for ACC_STEP in ${ACC_OPTS[@]}; do - for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do - for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do - ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${AC_OPTS} ${PROFILE_OPTS} --gradient-accumulation-steps ${ACC_STEP}" - bash ./run_multinode.sh --backend deepspeed ${ARGS} - bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS} - bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch,selective_gather - bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch - bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes selective_gather - cp -r logs ${PROFILE_DIR}/ - done - done -done - -MODEL="mistralai/Mixtral-8x7B-v0.1" -BATCH_SIZE_OPTS=(1) -SEQ_LENGTH_OPTS=(1024) -ACC_OPTS=(2 4 8 16) -for ACC_STEP in ${ACC_OPTS[@]}; do - for BATCH_SIZE in ${BATCH_SIZE_OPTS[@]}; do - for SEQ_LENGTH in ${SEQ_LENGTH_OPTS[@]}; do - ARGS="--model ${MODEL} --batch-size ${BATCH_SIZE} --seq-length ${SEQ_LENGTH} ${AC_OPTS} ${PROFILE_OPTS} --gradient-accumulation-steps ${ACC_STEP}" - bash ./run_multinode.sh --backend deepspeed ${ARGS} - bash ./run_multinode.sh --backend deepspeed ${ARGS} ${COMPILE_OPTS} - bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch,selective_gather - bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes prefetch - bash ./run_multinode.sh --backend deepspeed ${ARGS} ${N3Z_OPTS} --passes selective_gather - cp -r logs ${PROFILE_DIR}/ - done - done -done diff --git a/bench_dc_ulysses/run_correctness_test.sh b/bench_dc_ulysses/run_correctness_test.sh deleted file mode 100755 index 37decf60a9d4..000000000000 --- a/bench_dc_ulysses/run_correctness_test.sh +++ 
/dev/null @@ -1,4 +0,0 @@ -MASTER_ADDR="" - -ds_ssh "cd /scratch/amlt_code/run_z3_graph_rewrite/accelerate; BLOB_BASE_DIR=/mnt/post-training-ppo bash ./run.sh ${MASTER_ADDR} 4 32 deepspeed 3 meta-llama/Meta-Llama-3-70B-Instruct 1 --batch_size 1 --seq_length 512 --activation_checkpointing --bench_step 1000 --print_interval 1" 2>&1 | tee logs/debug_Meta-Llama-3-70B-Instruct_deepspeed_np32c0b1s512g1a1pALL.log -ds_ssh "cd /scratch/amlt_code/run_z3_graph_rewrite/accelerate; BLOB_BASE_DIR=/mnt/post-training-ppo bash ./run.sh ${MASTER_ADDR} 4 32 deepspeed 3 meta-llama/Meta-Llama-3-70B-Instruct 1 --compile --batch_size 1 --seq_length 512 --activation_checkpointing --passes prefetch,selective_gather --bench_step 1000 --print_interval 1" 2>&1 | tee logs/debug_Meta-Llama-3-70B-Instruct_deepspeed_np32c1b1s512g1a1pprefetch_selective_gather.log diff --git a/bench_dc_ulysses/run_multinode.sh b/bench_dc_ulysses/run_multinode.sh deleted file mode 100755 index c3e90a05ca79..000000000000 --- a/bench_dc_ulysses/run_multinode.sh +++ /dev/null @@ -1,112 +0,0 @@ -#!/bin/bash - -NUM_NODES=${NUM_NODES:-$(wc -l < hostfile_n4)} -NGPUS_PER_NODE=${NGPUS_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} -NUM_PROCESSES=$((${NUM_NODES} * ${NGPUS_PER_NODE})) - -BACKEND="deepspeed" # ignore -MODEL="meta-llama/Meta-Llama-3-8B" -COMPILE=0 -PASSES="ALL" -EXTRA_OPTS="" - -EAGER=0 -DEEPCOMPILE=0 -GRADIENT_ACCUMULATION_STEPS=1 -ACTIVATION_CHECKPOINTING=1 -BATCH_SIZE=1 -SEQ_LENGTH=512 - -while [[ $# -gt 0 ]]; do - case $1 in - --backend) - BACKEND="$2" - shift 2 - ;; - --batch-size) - BATCH_SIZE="$2" - EXTRA_OPTS="${EXTRA_OPTS} --batch_size $2" - shift 2 - ;; - --seq-length) - SEQ_LENGTH="$2" - EXTRA_OPTS="${EXTRA_OPTS} --seq_length $2" - shift 2 - ;; - --gradient-accumulation-steps) - GRADIENT_ACCUMULATION_STEPS="$2" - # EXTRA_OPTS="${EXTRA_OPTS} --gradient_accumulation_steps $2" - shift 2 - ;; - --activation-checkpointing) - ACTIVATION_CHECKPOINTING=1 - EXTRA_OPTS="${EXTRA_OPTS} 
--activation_checkpointing" - shift - ;; - --compile) - COMPILE=1 - EXTRA_OPTS="${EXTRA_OPTS} $1" - shift - ;; - --eager) - EAGER=1 - EXTRA_OPTS="${EXTRA_OPTS} --backend eager" - shift - ;; - --deepcompile) - DEEPCOMPILE=1 - shift - ;; - --passes) - PASSES="$2" - EXTRA_OPTS="${EXTRA_OPTS} $1 $2" - shift 2 - ;; - --profile) - EXTRA_OPTS="${EXTRA_OPTS} $1" - shift - ;; - --profile-dir) - EXTRA_OPTS="${EXTRA_OPTS} --profile_dir $2" - shift 2 - ;; - --model) - MODEL="$2" - shift 2 - ;; - --num-layers) - EXTRA_OPTS="${EXTRA_OPTS} --num_layers $2" - shift 2 - ;; - --num-gpus) - NGPUS_PER_NODE="$2" - NUM_PROCESSES=$((${NUM_NODES} * ${NGPUS_PER_NODE})) - shift 2 - ;; - *) - echo "Unknown option: $1" - exit 1 - ;; - esac -done - - -HOST_IP=$(hostname -i) - -mkdir -p logs - -SCRIPT_DIR=$(dirname $(realpath $0)) - -#replace , with _ in PASSES -PASSES=$(echo $PASSES | tr ',' '_') - -LOG_FILE=debug_b${BACKEND}np${NUM_PROCESSES}c${COMPILE}dc${DEEPCOMPILE}bs${BATCH_SIZE}seq${SEQ_LENGTH}.log - -if [ "${NUM_NODES}" == "1" ]; then - # avoid dependency on pdsh when possible - cd ${SCRIPT_DIR}; bash ./run.sh ${HOST_IP} ${NUM_NODES} ${NUM_PROCESSES} ${BACKEND} ${MODEL} ${GRADIENT_ACCUMULATION_STEPS} ${DEEPCOMPILE} ${EXTRA_OPTS} \ - 2>&1 | tee logs/${LOG_FILE} -else - ds_ssh -f hostfile_n${NUM_NODES} "cd ${SCRIPT_DIR}; bash ./run.sh ${HOST_IP} ${NUM_NODES} ${NUM_PROCESSES} ${BACKEND} ${MODEL} ${GRADIENT_ACCUMULATION_STEPS} ${SCHEDULE} ${OFFLOAD_OPT_STATES} ${EXTRA_OPTS}" \ - 2>&1 | tee logs/${LOG_FILE} -fi diff --git a/bench_dc_ulysses/run_ulysses.sh b/bench_dc_ulysses/run_ulysses.sh deleted file mode 100755 index 0554c92bed5f..000000000000 --- a/bench_dc_ulysses/run_ulysses.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/bash - -SEQ_LEN=${1:-1024} -COMPILE=${2:-eager} -SP_SIZE=${3:-2} -DP_SIZE=${4:-1} -LAYER_COUNT=${5:-""} -EXP_NAME=${6:-""} - -if [[ "$COMPILE" != "eager" && "$COMPILE" != "compile" && "$COMPILE" != "deepcompile" && "$COMPILE" != "ringattn" ]]; then - echo "Invalid mode: 
$COMPILE. Choose from eager, compile, deepcompile, ringattn." - exit 1 -fi - -HOST_IP=$(hostname -i | awk '{print $1}') -PORT=$(python3 -c "import socket; s = socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()") -NUM_NODES=1 -NUM_PROCESSES=$((SP_SIZE * DP_SIZE)) -MODEL="meta-llama/Llama-2-7b-chat-hf" -# MODEL="meta-llama/Llama-3.1-8B" -# MODEL="meta-llama/Llama-3.2-1B" -# MODEL="meta-llama/Llama-3.2-3B" -PROFILE_DIR=${PROFILE_DIR:-profiles} -mkdir -p ${PROFILE_DIR} -PROFILE_OPTS="--profile_dir ${PROFILE_DIR}" - -COMPILE_OPTS="--compile ${COMPILE}" -CONFIG_FILE="configs/torchcompile_config.yaml" -if [ "${COMPILE}" == "deepcompile" ]; then - CONFIG_FILE="configs/deepcompile_config.yaml" -fi - - -TIMESTAMP=$(date +"%Y%m%d_%H%M%S") -LOG_FILE=logs/log_${COMPILE}_seq${SEQ_LEN}_${TIMESTAMP}.log - -echo "HOST_IP: ${HOST_IP}" -echo "PORT: ${PORT}" -echo "NUM_NODES: ${NUM_NODES}" -echo "NUM_PROCESSES: ${NUM_PROCESSES}" -echo "MODEL: ${MODEL}" -echo "COMPILE: ${COMPILE}" -echo "SEQ_LEN: ${SEQ_LEN}" -echo "LOG_FILE: ${LOG_FILE}" - -EXTRA_OPTS="--seq_length=${SEQ_LEN} --experiment_folder=${EXP_NAME} --sp_size=${SP_SIZE} --dp_size=${DP_SIZE}" - -# Only pass --num_layers if provided -NUM_LAYER_OPTS="" -if [[ -n "${LAYER_COUNT}" ]]; then - NUM_LAYER_OPTS="--num_layers ${LAYER_COUNT}" -fi - -( -accelerate launch --main_process_ip ${HOST_IP} --main_process_port ${PORT} \ ---num_machines ${NUM_NODES} --num_processes ${NUM_PROCESSES} --machine_rank 0 \ ---config_file ${CONFIG_FILE} \ -run_acc_lm.py \ ---model_name "${MODEL}" ${NUM_LAYER_OPTS} \ -${PROFILE_OPTS} \ -${EXTRA_OPTS} \ -${COMPILE_OPTS} -) 2>&1 | tee ${LOG_FILE} - - diff --git a/bench_dc_ulysses/sample.slurm b/bench_dc_ulysses/sample.slurm deleted file mode 100644 index 199e49a53168..000000000000 --- a/bench_dc_ulysses/sample.slurm +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -# SLURM Job Submission OR Interactive Environment Setup -# For batch submission: sbatch sample.slurm -# For interactive: -# srun -A 
bcjw-delta-gpu --time=1:00:00 --nodes=1 --mem=100G --gpus=2 --partition=gpuA100x4-interactive --pty /bin/bash -# -# SBATCH directives (ignored when sourced, used for sbatch submission): -#SBATCH -A bcjw-delta-gpu -#SBATCH --time=0:20:00 -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=2 -#SBATCH --cpus-per-task=16 -#SBATCH --partition=gpuA100x4-interactive -#SBATCH --gpus=2 -#SBATCH --gpu-bind=closest -#SBATCH --mem=100G -#SBATCH --output=slurm_out/out_%j.log - -source ~/.bashrc -conda activate /u/ndani/autosp-env - -# # NCCL configuration -export NCCL_INCLUDE_DIR=/opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/12.8/nccl-2.25/include -export NCCL_HOME=/opt/nvidia/hpc_sdk/Linux_x86_64/25.3/comm_libs/12.8/nccl-2.25 -export CPATH=$NCCL_HOME/include:$CPATH -export LD_LIBRARY_PATH=$NCCL_HOME/lib:$LD_LIBRARY_PATH -# export CPATH=$NCCL_INCLUDE_DIR:$CPATH - -export TRITON_CACHE_DIR="/tmp/triton_$USER" -export NCCL_DEBUG=WARN -export NCCL_SOCKET_IFNAME=lo -export NCCL_IB_DISABLE=1 -export NCCL_P2P_LEVEL=NVL - -export LD_LIBRARY_PATH=/usr/local/cuda-12.8/lib64:$LD_LIBRARY_PATH -export PATH=/usr/local/cuda-12.8/bin:$PATH - -export HF_DATASETS_CACHE="/u/$USER/.cache" -export HF_HOME=$HF_DATASETS_CACHE -export HF_HUB_CACHE=$HF_DATASETS_CACHE -export HF_ASSETS_CACHE=$HF_DATASETS_CACHE -export TRANSFORMERS_CACHE=$HF_DATASETS_CACHE - diff --git a/bench_dc_ulysses/sp_dp_registry.py b/bench_dc_ulysses/sp_dp_registry.py deleted file mode 100644 index 4fc1913f1499..000000000000 --- a/bench_dc_ulysses/sp_dp_registry.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch -import torch.distributed as dist - -GROUP_REGISTRY = {} # int -> dist.ProcessGroup - -def register_groups(groups): - """groups: List[List[int]], e.g. 
[[0,1],[2,3]]""" - for gid, ranks in enumerate(groups): - if gid not in GROUP_REGISTRY: - GROUP_REGISTRY[gid] = dist.new_group(ranks) - -def get_group(gid: int): - return GROUP_REGISTRY[gid] if gid is not None else dist.group.WORLD - -def get_registry(): - return GROUP_REGISTRY - -def is_setup(): - return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False - -def sp_size(): - assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.' - - return GROUP_REGISTRY['SP_SIZE'] - -def dp_size(): - assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly' - - return GROUP_REGISTRY['DP_SIZE'] - -def populate_registry(SP_SIZE, DP_SIZE): - ## We register in the run_acc_lm.py file for baselines to reduce code-duplication. - ## Else the registration happens within the SP compiler pass within deepspeed. - group_listing = [] - offset = 0 - for _ in range(DP_SIZE): - group_listing.append([i + offset for i in range(SP_SIZE)]) - offset += SP_SIZE - - register_groups(group_listing) - - ## Extraneous metadata required for proper instatiation. 
## - GROUP_REGISTRY['SP_SIZE'] = SP_SIZE - GROUP_REGISTRY['DP_SIZE'] = DP_SIZE - GROUP_REGISTRY['is_reg'] = True diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py index fea046000516..d745bbda4624 100644 --- a/deepspeed/compile/fx.py +++ b/deepspeed/compile/fx.py @@ -162,4 +162,4 @@ def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Nod exclude = exclude or [] to_replace = [u for u in node.users if u not in exclude] for user in to_replace: - user.replace_input_with(node, replacement) \ No newline at end of file + user.replace_input_with(node, replacement) From a38674e7d0efe7a0480f68ff28405e0fd1a6e8ea Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sun, 22 Feb 2026 23:27:32 -0600 Subject: [PATCH 04/14] move constants and apis to deepspeed library --- deepspeed/compile/passes/sp_compile.py | 43 ++++++++++++++++++++++++++ deepspeed/compile/util.py | 11 +++---- deepspeed/runtime/constants.py | 7 +++++ 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py index 726d998e518a..19c1588d273e 100644 --- a/deepspeed/compile/passes/sp_compile.py +++ b/deepspeed/compile/passes/sp_compile.py @@ -20,11 +20,54 @@ from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.experimental.symbolic_shapes import ShapeEnv +from deepspeed.runtime import constants from ..custom_ops import all_to_all from ..fx import find_node_by_name, get_node_shape_meta from ..util import get_input_id_node, get_label_id_node, get_position_id_node, shard_tensor_node, get_sdpa_nodes, ShardingConfig +def prepare_autosp_inputs(input_id: torch.Tensor, label_id: torch.Tensor, position_id: torch.Tensor = None, attention_mask: torch.Tensor = None, seq_dim: int = 1): + """ + Prepare inputs for AutoSP by marking dynamic dimensions and tagging tensors. 
+ + Args: + input_id: Token IDs tensor (required) + label_id: Label IDs tensor (required) + position_id: Position IDs tensor (optional) + attention_mask: Attention mask tensor (optional) + seq_dim: Sequence dimension index to mark as dynamic (default: 1) + """ + + if input_id is None: + raise ValueError("input_id is required") + if label_id is None: + raise ValueError("label_id is required") + + if seq_dim < 0 or seq_dim >= input_id.ndim: + raise ValueError(f"seq_dim {seq_dim} must be a valid index for input_id with shape {input_id.shape}") + + if position_id is not None: + if seq_dim >= position_id.ndim: + raise ValueError(f"seq_dim {seq_dim} is out of bounds for position_id with shape {position_id.shape}") + + if attention_mask is not None: + if seq_dim >= attention_mask.ndim: + raise ValueError(f"seq_dim {seq_dim} is out of bounds for attention_mask with shape {attention_mask.shape}") + + torch._dynamo.decorators.mark_dynamic(input_id, seq_dim) + torch._dynamo.decorators.mark_dynamic(label_id, seq_dim) + if position_id is not None: + torch._dynamo.decorators.mark_dynamic(position_id, seq_dim) + if attention_mask is not None: + torch._dynamo.decorators.mark_dynamic(attention_mask, seq_dim) + + input_id.tag = constants.INPUT_ID_KEY + label_id.tag = constants.LABEL_ID_KEY + if position_id is not None: + position_id.tag = constants.POSITION_ID_KEY + + return input_id, label_id, position_id, attention_mask + def pass_shard_seq_dim(gm: GraphModule, example_inputs): """ Finds all direct and indirect consumers of the input sequence, label and position ids. 
diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 5d4fc0a01751..9e3150e1a221 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -24,10 +24,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.utils.torch import required_torch_version from deepspeed.ops.op_builder.dc import DeepCompileBuilder - -INPUT_ID_KEY = "input_id" -LABEL_ID_KEY = "label_id" -POSITION_ID_KEY = "position_id" +from deepspeed.runtime import constants def is_deepcompile_supported() -> bool: return required_torch_version(min_version=2.6, max_version=2.9) and get_accelerator().device_name() == "cuda" @@ -547,21 +544,21 @@ def get_sdpa_nodes(gm: GraphModule) -> List[Node]: def get_input_id_node(gm: GraphModule) -> Node: from .fx import find_node_by_tag - node = find_node_by_tag(gm, INPUT_ID_KEY) + node = find_node_by_tag(gm, constants.INPUT_ID_KEY) if node is None: raise RuntimeError("Failed to find a node for the input sequence.") return node def get_label_id_node(gm: GraphModule) -> Node: from .fx import find_node_by_tag - node = find_node_by_tag(gm, LABEL_ID_KEY) + node = find_node_by_tag(gm, constants.LABEL_ID_KEY) if node is None: raise RuntimeError("Failed to find a node for the label.") return node def get_position_id_node(gm: GraphModule) -> Node: from .fx import find_node_by_tag - node = find_node_by_tag(gm, POSITION_ID_KEY) + node = find_node_by_tag(gm, constants.POSITION_ID_KEY) return node def create_shard_offsets( diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 9e73bad73376..a916befc76f4 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -501,3 +501,10 @@ class ValidationMode: ######################################### USE_DATA_BEFORE_EXPERT_PARALLEL = "use_data_before_expert_parallelism" USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT = False + +######################################### +# AUTOSP +######################################### +INPUT_ID_KEY = "input_id" 
+LABEL_ID_KEY = "label_id" +POSITION_ID_KEY = "position_id" From fea194c928e1fa261dea80593371d99676566fa2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 27 Feb 2026 05:20:47 +0000 Subject: [PATCH 05/14] add zero-1 interoperability to autosp --- deepspeed/compile/backend.py | 7 ++ deepspeed/compile/config.py | 11 ++++ deepspeed/compile/custom_ops/all_to_all.py | 40 ++++++----- .../compile/custom_ops/sp_dp_registry.py | 48 ++++++++++++++ deepspeed/compile/init_sp.py | 4 +- deepspeed/compile/passes/sp_compile.py | 37 +++++++---- deepspeed/compile/util.py | 66 ++++++++++--------- deepspeed/runtime/engine.py | 40 ++++++----- 8 files changed, 178 insertions(+), 75 deletions(-) create mode 100644 deepspeed/compile/custom_ops/sp_dp_registry.py diff --git a/deepspeed/compile/backend.py b/deepspeed/compile/backend.py index 4f930be09c3c..b5d1e9890ba5 100644 --- a/deepspeed/compile/backend.py +++ b/deepspeed/compile/backend.py @@ -384,3 +384,10 @@ def compiler_fn(gm, sample_inputs): raise ValueError(f"Unsupported backend {backend}") return backend_fn + + +def make_autosp_backend(backend, compile_kwargs={}, free_activation=False, debug_log=False, sp_size=2, dp_size=1): + def backend_fn(gm: GraphModule, real_inputs): + apply_autosp(gm, real_inputs, debug_log, sp_size=sp_size, dp_size=dp_size) + return torch._inductor.compile(gm, real_inputs) + return backend_fn diff --git a/deepspeed/compile/config.py b/deepspeed/compile/config.py index 739add99271c..c10f54b06bd2 100644 --- a/deepspeed/compile/config.py +++ b/deepspeed/compile/config.py @@ -3,8 +3,10 @@ # DeepSpeed Team +from typing import List, Optional, Literal from deepspeed.runtime.config_utils import DeepSpeedConfigModel +PassName = Literal["z1", "z3", "autosp"] class CompileConfig(DeepSpeedConfigModel): """ Configure compile settings """ @@ -53,3 +55,12 @@ class CompileConfig(DeepSpeedConfigModel): keep_all_input_tensors: bool = False """ Keep real values for all input tensors in InputStorage instead of using dummy 
values """ + + passes: Optional[List[PassName]] = None + """ Composes different optimizations. """ + + sp_size: int = 1 + """ SP group-size """ + + dp_size: int = 1 + """ DP group-size """ diff --git a/deepspeed/compile/custom_ops/all_to_all.py b/deepspeed/compile/custom_ops/all_to_all.py index 6f4b5172b67a..58835d444edd 100644 --- a/deepspeed/compile/custom_ops/all_to_all.py +++ b/deepspeed/compile/custom_ops/all_to_all.py @@ -1,61 +1,71 @@ import torch import torch.distributed as dist +from .sp_dp_registry import get_group, is_setup, sp_size, dp_size @torch.library.custom_op("autosp::all_to_all", mutates_args=()) def all_to_all( input: torch.Tensor, scatter_idx: int, gather_idx: int, - world_size: int, name: str, ) -> torch.Tensor: + """ + All-to-all collective for SDPA tensors [B, N, S, H]. + + For QKV (scatter_idx=1, gather_idx=2): + [B, N, S/P, H] -> [B, N/P, S, H] + For O (scatter_idx=2, gather_idx=1): + [B, N/P, S, H] -> [B, N, S/P, H] + """ + assert is_setup(), 'Incorrect initialization of SP/DP mesh.' 
B, dim1, dim2, H = input.shape + gid = dist.get_rank() // sp_size() + group = get_group(gid) if scatter_idx == 1: N, local_S = dim1, dim2 - input_t = input.reshape(B, world_size, N // world_size, local_S, H) + input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H) input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=dist.group.WORLD) + dist.all_to_all_single(output, input_t, group=group) output = output.permute(1, 2, 0, 3, 4).contiguous() - output = output.reshape(B, N // world_size, world_size * local_S, H) - else: + output = output.reshape(B, N // sp_size(), sp_size() * local_S, H) + else: # scatter_idx == 2, O: scatter sequence, gather heads local_N, S = dim1, dim2 - input_t = input.reshape(B, local_N, world_size, S // world_size, H) + input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H) input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=dist.group.WORLD) + dist.all_to_all_single(output, input_t, group=group) output = output.permute(1, 0, 2, 3, 4).contiguous() - output = output.reshape(B, world_size * local_N, S // world_size, H) + output = output.reshape(B, sp_size() * local_N, S // sp_size(), H) return output @torch.library.register_fake("autosp::all_to_all") -def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, world_size: int, name: str): +def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str): B, dim1, dim2, H = input.shape if scatter_idx == 1: - return input.new_empty(B, dim1 // world_size, dim2 * world_size, H) + return input.new_empty(B, dim1 // sp_size(), dim2 * sp_size(), H) else: - return input.new_empty(B, dim1 * world_size, dim2 // world_size, H) + return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H) def _all_to_all_backward_setup(ctx, inputs, output): - _, scatter_idx, gather_idx, 
world_size, name = inputs + _, scatter_idx, gather_idx, name = inputs ctx.scatter_idx = gather_idx ctx.gather_idx = scatter_idx - ctx.world_size = world_size ctx.name = name + "_grad" def _all_to_all_backward(ctx, grad): return ( - all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.world_size, ctx.name), - None, None, None, None, + all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), + None, None, None, None ) diff --git a/deepspeed/compile/custom_ops/sp_dp_registry.py b/deepspeed/compile/custom_ops/sp_dp_registry.py new file mode 100644 index 000000000000..b7875699b663 --- /dev/null +++ b/deepspeed/compile/custom_ops/sp_dp_registry.py @@ -0,0 +1,48 @@ +import torch +import torch.distributed as dist + +GROUP_REGISTRY = {} # int -> dist.ProcessGroup + +def register_groups(groups): + """groups: List[List[int]], e.g. [[0,1],[2,3]]""" + for gid, ranks in enumerate(groups): + if gid not in GROUP_REGISTRY: + GROUP_REGISTRY[gid] = dist.new_group(ranks) + +def get_group(gid: int): + return GROUP_REGISTRY[gid] if gid is not None else dist.group.WORLD + +def get_registry(): + return GROUP_REGISTRY + +def is_setup(): + return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False + +def sp_size(): + assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.' + + return GROUP_REGISTRY['SP_SIZE'] + +def dp_size(): + assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly' + + return GROUP_REGISTRY['DP_SIZE'] + +def populate_registry(SP_SIZE, DP_SIZE): + """ Populate rank to SP/DP mesh index. """ + + if GROUP_REGISTRY.get('is_reg', False): + return + + group_listing = [] + offset = 0 + for _ in range(DP_SIZE): + group_listing.append([i + offset for i in range(SP_SIZE)]) + offset += SP_SIZE + + register_groups(group_listing) + + ## Extraneous metadata required for proper instatiation. 
## + GROUP_REGISTRY['SP_SIZE'] = SP_SIZE + GROUP_REGISTRY['DP_SIZE'] = DP_SIZE + GROUP_REGISTRY['is_reg'] = True diff --git a/deepspeed/compile/init_sp.py b/deepspeed/compile/init_sp.py index 7862420a2006..65ecb8e1cf9c 100644 --- a/deepspeed/compile/init_sp.py +++ b/deepspeed/compile/init_sp.py @@ -7,8 +7,8 @@ from torch.fx import GraphModule from .passes.sp_compile import apply_autosp -def init_autosp(): +def init_autosp(sp_size=2, dp_size=1): def backend_fn(gm: GraphModule, real_inputs): - apply_autosp(gm, real_inputs, debug=False) + apply_autosp(gm, real_inputs, debug=False, sp_size=sp_size, dp_size=dp_size) return torch._inductor.compile(gm, real_inputs) return backend_fn diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py index 19c1588d273e..4cbc74fba962 100644 --- a/deepspeed/compile/passes/sp_compile.py +++ b/deepspeed/compile/passes/sp_compile.py @@ -22,9 +22,9 @@ from deepspeed.runtime import constants -from ..custom_ops import all_to_all +from ..custom_ops import all_to_all, sp_dp_registry from ..fx import find_node_by_name, get_node_shape_meta -from ..util import get_input_id_node, get_label_id_node, get_position_id_node, shard_tensor_node, get_sdpa_nodes, ShardingConfig +from ..util import get_input_id_node, get_label_id_node, get_position_id_node, shard_tensor_node, get_sdpa_nodes def prepare_autosp_inputs(input_id: torch.Tensor, label_id: torch.Tensor, position_id: torch.Tensor = None, attention_mask: torch.Tensor = None, seq_dim: int = 1): """ @@ -73,7 +73,7 @@ def pass_shard_seq_dim(gm: GraphModule, example_inputs): Finds all direct and indirect consumers of the input sequence, label and position ids. Shard the sequence dimension used by all such consumers. 
""" - world_size = dist.get_world_size() + sp_size = sp_dp_registry.sp_size() input_ids_node = get_input_id_node(gm) val = get_node_shape_meta(input_ids_node) @@ -88,7 +88,7 @@ def pass_shard_seq_dim(gm: GraphModule, example_inputs): with gm.graph.inserting_after(sym_seq_dim_node): sharded_node = gm.graph.call_function( operator.floordiv, - args=(sym_seq_dim_node, world_size) + args=(sym_seq_dim_node, sp_size) ) sharded_input_nodes = set() @@ -128,23 +128,23 @@ def pass_shard_seq_dim(gm: GraphModule, example_inputs): def pass_shard_input_ids(gm: GraphModule, example_inputs): - config = ShardingConfig.from_distributed() + """Shard input_ids tensor across ranks.""" input_ids_node = get_input_id_node(gm) - shard_tensor_node(gm, input_ids_node, config) + shard_tensor_node(gm, input_ids_node) def pass_shard_label_ids(gm: GraphModule, example_inputs): - config = ShardingConfig.from_distributed() + """Shard label_ids tensor across ranks.""" label_ids_node = get_label_id_node(gm) - shard_tensor_node(gm, label_ids_node, config) + shard_tensor_node(gm, label_ids_node) def pass_shard_position_ids(gm: GraphModule, example_inputs): - config = ShardingConfig.from_distributed() + """Shard position_ids tensor across ranks.""" position_ids_node = get_position_id_node(gm) if position_ids_node is None: print("[WARNING] position id node not found. 
Skipping sharding of position ids.") return - shard_tensor_node(gm, position_ids_node, config) + shard_tensor_node(gm, position_ids_node) def pass_insert_attention_all_to_all(gm: GraphModule, real_inputs): @@ -162,7 +162,7 @@ def insert_a2a(node: Node, scatter_idx: int, gather_idx: int, name: str) -> Node with gm.graph.inserting_after(node): a2a_node = gm.graph.call_function( torch.ops.autosp.all_to_all.default, - args=(node, scatter_idx, gather_idx, world_size, name), + args=(node, scatter_idx, gather_idx, name), ) a2a_node.name = f"a2a_{name}" node.replace_all_uses_with(a2a_node) @@ -204,7 +204,22 @@ def apply_autosp( real_inputs, debug: bool = False, passes: Optional[List[Callable]] = None, + sp_size: int = 2, + dp_size: int = 1 ): + """ + Apply AutoSP (Ulysses) transformation passes to the graph and setup either DP/SP (2D) or SP (1D) mesh. + + Args: + gm: GraphModule to transform + real_inputs: Example inputs for shape propagation + debug: If True, print graph before/after each pass + passes: Optional custom list of passes (default: DEFAULT_PASSES) + """ + assert sp_size * dp_size <= torch.cuda.device_count(), 'Insufficient device count for mesh size' + + sp_dp_registry.populate_registry(sp_size, dp_size) + AUTOSP_PASSES = [ pass_shard_seq_dim, pass_shard_input_ids, diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 9e3150e1a221..25a047ab5852 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -26,6 +26,8 @@ from deepspeed.ops.op_builder.dc import DeepCompileBuilder from deepspeed.runtime import constants +from .custom_ops import sp_dp_registry + def is_deepcompile_supported() -> bool: return required_torch_version(min_version=2.6, max_version=2.9) and get_accelerator().device_name() == "cuda" @@ -524,17 +526,37 @@ def pad_tensors(specs: List[Tuple[torch.Tensor, int, int]]) -> List[torch.Tensor return padded -@dataclass -class ShardingConfig: - world_size: int - rank: int +INPUT_ID_KEY = "input_id" +LABEL_ID_KEY = 
"label_id" +POSITION_ID_KEY = "position_id" + + +def create_shard_offsets( + gm: GraphModule, + s0_node: Node +) -> Tuple[Node, Node]: + """ + Create FX nodes for computing shard start and end offsets. - @classmethod - def from_distributed(cls) -> "ShardingConfig": - return cls( - world_size=dist.get_world_size(), - rank=dist.get_rank(), - ) + Computes: + chunk_size = s0 // sp_size + start = rank * chunk_size + end = start + chunk_size + + Returns: + Tuple of (start_node, end_node) + """ + sp_size: int = sp_dp_registry.sp_size() + sp_rank: int = dist.get_rank() % sp_dp_registry.sp_size() + with gm.graph.inserting_after(s0_node): + chunk_size_node = gm.graph.call_function(operator.floordiv, args=(s0_node, sp_size)) + with gm.graph.inserting_after(chunk_size_node): + start_node = gm.graph.call_function(operator.mul, args=(sp_rank, chunk_size_node)) + with gm.graph.inserting_after(start_node): + end_node = gm.graph.call_function(operator.add, args=(start_node, chunk_size_node)) + + return start_node, end_node + def get_sdpa_nodes(gm: GraphModule) -> List[Node]: return list(gm.graph.find_nodes( @@ -561,28 +583,13 @@ def get_position_id_node(gm: GraphModule) -> Node: node = find_node_by_tag(gm, constants.POSITION_ID_KEY) return node -def create_shard_offsets( - gm: GraphModule, - sym_seq_dim_node: Node, - world_size: int, - rank: int -) -> Tuple[Node, Node]: - with gm.graph.inserting_after(sym_seq_dim_node): - chunk_size_node = gm.graph.call_function(operator.floordiv, args=(sym_seq_dim_node, world_size)) - with gm.graph.inserting_after(chunk_size_node): - start_node = gm.graph.call_function(operator.mul, args=(rank, chunk_size_node)) - with gm.graph.inserting_after(start_node): - end_node = gm.graph.call_function(operator.add, args=(start_node, chunk_size_node)) - - return start_node, end_node def create_symbolic_slice_indices( gm: GraphModule, sym_seq_dim_node: Node, - config: ShardingConfig ) -> Tuple[Node, Node]: - start_node, end_node = create_shard_offsets(gm, 
sym_seq_dim_node, config.world_size, config.rank) - + start_node, end_node = create_shard_offsets(gm, sym_seq_dim_node) + with gm.graph.inserting_after(end_node): slice_all = gm.graph.call_function(slice, args=(None, None, None)) with gm.graph.inserting_after(slice_all): @@ -592,8 +599,7 @@ def create_symbolic_slice_indices( def shard_tensor_node( gm: GraphModule, - tensor_node: Node, - config: ShardingConfig + tensor_node: Node ): from .fx import find_node_by_name, get_node_shape_meta, replace_node_users val = get_node_shape_meta(tensor_node) @@ -606,7 +612,7 @@ def shard_tensor_node( symb_seq_int_node = find_node_by_name(gm, str(seq_len)) assert symb_seq_int_node, f"Unable to find symbolic placeholder for {seq_len}" - slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node, config) + slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node) indices = (slice_all, slice_range) with gm.graph.inserting_after(tensor_node): diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index eb71a7c51765..9b2cc3f4863a 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1005,6 +1005,14 @@ def zero_sub_group_size(self): def zero_optimization_stage(self): return self._config.zero_optimization_stage + def compile_zero_optimization_stage(self): + """Determines if zero-pass is set in deepcompile's passes attributes.""" + return "z1" in self._config.compile_config.passes or "z3" in self._config.compile_config.passes + + def compile_autosp(self): + """Determines if AutoSP is set in deepcompile's passes attributes.""" + return "autosp" in self._config.compile_config.passes + def mics_shard_size(self): return self._config.mics_shard_size @@ -4383,23 +4391,21 @@ def passes_name_to_fn(passes): assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." 
compile_config = self._config.compile_config - if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] - and "offload_param" in self.config["zero_optimization"]) - and self._config.zero_config.offload_param.device == "cpu" - and self._config.zero_config.offload_optimizer.device == "cpu"): - compile_config.offload_parameters = True - if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: - backend = init_z1(self, backend, compile_config, compile_kwargs, schedule) - elif self.zero_optimization_stage() == ZeroStageEnum.gradients: - backend = init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) - elif self.zero_optimization_stage() == ZeroStageEnum.weights: - if required_torch_version(min_version=2.9): - raise RuntimeError( - "DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. " - "Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3.") - backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) - elif self.zero_optimization_stage() == ZeroStageEnum.disabled: - backend = init_autosp() + if self.compile_autosp(): + backend = init_autosp(sp_size=self._config.compile_config.sp_size, dp_size=self._config.compile_config.dp_size) + #backend = init_ulysses(self, backend, compile_config, compile_kwargs, schedule, sp_size=self._config.compile_config.sp_size, dp_size=self._config.compile_config.dp_size) + else: ## By default then only zero-style DP should be triggered in dc. 
## + if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] + and "offload_param" in self.config["zero_optimization"]) + and self._config.zero_config.offload_param.device == "cpu" + and self._config.zero_config.offload_optimizer.device == "cpu"): + compile_config.offload_parameters = True + if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: + backend = init_z1(self, backend, compile_config, compile_kwargs, schedule) + elif self.zero_optimization_stage() == ZeroStageEnum.gradients: + backend = init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) + elif self.zero_optimization_stage() == ZeroStageEnum.weights: + backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) # Hook state must align with whether DeepCompile is active. self._set_deepcompile_active(enable_deepcompile) From bd916b7ea2ed7b119337fabf76f26f0d49aa5067 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 27 Feb 2026 18:27:31 +0000 Subject: [PATCH 06/14] fix early termination of gradients issue when using autosp --- deepspeed/compile/init_sp.py | 4 ++-- deepspeed/runtime/engine.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/deepspeed/compile/init_sp.py b/deepspeed/compile/init_sp.py index 65ecb8e1cf9c..de15bcb5f925 100644 --- a/deepspeed/compile/init_sp.py +++ b/deepspeed/compile/init_sp.py @@ -7,8 +7,8 @@ from torch.fx import GraphModule from .passes.sp_compile import apply_autosp -def init_autosp(sp_size=2, dp_size=1): +def init_autosp(compile_config): def backend_fn(gm: GraphModule, real_inputs): - apply_autosp(gm, real_inputs, debug=False, sp_size=sp_size, dp_size=dp_size) + apply_autosp(gm, real_inputs, debug=False, sp_size=compile_config.sp_size, dp_size=compile_config.dp_size) return torch._inductor.compile(gm, real_inputs) return backend_fn diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9b2cc3f4863a..1164c90562cf 100755 --- 
a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2382,7 +2382,7 @@ def print_forward_breakdown(self, fwd_time): def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): # Skip gradient reduction when DeepCompile is enabled # DeepCompile handles its own gradient reduction through compiled graph operations - if self.is_deepcompile_active(): + if self.is_deepcompile_active() and not self.compile_autosp(): return # Pass (PP) gas boundary flag to optimizer (required for zero) @@ -4392,8 +4392,7 @@ def passes_name_to_fn(passes): compile_config = self._config.compile_config if self.compile_autosp(): - backend = init_autosp(sp_size=self._config.compile_config.sp_size, dp_size=self._config.compile_config.dp_size) - #backend = init_ulysses(self, backend, compile_config, compile_kwargs, schedule, sp_size=self._config.compile_config.sp_size, dp_size=self._config.compile_config.dp_size) + backend = init_autosp(compile_config) else: ## By default then only zero-style DP should be triggered in dc. 
## if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] and "offload_param" in self.config["zero_optimization"]) From 82beda161ac3eb55955841a73c118789fe6c5037 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sat, 28 Feb 2026 18:36:21 -0600 Subject: [PATCH 07/14] rename autosp specific constants --- deepspeed/compile/passes/sp_compile.py | 6 +++--- deepspeed/compile/util.py | 12 +++--------- deepspeed/runtime/constants.py | 6 +++--- deepspeed/runtime/engine.py | 3 +-- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py index 4cbc74fba962..0f0d2d6ceb33 100644 --- a/deepspeed/compile/passes/sp_compile.py +++ b/deepspeed/compile/passes/sp_compile.py @@ -61,10 +61,10 @@ def prepare_autosp_inputs(input_id: torch.Tensor, label_id: torch.Tensor, positi if attention_mask is not None: torch._dynamo.decorators.mark_dynamic(attention_mask, seq_dim) - input_id.tag = constants.INPUT_ID_KEY - label_id.tag = constants.LABEL_ID_KEY + input_id.tag = constants.AUTOSP_INPUT_ID_KEY + label_id.tag = constants.AUTOSP_LABEL_ID_KEY if position_id is not None: - position_id.tag = constants.POSITION_ID_KEY + position_id.tag = constants.AUTOSP_POSITION_ID_KEY return input_id, label_id, position_id, attention_mask diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 25a047ab5852..908aac0943e9 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -6,7 +6,6 @@ import functools import operator from typing import List, Tuple, Dict, Optional -from dataclasses import dataclass from collections import defaultdict import torch @@ -526,11 +525,6 @@ def pad_tensors(specs: List[Tuple[torch.Tensor, int, int]]) -> List[torch.Tensor return padded -INPUT_ID_KEY = "input_id" -LABEL_ID_KEY = "label_id" -POSITION_ID_KEY = "position_id" - - def create_shard_offsets( gm: GraphModule, s0_node: Node @@ -566,21 +560,21 @@ def 
get_sdpa_nodes(gm: GraphModule) -> List[Node]: def get_input_id_node(gm: GraphModule) -> Node: from .fx import find_node_by_tag - node = find_node_by_tag(gm, constants.INPUT_ID_KEY) + node = find_node_by_tag(gm, constants.AUTOSP_INPUT_ID_KEY) if node is None: raise RuntimeError("Failed to find a node for the input sequence.") return node def get_label_id_node(gm: GraphModule) -> Node: from .fx import find_node_by_tag - node = find_node_by_tag(gm, constants.LABEL_ID_KEY) + node = find_node_by_tag(gm, constants.AUTOSP_LABEL_ID_KEY) if node is None: raise RuntimeError("Failed to find a node for the label.") return node def get_position_id_node(gm: GraphModule) -> Node: from .fx import find_node_by_tag - node = find_node_by_tag(gm, constants.POSITION_ID_KEY) + node = find_node_by_tag(gm, constants.AUTOSP_POSITION_ID_KEY) return node diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index a916befc76f4..c81389d9ae28 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -505,6 +505,6 @@ class ValidationMode: ######################################### # AUTOSP ######################################### -INPUT_ID_KEY = "input_id" -LABEL_ID_KEY = "label_id" -POSITION_ID_KEY = "position_id" +AUTOSP_INPUT_ID_KEY = "input_id" +AUTOSP_LABEL_ID_KEY = "label_id" +AUTOSP_POSITION_ID_KEY = "position_id" diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 1164c90562cf..6f7bdaeb0f1a 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -4370,8 +4370,7 @@ def compile(self, enable_deepcompile = self.is_deepcompile_enabled() if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ and self.zero_optimization_stage() != ZeroStageEnum.weights \ - and self.zero_optimization_stage() != ZeroStageEnum.gradients \ - and self.zero_optimization_stage() != ZeroStageEnum.disabled: + and self.zero_optimization_stage() != ZeroStageEnum.gradients: logger.info( 
f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." ) From 634b706391b9ff6407090e27b894d89110ca1b22 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sat, 28 Feb 2026 18:57:55 -0600 Subject: [PATCH 08/14] fix merge conflicts --- deepspeed/compile/backend.py | 6 ------ deepspeed/compile/passes/sp_compile.py | 3 --- deepspeed/compile/util.py | 11 ----------- 3 files changed, 20 deletions(-) diff --git a/deepspeed/compile/backend.py b/deepspeed/compile/backend.py index b5d1e9890ba5..6aa43f64615d 100644 --- a/deepspeed/compile/backend.py +++ b/deepspeed/compile/backend.py @@ -385,9 +385,3 @@ def compiler_fn(gm, sample_inputs): return backend_fn - -def make_autosp_backend(backend, compile_kwargs={}, free_activation=False, debug_log=False, sp_size=2, dp_size=1): - def backend_fn(gm: GraphModule, real_inputs): - apply_autosp(gm, real_inputs, debug_log, sp_size=sp_size, dp_size=dp_size) - return torch._inductor.compile(gm, real_inputs) - return backend_fn diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py index 0f0d2d6ceb33..d6569c1fe96c 100644 --- a/deepspeed/compile/passes/sp_compile.py +++ b/deepspeed/compile/passes/sp_compile.py @@ -128,18 +128,15 @@ def pass_shard_seq_dim(gm: GraphModule, example_inputs): def pass_shard_input_ids(gm: GraphModule, example_inputs): - """Shard input_ids tensor across ranks.""" input_ids_node = get_input_id_node(gm) shard_tensor_node(gm, input_ids_node) def pass_shard_label_ids(gm: GraphModule, example_inputs): - """Shard label_ids tensor across ranks.""" label_ids_node = get_label_id_node(gm) shard_tensor_node(gm, label_ids_node) def pass_shard_position_ids(gm: GraphModule, example_inputs): - """Shard position_ids tensor across ranks.""" position_ids_node = get_position_id_node(gm) if position_ids_node is None: print("[WARNING] position id node not found. 
Skipping sharding of position ids.") diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 908aac0943e9..1071d6a5e9f0 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -529,17 +529,6 @@ def create_shard_offsets( gm: GraphModule, s0_node: Node ) -> Tuple[Node, Node]: - """ - Create FX nodes for computing shard start and end offsets. - - Computes: - chunk_size = s0 // sp_size - start = rank * chunk_size - end = start + chunk_size - - Returns: - Tuple of (start_node, end_node) - """ sp_size: int = sp_dp_registry.sp_size() sp_rank: int = dist.get_rank() % sp_dp_registry.sp_size() with gm.graph.inserting_after(s0_node): From ee8bdb06be632efc18531238cf391a6cbd1a7b7f Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sun, 1 Mar 2026 17:57:15 -0600 Subject: [PATCH 09/14] add missing __init__.py file --- deepspeed/compile/custom_ops/__init__.py | 4 ++++ deepspeed/compile/passes/sp_compile.py | 19 +++++++++---------- deepspeed/runtime/engine.py | 1 + 3 files changed, 14 insertions(+), 10 deletions(-) create mode 100644 deepspeed/compile/custom_ops/__init__.py diff --git a/deepspeed/compile/custom_ops/__init__.py b/deepspeed/compile/custom_ops/__init__.py new file mode 100644 index 000000000000..358ec2c8ef3f --- /dev/null +++ b/deepspeed/compile/custom_ops/__init__.py @@ -0,0 +1,4 @@ +from .all_to_all import all_to_all +from .sp_dp_registry import sp_dp_registry + +__all__ = ["all_to_all", "sp_dp_registry"] diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py index d6569c1fe96c..a7be66052481 100644 --- a/deepspeed/compile/passes/sp_compile.py +++ b/deepspeed/compile/passes/sp_compile.py @@ -145,16 +145,6 @@ def pass_shard_position_ids(gm: GraphModule, example_inputs): def pass_insert_attention_all_to_all(gm: GraphModule, real_inputs): - """ - Insert all-to-all collectives around SDPA for Ulysses parallelism. 
- - For each SDPA: - - Before Q, K, V: scatter heads (dim=1), gather sequence (dim=2) - - After O: scatter sequence (dim=2), gather heads (dim=1) - """ - world_size = dist.get_world_size() - attention_nodes = get_sdpa_nodes(gm) - def insert_a2a(node: Node, scatter_idx: int, gather_idx: int, name: str) -> Node: with gm.graph.inserting_after(node): a2a_node = gm.graph.call_function( @@ -166,6 +156,15 @@ def insert_a2a(node: Node, scatter_idx: int, gather_idx: int, name: str) -> Node a2a_node.update_arg(0, node) return a2a_node + attention_nodes = get_sdpa_nodes(gm) + if len(attention_nodes) == 0: + raise RuntimeError( + "AutoSP currently supports torch.nn.functional.scaled_dot_product_attention as the " + "attention backend. No SDPA attention operations were found in the compiled graph. " + "Please ensure your model uses torch.nn.functional.scaled_dot_product_attention " + "for AutoSP to work as expected." + ) + for idx, attn_node in enumerate(attention_nodes): q, k, v = attn_node.args[:3] suffix = f"_{idx}" if len(attention_nodes) > 1 else "" diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6f7bdaeb0f1a..ebf821d0da8e 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -4391,6 +4391,7 @@ def passes_name_to_fn(passes): compile_config = self._config.compile_config if self.compile_autosp(): + compile_kwargs['fullgraph'] = True backend = init_autosp(compile_config) else: ## By default then only zero-style DP should be triggered in dc. 
## if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] From a02897315cd8d8b967969d02c0ee1bbea388e83d Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sun, 1 Mar 2026 19:48:38 -0600 Subject: [PATCH 10/14] fix __init__.py --- deepspeed/compile/custom_ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/compile/custom_ops/__init__.py b/deepspeed/compile/custom_ops/__init__.py index 358ec2c8ef3f..85164c7beabc 100644 --- a/deepspeed/compile/custom_ops/__init__.py +++ b/deepspeed/compile/custom_ops/__init__.py @@ -1,4 +1,4 @@ from .all_to_all import all_to_all -from .sp_dp_registry import sp_dp_registry +from . import sp_dp_registry __all__ = ["all_to_all", "sp_dp_registry"] From 8c77acdb8b4b4cd67df0a4f8cf071996454ce178 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sun, 1 Mar 2026 23:40:36 -0600 Subject: [PATCH 11/14] address review comments --- deepspeed/compile/constants.py | 6 ++++++ deepspeed/compile/passes/sp_compile.py | 2 +- deepspeed/compile/util.py | 2 +- deepspeed/runtime/constants.py | 6 ------ deepspeed/runtime/engine.py | 14 ++++++++++---- 5 files changed, 18 insertions(+), 12 deletions(-) create mode 100644 deepspeed/compile/constants.py diff --git a/deepspeed/compile/constants.py b/deepspeed/compile/constants.py new file mode 100644 index 000000000000..3580082a2499 --- /dev/null +++ b/deepspeed/compile/constants.py @@ -0,0 +1,6 @@ +######################################### +# AUTOSP +######################################### +AUTOSP_INPUT_ID_KEY = "input_id" +AUTOSP_LABEL_ID_KEY = "label_id" +AUTOSP_POSITION_ID_KEY = "position_id" diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py index a7be66052481..aca4dac47edc 100644 --- a/deepspeed/compile/passes/sp_compile.py +++ b/deepspeed/compile/passes/sp_compile.py @@ -20,7 +20,7 @@ from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.experimental.symbolic_shapes 
import ShapeEnv -from deepspeed.runtime import constants +from deepspeed.compile import constants from ..custom_ops import all_to_all, sp_dp_registry from ..fx import find_node_by_name, get_node_shape_meta diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 1071d6a5e9f0..baf70f60e3fc 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -23,7 +23,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.utils.torch import required_torch_version from deepspeed.ops.op_builder.dc import DeepCompileBuilder -from deepspeed.runtime import constants +from deepspeed.compile import constants from .custom_ops import sp_dp_registry diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index c81389d9ae28..92abb58a2a49 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -502,9 +502,3 @@ class ValidationMode: USE_DATA_BEFORE_EXPERT_PARALLEL = "use_data_before_expert_parallelism" USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT = False -######################################### -# AUTOSP -######################################### -AUTOSP_INPUT_ID_KEY = "input_id" -AUTOSP_LABEL_ID_KEY = "label_id" -AUTOSP_POSITION_ID_KEY = "position_id" diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ebf821d0da8e..7166943923b3 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -4370,10 +4370,16 @@ def compile(self, enable_deepcompile = self.is_deepcompile_enabled() if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ and self.zero_optimization_stage() != ZeroStageEnum.weights \ - and self.zero_optimization_stage() != ZeroStageEnum.gradients: - logger.info( - f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." 
- ) + and self.zero_optimization_stage() != ZeroStageEnum.gradients \ + and self.compile_autosp() and self.zero_optimization_stage() not in [ZeroStageEnum.disabled, ZeroStageEnum.optimizer_states]: + if self.compile_autosp(): + logger.info( + f"Currently AutoSP does not compose with ZeRO stage 2 and 3. Falling back to the torch compiler." + ) + else: + logger.info( + f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." + ) enable_deepcompile = False if enable_deepcompile: From f5e4d4bcd63ff1a8187367aab19c8bc83e2711c4 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Mon, 2 Mar 2026 00:37:19 -0600 Subject: [PATCH 12/14] refactor engine.py to validate and select compiler backend --- deepspeed/runtime/engine.py | 112 +++++++++++++++++++++--------------- 1 file changed, 66 insertions(+), 46 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7166943923b3..28ea3f40dbee 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -4345,6 +4345,66 @@ def empty_partition_cache(self): gc.collect() get_accelerator().empty_cache() + def get_autosp_backend(self, compile_kwargs): + if self.compile_autosp() and self.zero_optimization_stage() not in [ZeroStageEnum.disabled, ZeroStageEnum.optimizer_states]: + logger.info( + f"Currently AutoSP does not compose with ZeRO stage 2 and 3. Falling back to the torch compiler." 
+ ) + return None + + compile_config = self._config.compile_config + compile_kwargs['fullgraph'] = True + return init_autosp(compile_config) + + def get_deepcompile_backend(self, backend, compile_kwargs, schedule): + if self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ + and self.zero_optimization_stage() != ZeroStageEnum.weights \ + and self.zero_optimization_stage() != ZeroStageEnum.gradients: + logger.info( + f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." + ) + return None + + compile_config = self._config.compile_config + if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] + and "offload_param" in self.config["zero_optimization"]) + and self._config.zero_config.offload_param.device == "cpu" + and self._config.zero_config.offload_optimizer.device == "cpu"): + compile_config.offload_parameters = True + if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: + return init_z1(self, backend, compile_config, compile_kwargs, schedule) + elif self.zero_optimization_stage() == ZeroStageEnum.gradients: + return init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) + elif self.zero_optimization_stage() == ZeroStageEnum.weights: + return init_z3(self, backend, compile_config, compile_kwargs, schedule) + return None + + def get_deepspeed_compile_backend(self, backend, compile_kwargs, schedule): + resolved_backend = None + + if schedule is not None: + + def passes_name_to_fn(passes): + for p in passes: + assert callable(p) or p in opt_passes, f"Unknown pass {p}" + return [p if callable(p) else opt_passes[p] for p in passes] + + schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] + + assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." 
+ + if self.compile_autosp(): + resolved_backend = self.get_autosp_backend(compile_kwargs) + else: + if self.validate_deepcompile_config(): + resolved_backend = self.get_deepcompile_backend(backend, compile_kwargs, schedule) + + # Fallback to torch backend if no DeepSpeed backend was selected. + if resolved_backend is None: + resolved_backend = backend + + return resolved_backend, schedule + def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, @@ -4367,59 +4427,19 @@ def compile(self, logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}") - enable_deepcompile = self.is_deepcompile_enabled() - if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ - and self.zero_optimization_stage() != ZeroStageEnum.weights \ - and self.zero_optimization_stage() != ZeroStageEnum.gradients \ - and self.compile_autosp() and self.zero_optimization_stage() not in [ZeroStageEnum.disabled, ZeroStageEnum.optimizer_states]: - if self.compile_autosp(): - logger.info( - f"Currently AutoSP does not compose with ZeRO stage 2 and 3. Falling back to the torch compiler." - ) - else: - logger.info( - f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." - ) - enable_deepcompile = False - - if enable_deepcompile: - - if schedule is not None: - - def passes_name_to_fn(passes): - for p in passes: - assert callable(p) or p in opt_passes, f"Unknown pass {p}" - return [p if callable(p) else opt_passes[p] for p in passes] - - schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] - - assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." 
- - compile_config = self._config.compile_config - if self.compile_autosp(): - compile_kwargs['fullgraph'] = True - backend = init_autosp(compile_config) - else: ## By default then only zero-style DP should be triggered in dc. ## - if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] - and "offload_param" in self.config["zero_optimization"]) - and self._config.zero_config.offload_param.device == "cpu" - and self._config.zero_config.offload_optimizer.device == "cpu"): - compile_config.offload_parameters = True - if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: - backend = init_z1(self, backend, compile_config, compile_kwargs, schedule) - elif self.zero_optimization_stage() == ZeroStageEnum.gradients: - backend = init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) - elif self.zero_optimization_stage() == ZeroStageEnum.weights: - backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) + if self.is_deepcompile_enabled(): + backend, schedule = self.get_deepspeed_compile_backend(backend, compile_kwargs, schedule) + is_deepspeed_compile_backend = backend is not None + # Hook state must align with whether DeepCompile is active. - self._set_deepcompile_active(enable_deepcompile) + self._set_deepcompile_active(is_deepspeed_compile_backend) # create new dict to avoid modifying original dict try: self.module.compile(**{**compile_kwargs, 'backend': backend}) except Exception: - if enable_deepcompile: + if is_deepspeed_compile_backend: # Restore default hooks if compilation fails before completing. 
self._set_deepcompile_active(False) raise From 4fdc54b84a3003517f16512c7b327070f2069b93 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Mon, 2 Mar 2026 01:23:40 -0600 Subject: [PATCH 13/14] fallback to torch backend --- deepspeed/runtime/engine.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 28ea3f40dbee..d5f7d8d211f6 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -4396,12 +4396,7 @@ def passes_name_to_fn(passes): if self.compile_autosp(): resolved_backend = self.get_autosp_backend(compile_kwargs) else: - if self.validate_deepcompile_config(): - resolved_backend = self.get_deepcompile_backend(backend, compile_kwargs, schedule) - - # Fallback to torch backend if no DeepSpeed backend was selected. - if resolved_backend is None: - resolved_backend = backend + resolved_backend = self.get_deepcompile_backend(backend, compile_kwargs, schedule) return resolved_backend, schedule @@ -4428,9 +4423,12 @@ def compile(self, logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}") if self.is_deepcompile_enabled(): - backend, schedule = self.get_deepspeed_compile_backend(backend, compile_kwargs, schedule) + resolved_backend, schedule = self.get_deepspeed_compile_backend(backend, compile_kwargs, schedule) + + is_deepspeed_compile_backend = resolved_backend is not None - is_deepspeed_compile_backend = backend is not None + # default to torch.compiler backend if deepspeed config validation fails + backend = resolved_backend or backend # Hook state must align with whether DeepCompile is active. 
self._set_deepcompile_active(is_deepspeed_compile_backend) From a72cd7d8ff2e4c9c1c1a19405a0052471cb3d992 Mon Sep 17 00:00:00 2001 From: Ahan Gupta Date: Wed, 4 Mar 2026 17:16:25 +0000 Subject: [PATCH 14/14] refactor to avoid None edgecases --- deepspeed/compile/custom_ops/all_to_all.py | 2 +- deepspeed/compile/passes/sp_compile.py | 2 +- deepspeed/runtime/engine.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/deepspeed/compile/custom_ops/all_to_all.py b/deepspeed/compile/custom_ops/all_to_all.py index 58835d444edd..03d20aae83b0 100644 --- a/deepspeed/compile/custom_ops/all_to_all.py +++ b/deepspeed/compile/custom_ops/all_to_all.py @@ -65,7 +65,7 @@ def _all_to_all_backward_setup(ctx, inputs, output): def _all_to_all_backward(ctx, grad): return ( all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), - None, None, None, None + None, None, None ) diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py index aca4dac47edc..308187b45d26 100644 --- a/deepspeed/compile/passes/sp_compile.py +++ b/deepspeed/compile/passes/sp_compile.py @@ -212,7 +212,7 @@ def apply_autosp( debug: If True, print graph before/after each pass passes: Optional custom list of passes (default: DEFAULT_PASSES) """ - assert sp_size * dp_size <= torch.cuda.device_count(), 'Insufficient device count for mesh size' + assert sp_size * dp_size <= dist.get_world_size(), 'Insufficient device count for mesh size' sp_dp_registry.populate_registry(sp_size, dp_size) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d5f7d8d211f6..39472ef803f9 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1011,7 +1011,7 @@ def compile_zero_optimization_stage(self): def compile_autosp(self): """Determines if AutoSP is set in deepcompile's passes attributes.""" - return "autosp" in self._config.compile_config.passes + return "autosp" in (getattr(self._config.compile_config, "passes", None) or []) def 
mics_shard_size(self): return self._config.mics_shard_size @@ -4422,6 +4422,7 @@ def compile(self, logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}") + resolved_backend = None if self.is_deepcompile_enabled(): resolved_backend, schedule = self.get_deepspeed_compile_backend(backend, compile_kwargs, schedule)