 Provides a custom model loader that subclasses vLLM's DefaultModelLoader
 and runs inside vLLM's EngineCore subprocess where weights are actually loaded.
 
-Key optimizations:
-    1. Network volume fix: eager read instead of mmap on FUSE/NFS (30-50x faster)
-    2. Patched safe_open: detect network volumes and use fast path
-    3. Auto-registered via vLLM's plugin system (entry_points)
+Key optimization:
+    Network volume fix: sets safetensors_load_strategy="eager" on FUSE/NFS
+    volumes where mmap is 30-50x slower than sequential read.
 
 Usage:
     # Option 1: Auto-registration via entry_points (pip install zerostart)
     from zerostart.integrations.vllm import register
     register()
     # Then: --load-format zerostart
-
-    # Option 3: Transparent hook (patches from_pretrained in parent process)
-    zerostart run --accelerate -p vllm -- python -m vllm.entrypoints.openai.api_server ...
 """
 
 from __future__ import annotations
 log = logging.getLogger("zerostart.vllm")
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, Iterator
-
-    import torch
     import torch.nn as nn
     from vllm.config import ModelConfig
     from vllm.config.load import LoadConfig
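
Aside (not part of the diff): a minimal usage sketch of the manual-registration path described in the module docstring above. The model path is a placeholder, and it assumes vLLM's LLM constructor accepts load_format as an engine argument:

    from vllm import LLM, SamplingParams
    from zerostart.integrations.vllm import register

    register()  # adds the "zerostart" load format to vLLM's loader registry

    llm = LLM(model="/volume/models/my-model", load_format="zerostart")  # placeholder path
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
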
@@ -53,7 +46,6 @@ def register() -> None:
         register_model_loader("zerostart")(ZerostartModelLoader)
         log.info("Registered zerostart model loader with vLLM")
     except ImportError:
-        # Fallback for older vLLM versions
         try:
             import vllm.model_executor.model_loader as ml
             registry = getattr(ml, "_LOAD_FORMAT_TO_MODEL_LOADER", None)
@@ -118,9 +110,12 @@ def _is_network_volume(path: str) -> bool:
         return _network_volume_cache[path]
 
     result = False
+    # Only truly network-backed filesystems where mmap page faults
+    # trigger network round-trips. Overlay is excluded because it's
+    # backed by local storage and mmap works fine there.
     slow_fs = frozenset({
         "fuse", "fuse.juicefs", "fuse.gcsfuse", "fuse.sshfs",
-        "nfs", "nfs4", "cifs", "smbfs", "9p", "overlay",
+        "nfs", "nfs4", "cifs", "smbfs", "9p",
     })
 
     try:
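
Aside (not part of the diff): the body of this try block is cut off by the hunk. As a hypothetical illustration only, not the repository's code, the filesystem type of a path can be resolved by scanning /proc/mounts and taking the longest mount point that prefixes the path; the helper name below is made up:

    import os

    def _fs_type_of(path: str) -> str:
        best, fstype = "", "unknown"
        real = os.path.realpath(path)
        with open("/proc/mounts") as mounts:
            for line in mounts:
                parts = line.split()
                if len(parts) < 3:
                    continue
                mount_point, mount_fstype = parts[1], parts[2]
                # Longest mount point that is a prefix of the path wins.
                if real.startswith(mount_point) and len(mount_point) > len(best):
                    best, fstype = mount_point, mount_fstype
        return fstype

    # e.g. _fs_type_of("/volume") == "fuse.juicefs" on a JuiceFS mount,
    # which would then be checked against a set like slow_fs above.
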
@@ -144,105 +139,56 @@ def _is_network_volume(path: str) -> bool:
     return result
 
 
-# ---------------------------------------------------------------------------
-# Fast weight iterator — replaces safetensors mmap with eager read on
-# network volumes, and patches safe_open for the same
-# ---------------------------------------------------------------------------
-
-def _fast_safetensors_weights_iterator(
-    hf_weights_files: list[str],
-) -> Generator[tuple[str, torch.Tensor], None, None]:
-    """Yield (name, tensor) pairs from safetensors files.
-
-    On network volumes: reads entire file into memory first (eager),
-    avoiding the 30-50x mmap penalty on FUSE/NFS.
-    On local NVMe: uses standard safe_open (mmap is fast).
-    """
-    import safetensors.torch
-
-    for st_file in hf_weights_files:
-        t0 = time.monotonic()
-
-        if _is_network_volume(st_file):
-            # Eager read: load entire file to avoid mmap page fault penalty
-            with open(st_file, "rb") as f:
-                data = f.read()
-            tensors = safetensors.torch.load(data)
-            elapsed = time.monotonic() - t0
-            log.info(
-                "Eager read %s (%.2fs, %d tensors, %.0f MB)",
-                Path(st_file).name, elapsed, len(tensors),
-                len(data) / 1e6,
-            )
-            yield from tensors.items()
-        else:
-            # Local NVMe: mmap is fast, use standard safe_open
-            from safetensors import safe_open
-            with safe_open(st_file, framework="pt") as f:
-                for name in f.keys():
-                    yield name, f.get_tensor(name)
-
-
 # ---------------------------------------------------------------------------
 # ZerostartModelLoader
 # ---------------------------------------------------------------------------
 
 class ZerostartModelLoader(_DefaultLoader):  # type: ignore[misc]
     """vLLM model loader with network volume acceleration.
 
-    Subclasses DefaultModelLoader and overrides the weight iteration
-    to use eager read on FUSE/NFS volumes. This runs INSIDE vLLM's
-    EngineCore subprocess where weights are actually loaded.
+    Subclasses DefaultModelLoader. On FUSE/NFS network volumes, sets
+    safetensors_load_strategy="eager" so vLLM reads entire files into
+    memory instead of using mmap (which is 30-50x slower on these FSes).
 
-    Key difference from transparent accelerate() hook:
-    - accelerate() patches from_pretrained in the parent process
-    - This loader patches weight loading in the EngineCore subprocess
-    - vLLM loads weights via safe_open, not from_pretrained
+    On local NVMe, delegates entirely to DefaultModelLoader (mmap is fast).
     """
 
     def __init__(self, load_config: LoadConfig):
-        # Rewrite load_format to "safetensors" BEFORE super().__init__
-        # so DefaultModelLoader._prepare_weights() doesn't reject "zerostart".
-        # We store the original to know we were invoked as zerostart.
+        import os
+
+        # Rewrite load_format from "zerostart" to "safetensors" so
+        # DefaultModelLoader._prepare_weights() doesn't reject it.
         self._zerostart_requested = getattr(load_config, "load_format", None) == "zerostart"
         if self._zerostart_requested:
             load_config.load_format = "safetensors"
 
-        if _DefaultLoader is not object:
-            super().__init__(load_config)
-        else:
-            self.load_config = load_config
-
-        # Detect if we're on a network volume
-        self._on_network_volume = any(
+        # Switch to eager loading if explicitly requested or on a network FS
+        # where mmap page faults trigger expensive network round-trips.
+        #
+        # Note: on FUSE mounts with warm page cache (e.g. RunPod MFS), mmap
+        # is actually faster than eager because it avoids copying data.
+        # Eager only helps on cold reads from slow network FSes (NFS, JuiceFS).
+        # Use ZEROSTART_EAGER=1 to force eager loading.
+        force_eager = os.environ.get("ZEROSTART_EAGER", "").lower() in ("1", "true")
+        on_network_volume = any(
             _is_network_volume(p)
             for p in ["/volume", "/gpu-cli-workspaces", "/workspace"]
             if Path(p).exists()
         )
 
-        if self._on_network_volume:
-            log.info("Network volume detected — using eager read for safetensors")
-            self._patch_safe_open()
-
-    def _patch_safe_open(self) -> None:
-        """Patch safetensors in this subprocess for eager read on network volumes."""
-        try:
-            import safetensors.torch as st
-
-            original_load_file = st.load_file
-
-            def patched_load_file(filename: str, device: str = "cpu") -> dict[str, Any]:
-                if _is_network_volume(str(filename)):
-                    with open(filename, "rb") as f:
-                        data = f.read()
-                    return st.load(data, device=device)
-                return original_load_file(filename, device=device)
+        if force_eager or on_network_volume:
+            current = getattr(load_config, "safetensors_load_strategy", "lazy")
+            if current != "eager":
+                load_config.safetensors_load_strategy = "eager"
+                reason = "ZEROSTART_EAGER=1" if force_eager else "network volume detected"
+                log.info(
+                    "Switched safetensors_load_strategy to 'eager' (%s)", reason,
+                )
 
-            st.load_file = patched_load_file
-            self._original_load_file = original_load_file
-            log.debug("Patched safetensors.torch.load_file in subprocess")
-        except ImportError:
-            pass
+        if _DefaultLoader is not object:
+            super().__init__(load_config)
+        else:
+            self.load_config = load_config
 
     def download_model(self, model_config: ModelConfig) -> None:
         """Download model via HF hub (standard path)."""
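
Aside (not part of the diff): to make the eager-versus-mmap distinction concrete, here is a small self-contained comparison of the two safetensors read paths; the shard path is a placeholder. It mirrors the two branches of the removed _fast_safetensors_weights_iterator above:

    import time
    import safetensors.torch
    from safetensors import safe_open

    st_file = "/volume/models/my-model/model-00001-of-00002.safetensors"  # placeholder

    # Eager path: one sequential read, then parse entirely from memory.
    t0 = time.monotonic()
    with open(st_file, "rb") as f:
        tensors = safetensors.torch.load(f.read())
    print(f"eager: {time.monotonic() - t0:.2f}s, {len(tensors)} tensors")

    # Lazy/mmap path: safe_open maps the file and faults pages in on access,
    # which is fast on local NVMe but slow over FUSE/NFS on a cold cache.
    t0 = time.monotonic()
    with safe_open(st_file, framework="pt") as f:
        tensors = {name: f.get_tensor(name) for name in f.keys()}
    print(f"mmap:  {time.monotonic() - t0:.2f}s, {len(tensors)} tensors")
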
@@ -259,26 +205,18 @@ def download_model(self, model_config: ModelConfig) -> None:
             log.warning("HF download failed: %s", e)
 
     def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        """Load weights with network volume optimization.
-
-        On network volumes: uses eager read (30-50x faster than mmap).
-        On local NVMe: delegates to DefaultModelLoader (mmap is fast).
-        """
+        """Load weights, delegating to DefaultModelLoader."""
         t0 = time.monotonic()
 
         if _DefaultLoader is not object and hasattr(super(), "load_weights"):
-            # Let DefaultModelLoader handle it — our safe_open patch
-            # is already installed and will intercept the reads
             super().load_weights(model, model_config)
         else:
             log.warning("DefaultModelLoader not available — basic weight loading")
             self._fallback_load_weights(model, model_config)
 
         elapsed = time.monotonic() - t0
-        log.info(
-            "Weight loading complete (%.2fs, network_volume=%s)",
-            elapsed, self._on_network_volume,
-        )
+        strategy = getattr(self.load_config, "safetensors_load_strategy", "unknown")
+        log.info("Weight loading complete (%.2fs, strategy=%s)", elapsed, strategy)
 
     def _fallback_load_weights(
         self, model: nn.Module, model_config: ModelConfig,
@@ -288,10 +226,13 @@ def _fallback_load_weights(
 
         model_path = Path(model_config.model)
         if not model_path.is_dir():
-            from zerostart.snapshot import _find_hf_cache_dir
-            cache_dir = _find_hf_cache_dir(model_config.model)
-            if cache_dir:
-                model_path = cache_dir
+            try:
+                from zerostart.snapshot import _find_hf_cache_dir
+                cache_dir = _find_hf_cache_dir(model_config.model)
+                if cache_dir:
+                    model_path = cache_dir
+            except ImportError:
+                pass
 
         sf_files = sorted(model_path.glob("*.safetensors"))
         if not sf_files:
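
Aside (not part of the diff): the rest of _fallback_load_weights lies outside this hunk. A hypothetical continuation, not the repository's code, in the spirit of the removed eager-read iterator might collect the shard tensors like this; how they are then handed to the model is not shown in the diff:

    import safetensors.torch
    from safetensors import safe_open

    weights: dict = {}
    for sf in sf_files:
        if _is_network_volume(str(sf)):
            # Eager sequential read avoids the mmap penalty on network volumes.
            with open(sf, "rb") as f:
                weights.update(safetensors.torch.load(f.read()))
        else:
            # Local disk: mmap via safe_open is fine.
            with safe_open(str(sf), framework="pt") as f:
                weights.update({name: f.get_tensor(name) for name in f.keys()})
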