159 changes: 159 additions & 0 deletions examples/speculative_decoding/eagle_utils.py
@@ -13,9 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import os
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from types import FrameType
from typing import Any

import numpy as np
@@ -24,10 +30,13 @@
from datasets import load_dataset
from PIL import Image
from scripts.ar_validate import validate_ar
from torch.distributed.tensor.experimental._attention import _SDPAMerger
from torch.utils.data import Dataset
from transformers import AutoProcessor, Trainer, TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother

import modelopt.torch.speculative.plugins.transformers
from modelopt.torch.speculative.utils import get_ttt_msk_func
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import is_master

@@ -566,3 +575,153 @@ def on_step_end(self, args, state, control, **kwargs):
except Exception:
print_rank_0("AR validation not available.")
return control


def _compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) -> torch.Tensor:
"""Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl."""

msk_func = get_ttt_msk_func(seq_length, ttt_step)

dtypemin = torch.finfo(dtype).min
q_len = seq_length
kv_len = seq_length * (1 + ttt_step)
# Return tensor mask for non-flex attention
tensor_mask = msk_func(
None,
None,
torch.arange(q_len).view(1, 1, q_len, 1),
torch.arange(kv_len).view(1, 1, 1, kv_len),
).to(torch.cuda.current_device())
tensor_mask = torch.full_like(
tensor_mask, 0, dtype=dtype, device=torch.cuda.current_device()
).masked_fill(~tensor_mask, dtypemin)
return tensor_mask


def get_patched_templated_ring_attn(orig_templated_attn: Callable):
"""
Return patched version of
torch.distributed.tensor.experimental._attention._templated_ring_attention
to support TTT.
"""

def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
"""Get chunk-interleaved TTT mask for current rank.
e.g.:
2 ranks, ttt_step=1;
full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
[x, 0, 0, 0, 0, x, 0, 0],
[x, x, 0, 0, 0, 0, x, 0],
[x, x, x, 0, 0, 0, 0, x],

rank 0, step0: [[0, 0, x, 0],
[x, 0, 0, x]]

rank 1, step0: [[0, 0, x, 0],
[x, 0, 0, x]]

rank 0, step1: [[0, 0, 0, 0],
[0, 0, 0, 0]]

rank 1, step1: [[x, x, 0, 0],
[x, x, 0, 0]]

"""
device = torch.cuda.current_device()
q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device)
kv_indices = (
torch.arange(q_len * size * (ttt_step + 1), device=device)
.view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
.reshape(-1)
)
msk_func = get_ttt_msk_func(q_len * size, ttt_step)
attn_mask = msk_func(
None,
None,
q_indices.view(1, 1, -1, 1),
kv_indices.view(1, 1, 1, -1),
)
attn_bias = torch.where(
attn_mask,
torch.zeros((), dtype=dtype, device=attn_mask.device),
torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device),
)

return attn_bias

def patched_templated_attn(*args, **kwargs):
"""Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
# Get original attention op
# Sensitive to impl of _templated_ring_attention
original_op = args[2]

        # This patch is only enabled for the eagle model via a context manager, not for the base model.
        patch_enabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH

        if patch_enabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
            raise ValueError(f"CP TTT only supports cuDNN attention for now. Got: {original_op}")

        # Unset is_causal to use the custom attention mask
        if patch_enabled:
            kwargs["is_causal"] = False

def patched_op(*args, **kwargs):
            # Inspect the parent frame to get the current shard info.
            # This is sensitive to the torch _templated_ring_attention implementation.
try:
frame: FrameType = inspect.currentframe()
f_back: FrameType = frame.f_back
rank = f_back.f_locals["rank"]
size = f_back.f_locals["size"]
query = f_back.f_locals["query"]
key = f_back.f_locals["key"]
i = f_back.f_locals["i"]
ttt_step = (key.shape[2] // query.shape[2]) - 1
except Exception as e:
print(f"Failed to capture loop variables in patched _templated_ring_attention: {e}")
# Set attn mask to permuted TTT mask
if "attn_bias" in kwargs:
kwargs["attn_bias"] = _get_sharded_ttt_msk(
i, rank, size, query.shape[2], ttt_step, query.dtype
)
# Perform shard attention
return original_op(*args, **kwargs)

return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)

return patched_templated_attn


def patch_ring_attention_for_ttt():
"""Patch torch ring attention to support context parallelism for TTT."""
    # Torch ring attention only supports no mask or a causal mask. The following patches enable the TTT mask.

    # 1. Disable load balancing, which is designed for the causal mask.
    # This affects how buffers are sharded, so it must be done permanently before accelerate/HF Trainer init.
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False

# 2. Patch templated ring attention for TTT mask.
original_templated_ring_attention = (
torch.distributed.tensor.experimental._attention._templated_ring_attention
)
original_templated_ring_attention_backward = (
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
)
torch.distributed.tensor.experimental._attention._templated_ring_attention = (
get_patched_templated_ring_attn(original_templated_ring_attention)
)
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
get_patched_templated_ring_attn(original_templated_ring_attention_backward)
)

    # 3. Patch the merger to skip the blank shard to avoid differences in the output.
original_sdpa_merger_step = _SDPAMerger.step

def patched_sdpa_merger_step(
self, out: torch.Tensor, lse: torch.Tensor, partial: bool
) -> torch.Tensor:
if lse.sum() <= 0:
return
return original_sdpa_merger_step(self, out, lse, partial)

_SDPAMerger.step = patched_sdpa_merger_step
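
Usage note (a sketch, not taken from the patch): the ring-attention patch only works if it runs before accelerate or the HF Trainer shards any buffers, and `_compute_ttt_attention_mask` builds the dense additive mask used on the non-flex-attention path. A minimal sketch, assuming `modelopt` and this example module are importable and a CUDA device is available; the toy sizes and the expected output shape are assumptions:

import torch

from eagle_utils import _compute_ttt_attention_mask, patch_ring_attention_for_ttt

# Apply the ring-attention patches before any accelerate / HF Trainer setup,
# since disabling load balancing changes how buffers are sharded.
patch_ring_attention_for_ttt()

# Dense additive TTT mask for a toy sequence: 0 where attention is allowed,
# dtype-min where it is masked out.
mask = _compute_ttt_attention_mask(
    batch_size=1, seq_length=8, ttt_step=1, dtype=torch.bfloat16
)
print(mask.shape)  # likely (1, 1, 8, 16): q_len x q_len * (1 + ttt_step), assuming the broadcast above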
1 change: 1 addition & 0 deletions examples/speculative_decoding/fsdp_config.json
@@ -0,0 +1 @@
{"fsdp_version":2}
31 changes: 15 additions & 16 deletions examples/speculative_decoding/launch_train.sh
@@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
EAGLE_CONFIG="${1#*=}"
;;
--fsdp_transformer_layer_cls_to_wrap*)
if [[ "$1" != *=* ]]; then shift; fi
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
;;
--num_gpu*)
if [[ "$1" != *=* ]]; then shift; fi
NUM_GPU="${1#*=}"
;;
--disable_tqdm*)
if [[ "$1" != *=* ]]; then shift; fi
DISABLE_TQDM="${1#*=}"
@@ -102,6 +94,14 @@
if [[ "$1" != *=* ]]; then shift; fi
AR_VALIDATE_STEPS="${1#*=}"
;;
--cp_size*)
if [[ "$1" != *=* ]]; then shift; fi
CP_SIZE="${1#*=}"
;;
--dp_size*)
if [[ "$1" != *=* ]]; then shift; fi
DP_SHARD_SIZE="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
@@ -129,15 +129,15 @@ LR=${LR:-"1e-4"}
TRAIN_BS=${TRAIN_BS:-4}
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
NUM_GPU=${NUM_GPU:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
DISABLE_TQDM=${DISABLE_TQDM:-False}
VLM_PROCESSOR=${VLM_PROCESSOR:-}
VLM_IMG_DIR=${VLM_IMG_DIR:-}
AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
ESTIMATE_AR=${ESTIMATE_AR:-False}
CP_SIZE=${CP_SIZE:-1}
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}

if [[ "$MODE" == "medusa" ]]; then
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -163,11 +163,6 @@ else
OFFLINE_TRAINING_ARGS=""
fi

if [[ "$NUM_GPU" == 1 ]]; then
MULTI_GPU=""
else
MULTI_GPU="--multi_gpu"
fi

if [[ "$VLM_PROCESSOR" != "" ]]; then
VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR"
@@ -177,7 +172,7 @@ fi

# Disable tokenizers parallelism to avoid warning
export TOKENIZERS_PARALLELISM=False
CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
CMD="accelerate launch --mixed_precision bf16 main.py \
--mode $MODE \
--eagle_decoder_type $EAGLE_DECODER_TYPE \
--model_name_or_path $MODEL \
@@ -206,6 +201,10 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
$VLM_ARGS \
$OFFLINE_TRAINING_ARGS \
$SPECULATIVE_ARGS \
--fsdp 'full_shard' \
--fsdp_config fsdp_config.json \
--cp_size $CP_SIZE \
--dp_shard_size $DP_SHARD_SIZE \
"

start_time=$(date +%s)
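
With the NUM_GPU/--multi_gpu plumbing removed, the data-parallel shard size is now derived from the detected GPU count and the requested context-parallel size. A quick sketch of that arithmetic, with a hypothetical 8-GPU node (GPU_COUNT itself is computed elsewhere in the script):

# Illustrative sizing only, mirroring DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}.
gpu_count = 8                          # hypothetical single node
cp_size = 2                            # --cp_size 2
dp_shard_size = gpu_count // cp_size   # default: 4-way FSDP sharding
assert cp_size * dp_shard_size == gpu_count  # each GPU belongs to exactly one (dp, cp) slot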
17 changes: 16 additions & 1 deletion examples/speculative_decoding/main.py
@@ -36,7 +36,13 @@

import torch
import transformers
from eagle_utils import EagleTrainerWithAccLog, EagleTrainingPlot, make_eagle_supervised_data_module
from accelerate import ParallelismConfig
from eagle_utils import (
EagleTrainerWithAccLog,
EagleTrainingPlot,
make_eagle_supervised_data_module,
patch_ring_attention_for_ttt,
)
from medusa_utils import make_medusa_supervised_data_module
from transformers.trainer_utils import get_last_checkpoint

@@ -100,6 +106,8 @@ class TrainingArguments(transformers.TrainingArguments):
remove_unused_columns: bool = field(
default=False, metadata={"help": "Set to False to keep extra args for VLM."}
)
cp_size: int = field(default=1, metadata={"help": "Context parallelism size."})
dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."})


@dataclass
@@ -130,6 +138,13 @@ def train():
model_args, data_args, training_args, medusa_args, eagle_args = (
parser.parse_args_into_dataclasses()
)
training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1:
patch_ring_attention_for_ttt()
        # Workaround specific to accelerate 1.12.0; removable after moving to 1.13.0.
training_args.parallelism_config.sp_backend = None
print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}")

# Detecting last checkpoint.
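
The two new fields only take effect once they are parsed into TrainingArguments and handed to accelerate's ParallelismConfig. A parsing-only sketch, assuming main.py's TrainingArguments dataclass is importable in this environment; it does not launch training, and the output directory is a placeholder:

import transformers

from main import TrainingArguments  # the dataclass extended above

parser = transformers.HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "/tmp/eagle_ckpt", "--cp_size", "2", "--dp_shard_size", "4"]
)
assert args.cp_size == 2 and args.dp_shard_size == 4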
9 changes: 4 additions & 5 deletions examples/speculative_decoding/requirements.txt
@@ -1,5 +1,4 @@
flash-attn
openai
py7zr
sentencepiece>=0.2.0
tensorboardX
accelerate==1.12.0
torch==2.8.0
transformers==5.0.0rc1
wandb