Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions torchspec/inference/engine/hf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,23 +231,19 @@ def generate(
for i, sample in enumerate(inference_outputs):
if self.mooncake_store is not None:
key = str(uuid.uuid4())
shapes = self.mooncake_store.put(
store_meta = self.mooncake_store.put(
key=key,
hidden_states=sample["hidden_states"],
target=sample["target"],
input_ids=sample["input_ids"],
last_hidden_states=sample["last_hidden_states"],
)

dtypes = {"hidden_states": sample["hidden_states"].dtype}
if sample["target"] is not None:
dtypes["target"] = sample["target"].dtype

results.append(
{
"mooncake_key": key,
"tensor_shapes": shapes,
"tensor_dtypes": dtypes,
"tensor_shapes": store_meta["shapes"],
"tensor_dtypes": store_meta["dtypes"],
"packed_loss_mask": packed_loss_mask_list[i],
}
)
Expand Down
5 changes: 3 additions & 2 deletions torchspec/inference/engine/sgl_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from torchspec.inference.engine.base import InferenceEngine
from torchspec.inference.engine.sgl_engine_decode import SglDecodeEngineMixin
from torchspec.ray.ray_actor import RayActor
from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE
from torchspec.utils.logging import logger, setup_file_logging
from torchspec.utils.misc import get_default_eagle3_aux_layer_ids

Expand Down Expand Up @@ -540,7 +541,7 @@ def _get_tensor_shapes(self, seq_len: int) -> dict:
def _get_tensor_dtypes(self) -> dict:
    """Get tensor dtypes for mooncake metadata.

    Returns:
        A dict mapping each tensor name to the ``torch.dtype`` reported
        for it. Both hidden-state entries use
        ``HIDDEN_STATES_STORAGE_DTYPE`` (imported from
        ``torchspec.transfer.mooncake.eagle_store``) so the reported
        dtype follows the store's constant instead of a literal
        repeated here.
    """
    # Each key appears exactly once: a repeated key in a dict literal is
    # silently overridden by its last occurrence, which hides stale values.
    return {
        "hidden_states": HIDDEN_STATES_STORAGE_DTYPE,
        "input_ids": torch.long,
        "last_hidden_states": HIDDEN_STATES_STORAGE_DTYPE,
    }
5 changes: 3 additions & 2 deletions torchspec/inference/engine/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from torchspec.inference.engine.base import InferenceEngine
from torchspec.ray.ray_actor import RayActor
from torchspec.transfer.mooncake.eagle_store import HIDDEN_STATES_STORAGE_DTYPE
from torchspec.utils.logging import logger, setup_file_logging
from torchspec.utils.misc import get_default_eagle3_aux_layer_ids

Expand Down Expand Up @@ -495,7 +496,7 @@ def _get_tensor_shapes(self, seq_len: int) -> dict:

def _get_tensor_dtypes(self) -> dict:
    """Return the dtype reported for each tensor name.

    Returns:
        A dict mapping ``"hidden_states"``, ``"input_ids"`` and
        ``"last_hidden_states"`` to a ``torch.dtype``. Both hidden-state
        entries use ``HIDDEN_STATES_STORAGE_DTYPE`` (imported from
        ``torchspec.transfer.mooncake.eagle_store``) rather than a
        literal repeated here.
    """
    # Each key appears exactly once: a repeated key in a dict literal is
    # silently overridden by its last occurrence, which hides stale values.
    return {
        "hidden_states": HIDDEN_STATES_STORAGE_DTYPE,
        "input_ids": torch.long,
        "last_hidden_states": HIDDEN_STATES_STORAGE_DTYPE,
    }
10 changes: 3 additions & 7 deletions torchspec/inference/engine/vllm_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def _store_and_get_metadata(
)

# Store to Mooncake
tensor_shapes = self._mooncake_store.put(
store_meta = self._mooncake_store.put(
key=mooncake_key,
hidden_states=hidden_states,
input_ids=input_ids,
Expand All @@ -676,15 +676,11 @@ def _store_and_get_metadata(

logger.debug(f"Successfully stored to Mooncake: key={mooncake_key}")

# Convert dtype to string for RPC serialization
# Include input_ids as a list for reconstruction (avoids Mooncake storage issues)
result[req_id] = {
"mooncake_key": mooncake_key,
"tensor_shapes": tensor_shapes,
"tensor_shapes": store_meta["shapes"],
"tensor_dtypes": {
"hidden_states": str(hidden_states.dtype).replace("torch.", ""),
"input_ids": str(input_ids.dtype).replace("torch.", ""),
"last_hidden_states": str(last_hidden_states.dtype).replace("torch.", ""),
k: str(v).replace("torch.", "") for k, v in store_meta["dtypes"].items()
},
"num_layers": len(layer_tensors),
"packed_loss_mask": self._packed_loss_mask_map.get(req_id),
Expand Down
38 changes: 32 additions & 6 deletions torchspec/transfer/mooncake/eagle_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import atexit
import ctypes
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import torch

Expand All @@ -48,6 +48,9 @@
torch.bool: 1,
}

# Canonical dtype for hidden-state tensors written to / read from Mooncake.
# ``put`` casts ``hidden_states``, ``last_hidden_states`` and ``target`` to
# this dtype before storing them, and ``get`` falls back to it when the
# caller-supplied dtype metadata lacks an entry.
HIDDEN_STATES_STORAGE_DTYPE = torch.bfloat16


class EagleMooncakeStore(MooncakeHiddenStateStore):
"""
Expand Down Expand Up @@ -118,7 +121,7 @@ def put(
input_ids: torch.Tensor,
last_hidden_states: Optional[torch.Tensor],
target: Optional[torch.Tensor] = None,
) -> Dict[str, Tuple[int, ...]]:
) -> Dict[str, Any]:
"""Store Eagle3 output tensors via async batch_put_from.

DtoH staging runs on ``_copy_stream`` so the caller's compute stream
Expand All @@ -128,9 +131,24 @@ def put(
still in-flight.

For GPU Direct send the path is synchronous (no DtoH needed).

Returns a dict with ``"shapes"`` and ``"dtypes"`` sub-dicts that
reflect the *actually stored* tensors (post dtype-cast). Callers
should forward these to consumers so metadata always matches bytes.
"""
self._ensure_initialized()
logger.debug("put: starting for key=%s", key)

if hidden_states.dtype != HIDDEN_STATES_STORAGE_DTYPE:
hidden_states = hidden_states.to(HIDDEN_STATES_STORAGE_DTYPE)
if (
last_hidden_states is not None
and last_hidden_states.dtype != HIDDEN_STATES_STORAGE_DTYPE
):
last_hidden_states = last_hidden_states.to(HIDDEN_STATES_STORAGE_DTYPE)
if target is not None and target.dtype != HIDDEN_STATES_STORAGE_DTYPE:
target = target.to(HIDDEN_STATES_STORAGE_DTYPE)

keys = [f"{key}_hs", f"{key}_ids"]
tensors = [hidden_states, input_ids]

Expand Down Expand Up @@ -187,13 +205,19 @@ def put(
"hidden_states": tuple(hidden_states.shape),
"input_ids": tuple(input_ids.shape),
}
dtypes = {
"hidden_states": hidden_states.dtype,
"input_ids": input_ids.dtype,
}
if target is not None:
shapes["target"] = tuple(target.shape)
dtypes["target"] = target.dtype
if last_hidden_states is not None:
shapes["last_hidden_states"] = tuple(last_hidden_states.shape)
dtypes["last_hidden_states"] = last_hidden_states.dtype

logger.debug("put: completed key=%s, shapes=%s", key, shapes)
return shapes
return {"shapes": shapes, "dtypes": dtypes}

def flush(self) -> None:
"""Block until all in-flight async puts have completed.
Expand Down Expand Up @@ -282,22 +306,24 @@ def get(
(
"hidden_states",
shapes["hidden_states"],
dtypes.get("hidden_states", torch.bfloat16),
dtypes.get("hidden_states", HIDDEN_STATES_STORAGE_DTYPE),
),
("input_ids", shapes["input_ids"], torch.int64),
]

if "target" in shapes:
keys.append(f"{key}_tgt")
tensor_specs.append(("target", shapes["target"], dtypes.get("target", torch.bfloat16)))
tensor_specs.append(
("target", shapes["target"], dtypes.get("target", HIDDEN_STATES_STORAGE_DTYPE))
)

if "last_hidden_states" in shapes:
keys.append(f"{key}_lhs")
tensor_specs.append(
(
"last_hidden_states",
shapes["last_hidden_states"],
dtypes.get("hidden_states", torch.bfloat16),
dtypes.get("hidden_states", HIDDEN_STATES_STORAGE_DTYPE),
)
)

Expand Down
Loading