
[Dev] Add Llama3 training example and fix cache save #14

Open
wtr0504 wants to merge 3 commits into SandAI-org:main from wtr0504:dev/training

Conversation


@wtr0504 wtr0504 commented Apr 1, 2026

🗂️ PR Category

  • ✨ New Feature
  • 🚀 Optimization (performance, memory, etc.)
  • 💥 Breaking Change
  • 🐛 Bug Fix
  • 🛠️ Development / Refactoring
  • 📚 Documentation
  • 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • 🧪 Testing

📝 Description

Summary

Add end-to-end Llama3 training example (example/training/) with FSDP support, a distributed training script, and an Nsys profiling launch script.
Fix a cache save bug where aot_autograd artifacts were empty, causing compiled graphs to fail to persist correctly.

Changes

example/training/llama3.py — Llama3 model definition adapted to use magi_compile
example/training/train.py — distributed training loop with FSDP and NVTX profiling hooks
example/training/train.sh — torchrun launcher with optional Nsys profiling
magi_compiler/magi_backend/piecewise_compiler.py — workaround for empty aot_autograd artifacts on cache save
magi_compiler/utils/nvtx.py — NVTX-based per-iteration profiling helper
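For reviewers, a minimal sketch of what a per-iteration NVTX helper like magi_compiler/utils/nvtx.py might look like; the names and API here are assumptions, not the PR's actual code, and the helper degrades to a no-op when CUDA (or PyTorch) is unavailable:

```python
# Hypothetical sketch of a per-iteration NVTX helper; the actual
# implementation in magi_compiler/utils/nvtx.py may differ.
from contextlib import contextmanager

try:
    import torch
    _HAS_CUDA = torch.cuda.is_available()
except ImportError:  # let the sketch run without PyTorch installed
    torch = None
    _HAS_CUDA = False


@contextmanager
def nvtx_iteration(step: int):
    """Wrap one training step in an NVTX range for the Nsight Systems timeline."""
    if _HAS_CUDA:
        torch.cuda.nvtx.range_push(f"iteration_{step}")
    try:
        yield
    finally:
        if _HAS_CUDA:
            torch.cuda.nvtx.range_pop()


# Usage: annotate each step of the training loop.
for step in range(2):
    with nvtx_iteration(step):
        pass  # forward / backward / optimizer.step() would go here
```

Ranges pushed this way show up as named spans in an `nsys` trace, which is what the train.sh launcher below is meant to capture.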

from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init

Do we need to add this pkg to requirements-test.txt?

device = torch.device("cpu")

# Initialize a small config for testing
config = ModelArgs(n_layers=10, max_batch_size=2, max_seq_len=1024)

Use official config for profiling
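To make the suggestion concrete: the published Llama 3 8B hyperparameters are sketched below. The ModelArgs field names are assumptions based on the snippet above, and the exact dataclass in the PR may differ:

```python
from dataclasses import dataclass


# Minimal stand-in for the ModelArgs dataclass used in example/training/llama3.py
# (field names here are assumptions inferred from the review snippet).
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int = 8       # Llama 3 uses grouped-query attention
    vocab_size: int = 128256
    max_batch_size: int = 2
    max_seq_len: int = 8192   # Llama 3 context length


# Official Llama 3 8B sizes, instead of the small 10-layer test config:
config = ModelArgs(n_layers=32, max_batch_size=2, max_seq_len=8192)
```

Profiling against the official shape matters because kernel timings for a 10-layer, 1024-token toy model are not representative of the full 32-layer, 8192-token workload.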

export MAGI_ENABLE_FX_GRAPH_VIZ=${MAGI_ENABLE_FX_GRAPH_VIZ:-false}

$NSYS_CMD torchrun $DISTRIBUTED_ARGS $SCRIPT_DIR/train.py \
$NSYS_ARGS

No NSYS_ARGS provided? Check again and try to simplify this script~
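One possible simplification, sketched here with assumed variable names and paths (the actual train.sh in the PR differs): put the nsys options on the `nsys profile` invocation itself rather than in a trailing `$NSYS_ARGS` after train.py, where torchrun would pass them to the training script.

```shell
#!/usr/bin/env bash
# Hypothetical simplified launcher; variable names and flags are assumptions.
set -euo pipefail

NPROC=${NPROC:-8}
PROFILE=${PROFILE:-0}

NSYS_CMD=""
if [ "$PROFILE" = "1" ]; then
    # nsys options belong on the nsys invocation, not after train.py
    NSYS_CMD="nsys profile --trace=cuda,nvtx,osrt -o llama3_train"
fi

$NSYS_CMD torchrun --nproc_per_node="$NPROC" train.py
```

With this shape there is no separate `$NSYS_ARGS` variable to forget, and the non-profiling path is plain `torchrun`.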

