From 0a01d3447170a92b0d87eb78fc679ff77f4a748a Mon Sep 17 00:00:00 2001 From: wtr Date: Tue, 31 Mar 2026 16:44:51 +0800 Subject: [PATCH 1/5] refactor: simplify heuristic save-node selection & add Transformer test --- tests/model_definition.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/model_definition.py b/tests/model_definition.py index 1fc7c9e..2ebcce2 100644 --- a/tests/model_definition.py +++ b/tests/model_definition.py @@ -232,7 +232,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: up = self.up_proj(x) return self.down_proj(gate * up) - @magi_compile(dynamic_arg_dims={"x": 0}) class TransformerBlock(nn.Module): """A single Transformer block""" @@ -243,7 +242,6 @@ def __init__(self, config: TransformerConfig): self.self_attn = Attention(config) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = TransformerMLP(config) - def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: residual = x x = self.input_layernorm(x).to(torch.bfloat16) @@ -269,7 +267,6 @@ def __init__(self, config: TransformerConfig): self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, dtype=config.params_dtype) - def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: """Forward pass of the Transformer model. From 7e55fd1c22a08a2577515935ad3b4ccfa06a55de Mon Sep 17 00:00:00 2001 From: wtr Date: Tue, 31 Mar 2026 20:19:45 +0800 Subject: [PATCH 2/5] chore --- tests/model_definition.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/model_definition.py b/tests/model_definition.py index 2ebcce2..1fc7c9e 100644 --- a/tests/model_definition.py +++ b/tests/model_definition.py @@ -232,6 +232,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: up = self.up_proj(x) return self.down_proj(gate * up) + @magi_compile(dynamic_arg_dims={"x": 0}) class TransformerBlock(nn.Module): """A single Transformer block""" @@ -242,6 +243,7 @@ def __init__(self, config: TransformerConfig): self.self_attn = Attention(config) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = TransformerMLP(config) + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: residual = x x = self.input_layernorm(x).to(torch.bfloat16) @@ -267,6 +269,7 @@ def __init__(self, config: TransformerConfig): self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, dtype=config.params_dtype) + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: """Forward pass of the Transformer model. From cf5da14f7f0451c756a1a74760819d2bafb0607c Mon Sep 17 00:00:00 2001 From: wtr Date: Wed, 1 Apr 2026 17:03:09 +0800 Subject: [PATCH 3/5] dev: Add training example with llama & fix cache save bug --- example/training/model.py | 237 ++++++++++++++++++ example/training/tokenizer.py | 211 ++++++++++++++++ example/training/train.py | 232 +++++++++++++++++ example/training/train.sh | 60 +++++ .../magi_backend/piecewise_compiler.py | 6 + magi_compiler/utils/nvtx.py | 55 ++++ 6 files changed, 801 insertions(+) create mode 100644 example/training/model.py create mode 100644 example/training/tokenizer.py create mode 100644 example/training/train.py create mode 100644 example/training/train.sh diff --git a/example/training/model.py b/example/training/model.py new file mode 100644 index 0000000..88447a6 --- /dev/null +++ b/example/training/model.py @@ -0,0 +1,237 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import fairscale.nn.model_parallel.initialize as fs_init +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding +from torch import nn + +from magi_compiler import magi_compile + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int = 8 + vocab_size: int = 128256 + multiple_of: int = 256 + ffn_dim_multiplier: float = 1.3 + norm_eps: float = 1e-5 + rope_theta: float = 500000 + + max_batch_size: int = 32 + max_seq_len: int = 8192 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + model_parallel_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = ColumnParallelLinear( + args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + self.wk = ColumnParallelLinear( + args.dim, self.n_kv_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + self.wv = ColumnParallelLinear( + args.dim, self.n_kv_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, args.dim, bias=False, input_is_parallel=True, init_method=lambda x: x + ) + + self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)).cuda() + self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)).cuda() + + def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + if self.training: + keys = xk + values = xv + else: + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x) + self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x) + self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +@magi_compile(dynamic_arg_dims={"x": 0}) +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): + h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x) + + self.freqs_cis = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len * 2, params.rope_theta) + + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device) + + mask = torch.triu(mask, diagonal=1) + + # When performing key-value caching, we compute the attention scores + # only for the new sequence. Thus, the matrix of scores is of size + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for + # j > cache_len + i, since row i corresponds to token cache_len + i. + mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h).float() + return output diff --git a/example/training/tokenizer.py b/example/training/tokenizer.py new file mode 100644 index 0000000..b18381e --- /dev/null +++ b/example/training/tokenizer.py @@ -0,0 +1,211 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import os +from logging import getLogger +from pathlib import Path +from typing import AbstractSet, Collection, Dict, Iterator, List, Literal, Sequence, TypedDict, Union, cast + +import tiktoken +from tiktoken.load import load_tiktoken_bpe + +logger = getLogger(__name__) + + +Role = Literal["system", "user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +Dialog = Sequence[Message] + + +class Tokenizer: + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a Tiktoken model. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [f"<|reserved_special_token_{i}|>" for i in range(5, self.num_reserved_special_tokens - 5)] + self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)} + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + logger.info(f"Reloaded tiktoken model from {model_path}") + + self.n_words: int = self.model.n_vocab + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.pad_id: int = -1 + self.stop_tokens = {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]} + logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + assert type(s) is str + + # The tiktoken tokenizer can handle <=400k chars without + # pyo3_runtime.PanicException. + TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + + # https://github.com/openai/tiktoken/issues/195 + # Here we iterate over subsequences and split if we exceed the limit + # of max consecutive non-whitespace or whitespace characters. + MAX_NO_WHITESPACES_CHARS = 25_000 + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend(self.model.encode(substr, allowed_special=allowed_special, disallowed_special=disallowed_special)) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. + return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + + +class ChatFormat: + def __init__(self, tokenizer: Tokenizer): + self.tokenizer = tokenizer + + def encode_header(self, message: Message) -> List[int]: + tokens = [] + tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) + tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False)) + tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"]) + tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) + return tokens + + def encode_message(self, message: Message) -> List[int]: + tokens = self.encode_header(message) + tokens.extend(self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)) + tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) + return tokens + + def encode_dialog_prompt(self, dialog: Dialog) -> List[int]: + tokens = [] + tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"]) + for message in dialog: + tokens.extend(self.encode_message(message)) + # Add the start of an assistant message for the model to complete. + tokens.extend(self.encode_header({"role": "assistant", "content": ""})) + return tokens diff --git a/example/training/train.py b/example/training/train.py new file mode 100644 index 0000000..bfe7fac --- /dev/null +++ b/example/training/train.py @@ -0,0 +1,232 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import json +import os +import sys +import time +from functools import partial +from pathlib import Path +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_rank, + initialize_model_parallel, + model_parallel_is_initialized, +) +from model import ModelArgs, Transformer, TransformerBlock +from tokenizer import ChatFormat, Tokenizer +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + +import magi_compiler.utils.nvtx as nvtx + + +def setup_fsdp(model: nn.Module, device_id: int): + """ + Wrap the given Llama3 model with PyTorch FSDP. + We apply auto_wrap_policy to wrap each TransformerBlock individually. + """ + llama_auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}) + + fsdp_model = FSDP( + model, auto_wrap_policy=llama_auto_wrap_policy, device_id=device_id, sync_module_states=True, use_orig_params=True + ) + local_params = sum(p.numel() for p in fsdp_model.parameters()) + print(f"[Rank {device_id}] Local param count: {local_params:,}") + return fsdp_model + + +class Llama: + @staticmethod + def build( + ckpt_dir: str, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + model_parallel_size: Optional[int] = None, + seed: int = 1, + ) -> "Llama": + """ + Build a Llama instance by initializing and loading a model checkpoint. + + Args: + ckpt_dir (str): Path to the directory containing checkpoint files. + tokenizer_path (str): Path to the tokenizer file. + max_seq_len (int): Maximum sequence length for input text. + max_batch_size (int): Maximum batch size for inference. + model_parallel_size (Optional[int], optional): Number of model parallel processes. + If not provided, it's determined from the environment. Defaults to None. + + Returns: + Llama: An instance of the Llama class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory, + or if the model parallel size does not match the number of checkpoint files. + + Note: + This method initializes the distributed process group, sets the device to CUDA, + and loads the pre-trained model and tokenizer. + """ + assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}." + assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist." + assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist." + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group("nccl") + if not model_parallel_is_initialized(): + if model_parallel_size is None: + model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) + initialize_model_parallel(model_parallel_size) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(seed) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + start_time = time.time() + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert model_parallel_size == len( + checkpoints + ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" + ckpt_path = checkpoints[get_model_parallel_rank()] + checkpoint = torch.load(ckpt_path, map_location="cpu") + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs(max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params) + tokenizer = Tokenizer(model_path=tokenizer_path) + assert model_args.vocab_size == tokenizer.n_words + if torch.cuda.is_bf16_supported(): + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.HalfTensor) + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=False) + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama(model, tokenizer) + + def __init__(self, model: Transformer, tokenizer: Tokenizer): + self.model = model + self.tokenizer = tokenizer + self.formatter = ChatFormat(tokenizer) + + +def main(): + # Setup distributed environment + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + print("Not running in distributed mode. Set RANK and WORLD_SIZE to use FSDP.") + # For demonstration purposes, we will mock the distributed setup if not provided. + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo") + + # Initialize model parallel group (since model uses fairscale VocabParallelEmbedding) + if not model_parallel_is_initialized(): + initialize_model_parallel(1) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + global_rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + else: + device = torch.device("cpu") + + # Initialize a small config for testing + config = ModelArgs(n_layers=10, max_batch_size=2, max_seq_len=1024) + + # Create Model + if global_rank == 0: + print(f"Initializing model on {world_size} devices...") + model = Transformer(config).to(device) + + # Wrap with FSDP + if global_rank == 0: + print("Wrapping model with FSDP...") + + if torch.cuda.is_available(): + fsdp_model = setup_fsdp(model, device_id=local_rank) + else: + # For CPU testing, fallback to DDP or just standard model if FSDP CPU is not supported + print(f"[Rank {global_rank}] CUDA not available. Running without FSDP for CPU fallback.") + fsdp_model = model + + # Optimizer + optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2) + + num_epochs = 5 + bsz, seq_len = config.max_batch_size, config.max_seq_len + + if global_rank == 0: + print(f"Starting training for {num_epochs} epochs...") + + for epoch in range(num_epochs): + # record start time + start_time = time.time() + + # Dummy data for each epoch + # Ensure different ranks get different data if needed, but here we just generate random + torch.manual_seed(epoch * world_size + global_rank) + input_ids = torch.randint(0, config.vocab_size, (bsz, seq_len), device=device) + labels = torch.randint(0, config.vocab_size, (bsz, seq_len), device=device) + + optimizer.zero_grad() + + nvtx.switch_profile(epoch, 3, 4) + # Forward pass + logits = fsdp_model(input_ids, start_pos=0) + + # Loss + loss = F.cross_entropy(logits.view(-1, config.vocab_size), labels.view(-1)) + + # Backward pass + loss.backward() + + optimizer.step() + + # record end time + end_time = time.time() + + if global_rank == 0: + print( + f"Epoch {epoch + 1}/{num_epochs} | Loss: {loss.item():.4f} | Time taken: {end_time - start_time:.4f} seconds" + ) + + if global_rank == 0: + print("Training done!") + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/example/training/train.sh b/example/training/train.sh new file mode 100644 index 0000000..8d990e0 --- /dev/null +++ b/example/training/train.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) +PROJECT_ROOT=$(cd "$SCRIPT_DIR/../.." &> /dev/null && pwd) + +### Distributed args ### +MASTER_ADDR=${MASTER_ADDR:-localhost} +MASTER_PORT=${MASTER_PORT:-29500} + +export CUDA_VISIBLE_DEVICES=1,2 +GPUS_PER_NODE=${GPUS_PER_NODE:-2} +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) +DISTRIBUTED_ARGS="--nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=$GPUS_PER_NODE --rdzv-backend=c10d --rdzv-endpoint=$MASTER_ADDR:$MASTER_PORT" + +### Nsys args ### +NSYS_PROFILE=${NSYS_PROFILE:-true} + +if [ "$NSYS_PROFILE" = true ]; then + mkdir -p "$PROJECT_ROOT/nsys_reports" + + BASE_NAME="nsys_llama3" + TIMESTAMP=$(date +"%Y%m%d_%H%M%S") + NSYS_SUFFIX="ts_${TIMESTAMP}" + + [ -n "$WORLD_SIZE" ] && NSYS_SUFFIX="${NSYS_SUFFIX}_worldsize_${WORLD_SIZE}" + [ -n "$COMPILE_MODE" ] && NSYS_SUFFIX="${NSYS_SUFFIX}_compile_${COMPILE_MODE}" + [ -n "$CUDA_GRAPH_MODE" ] && NSYS_SUFFIX="${NSYS_SUFFIX}_cudagraph_${CUDA_GRAPH_MODE}" + + NSYS_OUTPUT="$PROJECT_ROOT/nsys_reports/${BASE_NAME}_${NSYS_SUFFIX}" + + NSYS_CMD="nsys profile --force-overwrite true -o $NSYS_OUTPUT --trace=cuda,nvtx --capture-range=cudaProfilerApi" +else + NSYS_CMD="" +fi + +### Environment Variables ### +export ENABLE_REMOTE_DEBUG=${ENABLE_REMOTE_DEBUG:-false} +export MAGI_COMPILE_CACHE_ROOT_DIR=${MAGI_COMPILE_CACHE_ROOT_DIR:-"$PROJECT_ROOT/.cache"} +export MAGI_ENABLE_FX_GRAPH_VIZ=${MAGI_ENABLE_FX_GRAPH_VIZ:-false} + +$NSYS_CMD torchrun $DISTRIBUTED_ARGS $SCRIPT_DIR/train.py \ + $NSYS_ARGS diff --git a/magi_compiler/magi_backend/piecewise_compiler.py b/magi_compiler/magi_backend/piecewise_compiler.py index 293c11c..4a92e3a 100644 --- a/magi_compiler/magi_backend/piecewise_compiler.py +++ b/magi_compiler/magi_backend/piecewise_compiler.py @@ -221,6 +221,12 @@ def compile( restart_analysis_count = self._restart_analysis_counts.get(key, 0) if hasattr(self, "cache_dir") and self.cache_dir is not None: try: + # Workaround for empty aot_autograd artifacts + if getattr(compiled_graph, "_artifacts", None) is not None: + _, cache_info = compiled_graph._artifacts + if not cache_info.artifacts.get("aot_autograd"): + cache_info.artifacts["aot_autograd"] = [key] + path: Path = self.cache_dir / key compiled_graph.save(path=path.as_posix(), format="unpacked") compilation_counter.num_compiled_artifacts_saved += 1 diff --git a/magi_compiler/utils/nvtx.py b/magi_compiler/utils/nvtx.py index 7c36d82..f5729fb 100644 --- a/magi_compiler/utils/nvtx.py +++ b/magi_compiler/utils/nvtx.py @@ -103,3 +103,58 @@ def wrapped_fn(*args, **kwargs): return ret_val return cast(F, wrapped_fn) + + +def profile_start(event_name: str, record_shapes: bool = False): + torch.cuda.cudart().cudaProfilerStart() + torch.cuda.nvtx.range_push(event_name) + global _EMIT_NVTX_CTX + _EMIT_NVTX_CTX = torch.autograd.profiler.emit_nvtx(record_shapes=record_shapes) + _EMIT_NVTX_CTX.__enter__() + + +def profile_mark(event_name: str): + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push(event_name) + + +def profile_end(): + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + global _EMIT_NVTX_CTX + if _EMIT_NVTX_CTX is not None: + _EMIT_NVTX_CTX.__exit__(None, None, None) + _EMIT_NVTX_CTX = None + + +def switch_profile( + iter_id: int, start: int, end: int, profile_ranks: list[int] = None, event_name: str = None, record_shapes=True +): + """ + Controls the profiler state based on the iteration number. Turns on profiling + at the start iteration and turns it off at the end iteration. + + Args: + - iter_id: The current iteration number. + - start: The iteration number to start profiling. + - end: The iteration number to end profiling. + - profile_ranks: list of ranks to be profiled. + - event_name: Custom name for the profiling event. If None, defaults to 'iter_{iter_id}'. + """ + if profile_ranks is not None and torch.distributed.is_initialized() and torch.distributed.get_rank() not in profile_ranks: + return + + if event_name is None: + event_name = f"iter_{iter_id}" + + # Start profiling + if iter_id == start: + profile_start(event_name, record_shapes) + + # Stop profiling + elif iter_id == end: + profile_end() + + # Continue profiling + elif iter_id > start and iter_id < end: + profile_mark(event_name) From c66cec3d08fd414d31a510a4946c04b96d5722f4 Mon Sep 17 00:00:00 2001 From: wtr Date: Wed, 1 Apr 2026 19:17:39 +0800 Subject: [PATCH 4/5] chore --- example/training/{model.py => llama3.py} | 0 example/training/tokenizer.py | 211 ----------------------- example/training/train.py | 95 +--------- 3 files changed, 2 insertions(+), 304 deletions(-) rename example/training/{model.py => llama3.py} (100%) delete mode 100644 example/training/tokenizer.py diff --git a/example/training/model.py b/example/training/llama3.py similarity index 100% rename from example/training/model.py rename to example/training/llama3.py diff --git a/example/training/tokenizer.py b/example/training/tokenizer.py deleted file mode 100644 index b18381e..0000000 --- a/example/training/tokenizer.py +++ /dev/null @@ -1,211 +0,0 @@ -# Copyright (c) 2026 SandAI. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. - -import os -from logging import getLogger -from pathlib import Path -from typing import AbstractSet, Collection, Dict, Iterator, List, Literal, Sequence, TypedDict, Union, cast - -import tiktoken -from tiktoken.load import load_tiktoken_bpe - -logger = getLogger(__name__) - - -Role = Literal["system", "user", "assistant"] - - -class Message(TypedDict): - role: Role - content: str - - -Dialog = Sequence[Message] - - -class Tokenizer: - """ - Tokenizing and encoding/decoding text using the Tiktoken tokenizer. - """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - def __init__(self, model_path: str): - """ - Initializes the Tokenizer with a Tiktoken model. - - Args: - model_path (str): The path to the Tiktoken model file. - """ - assert os.path.isfile(model_path), model_path - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", # end of turn - ] + [f"<|reserved_special_token_{i}|>" for i in range(5, self.num_reserved_special_tokens - 5)] - self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)} - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - logger.info(f"Reloaded tiktoken model from {model_path}") - - self.n_words: int = self.model.n_vocab - # BOS / EOS token IDs - self.bos_id: int = self.special_tokens["<|begin_of_text|>"] - self.eos_id: int = self.special_tokens["<|end_of_text|>"] - self.pad_id: int = -1 - self.stop_tokens = {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]} - logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") - - def encode( - self, - s: str, - *, - bos: bool, - eos: bool, - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: - """ - Encodes a string into a list of token IDs. - - Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_tokens ("all"|set[str]): allowed special tokens in string - disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string - - Returns: - list[int]: A list of token IDs. - - By default, setting disallowed_special=() encodes a string by ignoring - special tokens. Specifically: - - Setting `disallowed_special` to () will cause all text corresponding - to special tokens to be encoded as natural text (insteading of raising - an error). - - Setting `allowed_special` to "all" will treat all text corresponding - to special tokens to be encoded as special tokens. - """ - assert type(s) is str - - # The tiktoken tokenizer can handle <=400k chars without - # pyo3_runtime.PanicException. - TIKTOKEN_MAX_ENCODE_CHARS = 400_000 - - # https://github.com/openai/tiktoken/issues/195 - # Here we iterate over subsequences and split if we exceed the limit - # of max consecutive non-whitespace or whitespace characters. - MAX_NO_WHITESPACES_CHARS = 25_000 - - substrs = ( - substr - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) - for substr in self._split_whitespaces_or_nonwhitespaces( - s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS - ) - ) - t: List[int] = [] - for substr in substrs: - t.extend(self.model.encode(substr, allowed_special=allowed_special, disallowed_special=disallowed_special)) - if bos: - t.insert(0, self.bos_id) - if eos: - t.append(self.eos_id) - return t - - def decode(self, t: Sequence[int]) -> str: - """ - Decodes a list of token IDs into a string. - - Args: - t (List[int]): The list of token IDs to be decoded. - - Returns: - str: The decoded string. - """ - # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. - return self.model.decode(cast(List[int], t)) - - @staticmethod - def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]: - """ - Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` - consecutive whitespaces or consecutive non-whitespaces. - """ - current_slice_len = 0 - current_slice_is_space = s[0].isspace() if len(s) > 0 else False - slice_start = 0 - - for i in range(len(s)): - is_now_space = s[i].isspace() - - if current_slice_is_space ^ is_now_space: - current_slice_len = 1 - current_slice_is_space = is_now_space - else: - current_slice_len += 1 - if current_slice_len > max_consecutive_slice_len: - yield s[slice_start:i] - slice_start = i - current_slice_len = 1 - yield s[slice_start:] - - -class ChatFormat: - def __init__(self, tokenizer: Tokenizer): - self.tokenizer = tokenizer - - def encode_header(self, message: Message) -> List[int]: - tokens = [] - tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) - tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False)) - tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"]) - tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) - return tokens - - def encode_message(self, message: Message) -> List[int]: - tokens = self.encode_header(message) - tokens.extend(self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)) - tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) - return tokens - - def encode_dialog_prompt(self, dialog: Dialog) -> List[int]: - tokens = [] - tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"]) - for message in dialog: - tokens.extend(self.encode_message(message)) - # Add the start of an assistant message for the model to complete. - tokens.extend(self.encode_header({"role": "assistant", "content": ""})) - return tokens diff --git a/example/training/train.py b/example/training/train.py index bfe7fac..e08e831 100644 --- a/example/training/train.py +++ b/example/training/train.py @@ -15,25 +15,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. -import json import os -import sys import time from functools import partial -from pathlib import Path -from typing import Optional import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -from fairscale.nn.model_parallel.initialize import ( - get_model_parallel_rank, - initialize_model_parallel, - model_parallel_is_initialized, -) -from model import ModelArgs, Transformer, TransformerBlock -from tokenizer import ChatFormat, Tokenizer +from fairscale.nn.model_parallel.initialize import initialize_model_parallel, model_parallel_is_initialized +from llama3 import ModelArgs, Transformer, TransformerBlock from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy @@ -55,88 +46,6 @@ def setup_fsdp(model: nn.Module, device_id: int): return fsdp_model -class Llama: - @staticmethod - def build( - ckpt_dir: str, - tokenizer_path: str, - max_seq_len: int, - max_batch_size: int, - model_parallel_size: Optional[int] = None, - seed: int = 1, - ) -> "Llama": - """ - Build a Llama instance by initializing and loading a model checkpoint. - - Args: - ckpt_dir (str): Path to the directory containing checkpoint files. - tokenizer_path (str): Path to the tokenizer file. - max_seq_len (int): Maximum sequence length for input text. - max_batch_size (int): Maximum batch size for inference. - model_parallel_size (Optional[int], optional): Number of model parallel processes. - If not provided, it's determined from the environment. Defaults to None. - - Returns: - Llama: An instance of the Llama class with the loaded model and tokenizer. - - Raises: - AssertionError: If there are no checkpoint files in the specified directory, - or if the model parallel size does not match the number of checkpoint files. - - Note: - This method initializes the distributed process group, sets the device to CUDA, - and loads the pre-trained model and tokenizer. - """ - assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}." - assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist." - assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist." - - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group("nccl") - if not model_parallel_is_initialized(): - if model_parallel_size is None: - model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) - initialize_model_parallel(model_parallel_size) - - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - torch.cuda.set_device(local_rank) - - # seed must be the same in all processes - torch.manual_seed(seed) - - if local_rank > 0: - sys.stdout = open(os.devnull, "w") - - start_time = time.time() - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" - assert model_parallel_size == len( - checkpoints - ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" - ckpt_path = checkpoints[get_model_parallel_rank()] - checkpoint = torch.load(ckpt_path, map_location="cpu") - with open(Path(ckpt_dir) / "params.json", "r") as f: - params = json.loads(f.read()) - - model_args: ModelArgs = ModelArgs(max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params) - tokenizer = Tokenizer(model_path=tokenizer_path) - assert model_args.vocab_size == tokenizer.n_words - if torch.cuda.is_bf16_supported(): - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) - else: - torch.set_default_tensor_type(torch.cuda.HalfTensor) - model = Transformer(model_args) - model.load_state_dict(checkpoint, strict=False) - print(f"Loaded in {time.time() - start_time:.2f} seconds") - - return Llama(model, tokenizer) - - def __init__(self, model: Transformer, tokenizer: Tokenizer): - self.model = model - self.tokenizer = tokenizer - self.formatter = ChatFormat(tokenizer) - - def main(): # Setup distributed environment if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: From 925aab88be553efceff6579d62fcd04f24fff709 Mon Sep 17 00:00:00 2001 From: wtr Date: Wed, 1 Apr 2026 19:27:16 +0800 Subject: [PATCH 5/5] chore --- example/training/train.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/example/training/train.sh b/example/training/train.sh index 8d990e0..53b60f3 100644 --- a/example/training/train.sh +++ b/example/training/train.sh @@ -23,8 +23,7 @@ PROJECT_ROOT=$(cd "$SCRIPT_DIR/../.." &> /dev/null && pwd) MASTER_ADDR=${MASTER_ADDR:-localhost} MASTER_PORT=${MASTER_PORT:-29500} -export CUDA_VISIBLE_DEVICES=1,2 -GPUS_PER_NODE=${GPUS_PER_NODE:-2} +GPUS_PER_NODE=${GPUS_PER_NODE:-1} NNODES=${NNODES:-1} NODE_RANK=${NODE_RANK:-0} WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) @@ -51,7 +50,7 @@ else NSYS_CMD="" fi -### Environment Variables ### +### Environment Variables For Debugging ### export ENABLE_REMOTE_DEBUG=${ENABLE_REMOTE_DEBUG:-false} export MAGI_COMPILE_CACHE_ROOT_DIR=${MAGI_COMPILE_CACHE_ROOT_DIR:-"$PROJECT_ROOT/.cache"} export MAGI_ENABLE_FX_GRAPH_VIZ=${MAGI_ENABLE_FX_GRAPH_VIZ:-false}