107 changes: 107 additions & 0 deletions train_tensor_parallel/README.md
# Ray Train + Tensor Parallelism Tutorial

A simple tutorial demonstrating how to train large language models with tensor parallelism using PyTorch-native FSDP2 + DTensor and Ray Train.

## Key Concepts

- **Tensor Parallelism (TP)**: Shards model weights across GPUs within a TP group
- **Data Parallelism (DP)**: Replicates the model across DP groups, each processing different data
- **2D Parallelism**: Combines TP and DP to scale to many GPUs (see the sketch below)
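
The sketch below shows one way these pieces can fit together with PyTorch's DTensor APIs. It is not an excerpt from `train.py`: the model name, the `model.model.layers[*].mlp.*` module paths, and the choice to shard only the MLP projections are illustrative assumptions for a Qwen2-style Hugging Face model.

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from transformers import AutoModelForCausalLM

# Assumes torch.distributed is already initialized (Ray Train does this on
# each worker) and that world_size == dp_size * tp_size.
dp_size, tp_size = 2, 2
mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B").to("cuda")

for block in model.model.layers:
    # Tensor parallelism: shard each decoder block's MLP along the "tp" dim.
    parallelize_module(
        block,
        mesh["tp"],
        {
            "mlp.gate_proj": ColwiseParallel(),
            "mlp.up_proj": ColwiseParallel(),
            "mlp.down_proj": RowwiseParallel(),
        },
    )
    # Data parallelism: FSDP2 shards (and gathers) parameters along the "dp" dim.
    fully_shard(block, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])
```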

## Quick Start

```bash
# 4 GPUs: 2-way tensor parallelism, 2-way data parallelism
python train.py \
    --model_name Qwen/Qwen2-7B \
    --tp_size 2 \
    --dp_size 2 \
    --num_workers 4 \
    --num_epochs 3

# 8 GPUs: 4-way tensor parallelism, 2-way data parallelism
python train.py \
    --model_name Qwen/Qwen2-7B \
    --tp_size 4 \
    --dp_size 2 \
    --num_workers 8 \
    --batch_size 2 \
    --seq_length 2048 \
    --num_epochs 3
```

## Arguments

| Argument | Description | Default |
|----------|-------------|---------|
| `--model_name` | HuggingFace model name | `Qwen/Qwen2-7B` |
| `--tp_size` | Tensor parallel degree | Required |
| `--dp_size` | Data parallel degree | `1` |
| `--num_workers` | Total workers (must equal tp_size * dp_size) | Required |
| `--dataset_name` | HuggingFace dataset | `wikitext` |
| `--dataset_percentage` | Percentage of dataset to use (0-100) | `10.0` |
| `--batch_size` | Per-GPU micro batch size | `1` |
| `--seq_length` | Maximum sequence length | `2048` |
| `--num_epochs` | Number of training epochs | `3` |
| `--learning_rate` | Learning rate | `1e-5` |
| `--weight_decay` | Weight decay | `0.01` |
| `--storage_path` | Checkpoint storage path | `/mnt/cluster_storage` |
| `--experiment_name` | Experiment name (auto-generated if not provided) | None |
| `--log_interval` | Logging interval (steps) | `10` |
| `--debug_steps` | Stop after N steps per epoch (0 = full epoch) | `0` |
| `--seed` | Random seed | `42` |

## Anyscale Job

```bash
anyscale job submit -f job.yaml
```

## File Structure

```
train_tensor_parallel/
├── train.py      # Main training script
├── args.py       # Command line arguments
├── job.yaml      # Anyscale job config
└── README.md     # This file
```

## How 2D Parallelism Works

With `tp_size=2` and `dp_size=2` on 4 GPUs:

```
Device Mesh (2x2):

              TP Dim
             [0] [1]
   DP       +---+---+
   Dim      | 0 | 1 |  <- TP Group 0 (same data, sharded model)
            +---+---+
            | 2 | 3 |  <- TP Group 1 (same data, sharded model)
            +---+---+
              ^   ^
              DP Groups (different data, gradient sync)
```

- **TP Groups** (rows): GPUs 0,1 and GPUs 2,3 share the same input data but have sharded model weights
- **DP Groups** (columns): GPUs 0,2 and GPUs 1,3 see different data and synchronize gradients (see the device-mesh sketch below)
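
A minimal sketch of how this layout maps onto a PyTorch device mesh. The dimension names and the rank bookkeeping are assumptions about how a script like `train.py` could derive them, not a verbatim excerpt:

```python
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is already initialized with 4 ranks.
dp_size, tp_size = 2, 2
mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

# In the 2x2 mesh, global rank r sits at row r // tp_size (DP dim) and
# column r % tp_size (TP dim), matching the diagram above.
dp_rank = mesh.get_local_rank("dp")  # which slice of the data this rank sees
tp_rank = mesh.get_local_rank("tp")  # which shard of the model this rank holds

tp_group = mesh["tp"].get_group()  # ranks sharing data, e.g. {0, 1}
dp_group = mesh["dp"].get_group()  # ranks syncing gradients, e.g. {0, 2}
```

The `dp_rank` computed here is exactly what the TP-aware data loading below shards on.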

## Key Implementation Details

### TP-Aware Data Loading

Standard data loaders shard by `world_rank`, giving each GPU different data. With tensor parallelism, all GPUs in a TP group must see identical data. This is handled by sharding based on `dp_rank` instead:

```python
# All TP ranks in same DP group get identical batches
sampler = DistributedSampler(
    dataset,
    num_replicas=dp_size,  # NOT world_size
    rank=dp_rank,          # NOT world_rank
)
```
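
A usage sketch that continues the snippet above; `dataset`, `dp_size`, `dp_rank`, and `num_epochs` are placeholders supplied by the surrounding training code:

```python
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank)

# Shuffling is driven by the sampler, so `shuffle` stays unset on the loader.
loader = DataLoader(dataset, batch_size=1, sampler=sampler)

for epoch in range(num_epochs):
    # Re-seed the permutation each epoch. Because all TP ranks in a DP group
    # use the same dp_rank, they draw identical batches in lockstep.
    sampler.set_epoch(epoch)
    for batch in loader:
        ...
```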

### Checkpointing

All workers save their model shards independently. Ray Train aggregates these into a single checkpoint that can be used for resuming training.
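
A rough sketch of what that per-worker save path can look like with PyTorch Distributed Checkpoint and `ray.train.report`. The exact layout in `train.py` may differ; the state-dict helper and the model-only checkpoint are simplifying assumptions:

```python
import tempfile

import ray.train
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict


def save_checkpoint(model, metrics):
    # Every worker writes its own shard files into a local temp directory...
    with tempfile.TemporaryDirectory() as tmpdir:
        dcp.save({"model": get_model_state_dict(model)}, checkpoint_id=tmpdir)
        # ...and reports them. Ray Train collects the shards from all workers
        # into a single checkpoint under --storage_path.
        ray.train.report(
            metrics,
            checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
        )
```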
120 changes: 120 additions & 0 deletions train_tensor_parallel/args.py
"""Command line argument parsing for tensor parallelism training."""

import argparse


def get_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Ray Train + FSDP2 + DTensor Tensor Parallelism Training"
)

# Model configuration
parser.add_argument(
"--model_name",
type=str,
default="Qwen/Qwen2-7B",
help="HuggingFace model name or path",
)

# Parallelism configuration
parser.add_argument(
"--tp_size",
type=int,
required=True,
help="Tensor parallel degree",
)
parser.add_argument(
"--dp_size",
type=int,
default=1,
help="Data parallel degree",
)
parser.add_argument(
"--num_workers",
type=int,
required=True,
help="Total number of workers (must equal tp_size * dp_size)",
)

# Dataset configuration
parser.add_argument(
"--dataset_name",
type=str,
default="wikitext",
help="HuggingFace dataset name",
)
parser.add_argument(
"--dataset_percentage",
type=float,
default=10.0,
help="Percentage of dataset to use (0-100)",
)

# Training configuration
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Per-GPU micro batch size",
)
parser.add_argument(
"--seq_length",
type=int,
default=2048,
help="Maximum sequence length",
)
parser.add_argument(
"--num_epochs",
type=int,
default=3,
help="Number of training epochs",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Learning rate",
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.01,
help="Weight decay",
)

# Checkpointing configuration
parser.add_argument(
"--storage_path",
type=str,
default="/mnt/cluster_storage",
help="Storage path for checkpoints",
)
parser.add_argument(
"--experiment_name",
type=str,
default=None,
help="Experiment name (auto-generated if not provided)",
)

# Logging and debugging
parser.add_argument(
"--log_interval",
type=int,
default=10,
help="Logging interval (steps)",
)
parser.add_argument(
"--debug_steps",
type=int,
default=0,
help="Stop after this many steps per epoch (0 = run full epoch)",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed",
)

return parser.parse_args()
50 changes: 50 additions & 0 deletions train_tensor_parallel/job.yaml
# Anyscale Job: Ray Train + FSDP2 + DTensor Tensor Parallelism
# This job runs training with PyTorch native FSDP2 + DTensor for tensor parallelism
#
# Submit with: anyscale job submit -f job.yaml
# Or with custom args: anyscale job submit -f job.yaml --entrypoint "python train.py --tp_size 4 --dp_size 2 --num_workers 8"

name: train-tp-fsdp-dtensor

entrypoint: >
  python train.py
  --model_name Qwen/Qwen2.5-0.5B
  --tp_size 2
  --dp_size 2
  --num_workers 4
  --dataset_name wikitext
  --dataset_percentage 1.0
  --batch_size 2
  --seq_length 1024
  --num_epochs 1
  --learning_rate 1e-5
  --log_interval 1
  --debug_steps 100

image_uri: anyscale/ray:2.53.0-py312-cu128

working_dir: .

requirements:
  - torch>=2.9.1
  - transformers>=4.45.0
  - datasets>=3.0.0
  - accelerate>=1.0.0

compute_config:
  head_node:
    instance_type: m5.xlarge
  worker_nodes:
    - instance_type: g4dn.12xlarge
      min_nodes: 1
      max_nodes: 1

env_vars:
  RAY_TRAIN_V2_ENABLED: "1"
  HF_HOME: /mnt/cluster_storage/huggingface

max_retries: 0

tags:
  project: tensor-parallelism
  framework: fsdp