From cc713ed5b996149b45f2c6dcb3b5e626cbcf456b Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Thu, 19 Mar 2026 22:34:41 +0000 Subject: [PATCH 1/4] [Docs] add PyTorch blog link to README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e9e6493..8417c44 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ TorchSpec currently includes training flows and examples for: ## 🚀 Blogs +- PyTorch blog: [TorchSpec: Speculative Decoding Training at Scale](https://pytorch.org/blog/torchspec-speculative-decoding-training-at-scale/) - Release blog: [TorchSpec: Speculative Decoding Training at Scale](https://lightseek.org/blog/torchspec-speculative-decoding-training-at-scale.html) - Released draft model: [lightseekorg/kimi-k2.5-eagle3](https://huggingface.co/lightseekorg/kimi-k2.5-eagle3) From 14f0e2a1a226199c2cecf58427d305f1af893695 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 20 Mar 2026 01:28:40 +0000 Subject: [PATCH 2/4] [Fix] cast hidden states to bfloat16 at Mooncake storage boundary SGLang may load models in float16 (e.g. MiniMax-M2.5) while training runs in bfloat16. Without an explicit cast, float16 bytes were stored and later interpreted as bfloat16, silently corrupting training data. Introduce HIDDEN_STATES_STORAGE_DTYPE as a single source of truth and cast hidden_states/last_hidden_states/target in EagleMooncakeStore.put() so both SGLang and vLLM paths are covered. --- torchspec/inference/engine/sgl_engine.py | 5 +++-- torchspec/inference/engine/vllm_engine.py | 5 +++-- torchspec/transfer/mooncake/eagle_store.py | 22 +++++++++++++++++++--- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/torchspec/inference/engine/sgl_engine.py b/torchspec/inference/engine/sgl_engine.py index 1b2bc88..07a466a 100644 --- a/torchspec/inference/engine/sgl_engine.py +++ b/torchspec/inference/engine/sgl_engine.py @@ -39,6 +39,7 @@ from torchspec.inference.engine.base import InferenceEngine from torchspec.inference.engine.sgl_engine_decode import SglDecodeEngineMixin from torchspec.ray.ray_actor import RayActor +from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE from torchspec.utils.logging import logger, setup_file_logging from torchspec.utils.misc import get_default_eagle3_aux_layer_ids @@ -540,7 +541,7 @@ def _get_tensor_shapes(self, seq_len: int) -> dict: def _get_tensor_dtypes(self) -> dict: """Get tensor dtypes for mooncake metadata.""" return { - "hidden_states": torch.bfloat16, + "hidden_states": HIDDEN_STATES_STORAGE_DTYPE, "input_ids": torch.long, - "last_hidden_states": torch.bfloat16, + "last_hidden_states": HIDDEN_STATES_STORAGE_DTYPE, } diff --git a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py index a6681d1..3006799 100644 --- a/torchspec/inference/engine/vllm_engine.py +++ b/torchspec/inference/engine/vllm_engine.py @@ -33,6 +33,7 @@ from torchspec.inference.engine.base import InferenceEngine from torchspec.ray.ray_actor import RayActor +from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE from torchspec.utils.logging import logger, setup_file_logging from torchspec.utils.misc import get_default_eagle3_aux_layer_ids @@ -495,7 +496,7 @@ def _get_tensor_shapes(self, seq_len: int) -> dict: def _get_tensor_dtypes(self) -> dict: return { - "hidden_states": torch.bfloat16, + "hidden_states": HIDDEN_STATES_STORAGE_DTYPE, "input_ids": torch.long, - "last_hidden_states": torch.bfloat16, + "last_hidden_states": HIDDEN_STATES_STORAGE_DTYPE, } diff --git a/torchspec/transfer/mooncake/eagle_store.py b/torchspec/transfer/mooncake/eagle_store.py index fa44a5d..2d7bfbe 100644 --- a/torchspec/transfer/mooncake/eagle_store.py +++ b/torchspec/transfer/mooncake/eagle_store.py @@ -48,6 +48,9 @@ torch.bool: 1, } +# Canonical dtype for hidden-state tensors written to / read from Mooncake. +HIDDEN_STATES_STORAGE_DTYPE = torch.bfloat16 + class EagleMooncakeStore(MooncakeHiddenStateStore): """ @@ -131,6 +134,17 @@ def put( """ self._ensure_initialized() logger.debug("put: starting for key=%s", key) + + if hidden_states.dtype != HIDDEN_STATES_STORAGE_DTYPE: + hidden_states = hidden_states.to(HIDDEN_STATES_STORAGE_DTYPE) + if ( + last_hidden_states is not None + and last_hidden_states.dtype != HIDDEN_STATES_STORAGE_DTYPE + ): + last_hidden_states = last_hidden_states.to(HIDDEN_STATES_STORAGE_DTYPE) + if target is not None and target.dtype != HIDDEN_STATES_STORAGE_DTYPE: + target = target.to(HIDDEN_STATES_STORAGE_DTYPE) + keys = [f"{key}_hs", f"{key}_ids"] tensors = [hidden_states, input_ids] @@ -282,14 +296,16 @@ def get( ( "hidden_states", shapes["hidden_states"], - dtypes.get("hidden_states", torch.bfloat16), + dtypes.get("hidden_states", HIDDEN_STATES_STORAGE_DTYPE), ), ("input_ids", shapes["input_ids"], torch.int64), ] if "target" in shapes: keys.append(f"{key}_tgt") - tensor_specs.append(("target", shapes["target"], dtypes.get("target", torch.bfloat16))) + tensor_specs.append( + ("target", shapes["target"], dtypes.get("target", HIDDEN_STATES_STORAGE_DTYPE)) + ) if "last_hidden_states" in shapes: keys.append(f"{key}_lhs") @@ -297,7 +313,7 @@ def get( ( "last_hidden_states", shapes["last_hidden_states"], - dtypes.get("hidden_states", torch.bfloat16), + dtypes.get("hidden_states", HIDDEN_STATES_STORAGE_DTYPE), ) ) From 8e569590c5afe80449edcbf0619b1b59bed3c26b Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 20 Mar 2026 01:40:35 +0000 Subject: [PATCH 3/4] [Fix] align emitted tensor_dtypes metadata with Mooncake storage dtype The previous commit casts hidden states to bfloat16 inside EagleMooncakeStore.put(), but the vLLM worker extension and HF runner still reported the original pre-cast dtype in their metadata dicts. Since the training-side data fetcher trusts that metadata to decode Mooncake bytes, the mismatch would silently corrupt reads. Both emitters now report HIDDEN_STATES_STORAGE_DTYPE so metadata and stored bytes agree. --- torchspec/inference/engine/hf_runner.py | 6 +++--- torchspec/inference/engine/vllm_worker_extension.py | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/torchspec/inference/engine/hf_runner.py b/torchspec/inference/engine/hf_runner.py index d5b3914..97168a8 100644 --- a/torchspec/inference/engine/hf_runner.py +++ b/torchspec/inference/engine/hf_runner.py @@ -36,7 +36,7 @@ from torchspec.config.inference_config import HFInferenceConfig from torchspec.config.mooncake_config import MooncakeConfig from torchspec.models.target import HFTargetModel -from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore +from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE, EagleMooncakeStore from torchspec.utils.logging import logger @@ -239,9 +239,9 @@ def generate( last_hidden_states=sample["last_hidden_states"], ) - dtypes = {"hidden_states": sample["hidden_states"].dtype} + dtypes = {"hidden_states": HIDDEN_STATES_STORAGE_DTYPE} if sample["target"] is not None: - dtypes["target"] = sample["target"].dtype + dtypes["target"] = HIDDEN_STATES_STORAGE_DTYPE results.append( { diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py index 99b8a4d..4441137 100644 --- a/torchspec/inference/engine/vllm_worker_extension.py +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -676,15 +676,18 @@ def _store_and_get_metadata( logger.debug(f"Successfully stored to Mooncake: key={mooncake_key}") - # Convert dtype to string for RPC serialization - # Include input_ids as a list for reconstruction (avoids Mooncake storage issues) + # Report post-cast dtypes so readers interpret bytes correctly. + # EagleMooncakeStore.put() casts to HIDDEN_STATES_STORAGE_DTYPE. + from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE + + _hs_dtype_str = str(HIDDEN_STATES_STORAGE_DTYPE).replace("torch.", "") result[req_id] = { "mooncake_key": mooncake_key, "tensor_shapes": tensor_shapes, "tensor_dtypes": { - "hidden_states": str(hidden_states.dtype).replace("torch.", ""), + "hidden_states": _hs_dtype_str, "input_ids": str(input_ids.dtype).replace("torch.", ""), - "last_hidden_states": str(last_hidden_states.dtype).replace("torch.", ""), + "last_hidden_states": _hs_dtype_str, }, "num_layers": len(layer_tensors), "packed_loss_mask": self._packed_loss_mask_map.get(req_id), From bc177a36c445df2397529656258ebbf3f959a455 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 20 Mar 2026 01:42:32 +0000 Subject: [PATCH 4/4] [Fix] return stored dtypes from EagleMooncakeStore.put() Make put() the single source of truth for both shapes and dtypes by returning {"shapes": ..., "dtypes": ...} from the post-cast tensors. Callers now use the store's return value instead of reading dtypes from their own pre-cast local variables. This eliminates the class of bugs where a producer emits metadata with the wrong dtype because put() silently cast under the hood. --- torchspec/inference/engine/hf_runner.py | 12 ++++-------- .../inference/engine/vllm_worker_extension.py | 13 +++---------- torchspec/transfer/mooncake/eagle_store.py | 16 +++++++++++++--- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/torchspec/inference/engine/hf_runner.py b/torchspec/inference/engine/hf_runner.py index 97168a8..6a6c75e 100644 --- a/torchspec/inference/engine/hf_runner.py +++ b/torchspec/inference/engine/hf_runner.py @@ -36,7 +36,7 @@ from torchspec.config.inference_config import HFInferenceConfig from torchspec.config.mooncake_config import MooncakeConfig from torchspec.models.target import HFTargetModel -from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE, EagleMooncakeStore +from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore from torchspec.utils.logging import logger @@ -231,7 +231,7 @@ def generate( for i, sample in enumerate(inference_outputs): if self.mooncake_store is not None: key = str(uuid.uuid4()) - shapes = self.mooncake_store.put( + store_meta = self.mooncake_store.put( key=key, hidden_states=sample["hidden_states"], target=sample["target"], @@ -239,15 +239,11 @@ def generate( last_hidden_states=sample["last_hidden_states"], ) - dtypes = {"hidden_states": HIDDEN_STATES_STORAGE_DTYPE} - if sample["target"] is not None: - dtypes["target"] = HIDDEN_STATES_STORAGE_DTYPE - results.append( { "mooncake_key": key, - "tensor_shapes": shapes, - "tensor_dtypes": dtypes, + "tensor_shapes": store_meta["shapes"], + "tensor_dtypes": store_meta["dtypes"], "packed_loss_mask": packed_loss_mask_list[i], } ) diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py index 4441137..04f1d08 100644 --- a/torchspec/inference/engine/vllm_worker_extension.py +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -666,7 +666,7 @@ def _store_and_get_metadata( ) # Store to Mooncake - tensor_shapes = self._mooncake_store.put( + store_meta = self._mooncake_store.put( key=mooncake_key, hidden_states=hidden_states, input_ids=input_ids, @@ -676,18 +676,11 @@ def _store_and_get_metadata( logger.debug(f"Successfully stored to Mooncake: key={mooncake_key}") - # Report post-cast dtypes so readers interpret bytes correctly. - # EagleMooncakeStore.put() casts to HIDDEN_STATES_STORAGE_DTYPE. - from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE - - _hs_dtype_str = str(HIDDEN_STATES_STORAGE_DTYPE).replace("torch.", "") result[req_id] = { "mooncake_key": mooncake_key, - "tensor_shapes": tensor_shapes, + "tensor_shapes": store_meta["shapes"], "tensor_dtypes": { - "hidden_states": _hs_dtype_str, - "input_ids": str(input_ids.dtype).replace("torch.", ""), - "last_hidden_states": _hs_dtype_str, + k: str(v).replace("torch.", "") for k, v in store_meta["dtypes"].items() }, "num_layers": len(layer_tensors), "packed_loss_mask": self._packed_loss_mask_map.get(req_id), diff --git a/torchspec/transfer/mooncake/eagle_store.py b/torchspec/transfer/mooncake/eagle_store.py index 2d7bfbe..caeef92 100644 --- a/torchspec/transfer/mooncake/eagle_store.py +++ b/torchspec/transfer/mooncake/eagle_store.py @@ -21,7 +21,7 @@ import atexit import ctypes import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch @@ -121,7 +121,7 @@ def put( input_ids: torch.Tensor, last_hidden_states: Optional[torch.Tensor], target: Optional[torch.Tensor] = None, - ) -> Dict[str, Tuple[int, ...]]: + ) -> Dict[str, Any]: """Store Eagle3 output tensors via async batch_put_from. DtoH staging runs on ``_copy_stream`` so the caller's compute stream @@ -131,6 +131,10 @@ def put( still in-flight. For GPU Direct send the path is synchronous (no DtoH needed). + + Returns a dict with ``"shapes"`` and ``"dtypes"`` sub-dicts that + reflect the *actually stored* tensors (post dtype-cast). Callers + should forward these to consumers so metadata always matches bytes. """ self._ensure_initialized() logger.debug("put: starting for key=%s", key) @@ -201,13 +205,19 @@ def put( "hidden_states": tuple(hidden_states.shape), "input_ids": tuple(input_ids.shape), } + dtypes = { + "hidden_states": hidden_states.dtype, + "input_ids": input_ids.dtype, + } if target is not None: shapes["target"] = tuple(target.shape) + dtypes["target"] = target.dtype if last_hidden_states is not None: shapes["last_hidden_states"] = tuple(last_hidden_states.shape) + dtypes["last_hidden_states"] = last_hidden_states.dtype logger.debug("put: completed key=%s, shapes=%s", key, shapes) - return shapes + return {"shapes": shapes, "dtypes": dtypes} def flush(self) -> None: """Block until all in-flight async puts have completed.