-
Notifications
You must be signed in to change notification settings - Fork 20
[Dev] Add Llama3 training example and fix cache save #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+482
β0
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,237 @@ | ||
| # Copyright (c) 2026 SandAI. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. | ||
|
|
||
| import math | ||
| from dataclasses import dataclass | ||
| from typing import Optional, Tuple | ||
|
|
||
| import fairscale.nn.model_parallel.initialize as fs_init | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding | ||
| from torch import nn | ||
|
|
||
| from magi_compiler import magi_compile | ||
|
|
||
|
|
||
| @dataclass | ||
| class ModelArgs: | ||
| dim: int = 4096 | ||
| n_layers: int = 32 | ||
| n_heads: int = 32 | ||
| n_kv_heads: int = 8 | ||
| vocab_size: int = 128256 | ||
| multiple_of: int = 256 | ||
| ffn_dim_multiplier: float = 1.3 | ||
| norm_eps: float = 1e-5 | ||
| rope_theta: float = 500000 | ||
|
|
||
| max_batch_size: int = 32 | ||
| max_seq_len: int = 8192 | ||
|
|
||
|
|
||
| class RMSNorm(torch.nn.Module): | ||
| def __init__(self, dim: int, eps: float = 1e-6): | ||
| super().__init__() | ||
| self.eps = eps | ||
| self.weight = nn.Parameter(torch.ones(dim)) | ||
|
|
||
| def _norm(self, x): | ||
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | ||
|
|
||
| def forward(self, x): | ||
| output = self._norm(x.float()).type_as(x) | ||
| return output * self.weight | ||
|
|
||
|
|
||
| def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): | ||
| freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) | ||
| t = torch.arange(end, device=freqs.device, dtype=torch.float32) | ||
| freqs = torch.outer(t, freqs) | ||
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 | ||
| return freqs_cis | ||
|
|
||
|
|
||
| def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): | ||
| ndim = x.ndim | ||
| assert 0 <= 1 < ndim | ||
| assert freqs_cis.shape == (x.shape[1], x.shape[-1]) | ||
| shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] | ||
| return freqs_cis.view(*shape) | ||
|
|
||
|
|
||
| def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | ||
| xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | ||
| freqs_cis = reshape_for_broadcast(freqs_cis, xq_) | ||
| xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | ||
| xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | ||
| return xq_out.type_as(xq), xk_out.type_as(xk) | ||
|
|
||
|
|
||
| def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: | ||
| """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" | ||
| bs, slen, n_kv_heads, head_dim = x.shape | ||
| if n_rep == 1: | ||
| return x | ||
| return x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim) | ||
|
|
||
|
|
||
| class Attention(nn.Module): | ||
| def __init__(self, args: ModelArgs): | ||
| super().__init__() | ||
| self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads | ||
| model_parallel_size = fs_init.get_model_parallel_world_size() | ||
| self.n_local_heads = args.n_heads // model_parallel_size | ||
| self.n_local_kv_heads = self.n_kv_heads // model_parallel_size | ||
| self.n_rep = self.n_local_heads // self.n_local_kv_heads | ||
| self.head_dim = args.dim // args.n_heads | ||
|
|
||
| self.wq = ColumnParallelLinear( | ||
| args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x | ||
| ) | ||
| self.wk = ColumnParallelLinear( | ||
| args.dim, self.n_kv_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x | ||
| ) | ||
| self.wv = ColumnParallelLinear( | ||
| args.dim, self.n_kv_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x | ||
| ) | ||
| self.wo = RowParallelLinear( | ||
| args.n_heads * self.head_dim, args.dim, bias=False, input_is_parallel=True, init_method=lambda x: x | ||
| ) | ||
|
|
||
| self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)).cuda() | ||
| self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)).cuda() | ||
|
|
||
| def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): | ||
| bsz, seqlen, _ = x.shape | ||
| xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) | ||
|
|
||
| xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) | ||
| xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) | ||
| xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) | ||
|
|
||
| xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) | ||
|
|
||
| if self.training: | ||
| keys = xk | ||
| values = xv | ||
| else: | ||
| self.cache_k = self.cache_k.to(xq) | ||
| self.cache_v = self.cache_v.to(xq) | ||
|
|
||
| self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk | ||
| self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv | ||
|
|
||
| keys = self.cache_k[:bsz, : start_pos + seqlen] | ||
| values = self.cache_v[:bsz, : start_pos + seqlen] | ||
|
|
||
| # repeat k/v heads if n_kv_heads < n_heads | ||
| keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) | ||
| values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) | ||
|
|
||
| xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | ||
| keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | ||
| values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | ||
| scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) | ||
| if mask is not None: | ||
| scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) | ||
| scores = F.softmax(scores.float(), dim=-1).type_as(xq) | ||
| output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) | ||
| output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) | ||
| return self.wo(output) | ||
|
|
||
|
|
||
| class FeedForward(nn.Module): | ||
| def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]): | ||
| super().__init__() | ||
| hidden_dim = int(2 * hidden_dim / 3) | ||
| # custom dim factor multiplier | ||
| if ffn_dim_multiplier is not None: | ||
| hidden_dim = int(ffn_dim_multiplier * hidden_dim) | ||
| hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) | ||
|
|
||
| self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x) | ||
| self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x) | ||
| self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x) | ||
|
|
||
| def forward(self, x): | ||
| return self.w2(F.silu(self.w1(x)) * self.w3(x)) | ||
|
|
||
|
|
||
| @magi_compile(dynamic_arg_dims={"x": 0}) | ||
| class TransformerBlock(nn.Module): | ||
| def __init__(self, layer_id: int, args: ModelArgs): | ||
| super().__init__() | ||
| self.n_heads = args.n_heads | ||
| self.dim = args.dim | ||
| self.head_dim = args.dim // args.n_heads | ||
| self.attention = Attention(args) | ||
| self.feed_forward = FeedForward( | ||
| dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier | ||
| ) | ||
| self.layer_id = layer_id | ||
| self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) | ||
| self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) | ||
|
|
||
| def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): | ||
| h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) | ||
| out = h + self.feed_forward(self.ffn_norm(h)) | ||
| return out | ||
|
|
||
|
|
||
| class Transformer(nn.Module): | ||
| def __init__(self, params: ModelArgs): | ||
| super().__init__() | ||
| self.params = params | ||
| self.vocab_size = params.vocab_size | ||
| self.n_layers = params.n_layers | ||
|
|
||
| self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x) | ||
|
|
||
| self.layers = torch.nn.ModuleList() | ||
| for layer_id in range(params.n_layers): | ||
| self.layers.append(TransformerBlock(layer_id, params)) | ||
|
|
||
| self.norm = RMSNorm(params.dim, eps=params.norm_eps) | ||
| self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x) | ||
|
|
||
| self.freqs_cis = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len * 2, params.rope_theta) | ||
|
|
||
| def forward(self, tokens: torch.Tensor, start_pos: int): | ||
| _bsz, seqlen = tokens.shape | ||
| h = self.tok_embeddings(tokens) | ||
| self.freqs_cis = self.freqs_cis.to(h.device) | ||
| freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] | ||
|
|
||
| mask = None | ||
| if seqlen > 1: | ||
| mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) | ||
|
|
||
| mask = torch.triu(mask, diagonal=1) | ||
|
|
||
| # When performing key-value caching, we compute the attention scores | ||
| # only for the new sequence. Thus, the matrix of scores is of size | ||
| # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for | ||
| # j > cache_len + i, since row i corresponds to token cache_len + i. | ||
| mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h) | ||
|
|
||
| for layer in self.layers: | ||
| h = layer(h, start_pos, freqs_cis, mask) | ||
| h = self.norm(h) | ||
| output = self.output(h).float() | ||
| return output | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| # Copyright (c) 2026 SandAI. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. | ||
|
|
||
| import os | ||
| import time | ||
| from functools import partial | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| from fairscale.nn.model_parallel.initialize import initialize_model_parallel, model_parallel_is_initialized | ||
| from llama3 import ModelArgs, Transformer, TransformerBlock | ||
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
| from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy | ||
|
|
||
| import magi_compiler.utils.nvtx as nvtx | ||
|
|
||
|
|
||
| def setup_fsdp(model: nn.Module, device_id: int): | ||
| """ | ||
| Wrap the given Llama3 model with PyTorch FSDP. | ||
| We apply auto_wrap_policy to wrap each TransformerBlock individually. | ||
| """ | ||
| llama_auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}) | ||
|
|
||
| fsdp_model = FSDP( | ||
| model, auto_wrap_policy=llama_auto_wrap_policy, device_id=device_id, sync_module_states=True, use_orig_params=True | ||
| ) | ||
| local_params = sum(p.numel() for p in fsdp_model.parameters()) | ||
| print(f"[Rank {device_id}] Local param count: {local_params:,}") | ||
| return fsdp_model | ||
|
|
||
|
|
||
| def main(): | ||
| # Setup distributed environment | ||
| if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: | ||
| print("Not running in distributed mode. Set RANK and WORLD_SIZE to use FSDP.") | ||
| # For demonstration purposes, we will mock the distributed setup if not provided. | ||
| os.environ["RANK"] = "0" | ||
| os.environ["WORLD_SIZE"] = "1" | ||
| os.environ["MASTER_ADDR"] = "localhost" | ||
| os.environ["MASTER_PORT"] = "12355" | ||
|
|
||
| dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo") | ||
|
|
||
| # Initialize model parallel group (since model uses fairscale VocabParallelEmbedding) | ||
| if not model_parallel_is_initialized(): | ||
| initialize_model_parallel(1) | ||
|
|
||
| local_rank = int(os.environ.get("LOCAL_RANK", 0)) | ||
| global_rank = int(os.environ.get("RANK", 0)) | ||
| world_size = int(os.environ.get("WORLD_SIZE", 1)) | ||
|
|
||
| if torch.cuda.is_available(): | ||
| torch.cuda.set_device(local_rank) | ||
| device = torch.device(f"cuda:{local_rank}") | ||
| else: | ||
| device = torch.device("cpu") | ||
|
|
||
| # Initialize model config | ||
| config = ModelArgs(max_batch_size=2, max_seq_len=1024) | ||
|
|
||
| # Create Model | ||
| if global_rank == 0: | ||
| print(f"Initializing model on {world_size} devices...") | ||
| model = Transformer(config).to(device) | ||
|
|
||
| # Wrap with FSDP | ||
| if global_rank == 0: | ||
| print("Wrapping model with FSDP...") | ||
|
|
||
| if torch.cuda.is_available(): | ||
| fsdp_model = setup_fsdp(model, device_id=local_rank) | ||
| else: | ||
| # For CPU testing, fallback to DDP or just standard model if FSDP CPU is not supported | ||
| print(f"[Rank {global_rank}] CUDA not available. Running without FSDP for CPU fallback.") | ||
| fsdp_model = model | ||
|
|
||
| # Optimizer | ||
| optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2) | ||
|
|
||
| num_epochs = 5 | ||
| bsz, seq_len = config.max_batch_size, config.max_seq_len | ||
|
|
||
| if global_rank == 0: | ||
| print(f"Starting training for {num_epochs} epochs...") | ||
|
|
||
| for epoch in range(num_epochs): | ||
| # record start time | ||
| start_time = time.time() | ||
|
|
||
| # Dummy data for each epoch | ||
| # Ensure different ranks get different data if needed, but here we just generate random | ||
| torch.manual_seed(epoch * world_size + global_rank) | ||
| input_ids = torch.randint(0, config.vocab_size, (bsz, seq_len), device=device) | ||
| labels = torch.randint(0, config.vocab_size, (bsz, seq_len), device=device) | ||
|
|
||
| optimizer.zero_grad() | ||
|
|
||
| nvtx.switch_profile(epoch, 3, 4) | ||
| # Forward pass | ||
| logits = fsdp_model(input_ids, start_pos=0) | ||
|
|
||
| # Loss | ||
| loss = F.cross_entropy(logits.view(-1, config.vocab_size), labels.view(-1)) | ||
|
|
||
| # Backward pass | ||
| loss.backward() | ||
|
|
||
| optimizer.step() | ||
|
|
||
| # record end time | ||
| end_time = time.time() | ||
|
|
||
| if global_rank == 0: | ||
| print( | ||
| f"Epoch {epoch + 1}/{num_epochs} | Loss: {loss.item():.4f} | Time taken: {end_time - start_time:.4f} seconds" | ||
| ) | ||
|
|
||
| if global_rank == 0: | ||
| print("Training done!") | ||
| dist.destroy_process_group() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.