
feat: MLX distributed model transfer between nodes #1463

Draft
AlexCheema wants to merge 20 commits into main from alexcheema/mlx-distributed-transfer

Conversation

@AlexCheema
Contributor

Motivation

When running exo across multiple nodes, every node downloads model weights independently from HuggingFace. This is slow and wasteful when one node already has the model — especially since MLX distributed provides high-speed local interconnects (5-6 GB/s over Thunderbolt) that are much faster than a typical internet download.

Changes

Automatic disk-to-memory transfer (during model loading):

When a multi-node instance is placed and only some nodes have the model downloaded, the node with the model broadcasts weights directly into the receivers' memory via MLX distributed all_sum. No disk write for weight data on the receiver.

Protocol:

  1. Ranks coordinate via all_sum bitmask to determine who has the model
  2. Source broadcasts metadata files (config.json, tokenizer, etc.) to receivers' disk
  3. Source broadcasts weight tensors directly into receivers' memory — 704 tensors for Qwen3-0.6B-4bit in ~4 seconds
  4. Receiver applies quantization config, then populates model with broadcast weights
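Step 1 can be sketched in plain Python (hypothetical helper; `mx.distributed.all_sum` is simulated here as a sum over every rank's one-hot contribution, since the real collective gives every rank the same result):

```python
# Minimal sketch of the coordination step, assuming a one-hot bitmask
# contribution per rank. Not the PR's actual coordinate_transfer code.
def coordinate(has_model_per_rank: list[bool]) -> tuple[bool, int]:
    """Return (needs_transfer, source_rank); identical on every rank,
    because all_sum delivers the same bitmask everywhere."""
    # Each rank contributes (1 << rank) if it has the model; the sum
    # stands in for mx.distributed.all_sum across the group.
    bitmask = sum((1 << r) for r, has in enumerate(has_model_per_rank) if has)
    if bitmask == 0:
        raise RuntimeError("no rank has the model")
    source_rank = (bitmask & -bitmask).bit_length() - 1  # lowest set bit
    all_ranks = (1 << len(has_model_per_rank)) - 1
    return bitmask != all_ranks, source_rank
```

Every rank runs the same deterministic decode of the bitmask, so no extra round of agreement on the source rank is needed.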

Key files:

  • model_transfer.py — new module: coordinate, metadata transfer, weight broadcast
  • utils_mlx.py — shard_and_load() integrates transfer before model loading
  • plan.py — skip download when peer has model; allow LoadModel with partial downloads

Explicit disk-to-disk transfer (API endpoint):

POST /v1/models/{model_id}/distribute — copies all model files (including safetensors) from a source node to target nodes via MLX distributed. For pre-staging models before inference.

Key files:

  • api.py, commands.py, tasks.py — new endpoint and command types
  • main.py — handles DistributeModel by creating transfer-only instances
  • runner.py — handles TransferModelToDisk task
  • runners.py — transfer_only flag on ShardAssignments

Why It Works

MLX distributed all_sum adds arrays across all ranks. Source has real data, receivers have zeros: all_sum(data + zeros) = data on all ranks. This works for any number of receivers and preserves dtypes (including uint32 packed quantized weights).
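The identity can be checked locally with a numpy stand-in for the collective (an elementwise reduce rather than `mx.distributed.all_sum`, so dtype preservation is visible):

```python
from functools import reduce
import numpy as np

# Simulate all_sum across three ranks: one source, two receivers.
def all_sum(per_rank_arrays):
    # reduce with np.add keeps the input dtype, unlike np.sum's upcasting.
    return reduce(np.add, per_rank_arrays)

source = np.array([7, 0xFFFF, 3], dtype=np.uint32)  # e.g. packed quantized data
zeros = np.zeros_like(source)
result = all_sum([source, zeros, zeros])

assert result.dtype == np.uint32        # dtype preserved end to end
assert (result == source).all()         # data + 0 + 0 == data on all ranks
```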

For the receiver, load_model(config_only, lazy=False, strict=False) creates the model architecture, nn.quantize() converts to QuantizedLinear layers matching the broadcast weight shapes, then load_weights() populates with the broadcast data.

Test Plan

Manual Testing

Hardware: 2x Mac Mini M4 Pro 48GB, connected via Thunderbolt (10GbE LAN)

  • Placed 2-node pipeline-parallel instance for mlx-community/Qwen3-0.6B-4bit with model only on node 1
  • Node 2 received 8 metadata files + 704 weight tensors via MLX distributed in ~4 seconds
  • Warmup inference generated coherent tokens
  • Chat completion returned correct response (/v1/chat/completions)

Automated Testing

  • Updated test_download_and_loading.py with 2 new tests:
    • test_plan_loads_model_when_any_node_has_download_for_multi_node — LoadModel fires when only one node has the model
    • test_plan_does_not_load_model_when_no_node_has_download — no LoadModel when nobody has it
  • All 22 plan unit tests pass
  • basedpyright 0 errors, ruff check passes

🤖 Generated with Claude Code

@AlexCheema AlexCheema force-pushed the alexcheema/mlx-distributed-transfer branch from 652f4a7 to a19ee34 Compare February 16, 2026 18:02
@AlexCheema AlexCheema enabled auto-merge (squash) February 17, 2026 17:45
@AlexCheema AlexCheema force-pushed the alexcheema/mlx-distributed-transfer branch from 1ad5c4b to c82d20b Compare February 17, 2026 17:52
AlexCheema and others added 20 commits February 17, 2026 10:05
Add two transfer features using MLX distributed all_sum:

1. Disk-to-memory (automatic): During model loading in multi-node instances,
   if a peer has the model and the local node doesn't, stream weight tensors
   directly into memory via all_sum. No disk write on the receiver.

2. Disk-to-disk (explicit API): POST /v1/models/{model_id}/distribute
   copies all model files from source to target nodes' disk via MLX
   distributed file transfer.

New module: model_transfer.py with coordinate_transfer(),
transfer_metadata_files(), broadcast_model_weights(), transfer_all_files()

Modified plan.py to skip downloads when peers have the model and accept
partial downloads for multi-node LoadModel.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…t receivers

When receiving weights via MLX distributed broadcast, the receiver node
has no .safetensors files on disk. Two issues caused rms_norm shape mismatch
during warmup:

1. model.safetensors.index.json was transferred as metadata (has .json ext),
   causing load_model to create lazy tensor refs to nonexistent files
2. lazy=True created dangling references even without the index file

Fix: exclude *.safetensors.index.json from metadata transfer, and use
lazy=False when receiver has no local weight files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Temporary debug logging to compare broadcast weight names/shapes
against model parameters after load_model to diagnose rms_norm crash.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Root cause: load_model(lazy=False, strict=False) without weight files
skips quantization, creating float32 Linear layers (310 params) instead
of uint32 QuantizedLinear layers (704 params). The broadcast weights from
the source are quantized, so load_weights silently fails to replace them
due to shape mismatches (e.g., broadcast (2048,128)/uint32 vs model
(2048,1024)/float32).

Fix: read quantization config from config.json and call nn.quantize()
before loading broadcast weights, ensuring QuantizedLinear layers match.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The transfer is working end-to-end. Remove the _debug_compare_weights
function that was used to diagnose the shape mismatch issue.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Previously, all ranks with the model sent real data in all_sum, which
corrupts results with >2 nodes (data+data+0 = 2*data). Now only the
designated source_rank sends; all others send zeros regardless of
whether they have local files.

Also switch to mx.load(lazy=True) + weights.pop() so the source only
has one tensor in memory at a time instead of loading all safetensors
upfront.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
No circular dependency — model_transfer.py doesn't import from
utils_mlx.py or runner.py. Also remove redundant `import json` that
shadowed the module-level import.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Collapse _broadcast_int/_broadcast_bytes/_broadcast_json chain into
  single _broadcast_json (each was only called once)
- Extract _node_has_download helper to deduplicate download-checking
  logic across _any_peer_has_model, _all_downloads_complete, and
  _any_download_complete
- Remove unused has_metadata_files function
- Fix module docstring ("two" → "three" transfer modes)
- Remove section divider comment banners
- Simplify redundant is_source check in temp_dir conditional

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
MLX < 0.31 doesn't support mx.load(lazy=True). Try lazy first,
fall back to eager loading on TypeError.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When the receiver has no safetensors files, load_model's internal
nn.quantize skips quantization (class_predicate finds no .scales keys
in empty weights dict), leaving the model un-quantized as full fp16.
With lazy=False, mx.eval(model.parameters()) materializes ~72GB of
fp16 data for a 36B-param model on a 24GB machine → silent OOM kill.

Fix: use lazy=True when broadcast_weights is available. This skips the
eager eval, and our code handles quantization correctly before loading
the broadcast weights.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
chat_template.jinja is needed by transformers for chat formatting
(e.g., GLM-4.7-Flash stores its chat template there).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of maintaining an allowlist of metadata extensions (which broke
when chat_template.jinja was missing), treat everything that isn't a
.safetensors file as metadata. More robust against new file formats.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
After model.load_weights(), both the broadcast_weights dict and the
model's parameter tree hold references to the same arrays. During
tensor_auto_parallel, the old full-size arrays can't be freed because
the dict still references them, causing ~2x peak memory.

Delete the dict before sharding so arrays are freed as each layer is
replaced with its sharded version.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of accumulating all weight tensors (~18GB for GLM) in a dict
before loading into the model, broadcast weights incrementally during
the sharding loop. Non-layer weights (embeddings, norms, lm_head) are
loaded upfront (~600MB), then each layer's weights are broadcast and
loaded just before that layer is sharded.

Peak memory drops from ~22GB to ~10GB, matching lazy-from-disk behavior.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… pipeline peak memory

- Add has_local_model field to LoadModel and TransferModelToDisk tasks,
  computed from download status in plan.py instead of filesystem checks
- Remove has_weight_files() function entirely
- In pipeline broadcast path, only load weights for layers in this
  node's [start_layer, end_layer) range — discard out-of-range results
  to reduce peak memory from ~22GB to ~10GB for 2-node pipeline

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Move transfer_metadata_files() outside the conditional needs_transfer
block so it always runs in multi-node setups. This ensures config.json,
tokenizer files, and other metadata are present on all nodes before
load_model() is called, regardless of download status.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add missing weight_loader parameter to Step35ShardingStrategy.shard_model
  to match base class signature (fixes basedpyright errors)
- Exclude start_distributed_test.py from pytest collection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The receiver's nn.quantize() used no class_predicate, quantizing ALL
Linear/Embedding layers. load_model's internal quantize selectively
skips layers without .scales in the weights (e.g. lm_head, embeddings).
For large models with selective quantization this created shape
mismatches — broadcast weights couldn't load into incorrectly-quantized
layers, leaving them with garbage data.

Fix: use broadcast metadata weight names to build the same
class_predicate as load_model, also pass the mode parameter and respect
per-layer overrides from config.json. Exclude .safetensors.index.json
from metadata transfer to avoid stale weight references on receivers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…nsfer

Step35ShardingStrategy.shard_model() accepted the weight_loader parameter
but never called it, unlike all other sharding strategies. This meant
receiver nodes using distributed weight broadcast with tensor parallelism
on Step35 models would get zero/garbage weights after sharding.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@AlexCheema AlexCheema force-pushed the alexcheema/mlx-distributed-transfer branch from c82d20b to de5c121 Compare February 17, 2026 18:06
@AlexCheema
Contributor Author

Code Review: feat: MLX distributed model transfer between nodes

Nice work on this feature -- the idea of leveraging all_sum with zeros to broadcast model weights is clever, and the layer-by-layer streaming approach to bound memory is well thought out. The manual testing on Thunderbolt shows real-world viability. Below are the issues and suggestions I found, ordered roughly by severity.


Bugs / Correctness Issues

1. TransferModelToDisk handler calls coordinate_transfer a second time
In runner.py, the TransferModelToDisk handler calls coordinate_transfer(group, task.has_local_model) again before transfer_all_files. But transfer_all_files -> _transfer_files_to_disk does NOT call coordinate_transfer internally -- it just does manifest broadcast + file transfer. So the second coordinate_transfer call is a standalone collective all_sum that re-negotiates which rank is source. If any node's disk state changed between plan time and execution time (unlikely but possible), the re-negotiation could pick a different source_rank than expected. Worth either removing the redundant coordination or documenting why re-negotiation is intentional.

2. ShardAssignments.transfer_only mutation on a Pydantic model
In main.py:

instance.shard_assignments.transfer_only = True

ShardAssignments extends CamelCaseModel, which is NOT frozen, so this mutation technically works. However, the codebase style says "Use @final and immutability wherever applicable" and the event-sourcing model expects state changes to flow through events. Mutating a field after construction bypasses validation and is fragile -- if ShardAssignments ever gets frozen=True (as the codebase style would suggest), this breaks. Consider either:

  • Making transfer_only part of the PlaceInstance command or a new command variant
  • Using model_copy(update={"transfer_only": True}) to create a new immutable instance

3. _broadcast_json uses list(data) for byte conversion
In model_transfer.py:

arr = mx.array(list(data), dtype=mx.uint8)

This creates a Python list of every byte. For 200KB of JSON metadata, that is ~200K Python int objects at ~28 bytes each = ~5.6MB of allocations plus list overhead. The deserialization side has the same issue:

bytes(cast(list[int], result.tolist()))

Consider using numpy.frombuffer(data, dtype=np.uint8) and converting to mx.array, or mx.array(memoryview(data)) if MLX supports it.

4. _transfer_file_to_disk also uses list(data) for 100MB chunks -- this is the biggest performance concern
Same pattern but at 1000x scale:

chunk_arr = mx.array(list(data), dtype=mx.uint8)

Converting 100MB of file data to a Python list[int] allocates ~2.8GB of Python int objects (100M ints * ~28 bytes each). On the receiver side:

f.write(bytes(cast(list[int], chunk_data.tolist())))

For multi-GB safetensors transfer, this will be extremely slow and memory-hungry. Using numpy.frombuffer as an intermediate would fix both directions. This should be fixed before merge.
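A sketch of the suggested replacement for both directions (variable names follow the review text; feeding the numpy array to `mx.array` is the assumed interop path):

```python
import numpy as np

def bytes_to_u8_array(data: bytes) -> np.ndarray:
    # Zero-copy view instead of list(data): no per-byte Python int objects.
    return np.frombuffer(data, dtype=np.uint8)

def u8_array_to_bytes(arr: np.ndarray) -> bytes:
    # Instead of bytes(arr.tolist()): one contiguous memcpy.
    return np.ascontiguousarray(arr, dtype=np.uint8).tobytes()

payload = b"hello \x00\xff world"
assert u8_array_to_bytes(bytes_to_u8_array(payload)) == payload
```

This drops the per-chunk overhead from gigabytes of transient Python objects to a single buffer copy on each side.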


Reliability / Error Handling

5. No error recovery in the all_sum protocol
All collective operations are synchronous lock-step across all ranks. If any rank throws (disk full, OOM, permission error) mid-transfer:

  • Other ranks hang forever waiting for the next all_sum
  • There is no timeout on any all_sum call (unlike eval_with_timeout used elsewhere for weight loading)
  • The only recovery is killing all processes

At minimum, consider wrapping transfer operations with a timeout mechanism similar to eval_with_timeout. A more robust approach: a "status" all_sum after each file/layer where ranks can signal failure.

6. Missing Shutdown task for transfer-only instances after transfer completes
After TransferModelToDisk completes, the runner transitions directly to RunnerShutdown and breaks out of the loop. But looking at the plan, _transfer_model_to_disk only fires when the runner is RunnerConnected. After the runner reaches RunnerShutdown, who cleans up the transfer-only instance? The _shutdown plan step checks runner.status but for transfer-only instances, the runner process has already exited. Is the instance left in a zombie state in the cluster? This lifecycle needs verification.


Design / Architecture

7. all_sum bandwidth is O(N * file_size)
all_sum is an all-reduce operation. Even though receivers send zeros, the collective still involves all ranks in the reduction tree, so total network traffic scales with N. For 2 nodes on Thunderbolt this is fine, but for 4+ nodes over a network this becomes a bottleneck vs O(file_size) for point-to-point. Worth at least a comment about scaling characteristics.

8. Metadata files are always transferred, even when all nodes have the model
In utils_mlx.py:

# Called unconditionally BEFORE checking needs_transfer
transfer_metadata_files(model_path, group, is_source)

coordinate_transfer is called first and returns (needs_transfer, source_rank), but transfer_metadata_files is called regardless of the result. This means every multi-node model load does an extra collective file transfer operation even when all nodes already have the model. This should be gated on needs_transfer.

9. Quantization replication is fragile
The quantization logic in utils_mlx.py replicates mlx_lm's internal load_model quantization behavior:

def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
    if p in quant_config:
        return quant_config[p]
    if not hasattr(m, "to_quantized"):
        return False
    return f"{p}.scales" in broadcast_weight_names

If mlx_lm changes how it decides which layers to quantize, this code will silently diverge, potentially causing shape mismatches or incorrect inference. Consider:

  • Adding a comment with the specific mlx_lm version this was matched against
  • Adding an assertion that verifies the quantized model's parameter shapes match the broadcast metadata before loading weights
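The suggested assertion could be as simple as a shape diff between the quantized model and the broadcast metadata, run before load_weights() (hypothetical names and shapes, echoing the mismatch from the commit log):

```python
def shape_mismatches(model_shapes: dict[str, tuple[int, ...]],
                     broadcast_shapes: dict[str, tuple[int, ...]]) -> list[str]:
    """Return human-readable mismatches; empty list means safe to load."""
    problems = []
    for name, shape in broadcast_shapes.items():
        if name not in model_shapes:
            problems.append(f"{name}: not present in quantized model")
        elif model_shapes[name] != shape:
            problems.append(f"{name}: model {model_shapes[name]} != broadcast {shape}")
    return problems

# The failure mode from the commit log: a layer left un-quantized.
model = {"layers.0.q_proj.weight": (2048, 1024)}      # float32 Linear
broadcast = {"layers.0.q_proj.weight": (2048, 128)}   # uint32 QuantizedLinear
assert shape_mismatches(model, broadcast)             # non-empty: refuse to load
```

Failing loudly here turns a silent divergence from mlx_lm's quantization logic into an immediate, debuggable error.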

10. model_path_for_id vs build_model_path
A new function model_path_for_id is introduced that doesn't require the directory to exist (unlike build_model_path). shard_and_load now uses model_path_for_id instead of build_model_path. If build_model_path has any side effects or validation beyond directory existence, this change could introduce subtle issues. Worth verifying they produce identical paths when the directory does exist.


Testing

11. No test coverage for the transfer protocol itself
The new tests only cover plan logic (when to emit LoadModel/TransferModelToDisk). There are no tests for:

  • coordinate_transfer (could be tested with mock groups)
  • _broadcast_json serialization round-trip
  • WeightBroadcastState layer partitioning logic
  • _extract_layer_index regex matching
  • _is_metadata_file classification
  • The quantization replication logic

Even without a real MLX distributed group, pure unit tests for _extract_layer_index, _is_metadata_file, _build_manifest, and _parse_mx_dtype would catch regressions.

12. The renamed test changes semantics significantly
test_plan_does_not_load_model_until_all_shards_downloaded_globally was renamed to test_plan_loads_model_when_any_node_has_download_for_multi_node and the assertion flipped from assert result is None to assert isinstance(result, LoadModel). This is a deliberate behavior change (multi-node instances no longer require all nodes to have the download). But it means the old guarantee (all downloads must complete before loading) is no longer tested for single-node instances. Consider adding an explicit test that verifies single-node instances still require a local download complete before LoadModel fires.


Minor / Nits

13. _LAYER_RE could miss some architectures
The regex (?:^|\.)(layers|h)\.(\d+)\. only matches layers.N. or h.N. patterns. If a new model architecture uses a different name (e.g., blocks, decoder_layers, encoder.layer), layer weights would be classified as non-layer weights and broadcast upfront instead of per-layer. This wouldn't cause correctness issues but would increase peak memory. Consider documenting this assumption.

14. DistributeModel.target_node_ids allows empty list
In commands.py, target_node_ids: list[NodeId] has no minimum length validation. The API endpoint checks for empty targets and returns 400, but the command type itself allows it. Adding a min_length=1 validator would be more defensive.

15. Unnecessary set() construction in DistributeModel handler

all_node_ids = set([command.source_node_id] + list(command.target_node_ids))

Could be simplified to:

all_node_ids = {command.source_node_id, *command.target_node_ids}

16. Duplicate weight loading code between broadcast_model_weights and prepare_weight_broadcast
Both functions have nearly identical logic for loading safetensors lazily and broadcasting metadata. broadcast_model_weights could be refactored to use prepare_weight_broadcast + iterate all layers, reducing code duplication.


Summary

The core concept is sound and the implementation handles the happy path well. The main concerns are:

  1. Performance: list(data) byte conversion for large chunks will be very slow and memory-hungry (item 4) -- this should be fixed before merge
  2. Reliability: No timeout or error recovery on collective operations (item 5)
  3. Correctness: The quantization replication logic is the most fragile part (item 9) and needs either version-pinning or runtime verification
  4. Testing: The transfer protocol itself has no unit tests (item 11)
  5. Unconditional metadata transfer: Should be gated on needs_transfer (item 8)

The plan.py refactoring is clean and the new test cases are good additions. The API endpoint design is reasonable. Overall a strong first iteration that needs some hardening, particularly around the byte conversion performance issue and error recovery.

@exo-explore exo-explore deleted a comment from AlexCheema Feb 18, 2026
@rltakashige rltakashige marked this pull request as draft February 18, 2026 22:35
auto-merge was automatically disabled February 18, 2026 22:35

Pull request was converted to draft

ecohash-co added a commit to ecohash-co/exo that referenced this pull request Mar 11, 2026
When multiple nodes need the same model, only one downloads from
HuggingFace while others fetch it over the LAN — eliminating redundant
internet downloads and cutting cluster startup time roughly in half.

Architecture:
- PeerFileServer: lightweight aiohttp server on each node (port 52416)
  that serves model files from local cache with Range request support
- PeerAwareShardDownloader: wraps ResumableShardDownloader, checks if any
  peer already has the model before hitting HuggingFace
- Streaming relay: followers can download from a peer while it's still
  downloading from HF, via .partial.meta companion files that track
  flushed byte boundaries
- Graceful fallback: if peer transfer fails, falls back to HuggingFace
  with .partial resume support

Key design decisions:
- No new gossipsub messages — reuses existing NodeDownloadProgress events
  and topology for peer discovery and IP resolution
- No leader election — first node to start becomes de facto seed
- Backend-agnostic — works with MLX, tinygrad, PyTorch (any engine)
- Network-agnostic — works over any LAN (Ethernet, WiFi, Thunderbolt)
- Zero config — enabled by default, disable with --no-peer-download
- Complementary to PR exo-explore#1463 (MLX memory-to-memory transfer)

Addresses: exo-explore#1257, exo-explore#721, exo-explore#1606

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ecohash-co added a commit to ecohash-co/exo that referenced this pull request Apr 20, 2026
(Same commit message as the March 11, 2026 commit above.)