diff --git a/magi_compiler/passes/joint_graph/joint_graph_partition.py b/magi_compiler/passes/joint_graph/joint_graph_partition.py index 8e19906..d96c7c3 100644 --- a/magi_compiler/passes/joint_graph/joint_graph_partition.py +++ b/magi_compiler/passes/joint_graph/joint_graph_partition.py @@ -14,7 +14,7 @@ import operator import os -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence from unittest.mock import patch import torch @@ -29,21 +29,20 @@ from magi_compiler.utils import compute_code_hash, magi_logger from magi_compiler.utils.visualize import joint_graph_vis -SAVE_TENSOR_NODES: Optional[list[fx.Node]] = None - def is_memory_increase_by_node(node: fx.Node) -> bool: - # Only support aten.to now - assert get_aten_target(node) == torch.ops.prims.convert_element_type + """Check if the operation increases memory size (e.g., casting fp16 to fp32).""" + assert get_aten_target(node) == torch.ops.prims.convert_element_type, "Only aten.to is supported" input_dtype = node.args[0].meta["tensor_meta"].dtype output_dtype = node.args[1] - assert output_dtype is not None + assert output_dtype is not None, "Output dtype must be specified" return output_dtype.itemsize > input_dtype.itemsize def is_compute_sensitive_op( node: fx.Node, op_types: OpTypes, custom_compute_sensitive_ops: list[torch._ops.OpOverload] ) -> bool: + """Check if the node is a compute-intensive operation.""" if op_types.is_compute_intensive(node): return True if node.op != "call_function": @@ -55,105 +54,143 @@ def is_compute_sensitive_op( return False -def is_primal_contribute_to_bwd_directly( +def _primal_contributes_to_bwd_directly( primal_node: fx.Node, node_info: NodeInfo, op_types: OpTypes, custom_compute_sensitive_ops: list[torch._ops.OpOverload] ) -> bool: """ - FSDP ensures that weights already reside in memory. If there exists a path from the primal to the bwd, and the path does not contain any matmul, then the primal contributes to the bwd directly. - And we should save this primals. + FSDP ensures that weights already reside in memory. + If there is a path from the primal (weight) to the backward pass that does not contain + any compute-intensive operations (like matmul), it contributes to the backward pass directly, + and we should save this primal node. """ if node_info.is_required_bw(primal_node): return True - topology_start = set({primal_node}) - while len(topology_start) > 0: - cur_node = topology_start.pop() + worklist = {primal_node} + visited = {primal_node} + + while worklist: + cur_node = worklist.pop() for user in cur_node.users: if node_info.is_required_bw(user): return True if is_compute_sensitive_op(user, op_types, custom_compute_sensitive_ops): continue - topology_start.add(user) + if user not in visited: + visited.add(user) + worklist.add(user) + return False -def is_compute_intensive_and_has_following_recomputable_ops( - intermidiate_node: fx.Node, - node_info: NodeInfo, - op_types: OpTypes, - custom_compute_sensitive_ops: list[torch._ops.OpOverload], -) -> Tuple[bool, fx.Node]: +def _push_down_save_node(node: fx.Node, node_info: NodeInfo, op_types: OpTypes) -> Optional[fx.Node]: """ - If compute-intensive node(CIN) is not the output of fwd graph(has following memory-intensive ops in the fwd graph), then we should save this CIN node. - NOTE: For CIN+aten.to, we should save aten.to op instead of CIN op to save more memory. + Starting from a compute-intensive node, walk forward through memory-efficient ops + (views, type-narrowing casts) to find the optimal save point. + + For example, for `matmul -> view -> to(fp16)`, we save the fp16 tensor rather than + the matmul output, since they carry the same information at a lower memory cost. + + Returns None if the node is a direct output of the forward graph (no explicit save needed). """ - if not is_compute_sensitive_op(intermidiate_node, op_types, custom_compute_sensitive_ops): - return False, None - - save_node = intermidiate_node - topology_start = set({save_node}) - while len(topology_start) > 0: - cur_node = topology_start.pop() - fwd_user_nodes = [] - for user in cur_node.users: - if node_info.is_required_fw(user): - fwd_user_nodes.append(user) + cur_node = node + save_node = node + + while True: + fwd_user_nodes = [u for u in cur_node.users if node_info.is_required_fw(u)] - if len(fwd_user_nodes) > 1: # multiple users, save current node - return True, save_node - elif len(fwd_user_nodes) == 0: # output, return - return False, None + if len(fwd_user_nodes) > 1: # branch point: multiple users, save here + return save_node + if len(fwd_user_nodes) == 0: # fwd graph output: autograd handles it + return None - # save current node if it's user is recomputable next_node = fwd_user_nodes[0] - if op_types.is_view(next_node): + + if next_node.op == "output": + return None + + # Try to push save_node down through memory-efficient ops + is_view = op_types.is_view(next_node) + is_type_convert = get_aten_target(next_node) == torch.ops.prims.convert_element_type + + if is_view: if save_node == cur_node: save_node = next_node - topology_start.add(next_node) - # Special case for aten.to, memory efficient case - elif get_aten_target(next_node) == torch.ops.prims.convert_element_type: - is_memory_increase = is_memory_increase_by_node(next_node) - if not is_memory_increase: + cur_node = next_node + elif is_type_convert: + if not is_memory_increase_by_node(next_node): save_node = next_node - topology_start.add(next_node) - elif next_node.op == "output": - return False, None + cur_node = next_node else: - return True, save_node - assert False, f"Should not reach here: {intermidiate_node=} {save_node=}" + return save_node + + +def _decide_save_node( + node: fx.Node, node_info: NodeInfo, primal_set: frozenset, op_types: OpTypes, custom_ops: list[torch._ops.OpOverload] +) -> Optional[fx.Node]: + """ + Unified decision function: given any node in the joint graph, return the optimal + node to save, or None if no save is needed. + + Two cases trigger saving: + 1. Primal node: backward needs it via a path with no compute-intensive barrier. + Save the primal as-is (it is already the smallest representation of itself). + 2. Compute-intensive forward node: push the save point down through memory-efficient + ops (views, type-narrowing casts) to minimize memory footprint. + """ + if node in primal_set: + if _primal_contributes_to_bwd_directly(node, node_info, op_types, custom_ops): + return node + return None + + if node_info.is_required_fw(node) and is_compute_sensitive_op(node, op_types, custom_ops): + return _push_down_save_node(node, node_info, op_types) + + return None + + +def _collect_save_node(save_node: fx.Node, output: OrderedSet) -> None: + """ + Add save_node to the output set. + If the node's output is a tuple (e.g., from ops like `torch.var_mean`), + collect all getitem users instead of the node itself. + """ + out_val = save_node.meta.get("val") + assert out_val is not None, f"save_node {save_node} must have output, otherwise it's no need to save" + + if isinstance(out_val, tuple): + for user in save_node.users: + assert ( + user.op == "call_function" and user.target == operator.getitem + ), f"save_node {save_node} must have a getitem user" + output.add(user) + else: + output.add(save_node) -# TODO: We find an elegant impl to heuristically save nodes, reconstruct this later def heuristic_choose_saved_values_set(joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1) -> list[fx.Node]: - output: OrderedSet[fx.Node] = OrderedSet() + """ + Heuristic to select which forward nodes to save for the backward pass. + + Rather than reasoning about primals and intermediates separately, we make a single + pass over all joint-graph nodes and apply a unified decision (_decide_save_node): + - Primal nodes that backward directly needs (no compute-intensive barrier) are saved as-is. + - Compute-intensive forward nodes whose outputs are consumed by later forward ops + are saved at their memory-optimal downstream position. + """ op_types = get_default_op_list() - custom_compute_sensitive_ops = get_compile_config().recompute_config.custom_compute_sensitive_ops - custom_compute_sensitive_ops: list[torch._ops.OpOverload] = resolve_defined_ops(custom_compute_sensitive_ops) - # Select the inputs that are required by the backward pass - for primal_node in node_info.inputs: - if is_primal_contribute_to_bwd_directly(primal_node, node_info, op_types, custom_compute_sensitive_ops): - output.add(primal_node) - magi_logger.info("MagiCompiler: saved_output forward-input = %s", output) - # Select the compute-intensive nodes that are required by the forward pass - for intermidiate_node in node_info.required_fw_nodes: - is_save, save_node = is_compute_intensive_and_has_following_recomputable_ops( - intermidiate_node, node_info, op_types, custom_compute_sensitive_ops - ) - if not is_save or save_node is None: - continue - out_val = save_node.meta.get("val") - assert out_val is not None, f"save_node {save_node} must have output, otherwise it's no need to save" - if isinstance(out_val, tuple): - for user in save_node.users: - assert ( - user.op == "call_function" and user.target == operator.getitem - ), f"save_node {save_node} must have a getitem user" - output.add(user) - else: - output.add(save_node) - magi_logger.info("MagiCompiler: saved_output compute-intensive = %s", output) - global SAVE_TENSOR_NODES - SAVE_TENSOR_NODES = list(output) + custom_ops: list[torch._ops.OpOverload] = resolve_defined_ops( + get_compile_config().recompute_config.custom_compute_sensitive_ops + ) + primal_set = frozenset(node_info.inputs) + output: OrderedSet[fx.Node] = OrderedSet() + + for node in joint_graph.nodes: + save_node = _decide_save_node(node, node_info, primal_set, op_types, custom_ops) + if save_node is not None: + _collect_save_node(save_node, output) + + magi_logger.info("MagiCompiler: saved_output = %s", output) return list(output) @@ -162,37 +199,37 @@ def custom_joint_graph_partition_fn( _joint_inputs, compiler="inductor", *, - num_fwd_outputs, + num_fwd_outputs: int, static_lifetime_input_indices: Optional[list[int]] = None, ) -> tuple[fx.GraphModule, fx.GraphModule]: recompute_config = get_compile_config().recompute_config - if recompute_config.recompute_policy == RecomputePolicy.HANDCRAFT: + partition_kwargs = dict(num_fwd_outputs=num_fwd_outputs, static_lifetime_input_indices=static_lifetime_input_indices) + + save_tensor_nodes: list[fx.Node] = [] + policy = recompute_config.recompute_policy + + if policy == RecomputePolicy.HANDCRAFT: magi_logger.info("MagiCompiler using handcraft recompute policy") # TODO: different memory budget definition from torch - with patch("torch._functorch.config.activation_memory_budget", recompute_config.memory_budget): - fwd_module, bwd_module = min_cut_rematerialization_partition( - joint_module, - _joint_inputs, - compiler, - num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, - ) - elif recompute_config.recompute_policy == RecomputePolicy.HEURISTIC: + ctx = patch("torch._functorch.config.activation_memory_budget", recompute_config.memory_budget) + elif policy == RecomputePolicy.HEURISTIC: magi_logger.info("MagiCompiler using heuristic recompute policy") - with patch("torch._functorch.partitioners.choose_saved_values_set", heuristic_choose_saved_values_set): - fwd_module, bwd_module = min_cut_rematerialization_partition( - joint_module, - _joint_inputs, - compiler, - num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, - ) - elif recompute_config.recompute_policy == RecomputePolicy.AUTOSEARCH: - raise ValueError(f"AutoSearch recompute policy is not supported yet") + + def _tracked_choose(joint_graph, node_info, memory_budget=1): + result = heuristic_choose_saved_values_set(joint_graph, node_info, memory_budget) + save_tensor_nodes.extend(result) + return result + + ctx = patch("torch._functorch.partitioners.choose_saved_values_set", _tracked_choose) + elif policy == RecomputePolicy.AUTOSEARCH: + raise ValueError("AutoSearch recompute policy is not supported yet") else: - raise ValueError(f"Invalid recompute policy: {recompute_config.recompute_policy}") + raise ValueError(f"Invalid recompute policy: {policy}") + + with ctx: + fwd_module, bwd_module = min_cut_rematerialization_partition(joint_module, _joint_inputs, compiler, **partition_kwargs) - joint_graph_vis(joint_module, fwd_module, bwd_module, save_tensor_nodes=SAVE_TENSOR_NODES) + joint_graph_vis(joint_module, fwd_module, bwd_module, save_tensor_nodes=save_tensor_nodes or None) return fwd_module, bwd_module diff --git a/tests/model_definition.py b/tests/model_definition.py index 86e2521..1fc7c9e 100644 --- a/tests/model_definition.py +++ b/tests/model_definition.py @@ -161,3 +161,164 @@ def create_mlp_model_with_initial_params(config: MLPConfig, device: torch.device model = MLP(config).to(device) initial_params = [p.clone().detach() for p in model.parameters()] return model, initial_params + + +@dataclass +class TransformerConfig: + """Configuration for the Transformer model""" + + vocab_size: int + hidden_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + num_key_value_heads: int + max_position_embeddings: int + rms_norm_eps: float = 1e-6 + params_dtype: torch.dtype = torch.bfloat16 + + +class Attention(nn.Module): + """Multi-head attention module""" + + def __init__(self, config: TransformerConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False, dtype=config.params_dtype) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, dtype=config.params_dtype + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, dtype=config.params_dtype + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False, dtype=config.params_dtype) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: + bsz, seq_len, _ = x.shape + q = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # GQA + if self.num_key_value_groups > 1: + k = k.repeat_interleave(self.num_key_value_groups, dim=1) + v = v.repeat_interleave(self.num_key_value_groups, dim=1) + + flash_attn_out = torch.ops.aten._scaled_dot_product_flash_attention( + q, k, v, dropout_p=0.0, is_causal=False, return_debug_mask=False + ) + attn_output = flash_attn_out[0] + + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size) + return self.o_proj(attn_output) + + +class TransformerMLP(nn.Module): + """MLP module for Transformer""" + + def __init__(self, config: TransformerConfig): + super().__init__() + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate = F.silu(self.gate_proj(x)) + up = self.up_proj(x) + return self.down_proj(gate * up) + + +@magi_compile(dynamic_arg_dims={"x": 0}) +class TransformerBlock(nn.Module): + """A single Transformer block""" + + def __init__(self, config: TransformerConfig): + super().__init__() + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = Attention(config) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = TransformerMLP(config) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: + residual = x + x = self.input_layernorm(x).to(torch.bfloat16) + x = self.self_attn(x, attention_mask=attention_mask) + x = residual + x + + residual = x + x = self.post_attention_layernorm(x).to(torch.bfloat16) + x = self.mlp(x) + x = residual + x + return x + + +class Transformer(nn.Module): + """A complete Transformer model""" + + config: TransformerConfig + + def __init__(self, config: TransformerConfig): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=config.params_dtype) + self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, dtype=config.params_dtype) + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: + """Forward pass of the Transformer model. + + Args: + input_ids (torch.Tensor): Input token ids + attention_mask (torch.Tensor): Attention mask + + Returns: + output (torch.Tensor): Output logits + + Shape: + - input_ids: (batch_size, seq_len) + - attention_mask: (batch_size, 1, seq_len, seq_len) + - output: (batch_size, seq_len, vocab_size) + """ + x = self.embed_tokens(input_ids) + for layer in self.layers: + x = layer(x, attention_mask=attention_mask) + x = self.norm(x).to(torch.bfloat16) + return self.lm_head(x) + + +def create_transformer_model(config: TransformerConfig, device: torch.device) -> Transformer: + """Create Transformer model + + Args: + config: Transformer configuration + device: Target device + + Returns: + model: Created Transformer model + """ + model = Transformer(config).to(device) + return model + + +def create_transformer_model_with_initial_params( + config: TransformerConfig, device: torch.device +) -> tuple[Transformer, list[torch.Tensor]]: + """Create Transformer model and return model with initial parameter snapshot + + Args: + config: Transformer configuration + device: Target device + + Returns: + model: Created Transformer model + initial_params: Initial snapshot of model parameters for verifying parameter updates + """ + model = Transformer(config).to(device) + initial_params = [p.clone().detach() for p in model.parameters()] + return model, initial_params diff --git a/tests/model_tests/test_mlp_training.py b/tests/model_tests/test_mlp_training.py index dc9b352..d475d44 100644 --- a/tests/model_tests/test_mlp_training.py +++ b/tests/model_tests/test_mlp_training.py @@ -16,7 +16,15 @@ import torch import torch.nn as nn -from tests.model_definition import MLP, MLPConfig, create_mlp_model_with_initial_params +from magi_compiler.utils import add_nvtx_event +from tests.model_definition import ( + MLP, + MLPConfig, + Transformer, + TransformerConfig, + create_mlp_model_with_initial_params, + create_transformer_model_with_initial_params, +) from tests.utils import CleanupCacheContext, enable_remote_debug @@ -92,6 +100,82 @@ def train_mlp_model( return epoch_losses +def train_transformer_model( + model: Transformer, + optimizer: torch.optim.Optimizer, + device: torch.device, + batch_size: int, + seq_len: int, + vocab_size: int, + num_epochs: int, + batches_per_epoch: int, + gradient_accumulation_steps: int = 1, +) -> list[float]: + """Execute training loop for Transformer model (supports gradient accumulation) + + Args: + model: Transformer model to train + optimizer: Optimizer + device: Training device + batch_size: Number of sequences per batch + seq_len: Length of each sequence + vocab_size: Vocabulary size + num_epochs: Number of training epochs + batches_per_epoch: Number of batches per epoch + gradient_accumulation_steps: Gradient accumulation steps, default is 1 (no accumulation) + + Returns: + epoch_losses: List of average losses per epoch + """ + epoch_losses = [] + + print(f"Starting Transformer training: {num_epochs} epochs, {batches_per_epoch} batches per epoch") + if gradient_accumulation_steps > 1: + print(f"Using gradient accumulation, accumulation steps: {gradient_accumulation_steps}") + + for epoch in range(num_epochs): + epoch_loss_sum = 0.0 + + for batch_idx in range(batches_per_epoch): + # Zero gradients at the start of each accumulation cycle + if batch_idx % gradient_accumulation_steps == 0: + optimizer.zero_grad() + + # Generate random input and target data + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long) + target_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long) + + # Forward pass + with add_nvtx_event("transformer_forward"): + output = model(input_ids) + + # Compute loss, divided by accumulation steps to maintain effective batch size consistency + loss = nn.functional.cross_entropy(output.view(-1, vocab_size), target_ids.view(-1)) / gradient_accumulation_steps + + # Backward pass (gradients are automatically accumulated) + with add_nvtx_event("transformer_backward"): + loss.backward() + + # Accumulate loss for logging (multiply by accumulation steps to restore original value) + epoch_loss_sum += loss.item() * gradient_accumulation_steps + + # Update parameters after accumulating gradient_accumulation_steps batches + if (batch_idx + 1) % gradient_accumulation_steps == 0: + optimizer.step() + + # Handle the last incomplete accumulation batch + if batches_per_epoch % gradient_accumulation_steps != 0: + optimizer.step() + optimizer.zero_grad() + + avg_loss = epoch_loss_sum / batches_per_epoch + epoch_losses.append(avg_loss) + print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.6f}") + + print("Training completed!") + return epoch_losses + + def verify_model_parameters_updated( initial_params: list[torch.Tensor], current_params: list[torch.Tensor], tolerance: float = 1e-6 ) -> bool: @@ -152,9 +236,63 @@ def test_mlp_training_with_magi_compiler(): print("Test passed: Model successfully completed multiple training epochs, parameters have been updated") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available, skipping test") +def test_transformer_training_with_magi_compiler(): + """Test Transformer training with magi_compiler in training scenario""" + + # Set device + device = torch.device("cuda") + + # Create Transformer configuration + transformer_config = TransformerConfig( + vocab_size=10000, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=2, + num_attention_heads=16, + num_key_value_heads=16, + max_position_embeddings=1024, + rms_norm_eps=1e-6, + params_dtype=torch.bfloat16, + ) + + # Create model and save initial parameters + model, initial_params = create_transformer_model_with_initial_params(transformer_config, device) + + # Create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + + # Training parameters + batch_size = 8 + seq_len = 1024 * 4 + vocab_size = transformer_config.vocab_size + num_epochs = 4 + batches_per_epoch = 2 + + # Execute training + epoch_losses = train_transformer_model( + model=model, + optimizer=optimizer, + device=device, + batch_size=batch_size, + seq_len=seq_len, + vocab_size=vocab_size, + num_epochs=num_epochs, + batches_per_epoch=batches_per_epoch, + ) + + # Verify model parameters have been updated + params_updated = verify_model_parameters_updated(initial_params=initial_params, current_params=list(model.parameters())) + + assert params_updated, "Transformer model parameters should change after training" + + print("Test passed: Transformer model successfully completed multiple training epochs, parameters have been updated") + + if __name__ == "__main__": # Usage: # ENABLE_REMOTE_DEBUG=true MAGI_ENABLE_FX_GRAPH_VIZ=true TORCH_LOGS=aot CUDA_VISIBLE_DEVICES=1 python pkgs/MagiCompiler/tests/test_mlp_training.py with CleanupCacheContext(): enable_remote_debug() test_mlp_training_with_magi_compiler() + test_transformer_training_with_magi_compiler()