From 34be75ced53049ac3bbd28ed88cc918dacc69d96 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 8 Jan 2026 08:52:23 +0000 Subject: [PATCH 1/6] init: cp ttt Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/eagle_utils.py | 147 ++++++++++++++++++ examples/speculative_decoding/launch_train.sh | 33 ++-- examples/speculative_decoding/main.py | 15 +- .../torch/speculative/plugins/transformers.py | 36 ++--- modelopt/torch/speculative/utils.py | 29 ++++ 5 files changed, 222 insertions(+), 38 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 45c9c6632..ade81e842 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -13,9 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json import os +from collections.abc import Callable from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import FrameType from typing import Any import numpy as np @@ -24,10 +30,13 @@ from datasets import load_dataset from PIL import Image from scripts.ar_validate import validate_ar +from torch.distributed.tensor.experimental._attention import _SDPAMerger from torch.utils.data import Dataset from transformers import AutoProcessor, Trainer, TrainerCallback from transformers.trainer_pt_utils import LabelSmoother +import modelopt.torch.speculative.plugins.transformers +from modelopt.torch.speculative.utils import get_ttt_msk_func from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import is_master @@ -566,3 +575,141 @@ def on_step_end(self, args, state, control, **kwargs): except Exception: print_rank_0("AR validation not available.") return control + + +def _compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) -> torch.Tensor: + """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl.""" + + msk_func = get_ttt_msk_func(seq_length, ttt_step) + + dtypemin = torch.finfo(dtype).min + q_len = seq_length + kv_len = seq_length * (1 + ttt_step) + # Return tensor mask for non-flex attention + tensor_mask = msk_func( + None, + None, + torch.arange(q_len).view(1, 1, q_len, 1), + torch.arange(kv_len).view(1, 1, 1, kv_len), + ).to(torch.cuda.current_device()) + tensor_mask = torch.full_like( + tensor_mask, 0, dtype=dtype, device=torch.cuda.current_device() + ).masked_fill(~tensor_mask, dtypemin) + return tensor_mask + + +def get_patched_templated_ring_attn(orig_templated_attn: Callable): + """ + Return patched version of + torch.distributed.tensor.experimental._attention._templated_ring_attention + to support TTT. + """ + + def _get_sharded_ttt_msk(i, rank, size, seq_length, ttt_step, dtype): + """Get chunk-interleaved TTT mask for current rank. 
+        e.g.:
+        2 ranks, ttt_step=1;
+        full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0],
+                         [x, 0, 0, 0, 0, x, 0, 0],
+                         [x, x, 0, 0, 0, 0, x, 0],
+                         [x, x, x, 0, 0, 0, 0, x],
+
+        rank 0, step0: [[0, 0, x, 0],
+                        [x, 0, 0, x]]
+
+        rank 1, step0: [[0, 0, x, 0],
+                        [x, 0, 0, x]]
+
+        rank 0, step1: [[0, 0, 0, 0],
+                        [0, 0, 0, 0]]
+
+        rank 1, step1: [[x, x, 0, 0],
+                        [x, x, 0, 0]]
+
+        """
+        # Get full TTT mask
+        attn_bias = _compute_ttt_attention_mask(1, seq_length * size, ttt_step, dtype)
+        # Chunk to get current ranks's q rows
+        attn_bias = attn_bias.chunk(size, dim=2)[rank]
+        # Split cols into seq_length blocks
+        attn_bias = attn_bias.split(seq_length, dim=3)
+        # Get interleaved col blocks for current rank
+        attn_bias = attn_bias[(rank - i) % size :: size]
+        return torch.cat(attn_bias, dim=3)
+
+    def patched_templated_attn(*args, **kwargs):
+        """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
+        # Get original attention op
+        # Sensitive to impl of _templated_ring_attention
+        original_op = args[2]
+
+        # This patch is only enabled for the eagle model by the context manager, not the base model.
+        patch_enabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
+
+        if patch_enabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
+            raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")
+
+        # Unset is_causal to use custom attn mask
+        if patch_enabled:
+            kwargs["is_causal"] = False
+
+        def patched_op(*args, **kwargs):
+            # Inspect the parent frame to get current shard info
+            # This is sensitive to torch _templated_ring_attention impl
+            try:
+                frame: FrameType = inspect.currentframe()
+                f_back: FrameType = frame.f_back
+                rank = f_back.f_locals["rank"]
+                size = f_back.f_locals["size"]
+                query = f_back.f_locals["query"]
+                key = f_back.f_locals["key"]
+                i = f_back.f_locals["i"]
+                ttt_step = (key.shape[2] // query.shape[2]) - 1
+            except Exception as e:
+                print(f"Failed to capture loop variables in patched _templated_ring_attention: {e}")
+            # Set attn mask to permuted TTT mask
+            if "attn_bias" in kwargs:
+                kwargs["attn_bias"] = _get_sharded_ttt_msk(
+                    i, rank, size, query.shape[2], ttt_step, query.dtype
+                )
+            # Perform shard attention
+            return original_op(*args, **kwargs)
+
+        return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs)
+
+    return patched_templated_attn
+
+
+def patch_ring_attention_for_ttt():
+    """Patch to enable context parallelism for TTT."""
+    # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.
+
+    # 1. Disable load balance, which is designed for causal mask.
+    # This affects how buffers are sharded, so it needs to be done permanently before accelerate/HF trainer init.
+    torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
+
+    # 2. Patch templated ring attention for TTT mask.
+    original_templated_ring_attention = (
+        torch.distributed.tensor.experimental._attention._templated_ring_attention
+    )
+    original_templated_ring_attention_backward = (
+        torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
+    )
+    torch.distributed.tensor.experimental._attention._templated_ring_attention = (
+        get_patched_templated_ring_attn(original_templated_ring_attention)
+    )
+    torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
+        get_patched_templated_ring_attn(original_templated_ring_attention_backward)
+    )
+
+    # 3. 
Patch merger to skip the blank shard to avoid difference in output. + original_sdpa_merger_step = _SDPAMerger.step + + def patched_sdpa_merger_step( + self, out: torch.Tensor, lse: torch.Tensor, partial: bool + ) -> torch.Tensor: + if lse.sum() <= 0: + return + return original_sdpa_merger_step(self, out, lse, partial) + + _SDPAMerger.step = patched_sdpa_merger_step diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index e3b6c5a21..5cdb3cd49 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi EAGLE_CONFIG="${1#*=}" ;; - --fsdp_transformer_layer_cls_to_wrap*) - if [[ "$1" != *=* ]]; then shift; fi - FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}" - ;; - --num_gpu*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_GPU="${1#*=}" - ;; --disable_tqdm*) if [[ "$1" != *=* ]]; then shift; fi DISABLE_TQDM="${1#*=}" @@ -102,6 +94,14 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi AR_VALIDATE_STEPS="${1#*=}" ;; + --cp_size*) + if [[ "$1" != *=* ]]; then shift; fi + CP_SIZE="${1#*=}" + ;; + --dp_size*) + if [[ "$1" != *=* ]]; then shift; fi + DP_SHARD_SIZE="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -129,8 +129,6 @@ LR=${LR:-"1e-4"} TRAIN_BS=${TRAIN_BS:-4} MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} -FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} -NUM_GPU=${NUM_GPU:-1} TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} DISABLE_TQDM=${DISABLE_TQDM:-False} @@ -138,6 +136,8 @@ VLM_PROCESSOR=${VLM_PROCESSOR:-} VLM_IMG_DIR=${VLM_IMG_DIR:-} AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} ESTIMATE_AR=${ESTIMATE_AR:-False} +CP_SIZE=${CP_SIZE:-1} +DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" @@ -163,11 +163,6 @@ else OFFLINE_TRAINING_ARGS="" fi -if [[ "$NUM_GPU" == 1 ]]; then - MULTI_GPU="" -else - MULTI_GPU="--multi_gpu" -fi if [[ "$VLM_PROCESSOR" != "" ]]; then VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR" @@ -177,7 +172,7 @@ fi # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False -CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ +CMD="accelerate launch --mixed_precision bf16 main.py \ --mode $MODE \ --eagle_decoder_type $EAGLE_DECODER_TYPE \ --model_name_or_path $MODEL \ @@ -197,7 +192,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ --weight_decay 0.0 \ --warmup_steps 100 \ --lr_scheduler_type linear \ - --logging_steps 100 \ + --logging_steps 5 \ --tf32 True \ --data_path $DATA \ --disable_tqdm $DISABLE_TQDM \ @@ -206,6 +201,10 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS \ + --fsdp 'full_shard' \ + --fsdp_config fsdp_config.json \ + --cp_size $CP_SIZE \ + --dp_shard_size $DP_SHARD_SIZE \ " start_time=$(date +%s) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index cd1af9563..12914b46d 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -36,7 +36,13 @@ import torch import transformers -from eagle_utils import EagleTrainerWithAccLog, EagleTrainingPlot, 
make_eagle_supervised_data_module +from accelerate import ParallelismConfig +from eagle_utils import ( + EagleTrainerWithAccLog, + EagleTrainingPlot, + make_eagle_supervised_data_module, + patch_ring_attention_for_ttt, +) from medusa_utils import make_medusa_supervised_data_module from transformers.trainer_utils import get_last_checkpoint @@ -100,6 +106,8 @@ class TrainingArguments(transformers.TrainingArguments): remove_unused_columns: bool = field( default=False, metadata={"help": "Set to False to keep extra args for VLM."} ) + cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) + dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."}) @dataclass @@ -130,6 +138,11 @@ def train(): model_args, data_args, training_args, medusa_args, eagle_args = ( parser.parse_args_into_dataclasses() ) + training_args.parallelism_config = ParallelismConfig( + cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size + ) + if training_args.cp_size > 1: + patch_ring_attention_for_ttt() print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}") # Detecting last checkpoint. diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 39df7b9b7..b2bddc81e 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -56,10 +56,13 @@ AcceptanceRateValidation, ResBlock, _setup_kimi_k2_decoder, + enable_cp_ttt_patch, + get_ttt_msk_func, temporary_set_config_value, ) IGNORE_TOKEN_ID = LabelSmoother.ignore_index +ENABLE_CP_TTT_PATCH = False @MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) @@ -370,7 +373,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -678,16 +681,7 @@ def _compute_ttt_attention_mask( self, batch_size, seq_length, ttt_step ) -> BlockMask | torch.Tensor: """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl.""" - - def msk_func(b, h, q_idx, kv_idx): - mask = kv_idx <= (q_idx - ttt_step) - for i in range(1, ttt_step + 1): - mask_block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & ( - kv_idx >= seq_length * i - ) - mask = mask | mask_block_i - return mask - + msk_func = get_ttt_msk_func(seq_length, ttt_step) dtypemin = torch.finfo(self._base_llm_config.dtype).min q_len = seq_length kv_len = seq_length * (1 + ttt_step) @@ -874,9 +868,9 @@ def forward( ) if not isinstance(past_key_values, Cache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values = DynamicCache(config=self._base_llm_config) if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) + eagle_cache = DynamicCache(config=self.eagle_module.config) # ====Run eagle forward==== eagle_loss = None @@ -912,13 +906,14 @@ def forward( if ttt_step == 0 else self._get_ttt_attention_mask(b, seq_length, ttt_step) ) - _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward( - eagle_input_hidden_states, - inputs_embeds, - attention_mask, - position_ids, - eagle_cache, - ) + with enable_cp_ttt_patch(): + _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward( + eagle_input_hidden_states, + inputs_embeds, + attention_mask, + position_ids, + 
eagle_cache, + ) eagle_input_hidden_states = torch.cat( ( torch.zeros( @@ -989,6 +984,7 @@ def _eagle_loss( assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) loss_mask = loss_mask[:, :, None] + loss_mask = loss_mask[:, : eagle_logits.shape[1]] classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)( eagle_logits ) diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 1f919de06..a3f91ce25 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -27,8 +27,11 @@ import torch.distributed from huggingface_hub import snapshot_download from torch import nn +from torch.nn.attention import SDPBackend, sdpa_kernel from transformers.cache_utils import DynamicCache +import modelopt.torch.speculative.plugins.transformers + KIMI_K2_REPO_ID = "moonshotai/Kimi-K2-Thinking" KIMI_K2_PACKAGE_NAME = "kimi_k2_temp" @@ -439,3 +442,29 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs): kimi_k2_module.DeepseekV3Attention.forward = patched_fwd_with_lazy_rope_init return getattr(kimi_k2_module, "DeepseekV3DecoderLayer") + + +def get_ttt_msk_func(seq_length, ttt_step): + """Return mask function for Eagle3 Training Time Test.""" + + def ttt_msk_func(b, h, q_idx, kv_idx): + mask = kv_idx <= (q_idx - ttt_step) + for i in range(1, ttt_step + 1): + mask_block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & ( + kv_idx >= seq_length * i + ) + mask = mask | mask_block_i + return mask + + return ttt_msk_func + + +@contextlib.contextmanager +def enable_cp_ttt_patch(): + """Context manager to enable CP TTT patch.""" + modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + try: + yield + finally: + modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False From f0d5cb96b79c16cebc45014de563627e1e08243e Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 8 Jan 2026 08:55:47 +0000 Subject: [PATCH 2/6] docstring Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/eagle_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index ade81e842..5b4334687 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -681,7 +681,7 @@ def patched_op(*args, **kwargs): def patch_ring_attention_for_ttt(): - """Patch to enable context parallelism for TTT.""" + """Patch torch ring attention to support context parallelism for TTT.""" # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask. # 1. Disable load balance, which is designed for causal mask. 
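
As an illustration outside the patches themselves, the TTT mask shape that patch_ring_attention_for_ttt has to support can be reproduced with a short CPU sketch. It assumes the series is applied so that get_ttt_msk_func is importable from modelopt.torch.speculative.utils; seq_length=4 and ttt_step=1 are toy values chosen only so the printed mask matches the diagram in the _get_sharded_ttt_msk docstring (where "x" marks an attended position):

    import torch
    from modelopt.torch.speculative.utils import get_ttt_msk_func

    # Toy sizes (assumption): 4 base tokens, one appended TTT step.
    seq_length, ttt_step = 4, 1
    msk_func = get_ttt_msk_func(seq_length, ttt_step)
    # Materialize the boolean mask the same way _compute_ttt_attention_mask does.
    mask = msk_func(
        None,
        None,
        torch.arange(seq_length).view(1, 1, -1, 1),
        torch.arange(seq_length * (1 + ttt_step)).view(1, 1, 1, -1),
    )
    print(mask[0, 0].int())
    # tensor([[0, 0, 0, 0, 1, 0, 0, 0],
    #         [1, 0, 0, 0, 0, 1, 0, 0],
    #         [1, 1, 0, 0, 0, 0, 1, 0],
    #         [1, 1, 1, 0, 0, 0, 0, 1]])

Each query row attends causally to base tokens shifted back by ttt_step, plus exactly one position in each appended TTT block, so a plain is_causal flag cannot express it; that is why the ring-attention patch swaps in a custom attn_bias per shard.
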
From e0f8e579aa69badc5e2d80a58aa53d43cf6d4f86 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 8 Jan 2026 08:57:12 +0000 Subject: [PATCH 3/6] add fsdp config Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/fsdp_config.json | 1 + 1 file changed, 1 insertion(+) create mode 100644 examples/speculative_decoding/fsdp_config.json diff --git a/examples/speculative_decoding/fsdp_config.json b/examples/speculative_decoding/fsdp_config.json new file mode 100644 index 000000000..6d934182f --- /dev/null +++ b/examples/speculative_decoding/fsdp_config.json @@ -0,0 +1 @@ +{"fsdp_version":2} \ No newline at end of file From 73235dc966f1b550c2e859d78d80b7026c0c27f8 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:26:11 -0800 Subject: [PATCH 4/6] update requirements Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/requirements.txt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/speculative_decoding/requirements.txt b/examples/speculative_decoding/requirements.txt index 765af6104..176e43a65 100644 --- a/examples/speculative_decoding/requirements.txt +++ b/examples/speculative_decoding/requirements.txt @@ -1,5 +1,4 @@ -flash-attn -openai -py7zr -sentencepiece>=0.2.0 -tensorboardX +accelerate==1.12.0 +torch==2.8.0 +transformers==5.0.0rc1 +wandb From 8918bac3a9a8c92079ad5197b5dd672bb94ee147 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 11 Jan 2026 17:36:17 -0800 Subject: [PATCH 5/6] efficient mask construction Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/eagle_utils.py | 32 +++++++++++++------ examples/speculative_decoding/main.py | 2 ++ .../torch/speculative/plugins/transformers.py | 2 ++ 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 5b4334687..ade92d21e 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -605,7 +605,7 @@ def get_patched_templated_ring_attn(orig_templated_attn: Callable): to support TTT. """ - def _get_sharded_ttt_msk(i, rank, size, seq_length, ttt_step, dtype): + def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype): """Get chunk-interleaved TTT mask for current rank. 
e.g.: 2 ranks, ttt_step=1; @@ -627,15 +627,27 @@ def _get_sharded_ttt_msk(i, rank, size, seq_length, ttt_step, dtype): [x, x, 0, 0]] """ - # Get full TTT mask - attn_bias = _compute_ttt_attention_mask(1, seq_length * size, ttt_step, dtype) - # Chunk to get current ranks's q rows - attn_bias = attn_bias.chunk(size, dim=2)[rank] - # Split cols into seq_length blocks - attn_bias = attn_bias.split(seq_length, dim=3) - # Get interleaved col blocks for current rank - attn_bias = attn_bias[(rank - i) % size :: size] - return torch.cat(attn_bias, dim=3) + device = torch.cuda.current_device() + q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device) + kv_indices = ( + torch.arange(q_len * size * (ttt_step + 1), device=device) + .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :] + .reshape(-1) + ) + msk_func = get_ttt_msk_func(q_len * size, ttt_step) + attn_mask = msk_func( + None, + None, + q_indices.view(1, 1, -1, 1), + kv_indices.view(1, 1, 1, -1), + ) + attn_bias = torch.where( + attn_mask, + torch.zeros((), dtype=dtype, device=attn_mask.device), + torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device), + ) + + return attn_bias def patched_templated_attn(*args, **kwargs): """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention.""" diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 12914b46d..f8452cd90 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -143,6 +143,8 @@ def train(): ) if training_args.cp_size > 1: patch_ring_attention_for_ttt() + # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 + training_args.parallelism_config.sp_backend = None print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}") # Detecting last checkpoint. diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index b2bddc81e..441aa16bb 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -63,6 +63,7 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index ENABLE_CP_TTT_PATCH = False +CACHED_SHARD_TTT_MASKS = {} @MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) @@ -901,6 +902,7 @@ def forward( # ====Perform training-time-testing with 3 extra eagle forward passes==== for ttt_step in range(self.num_ttt_steps): + # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then. attention_mask = ( attention_mask_0 if ttt_step == 0 From e8ca86dd9ed9c8809d8d2299c790bcbf457b4d79 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 11 Jan 2026 17:39:16 -0800 Subject: [PATCH 6/6] revert irrelevant change Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/launch_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 5cdb3cd49..ad49d614f 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -192,7 +192,7 @@ CMD="accelerate launch --mixed_precision bf16 main.py \ --weight_decay 0.0 \ --warmup_steps 100 \ --lr_scheduler_type linear \ - --logging_steps 5 \ + --logging_steps 100 \ --tf32 True \ --data_path $DATA \ --disable_tqdm $DISABLE_TQDM \
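
As a quick CPU-only check of the index-based construction introduced in "efficient mask construction", the sketch below (not part of the patches; it assumes the series is applied so get_ttt_msk_func is importable, and uses toy values size=2 ranks, q_len=4, ttt_step=1) verifies that evaluating the mask rule directly on a shard's row and column indices selects the same entries that the earlier chunk/split/cat implementation sliced out of the full mask:

    import torch
    from modelopt.torch.speculative.utils import get_ttt_msk_func

    # Toy values (assumption): 2 CP ranks, 4 tokens per rank, 1 TTT step.
    size, q_len, ttt_step = 2, 4, 1
    msk_func = get_ttt_msk_func(q_len * size, ttt_step)

    # Full (unsharded) boolean TTT mask over the whole sequence.
    full_mask = msk_func(
        None,
        None,
        torch.arange(q_len * size).view(1, 1, -1, 1),
        torch.arange(q_len * size * (ttt_step + 1)).view(1, 1, 1, -1),
    )

    for rank in range(size):
        for i in range(size):  # i-th ring hop: the KV chunk comes from rank (rank - i) % size
            q_idx = torch.arange(q_len * rank, q_len * (rank + 1))
            kv_idx = (
                torch.arange(q_len * size * (ttt_step + 1))
                .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :]
                .reshape(-1)
            )
            # New construction: evaluate the mask rule only at this shard's indices.
            sharded = msk_func(None, None, q_idx.view(1, 1, -1, 1), kv_idx.view(1, 1, 1, -1))
            # Old construction: slice the shard's rows and columns out of the full mask.
            sliced = full_mask[:, :, q_idx][:, :, :, kv_idx]
            assert torch.equal(sharded, sliced)

Building only the per-shard rows and columns avoids materializing the full (q_len * size) x (q_len * size * (ttt_step + 1)) bias on every ring iteration, which appears to be the motivation for the change.
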