diff --git a/torchspec/inference/engine/hf_runner.py b/torchspec/inference/engine/hf_runner.py index d5b3914..6a6c75e 100644 --- a/torchspec/inference/engine/hf_runner.py +++ b/torchspec/inference/engine/hf_runner.py @@ -231,7 +231,7 @@ def generate( for i, sample in enumerate(inference_outputs): if self.mooncake_store is not None: key = str(uuid.uuid4()) - shapes = self.mooncake_store.put( + store_meta = self.mooncake_store.put( key=key, hidden_states=sample["hidden_states"], target=sample["target"], @@ -239,15 +239,11 @@ def generate( last_hidden_states=sample["last_hidden_states"], ) - dtypes = {"hidden_states": sample["hidden_states"].dtype} - if sample["target"] is not None: - dtypes["target"] = sample["target"].dtype - results.append( { "mooncake_key": key, - "tensor_shapes": shapes, - "tensor_dtypes": dtypes, + "tensor_shapes": store_meta["shapes"], + "tensor_dtypes": store_meta["dtypes"], "packed_loss_mask": packed_loss_mask_list[i], } ) diff --git a/torchspec/inference/engine/sgl_engine.py b/torchspec/inference/engine/sgl_engine.py index 1b2bc88..07a466a 100644 --- a/torchspec/inference/engine/sgl_engine.py +++ b/torchspec/inference/engine/sgl_engine.py @@ -39,6 +39,7 @@ from torchspec.inference.engine.base import InferenceEngine from torchspec.inference.engine.sgl_engine_decode import SglDecodeEngineMixin from torchspec.ray.ray_actor import RayActor +from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE from torchspec.utils.logging import logger, setup_file_logging from torchspec.utils.misc import get_default_eagle3_aux_layer_ids @@ -540,7 +541,7 @@ def _get_tensor_shapes(self, seq_len: int) -> dict: def _get_tensor_dtypes(self) -> dict: """Get tensor dtypes for mooncake metadata.""" return { - "hidden_states": torch.bfloat16, + "hidden_states": HIDDEN_STATES_STORAGE_DTYPE, "input_ids": torch.long, - "last_hidden_states": torch.bfloat16, + "last_hidden_states": HIDDEN_STATES_STORAGE_DTYPE, } diff --git 
a/torchspec/inference/engine/vllm_engine.py b/torchspec/inference/engine/vllm_engine.py index a6681d1..3006799 100644 --- a/torchspec/inference/engine/vllm_engine.py +++ b/torchspec/inference/engine/vllm_engine.py @@ -33,6 +33,7 @@ from torchspec.inference.engine.base import InferenceEngine from torchspec.ray.ray_actor import RayActor +from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE from torchspec.utils.logging import logger, setup_file_logging from torchspec.utils.misc import get_default_eagle3_aux_layer_ids @@ -495,7 +496,7 @@ def _get_tensor_shapes(self, seq_len: int) -> dict: def _get_tensor_dtypes(self) -> dict: return { - "hidden_states": torch.bfloat16, + "hidden_states": HIDDEN_STATES_STORAGE_DTYPE, "input_ids": torch.long, - "last_hidden_states": torch.bfloat16, + "last_hidden_states": HIDDEN_STATES_STORAGE_DTYPE, } diff --git a/torchspec/inference/engine/vllm_worker_extension.py b/torchspec/inference/engine/vllm_worker_extension.py index 99b8a4d..04f1d08 100644 --- a/torchspec/inference/engine/vllm_worker_extension.py +++ b/torchspec/inference/engine/vllm_worker_extension.py @@ -666,7 +666,7 @@ def _store_and_get_metadata( ) # Store to Mooncake - tensor_shapes = self._mooncake_store.put( + store_meta = self._mooncake_store.put( key=mooncake_key, hidden_states=hidden_states, input_ids=input_ids, @@ -676,15 +676,11 @@ def _store_and_get_metadata( logger.debug(f"Successfully stored to Mooncake: key={mooncake_key}") - # Convert dtype to string for RPC serialization - # Include input_ids as a list for reconstruction (avoids Mooncake storage issues) result[req_id] = { "mooncake_key": mooncake_key, - "tensor_shapes": tensor_shapes, + "tensor_shapes": store_meta["shapes"], "tensor_dtypes": { - "hidden_states": str(hidden_states.dtype).replace("torch.", ""), - "input_ids": str(input_ids.dtype).replace("torch.", ""), - "last_hidden_states": str(last_hidden_states.dtype).replace("torch.", ""), + k: str(v).replace("torch.", "") for k, 
v in store_meta["dtypes"].items() }, "num_layers": len(layer_tensors), "packed_loss_mask": self._packed_loss_mask_map.get(req_id), diff --git a/torchspec/transfer/mooncake/eagle_store.py b/torchspec/transfer/mooncake/eagle_store.py index fa44a5d..caeef92 100644 --- a/torchspec/transfer/mooncake/eagle_store.py +++ b/torchspec/transfer/mooncake/eagle_store.py @@ -21,7 +21,7 @@ import atexit import ctypes import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch @@ -48,6 +48,9 @@ torch.bool: 1, } +# Canonical dtype for hidden-state tensors written to / read from Mooncake. +HIDDEN_STATES_STORAGE_DTYPE = torch.bfloat16 + class EagleMooncakeStore(MooncakeHiddenStateStore): """ @@ -118,7 +121,7 @@ def put( input_ids: torch.Tensor, last_hidden_states: Optional[torch.Tensor], target: Optional[torch.Tensor] = None, - ) -> Dict[str, Tuple[int, ...]]: + ) -> Dict[str, Any]: """Store Eagle3 output tensors via async batch_put_from. DtoH staging runs on ``_copy_stream`` so the caller's compute stream @@ -128,9 +131,24 @@ def put( still in-flight. For GPU Direct send the path is synchronous (no DtoH needed). + + Returns a dict with ``"shapes"`` and ``"dtypes"`` sub-dicts that + reflect the *actually stored* tensors (post dtype-cast). Callers + should forward these to consumers so metadata always matches bytes. 
""" self._ensure_initialized() logger.debug("put: starting for key=%s", key) + + if hidden_states.dtype != HIDDEN_STATES_STORAGE_DTYPE: + hidden_states = hidden_states.to(HIDDEN_STATES_STORAGE_DTYPE) + if ( + last_hidden_states is not None + and last_hidden_states.dtype != HIDDEN_STATES_STORAGE_DTYPE + ): + last_hidden_states = last_hidden_states.to(HIDDEN_STATES_STORAGE_DTYPE) + if target is not None and target.dtype != HIDDEN_STATES_STORAGE_DTYPE: + target = target.to(HIDDEN_STATES_STORAGE_DTYPE) + keys = [f"{key}_hs", f"{key}_ids"] tensors = [hidden_states, input_ids] @@ -187,13 +205,19 @@ def put( "hidden_states": tuple(hidden_states.shape), "input_ids": tuple(input_ids.shape), } + dtypes = { + "hidden_states": hidden_states.dtype, + "input_ids": input_ids.dtype, + } if target is not None: shapes["target"] = tuple(target.shape) + dtypes["target"] = target.dtype if last_hidden_states is not None: shapes["last_hidden_states"] = tuple(last_hidden_states.shape) + dtypes["last_hidden_states"] = last_hidden_states.dtype logger.debug("put: completed key=%s, shapes=%s", key, shapes) - return shapes + return {"shapes": shapes, "dtypes": dtypes} def flush(self) -> None: """Block until all in-flight async puts have completed. @@ -282,14 +306,16 @@ def get( ( "hidden_states", shapes["hidden_states"], - dtypes.get("hidden_states", torch.bfloat16), + dtypes.get("hidden_states", HIDDEN_STATES_STORAGE_DTYPE), ), ("input_ids", shapes["input_ids"], torch.int64), ] if "target" in shapes: keys.append(f"{key}_tgt") - tensor_specs.append(("target", shapes["target"], dtypes.get("target", torch.bfloat16))) + tensor_specs.append( + ("target", shapes["target"], dtypes.get("target", HIDDEN_STATES_STORAGE_DTYPE)) + ) if "last_hidden_states" in shapes: keys.append(f"{key}_lhs") @@ -297,7 +323,7 @@ def get( ( "last_hidden_states", shapes["last_hidden_states"], - dtypes.get("hidden_states", torch.bfloat16), + dtypes.get("last_hidden_states", HIDDEN_STATES_STORAGE_DTYPE), ) )