@dataclass
class ModelArgs:
    """Hyper-parameters read by the model classes defined in this file."""

    dim: int = 4096  # model width; also the input/output size of every block
    n_layers: int = 32  # number of TransformerBlock instances stacked in Transformer
    n_heads: int = 32  # query heads; head_dim is computed as dim // n_heads
    # Attention falls back to ``n_heads`` when this is None, so None is a legal value.
    n_kv_heads: Optional[int] = 8
    vocab_size: int = 128256  # rows of the token embedding / columns of the output projection
    multiple_of: int = 256  # FeedForward rounds its hidden size up to a multiple of this
    # FeedForward skips the extra scaling when this is None, so None is a legal value.
    ffn_dim_multiplier: Optional[float] = 1.3
    norm_eps: float = 1e-5  # epsilon forwarded to every RMSNorm
    rope_theta: float = 500000  # base passed to precompute_freqs_cis

    max_batch_size: int = 32  # first dimension of the attention KV cache
    max_seq_len: int = 8192  # KV cache length; the rotary table holds 2x this many positions


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # The statistics are computed in float32 and cast back to the input
        # dtype before the learned per-channel weight is applied.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the complex rotary-embedding table.

    Args:
        dim: Per-head dimension; ``dim // 2`` frequencies are generated.
        end: Number of positions to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` whose entries have
        unit magnitude and angle ``position * frequency``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x``.

    Dimension 1 and the last dimension of ``x`` are kept; every other
    dimension of the returned view has size 1.

    Raises:
        AssertionError: If ``x`` has fewer than two dimensions or
            ``freqs_cis`` is not shaped ``(x.shape[1], x.shape[-1])``.
    """
    ndim = x.ndim
    # Constant messages only: this helper is reached from a compiled block.
    assert ndim > 1, "x must have at least 2 dimensions"
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), "freqs_cis must have shape (x.shape[1], x.shape[-1])"
    shape = [d if i in (1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the supplied rotary table.

    Adjacent pairs along the last dimension are treated as complex numbers
    and multiplied by ``freqs_cis``. The arithmetic runs in float32 and the
    results are cast back to the dtypes of ``xq`` and ``xk``. ``flatten(3)``
    means both inputs are expected to be 4-D.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as the inputs.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``.

    ``x`` is ``(bs, slen, n_kv_heads, head_dim)``. When ``n_rep == 1`` the
    input tensor itself is returned without copying.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    expanded = x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
class Attention(nn.Module):
    """Grouped-query self-attention built from fairscale model-parallel linears.

    Head counts are divided by the model-parallel world size, so each rank
    holds ``n_local_heads`` query heads and ``n_local_kv_heads`` key/value
    heads.

    NOTE(review): every projection is built with ``init_method=lambda x: x``.
    That looks like it leaves the fairscale weights at whatever memory they
    were allocated with, which is fine when a checkpoint is loaded afterwards
    but not for training from scratch - confirm against fairscale.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.wk = ColumnParallelLinear(
            args.dim, self.n_kv_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.wv = ColumnParallelLinear(
            args.dim, self.n_kv_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim, args.dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )

        # The KV cache is only read on the non-training path of ``forward``.
        # It is therefore allocated lazily there, on the device and dtype of
        # the activations, instead of eagerly with ``.cuda()``: eager
        # allocation fails on CPU-only hosts and costs
        # 2 * max_batch_size * max_seq_len * n_local_kv_heads * head_dim
        # floats per layer during training for no benefit.
        self._kv_cache_shape = (args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)
        self.cache_k: Optional[torch.Tensor] = None
        self.cache_v: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """Attend over ``x``.

        Args:
            x: Activations of shape ``(bsz, seqlen, dim)``.
            start_pos: Offset of the first token of ``x`` inside the KV cache.
                Only used on the non-training path.
            freqs_cis: Rotary table slice forwarded to ``apply_rotary_emb``.
            mask: Optional additive mask added to the attention scores.

        Returns:
            The output of ``self.wo`` applied to the attended values.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if self.training:
            # Training attends over the current sequence only; no cache involved.
            keys = xk
            values = xv
        else:
            if self.cache_k is None or self.cache_v is None:
                self.cache_k = torch.zeros(self._kv_cache_shape, dtype=xq.dtype, device=xq.device)
                self.cache_v = torch.zeros(self._kv_cache_shape, dtype=xq.dtype, device=xq.device)
            else:
                self.cache_k = self.cache_k.to(xq)
                self.cache_v = self.cache_v.to(xq)

            self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
            self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

            keys = self.cache_k[:bsz, : start_pos + seqlen]
            values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        # Softmax runs in float32 and is cast back to the query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    """Gated MLP computing ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]):
        """Size and build the three projections.

        Args:
            dim: Input and output width.
            hidden_dim: Nominal hidden width; it is first scaled by 2/3.
            multiple_of: The final hidden width is rounded up to a multiple of this.
            ffn_dim_multiplier: Optional extra scale applied after the 2/3 factor.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round up to the next multiple of ``multiple_of``.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
        self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
        self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
# NOTE(review): ``dynamic_arg_dims={"x": 0}`` presumably tells magi_compile
# that dimension 0 of ``x`` may vary between calls - confirm against the
# magi_compiler documentation.
@magi_compile(dynamic_arg_dims={"x": 0})
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention and feed-forward, each with a residual."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """Apply the layer to ``x``; the remaining arguments are forwarded to attention."""
        attended = self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        h = x + attended
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    """Token embedding, ``n_layers`` TransformerBlocks, final norm and output projection."""

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        # NOTE(review): ``init_method=lambda x: x`` looks like it leaves the
        # fairscale weights uninitialised; verify before relying on the loss
        # values of a from-scratch run.
        self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)

        # Plain attribute (not a buffer): it is moved to the activation device in forward().
        self.freqs_cis = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len * 2, params.rope_theta)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Compute float32 logits for ``tokens``.

        Args:
            tokens: Integer token ids of shape ``(bsz, seqlen)``.
            start_pos: Position of the first token; selects the rotary slice
                and, outside training, the KV-cache offset.

        Returns:
            The output projection of the final hidden states, cast to float32.
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            #
            # In training mode Attention uses only the current keys, so the
            # scores are (seqlen, seqlen); prepending cache columns there
            # would make the mask un-broadcastable whenever start_pos > 0.
            if not self.training and start_pos > 0:
                cached_cols = torch.zeros((seqlen, start_pos), device=tokens.device)
                mask = torch.hstack([cached_cols, mask])
            mask = mask.type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
def setup_fsdp(model: nn.Module, device_id: int):
    """
    Wrap the given Llama3 model with PyTorch FSDP.
    We apply auto_wrap_policy to wrap each TransformerBlock individually.

    Args:
        model: The module to shard.
        device_id: CUDA device index handed to FSDP.

    Returns:
        The FSDP-wrapped model.
    """
    llama_auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock})

    fsdp_model = FSDP(
        model, auto_wrap_policy=llama_auto_wrap_policy, device_id=device_id, sync_module_states=True, use_orig_params=True
    )
    local_params = sum(p.numel() for p in fsdp_model.parameters())
    # Label the line with the global rank when a process group exists:
    # ``device_id`` is the local device index and repeats on every node.
    rank = dist.get_rank() if dist.is_initialized() else device_id
    print(f"[Rank {rank}] Local param count: {local_params:,}")
    return fsdp_model


def _train(use_cuda: bool) -> None:
    """Build the model and run the dummy-data training loop on this rank."""
    # Initialize model parallel group (since model uses fairscale VocabParallelEmbedding)
    if not model_parallel_is_initialized():
        initialize_model_parallel(1)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    is_main = global_rank == 0

    if use_cuda:
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")

    # Initialize a small config for testing
    config = ModelArgs(n_layers=10, max_batch_size=2, max_seq_len=1024)

    # Create Model
    if is_main:
        print(f"Initializing model on {world_size} devices...")
    model = Transformer(config).to(device)

    # Wrap with FSDP
    if is_main:
        print("Wrapping model with FSDP...")

    if use_cuda:
        fsdp_model = setup_fsdp(model, device_id=local_rank)
    else:
        # For CPU testing, fallback to DDP or just standard model if FSDP CPU is not supported
        print(f"[Rank {global_rank}] CUDA not available. Running without FSDP for CPU fallback.")
        fsdp_model = model

    # Optimizer
    optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2)

    num_epochs = 5
    # Epoch window handed to the profiler switch: capture opens at the first
    # value and closes when the second is reached.
    profile_start_epoch, profile_end_epoch = 3, 4
    bsz, seq_len = config.max_batch_size, config.max_seq_len

    if is_main:
        print(f"Starting training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        # record start time
        start_time = time.perf_counter()

        # Dummy data for each epoch
        # Ensure different ranks get different data if needed, but here we just generate random
        torch.manual_seed(epoch * world_size + global_rank)
        input_ids = torch.randint(0, config.vocab_size, (bsz, seq_len), device=device)
        labels = torch.randint(0, config.vocab_size, (bsz, seq_len), device=device)

        optimizer.zero_grad()

        # The profiler helpers call CUDA runtime APIs, so skip them on the CPU fallback.
        if use_cuda:
            nvtx.switch_profile(epoch, profile_start_epoch, profile_end_epoch)

        # Forward pass
        logits = fsdp_model(input_ids, start_pos=0)

        # Loss
        loss = F.cross_entropy(logits.view(-1, config.vocab_size), labels.view(-1))

        # Backward pass
        loss.backward()

        optimizer.step()

        # CUDA work is asynchronous; wait for it so the elapsed time covers
        # the whole step rather than just the kernel launches.
        if use_cuda:
            torch.cuda.synchronize()

        # record end time
        end_time = time.perf_counter()

        if is_main:
            print(
                f"Epoch {epoch + 1}/{num_epochs} | Loss: {loss.item():.4f} | Time taken: {end_time - start_time:.4f} seconds"
            )

    if is_main:
        print("Training done!")


def main():
    """Entry point: set up the process group, train, and always tear the group down."""
    # Setup distributed environment
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        print("Not running in distributed mode. Set RANK and WORLD_SIZE to use FSDP.")
        # For demonstration purposes, we will mock the distributed setup if not provided.
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        # Respect a rendezvous address/port the caller already exported.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")

    use_cuda = torch.cuda.is_available()
    dist.init_process_group("nccl" if use_cuda else "gloo")
    try:
        _train(use_cuda)
    finally:
        # Release the process group even when training raises.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
set -e

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
PROJECT_ROOT=$(cd -- "$SCRIPT_DIR/../.." &> /dev/null && pwd)

### Distributed args ###
MASTER_ADDR=${MASTER_ADDR:-localhost}
MASTER_PORT=${MASTER_PORT:-29500}

GPUS_PER_NODE=${GPUS_PER_NODE:-1}
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
WORLD_SIZE=$((GPUS_PER_NODE * NNODES))

# Arrays keep every argument intact even when a path contains spaces.
DISTRIBUTED_ARGS=(
    --nnodes="$NNODES"
    --node_rank="$NODE_RANK"
    --nproc_per_node="$GPUS_PER_NODE"
    --rdzv-backend=c10d
    --rdzv-endpoint="$MASTER_ADDR:$MASTER_PORT"
)

### Nsys args ###
NSYS_PROFILE=${NSYS_PROFILE:-true}

# Empty unless profiling is enabled; expands to nothing on the final command line.
NSYS_CMD=()
if [ "$NSYS_PROFILE" = true ]; then
    mkdir -p "$PROJECT_ROOT/nsys_reports"

    BASE_NAME="nsys_llama3"
    TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
    # WORLD_SIZE is always computed above, so it is always part of the name.
    NSYS_SUFFIX="ts_${TIMESTAMP}_worldsize_${WORLD_SIZE}"

    [ -n "$COMPILE_MODE" ] && NSYS_SUFFIX="${NSYS_SUFFIX}_compile_${COMPILE_MODE}"
    [ -n "$CUDA_GRAPH_MODE" ] && NSYS_SUFFIX="${NSYS_SUFFIX}_cudagraph_${CUDA_GRAPH_MODE}"

    NSYS_OUTPUT="$PROJECT_ROOT/nsys_reports/${BASE_NAME}_${NSYS_SUFFIX}"

    # Capture is gated on cudaProfilerStart/Stop, which the training script
    # triggers through nvtx.switch_profile.
    NSYS_CMD=(
        nsys profile
        --force-overwrite true
        -o "$NSYS_OUTPUT"
        --trace=cuda,nvtx
        --capture-range=cudaProfilerApi
    )
fi

### Environment Variables For Debugging ###
export ENABLE_REMOTE_DEBUG=${ENABLE_REMOTE_DEBUG:-false}
export MAGI_COMPILE_CACHE_ROOT_DIR=${MAGI_COMPILE_CACHE_ROOT_DIR:-"$PROJECT_ROOT/.cache"}
export MAGI_ENABLE_FX_GRAPH_VIZ=${MAGI_ENABLE_FX_GRAPH_VIZ:-false}

# Extra command-line arguments are forwarded to train.py. This replaces the
# trailing $NSYS_ARGS, which was never assigned anywhere in this script.
"${NSYS_CMD[@]}" torchrun "${DISTRIBUTED_ARGS[@]}" "$SCRIPT_DIR/train.py" "$@"
# The ``emit_nvtx`` context of the open profiling session, or None when no
# session is open. Defined at module level so that profile_mark/profile_end
# can be called safely before any profile_start.
_EMIT_NVTX_CTX = None


def profile_start(event_name: str, record_shapes: bool = False):
    """Open a profiling session and its first NVTX range.

    Starts the CUDA profiler, pushes an NVTX range named ``event_name`` and
    enters ``torch.autograd.profiler.emit_nvtx``. If a session is already
    open, only the current range is rotated to ``event_name``.

    Args:
        event_name: Name of the NVTX range to open.
        record_shapes: Forwarded to ``emit_nvtx``.
    """
    global _EMIT_NVTX_CTX
    if _EMIT_NVTX_CTX is not None:
        # Never nest sessions: a second start would leak the first context.
        profile_mark(event_name)
        return

    torch.cuda.cudart().cudaProfilerStart()
    torch.cuda.nvtx.range_push(event_name)
    _EMIT_NVTX_CTX = torch.autograd.profiler.emit_nvtx(record_shapes=record_shapes)
    _EMIT_NVTX_CTX.__enter__()


def profile_mark(event_name: str):
    """Close the current NVTX range and open a new one named ``event_name``.

    Does nothing when no session is open, so an unmatched call cannot pop a
    range that this module did not push.
    """
    if _EMIT_NVTX_CTX is None:
        return
    torch.cuda.nvtx.range_pop()
    torch.cuda.nvtx.range_push(event_name)


def profile_end():
    """Close the profiling session opened by ``profile_start``.

    Does nothing when no session is open. Teardown runs in the reverse order
    of setup: leave ``emit_nvtx``, pop the NVTX range, stop the CUDA profiler.
    """
    global _EMIT_NVTX_CTX
    if _EMIT_NVTX_CTX is None:
        return

    # Clear the global first so a failure below cannot leave a stale session.
    ctx, _EMIT_NVTX_CTX = _EMIT_NVTX_CTX, None
    try:
        ctx.__exit__(None, None, None)
    finally:
        torch.cuda.nvtx.range_pop()
        torch.cuda.cudart().cudaProfilerStop()


def switch_profile(
    iter_id: int,
    start: int,
    end: int,
    profile_ranks: "list[int] | None" = None,
    event_name: "str | None" = None,
    record_shapes=True,
):
    """
    Controls the profiler state based on the iteration number. Turns on profiling
    at the start iteration and turns it off at the end iteration.

    Iterations strictly between ``start`` and ``end`` rotate the NVTX range;
    iterations outside ``[start, end]`` are ignored. Because ``start`` is
    tested first, ``start == end`` opens a session that this function never
    closes.

    Args:
        - iter_id: The current iteration number.
        - start: The iteration number to start profiling.
        - end: The iteration number to end profiling.
        - profile_ranks: list of ranks to be profiled. Only consulted when
          torch.distributed is initialized; None profiles every rank.
        - event_name: Custom name for the profiling event. If None, defaults to 'iter_{iter_id}'.
        - record_shapes: Forwarded to ``profile_start`` when the session opens.
    """
    if profile_ranks is not None and torch.distributed.is_initialized() and torch.distributed.get_rank() not in profile_ranks:
        return

    if event_name is None:
        event_name = f"iter_{iter_id}"

    # Start profiling
    if iter_id == start:
        profile_start(event_name, record_shapes)

    # Stop profiling
    elif iter_id == end:
        profile_end()

    # Continue profiling
    elif start < iter_id < end:
        profile_mark(event_name)