Skip to content
57 changes: 35 additions & 22 deletions examples/run_simple_mcore_train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Tuple, Iterator

from megatron.core import parallel_state
from megatron.core import dist_checkpointing
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
Expand All @@ -25,19 +24,17 @@
from megatron.core.distributed.finalize_model_grads import finalize_model_grads
from megatron.core.tokenizers import MegatronTokenizer


_SEQUENCE_LENGTH: int = 64


def initialize_distributed(
tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1
) -> None:
"""
Initialize torch.distributed and Megatron-Core model parallel groups.
Set up torch.distributed and Megatron-Core model parallel groups.

Args:
tensor_model_parallel_size: Number of GPUs for tensor model parallelism.
pipeline_model_parallel_size: Number of GPUs for pipeline model parallelism.
tensor_model_parallel_size (int): Number of GPUs to use for tensor model parallelism.
pipeline_model_parallel_size (int): Number of GPUs to use for pipeline model parallelism.
"""
parallel_state.destroy_model_parallel()

Expand All @@ -59,10 +56,10 @@ def initialize_distributed(

def model_provider() -> GPTModel:
"""
Build and return a simple GPT model for demonstration.
Construct a minimal GPT model for demonstration and testing purposes.

Returns:
GPTModel: A small GPT model with 2 layers for testing.
GPTModel: A small GPT model instance with 2 layers.
"""
transformer_config: TransformerConfig = TransformerConfig(
num_layers=2,
Expand All @@ -84,10 +81,14 @@ def model_provider() -> GPTModel:

def get_train_data_iterator() -> Iterator:
"""
Create a mock dataset and return a data iterator.
Initialize and return an iterator over the training dataset for the GPT model.

This function takes no arguments: it sets up a mock dataset from a dataset configuration and a tokenizer,
builds the dataset, and returns an iterator for use in the training loop. When torch.distributed is
initialized, it also ensures that helper functions are compiled across the distributed processes.

Returns:
Iterator: Data iterator for training batches.
Iterator: An iterator that yields training batches for the GPT model.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
Expand Down Expand Up @@ -124,15 +125,20 @@ def forward_step_func(
data_iterator: Iterator, model: torch.nn.Module
) -> Tuple[torch.Tensor, Callable]:
"""
Forward step function that computes model output and returns loss function.
Perform a forward pass on a batch of training data and return the model output and loss function.

This function retrieves the next batch from the data iterator, moves all tensors to the appropriate device,
and computes the model's output tensor. It also defines and returns a loss function, partially applied with the
current loss mask, for use in the training loop.

Args:
data_iterator: Iterator providing training batches.
model: The GPT model to train.
data_iterator (Iterator): Iterator yielding training batches as dictionaries of tensors.
model (torch.nn.Module): The GPT model to be trained.

Returns:
Tuple of (output_tensor, loss_function) where loss_function is a partial
function that will compute the final loss when called.
Tuple[torch.Tensor, Callable]:
- output_tensor: The output tensor from the model's forward pass.
- loss_function: A callable that computes the loss when invoked with the model output.
"""

def loss_func(
Expand Down Expand Up @@ -164,11 +170,15 @@ def save_distributed_checkpoint(
checkpoint_path: str, gpt_model: torch.nn.Module
) -> None:
"""
Save model checkpoint using Megatron-Core distributed checkpointing.
Save a distributed checkpoint of the GPT model using Megatron-Core utilities.

This function extracts the underlying model if wrapped with DistributedDataParallel (DDP),
obtains its sharded state dictionary, and saves it to the specified directory using
Megatron-Core's distributed checkpointing mechanism.

Args:
checkpoint_path: Directory path to save checkpoint.
gpt_model: The model to checkpoint (may be wrapped with DDP).
checkpoint_path (str): Directory path where the checkpoint will be saved.
gpt_model (torch.nn.Module): The GPT model to checkpoint (may be wrapped with DDP).
"""
# Access underlying model if wrapped with DDP
model: torch.nn.Module = (
Expand All @@ -184,14 +194,17 @@ def load_distributed_checkpoint(
checkpoint_path: str, gpt_model: torch.nn.Module
) -> torch.nn.Module:
"""
Load model checkpoint using Megatron-Core distributed checkpointing.
Load a distributed checkpoint into the GPT model using Megatron-Core utilities.

This function extracts the underlying model if wrapped with DistributedDataParallel (DDP),
loads the checkpoint from the specified directory, and updates the model's state dictionary.

Args:
checkpoint_path: Directory path to load checkpoint from.
gpt_model: The model to load into (may be wrapped with DDP).
checkpoint_path (str): Directory path from which to load the checkpoint.
gpt_model (torch.nn.Module): The GPT model to load the checkpoint into (may be wrapped with DDP).

Returns:
The model with loaded checkpoint weights.
torch.nn.Module: The model with loaded checkpoint weights.
"""
# Access underlying model if wrapped with DDP
model: torch.nn.Module = (
Expand Down
Loading