feat: MLX distributed model transfer between nodes #1463

AlexCheema wants to merge 20 commits into main from
Conversation
Add two transfer features using MLX distributed all_sum:
1. Disk-to-memory (automatic): During model loading in multi-node instances,
if a peer has the model and the local node doesn't, stream weight tensors
directly into memory via all_sum. No disk write on the receiver.
2. Disk-to-disk (explicit API): POST /v1/models/{model_id}/distribute
copies all model files from source to target nodes' disk via MLX
distributed file transfer.
New module: model_transfer.py with coordinate_transfer(),
transfer_metadata_files(), broadcast_model_weights(), transfer_all_files()
Modified plan.py to skip downloads when peers have the model and accept
partial downloads for multi-node LoadModel.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…t receivers

When receiving weights via MLX distributed broadcast, the receiver node has no .safetensors files on disk. Two issues caused rms_norm shape mismatch during warmup:

1. model.safetensors.index.json was transferred as metadata (has .json ext), causing load_model to create lazy tensor refs to nonexistent files
2. lazy=True created dangling references even without the index file

Fix: exclude *.safetensors.index.json from metadata transfer, and use lazy=False when receiver has no local weight files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Temporary debug logging to compare broadcast weight names/shapes against model parameters after load_model to diagnose rms_norm crash. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Root cause: load_model(lazy=False, strict=False) without weight files skips quantization, creating float32 Linear layers (310 params) instead of uint32 QuantizedLinear layers (704 params). The broadcast weights from the source are quantized, so load_weights silently fails to replace them due to shape mismatches (e.g., broadcast (2048,128)/uint32 vs model (2048,1024)/float32).

Fix: read quantization config from config.json and call nn.quantize() before loading broadcast weights, ensuring QuantizedLinear layers match.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The transfer is working end-to-end. Remove the _debug_compare_weights function that was used to diagnose the shape mismatch issue. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Previously, all ranks with the model sent real data in all_sum, which corrupts results with >2 nodes (data+data+0 = 2*data). Now only the designated source_rank sends; all others send zeros regardless of whether they have local files. Also switch to mx.load(lazy=True) + weights.pop() so the source only has one tensor in memory at a time instead of loading all safetensors upfront. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
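The fixed protocol can be simulated in pure Python, with no MLX required (`all_sum` and `broadcast` here are stand-ins for `mx.distributed.all_sum`, not the PR's actual functions): only the designated source rank contributes real data, every other rank contributes zeros, so the elementwise sum recovers the data exactly once regardless of how many ranks hold the model locally.

```python
# Pure-Python stand-in for mx.distributed.all_sum: elementwise sum of each
# rank's contribution.
def all_sum(contributions):
    return [sum(col) for col in zip(*contributions)]

# Zero-contribution broadcast: only source_rank sends real data; every other
# rank sends zeros even if it has the model locally, so data is counted once.
def broadcast(data, world_size, source_rank):
    contributions = [
        data if rank == source_rank else [0.0] * len(data)
        for rank in range(world_size)
    ]
    return all_sum(contributions)

# With 3 nodes, the old scheme (every holder sends) would have produced
# 2*data when two nodes held the model; the fixed scheme stays correct.
assert broadcast([1.5, 2.0, 3.25], world_size=3, source_rank=0) == [1.5, 2.0, 3.25]
```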
No circular dependency — model_transfer.py doesn't import from utils_mlx.py or runner.py. Also remove redundant `import json` that shadowed the module-level import. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Collapse _broadcast_int/_broadcast_bytes/_broadcast_json chain into
single _broadcast_json (each was only called once)
- Extract _node_has_download helper to deduplicate download-checking
logic across _any_peer_has_model, _all_downloads_complete, and
_any_download_complete
- Remove unused has_metadata_files function
- Fix module docstring ("two" → "three" transfer modes)
- Remove section divider comment banners
- Simplify redundant is_source check in temp_dir conditional
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
MLX < 0.31 doesn't support mx.load(lazy=True). Try lazy first, fall back to eager loading on TypeError. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
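The fallback pattern can be sketched generically; the loader is injected here so the sketch runs without MLX installed (in the PR it would be `mx.load`, and `load_compat` is an illustrative name):

```python
def load_compat(load_fn, path):
    """Try lazy loading first (MLX >= 0.31); an older mx.load() rejects the
    `lazy` keyword with a TypeError, in which case fall back to eager."""
    try:
        return load_fn(path, lazy=True)
    except TypeError:
        return load_fn(path)

# Stub simulating an old MLX whose loader has no `lazy` parameter.
def old_load(path):
    return {"path": path, "lazy": False}

# Stub simulating a new MLX.
def new_load(path, lazy=False):
    return {"path": path, "lazy": lazy}

assert load_compat(old_load, "w.safetensors") == {"path": "w.safetensors", "lazy": False}
assert load_compat(new_load, "w.safetensors") == {"path": "w.safetensors", "lazy": True}
```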
When the receiver has no safetensors files, load_model's internal nn.quantize skips quantization (class_predicate finds no .scales keys in empty weights dict), leaving the model un-quantized as full fp16. With lazy=False, mx.eval(model.parameters()) materializes ~72GB of fp16 data for a 36B-param model on a 24GB machine → silent OOM kill.

Fix: use lazy=True when broadcast_weights is available. This skips the eager eval, and our code handles quantization correctly before loading the broadcast weights.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
chat_template.jinja is needed by transformers for chat formatting (e.g., GLM-4.7-Flash stores its chat template there). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of maintaining an allowlist of metadata extensions (which broke when chat_template.jinja was missing), treat everything that isn't a .safetensors file as metadata. More robust against new file formats. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
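Combined with the earlier index-file exclusion, the classification rule can be sketched as follows (the helper name is illustrative, not the PR's actual function):

```python
def is_metadata_file(filename: str) -> bool:
    """Everything that is not a weight shard is metadata -- more robust than
    an extension allowlist (which missed chat_template.jinja). The index
    file is excluded too: on a receiver with no local shards it would leave
    lazy references to .safetensors files that don't exist."""
    if filename.endswith(".safetensors.index.json"):
        return False
    return not filename.endswith(".safetensors")

assert is_metadata_file("config.json")
assert is_metadata_file("chat_template.jinja")          # no allowlist needed
assert not is_metadata_file("model-00001-of-00004.safetensors")
assert not is_metadata_file("model.safetensors.index.json")
```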
After model.load_weights(), both the broadcast_weights dict and the model's parameter tree hold references to the same arrays. During tensor_auto_parallel, the old full-size arrays can't be freed because the dict still references them, causing ~2x peak memory. Delete the dict before sharding so arrays are freed as each layer is replaced with its sharded version. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
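The effect is ordinary Python reference counting, illustrated here with `bytearray`s standing in for MLX arrays (all names are illustrative):

```python
buf = bytearray(1024)                    # stands in for a full-size weight array
broadcast_weights = {"layers.0.w": buf}  # the accumulated broadcast dict
params = {"layers.0.w": broadcast_weights["layers.0.w"]}  # load_weights(): same object, second reference

# Both containers point at the same buffer, so replacing only the model-side
# entry during sharding cannot free the full-size array.
assert params["layers.0.w"] is broadcast_weights["layers.0.w"]

del broadcast_weights                    # drop the dict before sharding begins
params["layers.0.w"] = bytearray(512)    # "sharded" replacement
# the original full-size buffer now has no remaining references and is freed
```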
Instead of accumulating all weight tensors (~18GB for GLM) in a dict before loading into the model, broadcast weights incrementally during the sharding loop. Non-layer weights (embeddings, norms, lm_head) are loaded upfront (~600MB), then each layer's weights are broadcast and loaded just before that layer is sharded. Peak memory drops from ~22GB to ~10GB, matching lazy-from-disk behavior. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
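A toy accounting of the two strategies (pure Python; the sizes are illustrative, not the PR's real measurements): accumulating every tensor in a dict before loading keeps all of them resident at once, while per-layer broadcast keeps only the non-layer weights plus the single layer currently being transferred.

```python
def peak_resident(layer_sizes, non_layer_size, incremental):
    """Peak size (arbitrary units) of tensors held outside the model."""
    if not incremental:
        # full dict built upfront: every layer resident simultaneously
        return non_layer_size + sum(layer_sizes)
    peak = 0
    for size in layer_sizes:
        # one layer broadcast, loaded into the model, then freed
        peak = max(peak, non_layer_size + size)
    return peak

layers = [450] * 40                                   # hypothetical: 40 layers
assert peak_resident(layers, 600, incremental=False) == 18600
assert peak_resident(layers, 600, incremental=True) == 1050
```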
… pipeline peak memory

- Add has_local_model field to LoadModel and TransferModelToDisk tasks, computed from download status in plan.py instead of filesystem checks
- Remove has_weight_files() function entirely
- In pipeline broadcast path, only load weights for layers in this node's [start_layer, end_layer) range — discard out-of-range results to reduce peak memory from ~22GB to ~10GB for 2-node pipeline

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Move transfer_metadata_files() outside the conditional needs_transfer block so it always runs in multi-node setups. This ensures config.json, tokenizer files, and other metadata are present on all nodes before load_model() is called, regardless of download status. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add missing weight_loader parameter to Step35ShardingStrategy.shard_model to match base class signature (fixes basedpyright errors)
- Exclude start_distributed_test.py from pytest collection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The receiver's nn.quantize() used no class_predicate, quantizing ALL Linear/Embedding layers. load_model's internal quantize selectively skips layers without .scales in the weights (e.g. lm_head, embeddings). For large models with selective quantization this created shape mismatches — broadcast weights couldn't load into incorrectly-quantized layers, leaving them with garbage data.

Fix: use broadcast metadata weight names to build the same class_predicate as load_model, also pass the mode parameter and respect per-layer overrides from config.json. Exclude .safetensors.index.json from metadata transfer to avoid stale weight references on receivers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…nsfer

Step35ShardingStrategy.shard_model() accepted the weight_loader parameter but never called it, unlike all other sharding strategies. This meant receiver nodes using distributed weight broadcast with tensor parallelism on Step35 models would get zero/garbage weights after sharding.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review: feat: MLX distributed model transfer between nodes

Nice work on this feature -- the idea of leveraging …

Bugs / Correctness Issues

1. …
2. `instance.shard_assignments.transfer_only = True` …
3. `arr = mx.array(list(data), dtype=mx.uint8)` — this creates a Python … `bytes(cast(list[int], result.tolist()))` … Consider using …
4. `chunk_arr = mx.array(list(data), dtype=mx.uint8)` — converting 100MB of file data to a Python … `f.write(bytes(cast(list[int], chunk_data.tolist())))` — for multi-GB safetensors transfer, this will be extremely slow and memory-hungry. Using …

Reliability / Error Handling

5. No error recovery in the all_sum protocol. At minimum, consider wrapping transfer operations with a timeout mechanism similar to …
6. Missing …

Design / Architecture

7. …
8. Metadata files are always transferred, even when all nodes have the model:

   ```python
   # Called unconditionally BEFORE checking needs_transfer
   transfer_metadata_files(model_path, group, is_source)
   ```

9. Quantization replication is fragile:

   ```python
   def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
       if p in quant_config:
           return quant_config[p]
       if not hasattr(m, "to_quantized"):
           return False
       return f"{p}.scales" in broadcast_weight_names
   ```

   If …

10. …

Testing

11. No test coverage for the transfer protocol itself. Even without a real MLX distributed group, pure unit tests for …
12. The renamed test changes semantics significantly.

Minor / Nits

13. …
14. …
15. Unnecessary: `all_node_ids = set([command.source_node_id] + list(command.target_node_ids))` could be simplified to `all_node_ids = {command.source_node_id, *command.target_node_ids}`
16. Duplicate weight loading code between …

Summary

The core concept is sound and the implementation handles the happy path well. The main concerns are: …

The plan.py refactoring is clean and the new test cases are good additions. The API endpoint design is reasonable. Overall a strong first iteration that needs some hardening, particularly around the byte conversion performance issue and error recovery.
Pull request was converted to draft
When multiple nodes need the same model, only one downloads from HuggingFace while others fetch it over the LAN — eliminating redundant internet downloads and cutting cluster startup time roughly in half.

Architecture:
- PeerFileServer: lightweight aiohttp server on each node (port 52416) that serves model files from local cache with Range request support
- PeerAwareShardDownloader: wraps ResumableShardDownloader, checks if any peer already has the model before hitting HuggingFace
- Streaming relay: followers can download from a peer while it's still downloading from HF, via .partial.meta companion files that track flushed byte boundaries
- Graceful fallback: if peer transfer fails, falls back to HuggingFace with .partial resume support

Key design decisions:
- No new gossipsub messages — reuses existing NodeDownloadProgress events and topology for peer discovery and IP resolution
- No leader election — first node to start becomes de facto seed
- Backend-agnostic — works with MLX, tinygrad, PyTorch (any engine)
- Network-agnostic — works over any LAN (Ethernet, WiFi, Thunderbolt)
- Zero config — enabled by default, disable with --no-peer-download
- Complementary to PR exo-explore#1463 (MLX memory-to-memory transfer)

Addresses: exo-explore#1257, exo-explore#721, exo-explore#1606

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Motivation
When running exo across multiple nodes, every node downloads model weights independently from HuggingFace. This is slow and wasteful when one node already has the model — especially since MLX distributed provides high-speed local interconnects (5-6 GB/s on Thunderbolt) that are much faster than internet download.
Changes
Automatic disk-to-memory transfer (during model loading):
When a multi-node instance is placed and only some nodes have the model downloaded, the node with the model broadcasts weights directly into the receivers' memory via MLX distributed
all_sum. No disk write for weight data on the receiver.

Protocol:

- `all_sum` bitmask to determine who has the model

Key files:

- `model_transfer.py` — new module: coordinate, metadata transfer, weight broadcast
- `utils_mlx.py` — `shard_and_load()` integrates transfer before model loading
- `plan.py` — skip download when peer has model; allow `LoadModel` with partial downloads

Explicit disk-to-disk transfer (API endpoint):

`POST /v1/models/{model_id}/distribute` — copies all model files (including safetensors) from a source node to target nodes via MLX distributed. For pre-staging models before inference.

Key files:

- `api.py`, `commands.py`, `tasks.py` — new endpoint and command types
- `main.py` — handles `DistributeModel` by creating transfer-only instances
- `runner.py` — handles `TransferModelToDisk` task
- `runners.py` — `transfer_only` flag on `ShardAssignments`

Why It Works
MLX distributed `all_sum` adds arrays across all ranks. Source has real data, receivers have zeros: `all_sum(data + zeros) = data` on all ranks. This works for any number of receivers and preserves dtypes (including uint32 packed quantized weights).

For the receiver, `load_model(config_only, lazy=False, strict=False)` creates the model architecture, `nn.quantize()` converts to QuantizedLinear layers matching the broadcast weight shapes, then `load_weights()` populates with the broadcast data.
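The presence-check step of the protocol (the `all_sum` bitmask mentioned above) can be simulated without MLX. A one-hot encoding is one plausible scheme, not necessarily the PR's exact wire format: each rank contributes a vector with a 1 at its own index iff it has the model, and the summed result is a bitmask every rank can read to pick a source rank.

```python
# Pure-Python stand-in for mx.distributed.all_sum.
def all_sum(contributions):
    return [sum(col) for col in zip(*contributions)]

def presence_bitmask(has_model_by_rank):
    """Each rank contributes a one-hot vector; the all_sum result tells
    every rank which ranks already hold the model on disk."""
    world_size = len(has_model_by_rank)
    contributions = []
    for rank in range(world_size):
        vec = [0] * world_size
        vec[rank] = 1 if has_model_by_rank[rank] else 0
        contributions.append(vec)
    return all_sum(contributions)

mask = presence_bitmask([True, False, True])
assert mask == [1, 0, 1]
source_rank = mask.index(1)       # e.g. lowest rank with the model
assert source_rank == 0
```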
Manual Testing
Hardware: 2x Mac Mini M4 Pro 48GB, connected via Thunderbolt (10GbE LAN)
mlx-community/Qwen3-0.6B-4bitwith model only on node 1/v1/chat/completions)Automated Testing
test_download_and_loading.pywith 2 new tests:test_plan_loads_model_when_any_node_has_download_for_multi_node— LoadModel fires when only one node has the modeltest_plan_does_not_load_model_when_no_node_has_download— no LoadModel when nobody has itbasedpyright0 errors,ruff checkpasses🤖 Generated with Claude Code