237 changes: 237 additions & 0 deletions example/training/llama3.py
@@ -0,0 +1,237 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from torch import nn

from magi_compiler import magi_compile


@dataclass
class ModelArgs:
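    # Defaults roughly follow the Llama 3 8B configuration.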
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
    n_kv_heads: Optional[int] = 8
vocab_size: int = 128256
multiple_of: int = 256
ffn_dim_multiplier: float = 1.3
norm_eps: float = 1e-5
rope_theta: float = 500000

max_batch_size: int = 32
max_seq_len: int = 8192


class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
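    """Precompute the complex RoPE table: entry (t, k) = exp(i * t * theta**(-2k/dim)), shape (end, dim // 2)."""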
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
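    """Reshape freqs_cis from (seqlen, head_dim // 2) to (1, seqlen, 1, head_dim // 2) so it broadcasts over batch and heads."""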
ndim = x.ndim
    assert ndim > 1
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
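    """Apply rotary position embeddings: view consecutive pairs of head dims as complex numbers and rotate them by freqs_cis."""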
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim)


class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads

self.wq = ColumnParallelLinear(
args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.wk = ColumnParallelLinear(
args.dim, self.n_kv_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.wv = ColumnParallelLinear(
args.dim, self.n_kv_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim, args.dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)

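        # Inference-only KV cache, preallocated for the maximum batch size and sequence length.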
self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)).cuda()
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)).cuda()

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

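        # Training attends over the current sequence only; inference appends the new keys/values to the cache.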
if self.training:
keys = xk
values = xv
else:
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)

self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]

# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)

xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)


class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]):
super().__init__()
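        # SwiGLU sizing: 2/3 of the nominal hidden dim, optionally rescaled, then rounded up to a multiple of multiple_of.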
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)

def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))


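# magi_compile wraps each TransformerBlock for compilation; dynamic_arg_dims={"x": 0}
# presumably marks the batch dimension of `x` as dynamic (assumed semantics).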
@magi_compile(dynamic_arg_dims={"x": 0})
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out


class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers

self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)

self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))

self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)

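        # Precompute RoPE tables for 2x the max sequence length, as in the Llama 3 reference code.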
self.freqs_cis = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len * 2, params.rope_theta)

def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)

mask = torch.triu(mask, diagonal=1)

# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)

for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output
141 changes: 141 additions & 0 deletions example/training/train.py
@@ -0,0 +1,141 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
import time
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import initialize_model_parallel, model_parallel_is_initialized
from llama3 import ModelArgs, Transformer, TransformerBlock
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

import magi_compiler.utils.nvtx as nvtx


def setup_fsdp(model: nn.Module, device_id: int):
"""
Wrap the given Llama3 model with PyTorch FSDP.
We apply auto_wrap_policy to wrap each TransformerBlock individually.
"""
llama_auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock})

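    # sync_module_states broadcasts rank 0's module states at wrap time;
    # use_orig_params keeps the original parameter objects visible to the optimizer.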
fsdp_model = FSDP(
model, auto_wrap_policy=llama_auto_wrap_policy, device_id=device_id, sync_module_states=True, use_orig_params=True
)
local_params = sum(p.numel() for p in fsdp_model.parameters())
print(f"[Rank {device_id}] Local param count: {local_params:,}")
return fsdp_model


def main():
# Setup distributed environment
if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
print("Not running in distributed mode. Set RANK and WORLD_SIZE to use FSDP.")
# For demonstration purposes, we will mock the distributed setup if not provided.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo")

# Initialize model parallel group (since model uses fairscale VocabParallelEmbedding)
if not model_parallel_is_initialized():
initialize_model_parallel(1)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
else:
device = torch.device("cpu")

# Initialize a small config for testing
config = ModelArgs(n_layers=10, max_batch_size=2, max_seq_len=1024)

# Create Model
if global_rank == 0:
print(f"Initializing model on {world_size} devices...")
model = Transformer(config).to(device)

# Wrap with FSDP
if global_rank == 0:
print("Wrapping model with FSDP...")

if torch.cuda.is_available():
fsdp_model = setup_fsdp(model, device_id=local_rank)
else:
        # CPU fallback: this FSDP setup targets CUDA, so fall back to the unwrapped model.
print(f"[Rank {global_rank}] CUDA not available. Running without FSDP for CPU fallback.")
fsdp_model = model

# Optimizer
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2)

num_epochs = 5
bsz, seq_len = config.max_batch_size, config.max_seq_len

if global_rank == 0:
print(f"Starting training for {num_epochs} epochs...")

for epoch in range(num_epochs):
# record start time
start_time = time.time()

# Dummy data for each epoch
        # Seeding per (epoch, rank) keeps the random tokens different across ranks.
torch.manual_seed(epoch * world_size + global_rank)
input_ids = torch.randint(0, config.vocab_size, (bsz, seq_len), device=device)
labels = torch.randint(0, config.vocab_size, (bsz, seq_len), device=device)

optimizer.zero_grad()

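        # NVTX profiling hook from magi_compiler; the (epoch, 3, 4) arguments presumably bound the profiled epoch window (assumed semantics).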
nvtx.switch_profile(epoch, 3, 4)
# Forward pass
logits = fsdp_model(input_ids, start_pos=0)

# Loss
loss = F.cross_entropy(logits.view(-1, config.vocab_size), labels.view(-1))
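        # With uniformly random labels, the loss should hover near ln(vocab_size) ~= 11.76.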

# Backward pass
loss.backward()

optimizer.step()

# record end time
end_time = time.time()

if global_rank == 0:
print(
f"Epoch {epoch + 1}/{num_epochs} | Loss: {loss.item():.4f} | Time taken: {end_time - start_time:.4f} seconds"
)

if global_rank == 0:
print("Training done!")
dist.destroy_process_group()


if __name__ == "__main__":
main()
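
# Typical multi-GPU launch (assumed; torchrun provides RANK/WORLD_SIZE/LOCAL_RANK):
#   torchrun --nproc_per_node=8 example/training/train.py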