Skip to content

Commit 1f3d10e

Browse files
vLLM loader: use safetensors_load_strategy="eager" on network volumes
The previous approach patched safetensors.torch.load_file, but vLLM loads weights with safe_open (mmap) via safetensors_weights_iterator(). vLLM v0.17+ already has a built-in "eager" strategy that does open().read(), so we now simply set load_config.safetensors_load_strategy="eager" on network FUSE volumes. Benchmarking on RunPod (RTX 4090, MFS FUSE, Qwen2.5-1.5B) shows that eager is actually slower with a warm page cache (2.5s vs 0.7s for mmap); the eager path helps only on cold reads from slow network filesystems, where mmap page faults trigger expensive network round-trips. We also removed "overlay" from the slow_fs detection set, since overlay-backed storage is local and mmap is fast there. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fe408f2 commit 1f3d10e

File tree

2 files changed

+47
-106
lines changed

2 files changed

+47
-106
lines changed

python/zerostart/integrations/vllm.py

Lines changed: 46 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
Provides a custom model loader that subclasses vLLM's DefaultModelLoader
44
and runs inside vLLM's EngineCore subprocess where weights are actually loaded.
55
6-
Key optimizations:
7-
1. Network volume fix: eager read instead of mmap on FUSE/NFS (30-50x faster)
8-
2. Patched safe_open: detect network volumes and use fast path
9-
3. Auto-registered via vLLM's plugin system (entry_points)
6+
Key optimization:
7+
Network volume fix: sets safetensors_load_strategy="eager" on FUSE/NFS
8+
volumes where mmap is 30-50x slower than sequential read.
109
1110
Usage:
1211
# Option 1: Auto-registration via entry_points (pip install zerostart)
@@ -16,9 +15,6 @@
1615
from zerostart.integrations.vllm import register
1716
register()
1817
# Then: --load-format zerostart
19-
20-
# Option 3: Transparent hook (patches from_pretrained in parent process)
21-
zerostart run --accelerate -p vllm -- python -m vllm.entrypoints.openai.api_server ...
2218
"""
2319

2420
from __future__ import annotations
@@ -31,9 +27,6 @@
3127
log = logging.getLogger("zerostart.vllm")
3228

3329
if TYPE_CHECKING:
34-
from collections.abc import Generator, Iterator
35-
36-
import torch
3730
import torch.nn as nn
3831
from vllm.config import ModelConfig
3932
from vllm.config.load import LoadConfig
@@ -53,7 +46,6 @@ def register() -> None:
5346
register_model_loader("zerostart")(ZerostartModelLoader)
5447
log.info("Registered zerostart model loader with vLLM")
5548
except ImportError:
56-
# Fallback for older vLLM versions
5749
try:
5850
import vllm.model_executor.model_loader as ml
5951
registry = getattr(ml, "_LOAD_FORMAT_TO_MODEL_LOADER", None)
@@ -118,9 +110,12 @@ def _is_network_volume(path: str) -> bool:
118110
return _network_volume_cache[path]
119111

120112
result = False
113+
# Only truly network-backed filesystems where mmap page faults
114+
# trigger network round-trips. Overlay is excluded because it's
115+
# backed by local storage and mmap works fine there.
121116
slow_fs = frozenset({
122117
"fuse", "fuse.juicefs", "fuse.gcsfuse", "fuse.sshfs",
123-
"nfs", "nfs4", "cifs", "smbfs", "9p", "overlay",
118+
"nfs", "nfs4", "cifs", "smbfs", "9p",
124119
})
125120

126121
try:
@@ -144,105 +139,56 @@ def _is_network_volume(path: str) -> bool:
144139
return result
145140

146141

147-
# ---------------------------------------------------------------------------
148-
# Fast weight iterator — replaces safetensors mmap with eager read on
149-
# network volumes, and patches safe_open for the same
150-
# ---------------------------------------------------------------------------
151-
152-
def _fast_safetensors_weights_iterator(
153-
hf_weights_files: list[str],
154-
) -> Generator[tuple[str, torch.Tensor], None, None]:
155-
"""Yield (name, tensor) pairs from safetensors files.
156-
157-
On network volumes: reads entire file into memory first (eager),
158-
avoiding the 30-50x mmap penalty on FUSE/NFS.
159-
On local NVMe: uses standard safe_open (mmap is fast).
160-
"""
161-
import safetensors.torch
162-
163-
for st_file in hf_weights_files:
164-
t0 = time.monotonic()
165-
166-
if _is_network_volume(st_file):
167-
# Eager read: load entire file to avoid mmap page fault penalty
168-
with open(st_file, "rb") as f:
169-
data = f.read()
170-
tensors = safetensors.torch.load(data)
171-
elapsed = time.monotonic() - t0
172-
log.info(
173-
"Eager read %s (%.2fs, %d tensors, %.0f MB)",
174-
Path(st_file).name, elapsed, len(tensors),
175-
len(data) / 1e6,
176-
)
177-
yield from tensors.items()
178-
else:
179-
# Local NVMe: mmap is fast, use standard safe_open
180-
from safetensors import safe_open
181-
with safe_open(st_file, framework="pt") as f:
182-
for name in f.keys():
183-
yield name, f.get_tensor(name)
184-
185-
186142
# ---------------------------------------------------------------------------
187143
# ZerostartModelLoader
188144
# ---------------------------------------------------------------------------
189145

190146
class ZerostartModelLoader(_DefaultLoader): # type: ignore[misc]
191147
"""vLLM model loader with network volume acceleration.
192148
193-
Subclasses DefaultModelLoader and overrides the weight iteration
194-
to use eager read on FUSE/NFS volumes. This runs INSIDE vLLM's
195-
EngineCore subprocess where weights are actually loaded.
149+
Subclasses DefaultModelLoader. On FUSE/NFS network volumes, sets
150+
safetensors_load_strategy="eager" so vLLM reads entire files into
151+
memory instead of using mmap (which is 30-50x slower on these FSes).
196152
197-
Key difference from transparent accelerate() hook:
198-
- accelerate() patches from_pretrained in the parent process
199-
- This loader patches weight loading in the EngineCore subprocess
200-
- vLLM loads weights via safe_open, not from_pretrained
153+
On local NVMe, delegates entirely to DefaultModelLoader (mmap is fast).
201154
"""
202155

203156
def __init__(self, load_config: LoadConfig):
204-
# Rewrite load_format to "safetensors" BEFORE super().__init__
205-
# so DefaultModelLoader._prepare_weights() doesn't reject "zerostart".
206-
# We store the original to know we were invoked as zerostart.
157+
import os
158+
159+
# Rewrite load_format from "zerostart" to "safetensors" so
160+
# DefaultModelLoader._prepare_weights() doesn't reject it.
207161
self._zerostart_requested = getattr(load_config, "load_format", None) == "zerostart"
208162
if self._zerostart_requested:
209163
load_config.load_format = "safetensors"
210164

211-
if _DefaultLoader is not object:
212-
super().__init__(load_config)
213-
else:
214-
self.load_config = load_config
215-
216-
# Detect if we're on a network volume
217-
self._on_network_volume = any(
165+
# Switch to eager loading if explicitly requested or on a network FS
166+
# where mmap page faults trigger expensive network round-trips.
167+
#
168+
# Note: on FUSE mounts with warm page cache (e.g. RunPod MFS), mmap
169+
# is actually faster than eager because it avoids copying data.
170+
# Eager only helps on cold reads from slow network FSes (NFS, JuiceFS).
171+
# Use ZEROSTART_EAGER=1 to force eager loading.
172+
force_eager = os.environ.get("ZEROSTART_EAGER", "").lower() in ("1", "true")
173+
on_network_volume = any(
218174
_is_network_volume(p)
219175
for p in ["/volume", "/gpu-cli-workspaces", "/workspace"]
220176
if Path(p).exists()
221177
)
222178

223-
if self._on_network_volume:
224-
log.info("Network volume detected — using eager read for safetensors")
225-
self._patch_safe_open()
226-
227-
def _patch_safe_open(self) -> None:
228-
"""Patch safetensors in this subprocess for eager read on network volumes."""
229-
try:
230-
import safetensors.torch as st
231-
232-
original_load_file = st.load_file
233-
234-
def patched_load_file(filename: str, device: str = "cpu") -> dict[str, Any]:
235-
if _is_network_volume(str(filename)):
236-
with open(filename, "rb") as f:
237-
data = f.read()
238-
return st.load(data, device=device)
239-
return original_load_file(filename, device=device)
179+
if force_eager or on_network_volume:
180+
current = getattr(load_config, "safetensors_load_strategy", "lazy")
181+
if current != "eager":
182+
load_config.safetensors_load_strategy = "eager"
183+
reason = "ZEROSTART_EAGER=1" if force_eager else "network volume detected"
184+
log.info(
185+
"Switched safetensors_load_strategy to 'eager' (%s)", reason,
186+
)
240187

241-
st.load_file = patched_load_file
242-
self._original_load_file = original_load_file
243-
log.debug("Patched safetensors.torch.load_file in subprocess")
244-
except ImportError:
245-
pass
188+
if _DefaultLoader is not object:
189+
super().__init__(load_config)
190+
else:
191+
self.load_config = load_config
246192

247193
def download_model(self, model_config: ModelConfig) -> None:
248194
"""Download model via HF hub (standard path)."""
@@ -259,26 +205,18 @@ def download_model(self, model_config: ModelConfig) -> None:
259205
log.warning("HF download failed: %s", e)
260206

261207
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
262-
"""Load weights with network volume optimization.
263-
264-
On network volumes: uses eager read (30-50x faster than mmap).
265-
On local NVMe: delegates to DefaultModelLoader (mmap is fast).
266-
"""
208+
"""Load weights, delegating to DefaultModelLoader."""
267209
t0 = time.monotonic()
268210

269211
if _DefaultLoader is not object and hasattr(super(), "load_weights"):
270-
# Let DefaultModelLoader handle it — our safe_open patch
271-
# is already installed and will intercept the reads
272212
super().load_weights(model, model_config)
273213
else:
274214
log.warning("DefaultModelLoader not available — basic weight loading")
275215
self._fallback_load_weights(model, model_config)
276216

277217
elapsed = time.monotonic() - t0
278-
log.info(
279-
"Weight loading complete (%.2fs, network_volume=%s)",
280-
elapsed, self._on_network_volume,
281-
)
218+
strategy = getattr(self.load_config, "safetensors_load_strategy", "unknown")
219+
log.info("Weight loading complete (%.2fs, strategy=%s)", elapsed, strategy)
282220

283221
def _fallback_load_weights(
284222
self, model: nn.Module, model_config: ModelConfig,
@@ -288,10 +226,13 @@ def _fallback_load_weights(
288226

289227
model_path = Path(model_config.model)
290228
if not model_path.is_dir():
291-
from zerostart.snapshot import _find_hf_cache_dir
292-
cache_dir = _find_hf_cache_dir(model_config.model)
293-
if cache_dir:
294-
model_path = cache_dir
229+
try:
230+
from zerostart.snapshot import _find_hf_cache_dir
231+
cache_dir = _find_hf_cache_dir(model_config.model)
232+
if cache_dir:
233+
model_path = cache_dir
234+
except ImportError:
235+
pass
295236

296237
sf_files = sorted(model_path.glob("*.safetensors"))
297238
if not sf_files:

tests/test_vllm_integration.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ r = subprocess.run(
162162
elapsed = time.monotonic() - t0
163163
print(r.stdout.strip())
164164
if r.returncode != 0:
165-
print("STDERR:", r.stderr[-1000:])
165+
print("STDERR:", r.stderr[-1500:])
166166
print(f"Wall clock: {elapsed:.2f}s")
167167
PYEOF
168168

0 commit comments

Comments
 (0)