From c93aaa78865f3a3c5c6d4a6313f2a59a29898a46 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 5 Mar 2026 22:10:24 +0000 Subject: [PATCH 1/4] feat(mlx-distributed): add new MLX-distributed backend Add new MLX distributed backend with support for both TCP and RDMA for model sharding. This implementation ties in the discovery implementation already in place, and re-uses the same P2P mechanism for the TCP MLX-distributed inferencing. The Auto-parallel implementation is inspired by Exo's ones (who have been added to acknowledgement for the great work!) Signed-off-by: Ettore Di Giacinto --- .github/workflows/backend.yml | 68 ++++ Makefile | 10 +- README.md | 1 + backend/index.yaml | 81 ++++ backend/python/mlx-distributed/Makefile | 23 ++ backend/python/mlx-distributed/backend.py | 351 ++++++++++++++++++ backend/python/mlx-distributed/coordinator.py | 104 ++++++ backend/python/mlx-distributed/install.sh | 15 + .../mlx-distributed/requirements-mps.txt | 1 + .../python/mlx-distributed/requirements.txt | 4 + backend/python/mlx-distributed/run.sh | 11 + backend/python/mlx-distributed/sharding.py | 136 +++++++ backend/python/mlx-distributed/test.py | 33 ++ backend/python/mlx-distributed/test.sh | 12 + core/application/p2p.go | 28 +- core/cli/run.go | 12 +- core/cli/worker/worker.go | 1 + core/cli/worker/worker_p2p.go | 4 +- core/cli/worker/worker_p2p_mlx.go | 149 ++++++++ core/config/application_config.go | 13 +- core/explorer/discovery.go | 2 +- core/http/endpoints/localai/p2p.go | 3 +- core/http/react-ui/src/pages/P2P.jsx | 167 +++++++-- core/http/routes/ui_api.go | 62 +++- core/http/views/p2p.html | 10 +- core/p2p/node.go | 5 +- core/schema/localai.go | 3 +- docs/content/features/mlx-distributed.md | 110 ++++++ 28 files changed, 1348 insertions(+), 71 deletions(-) create mode 100644 backend/python/mlx-distributed/Makefile create mode 100644 backend/python/mlx-distributed/backend.py create mode 100644 backend/python/mlx-distributed/coordinator.py create mode 100644 backend/python/mlx-distributed/install.sh create mode 100644 backend/python/mlx-distributed/requirements-mps.txt create mode 100644 backend/python/mlx-distributed/requirements.txt create mode 100644 backend/python/mlx-distributed/run.sh create mode 100644 backend/python/mlx-distributed/sharding.py create mode 100644 backend/python/mlx-distributed/test.py create mode 100644 backend/python/mlx-distributed/test.sh create mode 100644 core/cli/worker/worker_p2p_mlx.go create mode 100644 docs/content/features/mlx-distributed.md diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 2bc4c259368b..6804b85d7a0f 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -157,6 +157,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-cpu-mlx-distributed' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'true' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' # CUDA 12 builds - build-type: 'cublas' cuda-major-version: "12" @@ -470,6 +483,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "8" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-12-mlx-distributed' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" @@ -822,6 +848,19 @@ jobs: backend: "mlx-audio" dockerfile: "./backend/Dockerfile.python" context: "./" + - build-type: 'l4t' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/arm64' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-cuda-13-arm64-mlx-distributed' + runs-on: 'ubuntu-24.04-arm' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + ubuntu-version: '2404' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" @@ -926,6 +965,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "13" + cuda-minor-version: "0" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-13-mlx-distributed' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" @@ -1423,6 +1475,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' + - build-type: 'l4t' + cuda-major-version: "12" + cuda-minor-version: "0" + platforms: 'linux/arm64' + tag-latest: 'auto' + tag-suffix: '-nvidia-l4t-mlx-distributed' + runs-on: 'ubuntu-24.04-arm' + base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" + skip-drivers: 'true' + backend: "mlx-distributed" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2204' # SYCL additional backends - build-type: 'intel' cuda-major-version: "" @@ -2016,6 +2081,9 @@ jobs: - backend: "mlx-audio" tag-suffix: "-metal-darwin-arm64-mlx-audio" build-type: "mps" + - backend: "mlx-distributed" + tag-suffix: "-metal-darwin-arm64-mlx-distributed" + build-type: "mps" - backend: "stablediffusion-ggml" tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml" build-type: "metal" diff --git a/Makefile b/Makefile index 54c21088afa0..15af1d39a5f6 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/voxtral +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/voxtral GOCMD=go GOTEST=$(GOCMD) test @@ -451,6 +451,10 @@ backends/mlx-audio: BACKEND=mlx-audio $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)" +backends/mlx-distributed: + BACKEND=mlx-distributed $(MAKE) build-darwin-python-backend + ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-distributed.tar)" + backends/stablediffusion-ggml-darwin: BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)" @@ -495,6 +499,7 @@ BACKEND_NEMO = nemo|python|.|false|true BACKEND_VOXCPM = voxcpm|python|.|false|true BACKEND_WHISPERX = whisperx|python|.|false|true BACKEND_ACE_STEP = ace-step|python|.|false|true +BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true # Helper function to build docker image for a backend # Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG) @@ -548,12 +553,13 @@ $(eval $(call generate-docker-build-target,$(BACKEND_NEMO))) $(eval $(call generate-docker-build-target,$(BACKEND_VOXCPM))) $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX))) $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP))) +$(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED))) # Pattern rule for docker-save targets docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral +docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-voxtral docker-build-mlx-distributed ######################################################## ### Mock Backend for E2E Tests diff --git a/README.md b/README.md index af937aa014dd..2ef0d9a6c2e9 100644 --- a/README.md +++ b/README.md @@ -432,6 +432,7 @@ LocalAI couldn't have been built without the help of great software already avai - https://github.com/EdVince/Stable-Diffusion-NCNN - https://github.com/ggerganov/whisper.cpp - https://github.com/rhasspy/piper +- [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation ## 🤗 Contributors diff --git a/backend/index.yaml b/backend/index.yaml index e518170ca680..392afa7357f0 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -259,6 +259,31 @@ nvidia-l4t: "nvidia-l4t-mlx-audio" nvidia-l4t-cuda-12: "nvidia-l4t-mlx-audio" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-audio" +- &mlx-distributed + name: "mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-distributed" + icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4 + urls: + - https://github.com/ml-explore/mlx-lm + mirrors: + - localai/localai-backends:latest-metal-darwin-arm64-mlx-distributed + license: MIT + description: | + Run distributed LLM inference with MLX across multiple Apple Silicon Macs + tags: + - text-to-text + - LLM + - MLX + - distributed + capabilities: + default: "cpu-mlx-distributed" + nvidia: "cuda12-mlx-distributed" + metal: "metal-mlx-distributed" + nvidia-cuda-12: "cuda12-mlx-distributed" + nvidia-cuda-13: "cuda13-mlx-distributed" + nvidia-l4t: "nvidia-l4t-mlx-distributed" + nvidia-l4t-cuda-12: "nvidia-l4t-mlx-distributed" + nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-distributed" - &rerankers name: "rerankers" alias: "rerankers" @@ -791,6 +816,11 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio" mirrors: - localai/localai-backends:master-metal-darwin-arm64-mlx-audio +- !!merge <<: *mlx-distributed + name: "mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-distributed" + mirrors: + - localai/localai-backends:master-metal-darwin-arm64-mlx-distributed ## mlx - !!merge <<: *mlx name: "cpu-mlx" @@ -944,6 +974,57 @@ uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio +## mlx-distributed +- !!merge <<: *mlx-distributed + name: "cpu-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-mlx-distributed" + mirrors: + - localai/localai-backends:latest-cpu-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cpu-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cpu-mlx-distributed" + mirrors: + - localai/localai-backends:master-cpu-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda12-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda12-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-mlx-distributed" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-12-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda13-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda13-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-mlx-distributed" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-13-mlx-distributed +- !!merge <<: *mlx-distributed + name: "nvidia-l4t-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-mlx-distributed" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-mlx-distributed +- !!merge <<: *mlx-distributed + name: "nvidia-l4t-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-mlx-distributed" + mirrors: + - localai/localai-backends:master-nvidia-l4t-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda13-nvidia-l4t-arm64-mlx-distributed" + uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed" + mirrors: + - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed +- !!merge <<: *mlx-distributed + name: "cuda13-nvidia-l4t-arm64-mlx-distributed-development" + uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed" + mirrors: + - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed - !!merge <<: *kitten-tts name: "kitten-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts" diff --git a/backend/python/mlx-distributed/Makefile b/backend/python/mlx-distributed/Makefile new file mode 100644 index 000000000000..b322efbe97cb --- /dev/null +++ b/backend/python/mlx-distributed/Makefile @@ -0,0 +1,23 @@ +.PHONY: mlx-distributed +mlx-distributed: + bash install.sh + +.PHONY: run +run: + @echo "Running mlx-distributed..." + bash run.sh + @echo "mlx-distributed run." + +.PHONY: test +test: + @echo "Testing mlx-distributed..." + bash test.sh + @echo "mlx-distributed tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/mlx-distributed/backend.py b/backend/python/mlx-distributed/backend.py new file mode 100644 index 000000000000..34d1fad3eba1 --- /dev/null +++ b/backend/python/mlx-distributed/backend.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python3 +""" +MLX Distributed Inference Backend for LocalAI. + +Rank 0 mode: Starts a gRPC server that coordinates distributed inference. +Worker mode: Enters a loop waiting for commands from rank 0. +""" +import asyncio +from concurrent import futures +import argparse +import json +import os +import signal +import sys +import tempfile + +import grpc + +import backend_pb2 +import backend_pb2_grpc + +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + + +def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None): + """Initialize MLX distributed runtime.""" + import mlx.core as mx + + if backend == "ring": + os.environ["MLX_HOSTFILE"] = hostfile + os.environ["MLX_RANK"] = str(rank) + os.environ["MLX_RING_VERBOSE"] = "1" + return mx.distributed.init(backend="ring", strict=True) + elif backend == "jaccl": + os.environ["MLX_IBV_DEVICES"] = hostfile + os.environ["MLX_RANK"] = str(rank) + if coordinator: + os.environ["MLX_JACCL_COORDINATOR"] = coordinator + return mx.distributed.init(backend="jaccl", strict=True) + else: + raise ValueError(f"Unknown backend: {backend}") + + +def is_float(s): + try: + float(s) + return True + except ValueError: + return False + + +def is_int(s): + try: + int(s) + return True + except ValueError: + return False + + +class BackendServicer(backend_pb2_grpc.BackendServicer): + """gRPC servicer for distributed MLX inference (runs only on rank 0).""" + + def __init__(self, group, dist_backend="ring"): + self.group = group + self.dist_backend = dist_backend + self.model = None + self.tokenizer = None + self.coordinator = None + self.options = {} + + def Health(self, request, context): + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + async def LoadModel(self, request, context): + try: + import mlx.core as mx + from mlx_lm import load + from coordinator import DistributedCoordinator, CMD_LOAD_MODEL + from sharding import pipeline_auto_parallel + + print(f"[Rank 0] Loading distributed model: {request.Model}", file=sys.stderr) + + options = request.Options + self.options = {} + for opt in options: + if ":" not in opt: + continue + key, value = opt.split(":", 1) + if is_float(value): + value = float(value) + elif is_int(value): + value = int(value) + elif value.lower() in ["true", "false"]: + value = value.lower() == "true" + self.options[key] = value + + self.coordinator = DistributedCoordinator(self.group) + + # Broadcast load command to all ranks + self.coordinator.broadcast_command(CMD_LOAD_MODEL) + self.coordinator.broadcast_model_name(request.Model) + + tokenizer_config = {} + if request.TrustRemoteCode or self.options.get("trust_remote_code", False): + tokenizer_config["trust_remote_code"] = True + + if tokenizer_config: + self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config) + else: + self.model, self.tokenizer = load(request.Model) + + # Apply pipeline parallelism + self.model = pipeline_auto_parallel(self.model, self.group) + + print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr) + + except Exception as err: + print(f"[Rank 0] Error loading model: {err}", file=sys.stderr) + return backend_pb2.Result(success=False, message=f"Error loading model: {err}") + + return backend_pb2.Result(message="Model loaded with distributed sharding", success=True) + + async def Predict(self, request, context): + try: + import mlx.core as mx + from mlx_lm import stream_generate + from mlx_lm.sample_utils import make_sampler + from coordinator import CMD_GENERATE + + prompt_text = self._prepare_prompt(request) + tokens = self.tokenizer.encode(prompt_text) + if hasattr(tokens, 'tolist'): + tokens = tokens.tolist() + + # Broadcast generate command + tokens + self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) + self.coordinator.broadcast_tokens(tokens) + + max_tokens, sampler_params = self._build_generation_params(request) + gen_params = self.coordinator.broadcast_generation_params( + max_tokens=max_tokens, + temperature=sampler_params.get('temp', 0.6), + top_p=sampler_params.get('top_p', 1.0), + ) + + sampler = make_sampler(**sampler_params) + + generated = [] + for response in stream_generate( + self.model, + self.tokenizer, + prompt=tokens, + max_tokens=gen_params["max_tokens"], + sampler=sampler, + ): + generated.append(response.text) + + return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8')) + + except Exception as e: + print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"Generation failed: {str(e)}") + return backend_pb2.Reply(message=bytes("", encoding='utf-8')) + + async def PredictStream(self, request, context): + try: + import mlx.core as mx + from mlx_lm import stream_generate + from mlx_lm.sample_utils import make_sampler + from coordinator import CMD_GENERATE + + prompt_text = self._prepare_prompt(request) + tokens = self.tokenizer.encode(prompt_text) + if hasattr(tokens, 'tolist'): + tokens = tokens.tolist() + + self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) + self.coordinator.broadcast_tokens(tokens) + + max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512) + gen_params = self.coordinator.broadcast_generation_params( + max_tokens=max_tokens, + temperature=sampler_params.get('temp', 0.6), + top_p=sampler_params.get('top_p', 1.0), + ) + + sampler = make_sampler(**sampler_params) + + for response in stream_generate( + self.model, + self.tokenizer, + prompt=tokens, + max_tokens=gen_params["max_tokens"], + sampler=sampler, + ): + yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) + + except Exception as e: + print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"Streaming failed: {str(e)}") + yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) + + def _prepare_prompt(self, request): + if not request.Prompt and request.UseTokenizerTemplate and request.Messages: + messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages] + return self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + return request.Prompt + + def _build_generation_params(self, request, default_max_tokens=200): + max_tokens = getattr(request, 'Tokens', default_max_tokens) + if max_tokens == 0: + max_tokens = default_max_tokens + + temp = getattr(request, 'Temperature', 0.0) + if temp == 0.0: + temp = 0.6 + + top_p = getattr(request, 'TopP', 0.0) + if top_p == 0.0: + top_p = 1.0 + + sampler_params = { + 'temp': temp, + 'top_p': top_p, + 'min_p': getattr(request, 'MinP', 0.0), + 'top_k': getattr(request, 'TopK', 0), + 'xtc_threshold': 0.0, + 'xtc_probability': 0.0, + } + + seed = getattr(request, 'Seed', 0) + if seed != 0: + import mlx.core as mx + mx.random.seed(seed) + + if hasattr(self, 'options'): + if 'max_tokens' in self.options: + max_tokens = self.options['max_tokens'] + option_mapping = { + 'temp': 'temp', 'temperature': 'temp', + 'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k', + } + for opt_key, param_key in option_mapping.items(): + if opt_key in self.options: + sampler_params[param_key] = self.options[opt_key] + + xtc_special_tokens = [] + if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: + xtc_special_tokens = [self.tokenizer.eos_token_id] + sampler_params['xtc_special_tokens'] = xtc_special_tokens + + return max_tokens, sampler_params + + +def run_worker(group): + """Worker loop for ranks > 0. Waits for commands from rank 0.""" + from mlx_lm import load, stream_generate + from mlx_lm.sample_utils import make_sampler + from coordinator import DistributedCoordinator, CMD_LOAD_MODEL, CMD_GENERATE, CMD_SHUTDOWN + from sharding import pipeline_auto_parallel + import mlx.core as mx + + coordinator = DistributedCoordinator(group) + model = None + tokenizer = None + + print(f"[Rank {group.rank()}] Worker started, waiting for commands...", file=sys.stderr) + + while True: + cmd, payload_size = coordinator.wait_for_command() + + if cmd == CMD_LOAD_MODEL: + model_name = coordinator.broadcast_model_name() + print(f"[Rank {group.rank()}] Loading model: {model_name}", file=sys.stderr) + model, tokenizer = load(model_name) + model = pipeline_auto_parallel(model, group) + print(f"[Rank {group.rank()}] Model loaded and sharded", file=sys.stderr) + + elif cmd == CMD_GENERATE: + if model is None: + print(f"[Rank {group.rank()}] No model loaded, skipping generate", file=sys.stderr) + continue + + token_count = coordinator.broadcast_token_count(payload_size) + tokens_array = coordinator.broadcast_tokens([0] * token_count) + tokens = tokens_array.tolist() + + gen_params = coordinator.broadcast_generation_params() + + sampler = make_sampler( + temp=gen_params["temperature"], + top_p=gen_params["top_p"], + ) + + # Participate in distributed compute, discard output + for _ in stream_generate( + model, tokenizer, + prompt=tokens, + max_tokens=gen_params["max_tokens"], + sampler=sampler, + ): + pass + + elif cmd == CMD_SHUTDOWN: + print(f"[Rank {group.rank()}] Shutting down", file=sys.stderr) + break + + +async def serve(address, group, dist_backend): + server = grpc.aio.server( + migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), + ('grpc.max_send_message_length', 50 * 1024 * 1024), + ('grpc.max_receive_message_length', 50 * 1024 * 1024), + ], + ) + backend_pb2_grpc.add_BackendServicer_to_server( + BackendServicer(group, dist_backend), server + ) + server.add_insecure_port(address) + + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5))) + + await server.start() + print(f"[Rank 0] gRPC server listening on {address}", file=sys.stderr) + await server.wait_for_termination() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MLX Distributed Backend") + parser.add_argument("--addr", default="localhost:50051", help="gRPC listen address (rank 0 only)") + parser.add_argument("--worker", action="store_true", help="Run in worker mode (rank > 0)") + parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"], help="MLX distributed backend") + parser.add_argument("--hostfile", required=True, help="Path to hostfile JSON") + parser.add_argument("--rank", type=int, required=True, help="Rank of this process") + parser.add_argument("--coordinator", default=None, help="JACCL coordinator address") + args = parser.parse_args() + + group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator) + + if args.worker or args.rank > 0: + run_worker(group) + else: + asyncio.run(serve(args.addr, group, args.backend)) diff --git a/backend/python/mlx-distributed/coordinator.py b/backend/python/mlx-distributed/coordinator.py new file mode 100644 index 000000000000..1ec686a45070 --- /dev/null +++ b/backend/python/mlx-distributed/coordinator.py @@ -0,0 +1,104 @@ +""" +Distributed coordination using MLX distributed primitives. + +Rank 0 broadcasts commands and tokens to all ranks via all_sum/all_gather. +Worker ranks wait in a loop for commands from rank 0. +""" +import json +import struct + +import mlx.core as mx + + +CMD_IDLE = 0 +CMD_GENERATE = 1 +CMD_LOAD_MODEL = 2 +CMD_SHUTDOWN = -1 + + +class DistributedCoordinator: + def __init__(self, group): + self.group = group + self.rank = group.rank() + self.world_size = group.size() + + def broadcast_command(self, cmd, payload_size=0): + """Rank 0 broadcasts a command to all ranks. + + Uses all_sum with only rank 0 providing non-zero values so every + rank receives the same command array. + """ + if self.rank == 0: + cmd_array = mx.array([cmd, payload_size], dtype=mx.int32) + else: + cmd_array = mx.zeros((2,), dtype=mx.int32) + result = mx.distributed.all_sum(cmd_array, group=self.group) + mx.eval(result) + return int(result[0].item()), int(result[1].item()) + + def broadcast_tokens(self, tokens): + """Broadcast input token ids from rank 0 to all ranks. + + Rank 0 provides the real token array; other ranks provide zeros of the + same shape. ``all_sum`` ensures every rank ends up with identical data. + """ + if self.rank == 0: + token_array = mx.array(tokens, dtype=mx.int32) + else: + token_array = mx.zeros((len(tokens),), dtype=mx.int32) + result = mx.distributed.all_sum(token_array, group=self.group) + mx.eval(result) + return result + + def broadcast_token_count(self, count): + """Broadcast the number of tokens so workers can prepare a buffer.""" + if self.rank == 0: + count_array = mx.array([count], dtype=mx.int32) + else: + count_array = mx.zeros((1,), dtype=mx.int32) + result = mx.distributed.all_sum(count_array, group=self.group) + mx.eval(result) + return int(result[0].item()) + + def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0): + """Broadcast generation parameters from rank 0.""" + if self.rank == 0: + params = mx.array([max_tokens, temperature, top_p], dtype=mx.float32) + else: + params = mx.zeros((3,), dtype=mx.float32) + result = mx.distributed.all_sum(params, group=self.group) + mx.eval(result) + return { + "max_tokens": int(result[0].item()), + "temperature": float(result[1].item()), + "top_p": float(result[2].item()), + } + + def wait_for_command(self): + """Worker ranks block here until rank 0 broadcasts a command.""" + return self.broadcast_command(CMD_IDLE, 0) + + def broadcast_model_name(self, model_name=""): + """Broadcast model name string from rank 0 to all ranks. + + Encodes the model name as int32 codepoints so it can travel via + all_sum. + """ + if self.rank == 0: + encoded = [ord(c) for c in model_name] + # First broadcast the length + length = self.broadcast_token_count(len(encoded)) + if length > 0: + name_array = mx.array(encoded, dtype=mx.int32) + result = mx.distributed.all_sum(name_array, group=self.group) + mx.eval(result) + return model_name + return "" + else: + length = self.broadcast_token_count(0) + if length > 0: + name_array = mx.zeros((length,), dtype=mx.int32) + result = mx.distributed.all_sum(name_array, group=self.group) + mx.eval(result) + return "".join(chr(int(c.item())) for c in result) + return "" diff --git a/backend/python/mlx-distributed/install.sh b/backend/python/mlx-distributed/install.sh new file mode 100644 index 000000000000..253ee0c13f1b --- /dev/null +++ b/backend/python/mlx-distributed/install.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -e + +USE_PIP=true +PYTHON_VERSION="" + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +installRequirements diff --git a/backend/python/mlx-distributed/requirements-mps.txt b/backend/python/mlx-distributed/requirements-mps.txt new file mode 100644 index 000000000000..1b47ff0e5d37 --- /dev/null +++ b/backend/python/mlx-distributed/requirements-mps.txt @@ -0,0 +1 @@ +mlx-lm diff --git a/backend/python/mlx-distributed/requirements.txt b/backend/python/mlx-distributed/requirements.txt new file mode 100644 index 000000000000..fe67cdb50ebe --- /dev/null +++ b/backend/python/mlx-distributed/requirements.txt @@ -0,0 +1,4 @@ +grpcio==1.71.0 +protobuf +certifi +setuptools diff --git a/backend/python/mlx-distributed/run.sh b/backend/python/mlx-distributed/run.sh new file mode 100644 index 000000000000..8f608a54165d --- /dev/null +++ b/backend/python/mlx-distributed/run.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/mlx-distributed/sharding.py b/backend/python/mlx-distributed/sharding.py new file mode 100644 index 000000000000..0ea3dad54b4e --- /dev/null +++ b/backend/python/mlx-distributed/sharding.py @@ -0,0 +1,136 @@ +""" +Auto-parallelism for MLX distributed inference. + +Provides pipeline parallelism (Ring backend) by wrapping model layers with +distributed send/recv operations. Ported from exo's auto_parallel.py with +simplifications for LocalAI's use case. +""" +import mlx.core as mx +import mlx.nn as nn + + +class PipelineFirstLayer(nn.Module): + """Wraps the first layer on each rank to receive from the previous rank.""" + + def __init__(self, original_layer, rank, group): + super().__init__() + dict.__setitem__(self, "_original_layer", original_layer) + self.rank = rank + self.group = group + + @property + def original_layer(self): + return self["_original_layer"] + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self["_original_layer"], name) + + def __call__(self, x, *args, **kwargs): + if self.rank != 0: + mx.eval(x) + x = mx.distributed.recv_like(x, self.rank - 1, group=self.group) + mx.eval(x) + return self.original_layer(x, *args, **kwargs) + + +class PipelineLastLayer(nn.Module): + """Wraps the last layer on each rank to send to the next rank.""" + + def __init__(self, original_layer, rank, world_size, group): + super().__init__() + dict.__setitem__(self, "_original_layer", original_layer) + self.rank = rank + self.world_size = world_size + self.group = group + + @property + def original_layer(self): + return self["_original_layer"] + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self["_original_layer"], name) + + def __call__(self, x, *args, **kwargs): + output = self.original_layer(x, *args, **kwargs) + mx.eval(output) + if self.rank != self.world_size - 1: + output = mx.distributed.send( + output, (self.rank + 1) % self.world_size, group=self.group + ) + mx.eval(output) + # Gather output from all ranks so every rank has the final result + output = mx.distributed.all_gather(output, group=self.group)[ + -output.shape[0] : + ] + mx.eval(output) + return output + + +def get_inner_model(model): + """Get the inner model (model.model or model.transformer).""" + for attr in ("model", "transformer"): + inner = getattr(model, attr, None) + if isinstance(inner, nn.Module): + # Some models have model.model (e.g. language_model.model) + inner_inner = getattr(inner, "model", None) + if isinstance(inner_inner, nn.Module): + return inner_inner + return inner + raise ValueError("Model must have a 'model' or 'transformer' attribute") + + +def get_layers(inner_model): + """Get the list of transformer layers.""" + for attr in ("layers", "h"): + layers = getattr(inner_model, attr, None) + if layers is not None: + return layers + raise ValueError("Model must have a 'layers' or 'h' attribute") + + +def pipeline_auto_parallel(model, group, start_layer=None, end_layer=None): + """Apply pipeline parallelism to a model. + + Each rank only keeps its slice of layers. The first layer receives from + the previous rank, and the last layer sends to the next rank. + + Args: + model: The MLX model (must have model.layers or similar) + group: The distributed group + start_layer: First layer index for this rank (auto-computed if None) + end_layer: Last layer index (exclusive) for this rank (auto-computed if None) + """ + rank = group.rank() + world_size = group.size() + + inner = get_inner_model(model) + layers = list(get_layers(inner)) + total_layers = len(layers) + + if start_layer is None or end_layer is None: + layers_per_rank = total_layers // world_size + remainder = total_layers % world_size + start_layer = rank * layers_per_rank + min(rank, remainder) + end_layer = start_layer + layers_per_rank + (1 if rank < remainder else 0) + + layers = layers[start_layer:end_layer] + for layer in layers: + mx.eval(layer) + + # Wrap first and last layers + layers[0] = PipelineFirstLayer(layers[0], rank, group=group) + layers[-1] = PipelineLastLayer(layers[-1], rank, world_size, group=group) + + # Replace layers on the inner model + if hasattr(inner, "layers"): + inner.layers = layers + elif hasattr(inner, "h"): + inner.h = layers + + return model diff --git a/backend/python/mlx-distributed/test.py b/backend/python/mlx-distributed/test.py new file mode 100644 index 000000000000..16f96d04c692 --- /dev/null +++ b/backend/python/mlx-distributed/test.py @@ -0,0 +1,33 @@ +import unittest +import subprocess +import time + +import grpc +import backend_pb2 +import backend_pb2_grpc + + +class TestBackendServicer(unittest.TestCase): + def setUp(self): + self.service = subprocess.Popen( + ["python", "backend.py", "--addr", "localhost:50051", + "--hostfile", "/dev/null", "--rank", "0"] + ) + time.sleep(10) + + def tearDown(self) -> None: + self.service.terminate() + self.service.wait() + + def test_server_startup(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b'OK') + except Exception as err: + print(err) + self.fail("Server failed to start") + finally: + self.tearDown() diff --git a/backend/python/mlx-distributed/test.sh b/backend/python/mlx-distributed/test.sh new file mode 100644 index 000000000000..f31ae54e47dc --- /dev/null +++ b/backend/python/mlx-distributed/test.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/core/application/p2p.go b/core/application/p2p.go index 99527e841260..8522a121dc96 100644 --- a/core/application/p2p.go +++ b/core/application/p2p.go @@ -84,19 +84,37 @@ func (a *Application) StartP2P() error { n = node } - // Attach a ServiceDiscoverer to the p2p node + // Attach a ServiceDiscoverer to the p2p node for llama.cpp workers xlog.Info("Starting P2P server discovery...") - if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node schema.NodeData) { + if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID), func(serviceID string, node schema.NodeData) { var tunnelAddresses []string - for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) { + for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID)) { if v.IsOnline() { tunnelAddresses = append(tunnelAddresses, v.TunnelAddress) } else { xlog.Info("Node is offline", "node", v.ID) } } - if a.applicationConfig.TunnelCallback != nil { - a.applicationConfig.TunnelCallback(tunnelAddresses) + if a.applicationConfig.LlamaCPPTunnelCallback != nil { + a.applicationConfig.LlamaCPPTunnelCallback(tunnelAddresses) + } + }, true); err != nil { + return err + } + + // Attach a ServiceDiscoverer for MLX distributed workers + xlog.Info("Starting MLX P2P worker discovery...") + if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.MLXWorkerID), func(serviceID string, node schema.NodeData) { + var tunnelAddresses []string + for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.MLXWorkerID)) { + if v.IsOnline() { + tunnelAddresses = append(tunnelAddresses, v.TunnelAddress) + } else { + xlog.Info("MLX node is offline", "node", v.ID) + } + } + if a.applicationConfig.MLXTunnelCallback != nil { + a.applicationConfig.MLXTunnelCallback(tunnelAddresses) } }, true); err != nil { return err diff --git a/core/cli/run.go b/core/cli/run.go index 7c7c92d0ed42..33671d3e1965 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -2,8 +2,10 @@ package cli import ( "context" + "encoding/json" "fmt" "os" + "path/filepath" "strings" "time" @@ -140,12 +142,18 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { config.WithMachineTag(r.MachineTag), config.WithAPIAddress(r.Address), config.WithAgentJobRetentionDays(r.AgentJobRetentionDays), - config.WithTunnelCallback(func(tunnels []string) { + config.WithLlamaCPPTunnelCallback(func(tunnels []string) { tunnelEnvVar := strings.Join(tunnels, ",") - // TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar) xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar) }), + config.WithMLXTunnelCallback(func(tunnels []string) { + hostfile := filepath.Join(os.TempDir(), "localai_mlx_hostfile.json") + data, _ := json.Marshal(tunnels) + os.WriteFile(hostfile, data, 0644) + os.Setenv("MLX_DISTRIBUTED_HOSTFILE", hostfile) + xlog.Debug("setting MLX_DISTRIBUTED_HOSTFILE", "value", hostfile, "tunnels", tunnels) + }), } if r.DisableMetricsEndpoint { diff --git a/core/cli/worker/worker.go b/core/cli/worker/worker.go index 0a636c3bfacb..4cd5ac471444 100644 --- a/core/cli/worker/worker.go +++ b/core/cli/worker/worker.go @@ -9,5 +9,6 @@ type WorkerFlags struct { type Worker struct { P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"` + P2PMLX P2PMLX `cmd:"" name:"p2p-mlx" help:"Starts a LocalAI MLX distributed worker in P2P mode (requires a token)"` LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"` } diff --git a/core/cli/worker/worker_p2p.go b/core/cli/worker/worker_p2p.go index 868357ccffd5..b9baf3bf34d4 100644 --- a/core/cli/worker/worker_p2p.go +++ b/core/cli/worker/worker_p2p.go @@ -62,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error { p = r.RunnerPort } - _, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) + _, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID)) if err != nil { return err } @@ -104,7 +104,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error { } }() - _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) + _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID)) if err != nil { return err } diff --git a/core/cli/worker/worker_p2p_mlx.go b/core/cli/worker/worker_p2p_mlx.go new file mode 100644 index 000000000000..802577ed3f0a --- /dev/null +++ b/core/cli/worker/worker_p2p_mlx.go @@ -0,0 +1,149 @@ +package worker + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "time" + + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/p2p" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/signals" + "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/xlog" + "github.com/phayes/freeport" +) + +const ( + mlxDistributedGalleryName = "mlx-distributed" +) + +type P2PMLX struct { + WorkerFlags `embed:""` + Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"` + Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode" group:"p2p"` + MLXListenPort string `env:"MLX_LISTEN_PORT" default:"5555" help:"Port for MLX distributed communication"` + MLXBackend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend (ring or jaccl)"` +} + +func findMLXDistributedBackend(galleries string, systemState *system.SystemState) (string, error) { + backends, err := gallery.ListSystemBackends(systemState) + if err != nil { + xlog.Warn("Failed listing system backends", "error", err) + return "", err + } + + backend, ok := backends.Get(mlxDistributedGalleryName) + if !ok { + ml := model.NewModelLoader(systemState) + var gals []config.Gallery + if err := json.Unmarshal([]byte(galleries), &gals); err != nil { + xlog.Error("failed loading galleries", "error", err) + return "", err + } + err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, mlxDistributedGalleryName, nil, true) + if err != nil { + xlog.Error("mlx-distributed backend not found, failed to install it", "error", err) + return "", err + } + } + + backendPath := filepath.Dir(backend.RunFile) + if backendPath == "" { + return "", errors.New("mlx-distributed backend not found, install it first") + } + + return backendPath, nil +} + +func (r *P2PMLX) Run(ctx *cliContext.Context) error { + systemState, err := system.GetSystemState( + system.WithBackendPath(r.BackendsPath), + system.WithBackendSystemPath(r.BackendsSystemPath), + ) + if err != nil { + return err + } + + if r.Token == "" { + return fmt.Errorf("Token is required") + } + + port, err := freeport.GetFreePort() + if err != nil { + return err + } + if r.MLXListenPort != "" { + fmt.Sscanf(r.MLXListenPort, "%d", &port) + } + + address := "127.0.0.1" + + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + backendPath, err := findMLXDistributedBackend(r.BackendGalleries, systemState) + if err != nil { + xlog.Warn("Could not find mlx-distributed backend from gallery, will use backend.py directly", "error", err) + } + + // Start the MLX worker process + go func() { + for { + xlog.Info("Starting mlx-distributed worker", "address", address, "port", port) + + var cmd *exec.Cmd + if backendPath != "" { + cmd = exec.Command( + filepath.Join(backendPath, "run.sh"), + "--worker", + "--backend", r.MLXBackend, + "--hostfile", os.Getenv("MLX_DISTRIBUTED_HOSTFILE"), + "--rank", "0", // Will be overridden by hostfile position + ) + } else { + cmd = exec.Command( + "python3", "backend.py", + "--worker", + "--backend", r.MLXBackend, + "--hostfile", os.Getenv("MLX_DISTRIBUTED_HOSTFILE"), + "--rank", "0", + ) + } + + cmd.Env = os.Environ() + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + if err := cmd.Start(); err != nil { + xlog.Error("Failed to start mlx-distributed worker", "error", err) + } + + cmd.Wait() + time.Sleep(2 * time.Second) + } + }() + + // Expose this worker on the p2p network + _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.MLXWorkerID)) + if err != nil { + return err + } + + xlog.Info("MLX distributed worker registered on P2P network", "address", address, "port", port) + + signals.RegisterGracefulTerminationHandler(func() { + cancel() + }) + + for { + time.Sleep(1 * time.Second) + } +} diff --git a/core/config/application_config.go b/core/config/application_config.go index 79276e49feed..dd4e0bcfffae 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -80,7 +80,8 @@ type ApplicationConfig struct { APIAddress string - TunnelCallback func(tunnels []string) + LlamaCPPTunnelCallback func(tunnels []string) + MLXTunnelCallback func(tunnels []string) DisableRuntimeSettings bool @@ -416,9 +417,15 @@ func WithContextSize(ctxSize int) AppOption { } } -func WithTunnelCallback(callback func(tunnels []string)) AppOption { +func WithLlamaCPPTunnelCallback(callback func(tunnels []string)) AppOption { return func(o *ApplicationConfig) { - o.TunnelCallback = callback + o.LlamaCPPTunnelCallback = callback + } +} + +func WithMLXTunnelCallback(callback func(tunnels []string)) AppOption { + return func(o *ApplicationConfig) { + o.MLXTunnelCallback = callback } } diff --git a/core/explorer/discovery.go b/core/explorer/discovery.go index 989e784d32b6..36a193b71ec4 100644 --- a/core/explorer/discovery.go +++ b/core/explorer/discovery.go @@ -156,7 +156,7 @@ func (s *DiscoveryServer) retrieveNetworkData(c context.Context, ledger *blockch for d := range data { toScanForWorkers := false cd := ClusterData{} - isWorkerCluster := d == p2p.WorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.WorkerID)) + isWorkerCluster := d == p2p.LlamaCPPWorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.LlamaCPPWorkerID)) isFederatedCluster := d == p2p.FederatedID || (strings.Contains(d, "_") && strings.Contains(d, p2p.FederatedID)) switch { case isWorkerCluster: diff --git a/core/http/endpoints/localai/p2p.go b/core/http/endpoints/localai/p2p.go index afd7d048dc83..cc630be4f440 100644 --- a/core/http/endpoints/localai/p2p.go +++ b/core/http/endpoints/localai/p2p.go @@ -15,8 +15,9 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc { // Render index return func(c echo.Context) error { return c.JSON(200, schema.P2PNodesResponse{ - Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)), + LlamaCPPNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)), FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)), + MLXNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)), }) } } diff --git a/core/http/react-ui/src/pages/P2P.jsx b/core/http/react-ui/src/pages/P2P.jsx index c1f97a4ec0fc..f77b5db219de 100644 --- a/core/http/react-ui/src/pages/P2P.jsx +++ b/core/http/react-ui/src/pages/P2P.jsx @@ -103,8 +103,9 @@ function StepNumber({ n, bg, color }) { export default function P2P() { const { addToast } = useOutletContext() const [workers, setWorkers] = useState([]) + const [mlxWorkers, setMlxWorkers] = useState([]) const [federation, setFederation] = useState([]) - const [stats, setStats] = useState({ workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 } }) + const [stats, setStats] = useState({ llama_cpp_workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 }, mlx_workers: { online: 0, total: 0 } }) const [loading, setLoading] = useState(true) const [enabled, setEnabled] = useState(false) const [token, setToken] = useState('') @@ -129,7 +130,13 @@ export default function P2P() { if (p2pToken) { if (wRes.status === 'fulfilled') { const data = wRes.value - setWorkers(data?.nodes || (Array.isArray(data) ? data : [])) + // Handle both old format ({nodes: [...]}) and new grouped format ({llama_cpp: {nodes: [...]}, mlx: {nodes: [...]}}) + if (data?.llama_cpp) { + setWorkers(data.llama_cpp.nodes || []) + setMlxWorkers(data.mlx?.nodes || []) + } else { + setWorkers(data?.nodes || (Array.isArray(data) ? data : [])) + } } if (fRes.status === 'fulfilled') { const data = fRes.value @@ -274,8 +281,10 @@ export default function P2P() { // ── P2P Enabled ── const fedOnline = stats.federated?.online ?? 0 const fedTotal = stats.federated?.total ?? 0 - const wrkOnline = stats.workers?.online ?? 0 - const wrkTotal = stats.workers?.total ?? 0 + const llamaOnline = stats.llama_cpp_workers?.online ?? 0 + const llamaTotal = stats.llama_cpp_workers?.total ?? 0 + const mlxOnline = stats.mlx_workers?.online ?? 0 + const mlxTotal = stats.mlx_workers?.total ?? 0 return (
@@ -401,7 +410,7 @@ export default function P2P() { fontSize: '0.75rem', color: activeTab === 'sharding' ? 'var(--color-accent)' : 'var(--color-text-muted)', }}> - {wrkOnline}/{wrkTotal} workers + {llamaOnline + mlxOnline}/{llamaTotal + mlxTotal} workers
@@ -562,6 +571,21 @@ export default function P2P() { borderRadius: 'var(--radius-lg)', overflow: 'hidden', }}>
+
+ + Different from federation: Federation distributes whole requests across instances. Model sharding splits a single model across machines for joint inference. +
+ + {/* ── llama.cpp RPC Workers Section ── */} +

+ + llama.cpp RPC Workers +

+ {/* Architecture diagram */}

Model weights are split across RPC workers. Each worker holds a portion of the model layers in its memory (GPU or CPU). - The LocalAI instance orchestrates inference by communicating with all workers via RPC.

-
- - Different from federation: Federation distributes whole requests across instances. Model sharding splits a single model's weights across machines for joint inference. Currently only supported with llama.cpp based models. -
- - {/* Status + nodes */} + {/* llama.cpp Status + nodes */}
-

Connected Workers

+

Connected Workers

- 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{wrkOnline} - /{wrkTotal} + 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{llamaOnline} + /{llamaTotal}
@@ -633,7 +647,7 @@ export default function P2P() { borderRadius: 'var(--radius-lg)', }}> -

No workers available

+

No llama.cpp workers connected

Start workers to see them here

) : ( @@ -645,17 +659,99 @@ export default function P2P() { )} - {/* Setup Guide */} + {/* ── MLX Distributed Workers Section ── */} +
+

+ + MLX Distributed Workers +

+ + {/* MLX Architecture diagram */} +
+
+
+
+ +
+
LocalAI
+
Rank 0
+
+
+ + Ring / JACCL +
+
+
+ {['Layers 1-16', 'Layers 17-32'].map((label, i) => ( +
+
+ +
+
{label}
+
+ ))} +
+
MLX Workers
+
Pipeline parallel
+
+
+

+ MLX distributed uses native Apple Silicon communication. All nodes execute model code simultaneously via pipeline or tensor parallelism. +

+
+ + {/* MLX Status + nodes */} +
+

Connected MLX Workers

+
+ 0 ? 'var(--color-success)' : 'var(--color-error)' }}>{mlxOnline} + /{mlxTotal} +
+
+ + {mlxWorkers.length === 0 ? ( +
+ +

No MLX workers connected

+

Start MLX workers on Apple Silicon Macs

+
+ ) : ( +
+ {mlxWorkers.map((node, i) => ( + + ))} +
+ )} +
+ + {/* Setup Guides */}

- Start a llama.cpp RPC Worker + Setup Workers

+

llama.cpp RPC Worker

Each worker exposes its GPU/CPU memory as a shard for distributed model inference.

@@ -663,17 +759,24 @@ export default function P2P() { command={`docker run -ti --net host \\\n -e TOKEN="${token}" \\\n --name local-ai-worker \\\n localai/localai:latest-cpu worker p2p-llama-cpp-rpc`} addToast={addToast} /> -

- Run this on each machine you want to contribute as a shard. The worker will automatically join the network and advertise its resources. -

+
-

- For GPU images and all available options, see the{' '} - Container images - {' '}and{' '} - Worker docs. +

+

MLX Distributed Worker

+

+ Run on Apple Silicon Macs to participate in distributed MLX inference via pipeline parallelism. +

+ +

+ For more information, see the{' '} + MLX Distributed docs.

diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go index 563c9b4999fa..0ef0aa0bb435 100644 --- a/core/http/routes/ui_api.go +++ b/core/http/routes/ui_api.go @@ -1026,11 +1026,12 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model // P2P APIs app.GET("/api/p2p/workers", func(c echo.Context) error { - nodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)) + llamaNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)) + mlxNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)) - nodesJSON := make([]map[string]interface{}, 0, len(nodes)) - for _, n := range nodes { - nodesJSON = append(nodesJSON, map[string]interface{}{ + llamaJSON := make([]map[string]any, 0, len(llamaNodes)) + for _, n := range llamaNodes { + llamaJSON = append(llamaJSON, map[string]any{ "name": n.Name, "id": n.ID, "tunnelAddress": n.TunnelAddress, @@ -1040,8 +1041,27 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model }) } - return c.JSON(200, map[string]interface{}{ - "nodes": nodesJSON, + mlxJSON := make([]map[string]any, 0, len(mlxNodes)) + for _, n := range mlxNodes { + mlxJSON = append(mlxJSON, map[string]any{ + "name": n.Name, + "id": n.ID, + "tunnelAddress": n.TunnelAddress, + "serviceID": n.ServiceID, + "lastSeen": n.LastSeen, + "isOnline": n.IsOnline(), + }) + } + + return c.JSON(200, map[string]any{ + "llama_cpp": map[string]any{ + "nodes": llamaJSON, + }, + "mlx": map[string]any{ + "nodes": mlxJSON, + }, + // Keep backward-compatible "nodes" key with llama.cpp workers + "nodes": llamaJSON, }) }) @@ -1066,13 +1086,14 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model }) app.GET("/api/p2p/stats", func(c echo.Context) error { - workerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)) + llamaCPPNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)) federatedNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)) + mlxWorkerNodes := p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)) - workersOnline := 0 - for _, n := range workerNodes { + llamaCPPOnline := 0 + for _, n := range llamaCPPNodes { if n.IsOnline() { - workersOnline++ + llamaCPPOnline++ } } @@ -1083,15 +1104,26 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model } } - return c.JSON(200, map[string]interface{}{ - "workers": map[string]interface{}{ - "online": workersOnline, - "total": len(workerNodes), + mlxWorkersOnline := 0 + for _, n := range mlxWorkerNodes { + if n.IsOnline() { + mlxWorkersOnline++ + } + } + + return c.JSON(200, map[string]any{ + "llama_cpp_workers": map[string]any{ + "online": llamaCPPOnline, + "total": len(llamaCPPNodes), }, - "federated": map[string]interface{}{ + "federated": map[string]any{ "online": federatedOnline, "total": len(federatedNodes), }, + "mlx_workers": map[string]any{ + "online": mlxWorkersOnline, + "total": len(mlxWorkerNodes), + }, }) }) diff --git a/core/http/views/p2p.html b/core/http/views/p2p.html index c05c488f8b84..b3b595cbd216 100644 --- a/core/http/views/p2p.html +++ b/core/http/views/p2p.html @@ -262,8 +262,8 @@

Workers

- - / + + /

workers

@@ -469,8 +469,8 @@

Worker Network
Active Workers
- - / + + /
@@ -657,7 +657,7 @@

workerNodes: [], federationNodes: [], stats: { - workers: { online: 0, total: 0 }, + llama_cpp_workers: { online: 0, total: 0 }, federated: { online: 0, total: 0 } }, diff --git a/core/p2p/node.go b/core/p2p/node.go index 78efb77cacce..4996531dc9ec 100644 --- a/core/p2p/node.go +++ b/core/p2p/node.go @@ -9,8 +9,9 @@ import ( ) const ( - defaultServicesID = "services" - WorkerID = "worker" + defaultServicesID = "services" + LlamaCPPWorkerID = "worker" + MLXWorkerID = "mlx_worker" ) var mu sync.Mutex diff --git a/core/schema/localai.go b/core/schema/localai.go index ccf3e6e10f18..6f98bf320eee 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -136,8 +136,9 @@ func (d NodeData) IsOnline() bool { } type P2PNodesResponse struct { - Nodes []NodeData `json:"nodes" yaml:"nodes"` + LlamaCPPNodes []NodeData `json:"llama_cpp_nodes" yaml:"llama_cpp_nodes"` FederatedNodes []NodeData `json:"federated_nodes" yaml:"federated_nodes"` + MLXNodes []NodeData `json:"mlx_nodes" yaml:"mlx_nodes"` } type SysInfoModel struct { diff --git a/docs/content/features/mlx-distributed.md b/docs/content/features/mlx-distributed.md new file mode 100644 index 000000000000..e34f31c461a6 --- /dev/null +++ b/docs/content/features/mlx-distributed.md @@ -0,0 +1,110 @@ ++++ +disableToc = false +title = "MLX Distributed Inference" +weight = 18 +url = '/features/mlx-distributed/' ++++ + +MLX distributed inference allows you to split large language models across multiple Apple Silicon Macs (or other devices) for joint inference. Unlike federation (which distributes whole requests), MLX distributed splits a single model's layers across machines so they all participate in every forward pass. + +## How It Works + +MLX distributed uses **pipeline parallelism** via the Ring backend: each node holds a slice of the model's layers. During inference, activations flow from rank 0 through each subsequent rank in a pipeline. The last rank gathers the final output. + +For high-bandwidth setups (e.g., Thunderbolt-connected Macs), **JACCL** (tensor parallelism via RDMA) is also supported, where each rank holds all layers but with sharded weights. + +## Prerequisites + +- Two or more machines with MLX installed (Apple Silicon recommended) +- Network connectivity between all nodes (TCP for Ring, RDMA/Thunderbolt for JACCL) +- Same model accessible on all nodes (e.g., from Hugging Face cache) + +## Quick Start with P2P + +The simplest way to use MLX distributed is with LocalAI's P2P auto-discovery. + +### 1. Start LocalAI with P2P + +```bash +docker run -ti --net host \ + --name local-ai \ + localai/localai:latest-metal-darwin-arm64 run --p2p +``` + +This generates a network token. Copy it for the next step. + +### 2. Start MLX Workers + +On each additional Mac: + +```bash +docker run -ti --net host \ + -e TOKEN="" \ + --name local-ai-mlx-worker \ + localai/localai:latest-metal-darwin-arm64 worker p2p-mlx +``` + +Workers auto-register on the P2P network. The LocalAI server discovers them and generates a hostfile for MLX distributed. + +### 3. Use the Model + +Load any MLX-compatible model. The `mlx-distributed` backend will automatically shard it across all available ranks: + +```yaml +name: llama-distributed +backend: mlx-distributed +parameters: + model: mlx-community/Llama-3.2-1B-Instruct-4bit +``` + +## Manual Setup with Hostfile + +For setups without P2P, you can provide a hostfile directly. + +### Ring Backend (TCP) + +Create a JSON hostfile listing all ranks: + +```json +["192.168.1.10:5555", "192.168.1.11:5555"] +``` + +Start rank 0 (the gRPC server): + +```bash +python backend.py --addr localhost:50051 --backend ring --hostfile hosts.json --rank 0 +``` + +Start rank 1 (worker): + +```bash +python backend.py --worker --backend ring --hostfile hosts.json --rank 1 +``` + +### JACCL Backend (RDMA/Thunderbolt) + +Create a JSON device matrix (`null` on diagonal): + +```json +[ + [null, "rdma_thunderbolt0"], + ["rdma_thunderbolt0", null] +] +``` + +Start with `--backend jaccl` and `--coordinator `. + +## Environment Variables + +| Variable | Description | +|----------|-------------| +| `MLX_DISTRIBUTED_HOSTFILE` | Path to hostfile JSON (auto-set by P2P) | +| `MLX_LISTEN_PORT` | Port for MLX communication (default: 5555) | +| `MLX_DISTRIBUTED_BACKEND` | Backend type: `ring` or `jaccl` (default: ring) | + +## Troubleshooting + +- **All ranks must have the model downloaded.** MLX distributed does not transfer model weights over the network. Ensure each node has the model in its Hugging Face cache. +- **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. +- **Rank assignment:** In P2P mode, rank 0 is always the LocalAI server. Worker ranks are assigned by sorting node IDs. +- **Performance:** Pipeline parallelism adds latency proportional to the number of ranks. For best results, use the fewest ranks needed to fit your model in memory. From f81cbb67d12f5917d7a5eb3c4d6f1b1bbcaa6204 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 5 Mar 2026 22:28:32 +0000 Subject: [PATCH 2/4] expose a CLI to facilitate backend starting Signed-off-by: Ettore Di Giacinto --- backend/python/mlx-distributed/backend.py | 26 ++++-- core/cli/worker/worker.go | 7 +- core/cli/worker/worker_mlx_common.go | 65 +++++++++++++++ core/cli/worker/worker_mlx_distributed.go | 62 +++++++++++++++ core/cli/worker/worker_p2p_mlx.go | 96 ++++++----------------- docs/content/features/mlx-distributed.md | 88 +++++++++++++++++---- 6 files changed, 246 insertions(+), 98 deletions(-) create mode 100644 core/cli/worker/worker_mlx_common.go create mode 100644 core/cli/worker/worker_mlx_distributed.go diff --git a/backend/python/mlx-distributed/backend.py b/backend/python/mlx-distributed/backend.py index 34d1fad3eba1..c6454af54f2e 100644 --- a/backend/python/mlx-distributed/backend.py +++ b/backend/python/mlx-distributed/backend.py @@ -23,7 +23,16 @@ def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None): - """Initialize MLX distributed runtime.""" + """Initialize MLX distributed runtime. + + Ring: MLX_HOSTFILE points to a JSON array of "ip:port" strings. Each rank + binds to its own entry (hostfile[rank]) and connects to neighbors for the + ring pipeline. + + JACCL: MLX_IBV_DEVICES points to a JSON 2D matrix of RDMA device names. + MLX_JACCL_COORDINATOR is rank 0's ip:port where it runs a TCP service that + helps all ranks establish RDMA connections. + """ import mlx.core as mx if backend == "ring": @@ -335,12 +344,17 @@ async def serve(address, group, dist_backend): if __name__ == "__main__": parser = argparse.ArgumentParser(description="MLX Distributed Backend") - parser.add_argument("--addr", default="localhost:50051", help="gRPC listen address (rank 0 only)") + parser.add_argument("--addr", default="localhost:50051", + help="gRPC API listen address for LocalAI (rank 0 only, separate from ring communication)") parser.add_argument("--worker", action="store_true", help="Run in worker mode (rank > 0)") - parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"], help="MLX distributed backend") - parser.add_argument("--hostfile", required=True, help="Path to hostfile JSON") - parser.add_argument("--rank", type=int, required=True, help="Rank of this process") - parser.add_argument("--coordinator", default=None, help="JACCL coordinator address") + parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"], + help="ring = TCP pipeline parallelism, jaccl = RDMA tensor parallelism") + parser.add_argument("--hostfile", required=True, + help="Ring: JSON array of 'ip:port' where entry i is rank i's listen address. " + "JACCL: JSON 2D matrix of RDMA device names (null on diagonal).") + parser.add_argument("--rank", type=int, required=True, help="Rank of this process (0 = server, >0 = worker)") + parser.add_argument("--coordinator", default=None, + help="JACCL only: coordinator ip:port — rank 0's address for RDMA setup (same value on all ranks)") args = parser.parse_args() group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator) diff --git a/core/cli/worker/worker.go b/core/cli/worker/worker.go index 4cd5ac471444..1ddb972a0b7e 100644 --- a/core/cli/worker/worker.go +++ b/core/cli/worker/worker.go @@ -8,7 +8,8 @@ type WorkerFlags struct { } type Worker struct { - P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"` - P2PMLX P2PMLX `cmd:"" name:"p2p-mlx" help:"Starts a LocalAI MLX distributed worker in P2P mode (requires a token)"` - LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"` + P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"` + P2PMLX P2PMLX `cmd:"" name:"p2p-mlx" help:"Starts a LocalAI MLX distributed worker in P2P mode (requires a token)"` + LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"` + MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"` } diff --git a/core/cli/worker/worker_mlx_common.go b/core/cli/worker/worker_mlx_common.go new file mode 100644 index 000000000000..2a1b2443f23c --- /dev/null +++ b/core/cli/worker/worker_mlx_common.go @@ -0,0 +1,65 @@ +package worker + +import ( + "context" + "encoding/json" + "errors" + "os/exec" + "path/filepath" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/xlog" +) + +const mlxDistributedGalleryName = "mlx-distributed" + +// findMLXDistributedBackendPath finds or installs the mlx-distributed backend +// and returns the directory containing run.sh. +func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState) (string, error) { + backends, err := gallery.ListSystemBackends(systemState) + if err != nil { + return "", err + } + + backend, ok := backends.Get(mlxDistributedGalleryName) + if !ok { + ml := model.NewModelLoader(systemState) + var gals []config.Gallery + if err := json.Unmarshal([]byte(galleries), &gals); err != nil { + xlog.Error("failed loading galleries", "error", err) + return "", err + } + if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, mlxDistributedGalleryName, nil, true); err != nil { + xlog.Error("mlx-distributed backend not found, failed to install it", "error", err) + return "", err + } + // Re-fetch after install + backends, err = gallery.ListSystemBackends(systemState) + if err != nil { + return "", err + } + backend, ok = backends.Get(mlxDistributedGalleryName) + if !ok { + return "", errors.New("mlx-distributed backend not found after install") + } + } + + backendPath := filepath.Dir(backend.RunFile) + if backendPath == "" { + return "", errors.New("mlx-distributed backend not found, install it first") + } + return backendPath, nil +} + +// buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend. +// backendPath is the directory containing run.sh (empty string to fall back to +// running backend.py directly via python3). +func buildMLXCommand(backendPath string, args ...string) *exec.Cmd { + if backendPath != "" { + return exec.Command(filepath.Join(backendPath, "run.sh"), args...) + } + return exec.Command("python3", append([]string{"backend.py"}, args...)...) +} diff --git a/core/cli/worker/worker_mlx_distributed.go b/core/cli/worker/worker_mlx_distributed.go new file mode 100644 index 000000000000..e701d927a9ad --- /dev/null +++ b/core/cli/worker/worker_mlx_distributed.go @@ -0,0 +1,62 @@ +package worker + +import ( + "fmt" + "os" + "syscall" + + cliContext "github.com/mudler/LocalAI/core/cli/context" + "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/xlog" +) + +type MLXDistributed struct { + WorkerFlags `embed:""` + Hostfile string `env:"MLX_DISTRIBUTED_HOSTFILE" required:"" help:"Path to hostfile JSON. Ring: array of 'ip:port' where entry i is rank i's listen address. JACCL: 2D matrix of RDMA device names."` + Rank int `env:"MLX_RANK" required:"" help:"Rank of this process (0 = gRPC server + ring participant, >0 = worker only)"` + Backend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend: 'ring' (TCP pipeline parallelism) or 'jaccl' (RDMA tensor parallelism)"` + Addr string `env:"MLX_DISTRIBUTED_ADDR" default:"localhost:50051" help:"gRPC API listen address for LocalAI (rank 0 only, separate from ring communication)"` + Coordinator string `env:"MLX_JACCL_COORDINATOR" default:"" help:"JACCL coordinator ip:port — rank 0's address where it accepts RDMA setup connections (all ranks must use the same value)"` +} + +func (r *MLXDistributed) Run(ctx *cliContext.Context) error { + systemState, err := system.GetSystemState( + system.WithBackendPath(r.BackendsPath), + system.WithBackendSystemPath(r.BackendsSystemPath), + ) + if err != nil { + return err + } + + backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState) + if err != nil { + return fmt.Errorf("cannot find mlx-distributed backend: %w", err) + } + + args := []string{ + "--backend", r.Backend, + "--hostfile", r.Hostfile, + "--rank", fmt.Sprint(r.Rank), + } + + if r.Rank == 0 { + args = append(args, "--addr", r.Addr) + } else { + args = append(args, "--worker") + } + + if r.Backend == "jaccl" && r.Coordinator != "" { + args = append(args, "--coordinator", r.Coordinator) + } + + cmd := buildMLXCommand(backendPath, args...) + runSh := cmd.Path + + xlog.Info("Starting mlx-distributed", "rank", r.Rank, "backend", r.Backend, "hostfile", r.Hostfile) + + return syscall.Exec( + runSh, + append([]string{runSh}, args...), + os.Environ(), + ) +} diff --git a/core/cli/worker/worker_p2p_mlx.go b/core/cli/worker/worker_p2p_mlx.go index 802577ed3f0a..ffa1c4ef912d 100644 --- a/core/cli/worker/worker_p2p_mlx.go +++ b/core/cli/worker/worker_p2p_mlx.go @@ -2,29 +2,18 @@ package worker import ( "context" - "encoding/json" - "errors" "fmt" "os" - "os/exec" - "path/filepath" "time" cliContext "github.com/mudler/LocalAI/core/cli/context" - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/p2p" - "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" "github.com/phayes/freeport" ) -const ( - mlxDistributedGalleryName = "mlx-distributed" -) - type P2PMLX struct { WorkerFlags `embed:""` Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"` @@ -33,37 +22,11 @@ type P2PMLX struct { MLXBackend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend (ring or jaccl)"` } -func findMLXDistributedBackend(galleries string, systemState *system.SystemState) (string, error) { - backends, err := gallery.ListSystemBackends(systemState) - if err != nil { - xlog.Warn("Failed listing system backends", "error", err) - return "", err - } - - backend, ok := backends.Get(mlxDistributedGalleryName) - if !ok { - ml := model.NewModelLoader(systemState) - var gals []config.Gallery - if err := json.Unmarshal([]byte(galleries), &gals); err != nil { - xlog.Error("failed loading galleries", "error", err) - return "", err - } - err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, mlxDistributedGalleryName, nil, true) - if err != nil { - xlog.Error("mlx-distributed backend not found, failed to install it", "error", err) - return "", err - } - } - - backendPath := filepath.Dir(backend.RunFile) - if backendPath == "" { - return "", errors.New("mlx-distributed backend not found, install it first") +func (r *P2PMLX) Run(ctx *cliContext.Context) error { + if r.Token == "" { + return fmt.Errorf("token is required") } - return backendPath, nil -} - -func (r *P2PMLX) Run(ctx *cliContext.Context) error { systemState, err := system.GetSystemState( system.WithBackendPath(r.BackendsPath), system.WithBackendSystemPath(r.BackendsSystemPath), @@ -72,10 +35,6 @@ func (r *P2PMLX) Run(ctx *cliContext.Context) error { return err } - if r.Token == "" { - return fmt.Errorf("Token is required") - } - port, err := freeport.GetFreePort() if err != nil { return err @@ -89,49 +48,39 @@ func (r *P2PMLX) Run(ctx *cliContext.Context) error { c, cancel := context.WithCancel(context.Background()) defer cancel() - backendPath, err := findMLXDistributedBackend(r.BackendGalleries, systemState) + backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState) if err != nil { - xlog.Warn("Could not find mlx-distributed backend from gallery, will use backend.py directly", "error", err) + xlog.Warn("Could not find mlx-distributed backend from gallery, will try backend.py directly", "error", err) } - // Start the MLX worker process go func() { for { - xlog.Info("Starting mlx-distributed worker", "address", address, "port", port) - - var cmd *exec.Cmd - if backendPath != "" { - cmd = exec.Command( - filepath.Join(backendPath, "run.sh"), - "--worker", - "--backend", r.MLXBackend, - "--hostfile", os.Getenv("MLX_DISTRIBUTED_HOSTFILE"), - "--rank", "0", // Will be overridden by hostfile position - ) - } else { - cmd = exec.Command( - "python3", "backend.py", - "--worker", - "--backend", r.MLXBackend, - "--hostfile", os.Getenv("MLX_DISTRIBUTED_HOSTFILE"), - "--rank", "0", - ) + hostfile := os.Getenv("MLX_DISTRIBUTED_HOSTFILE") + if hostfile == "" { + xlog.Info("Waiting for MLX_DISTRIBUTED_HOSTFILE to be set by P2P discovery...") + time.Sleep(2 * time.Second) + continue } + xlog.Info("Starting mlx-distributed worker", "address", address, "port", port, "hostfile", hostfile) + + cmd := buildMLXCommand(backendPath, + "--worker", + "--backend", r.MLXBackend, + "--hostfile", hostfile, + "--rank", "0", + ) cmd.Env = os.Environ() cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout - if err := cmd.Start(); err != nil { - xlog.Error("Failed to start mlx-distributed worker", "error", err) + if err := cmd.Run(); err != nil { + xlog.Error("mlx-distributed worker exited", "error", err) } - - cmd.Wait() time.Sleep(2 * time.Second) } }() - // Expose this worker on the p2p network _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.MLXWorkerID)) if err != nil { return err @@ -143,7 +92,6 @@ func (r *P2PMLX) Run(ctx *cliContext.Context) error { cancel() }) - for { - time.Sleep(1 * time.Second) - } + <-c.Done() + return nil } diff --git a/docs/content/features/mlx-distributed.md b/docs/content/features/mlx-distributed.md index e34f31c461a6..7970c3a6d109 100644 --- a/docs/content/features/mlx-distributed.md +++ b/docs/content/features/mlx-distributed.md @@ -57,33 +57,46 @@ parameters: model: mlx-community/Llama-3.2-1B-Instruct-4bit ``` -## Manual Setup with Hostfile +## Manual Setup with CLI -For setups without P2P, you can provide a hostfile directly. +For setups without P2P, use the `worker mlx-distributed` command. LocalAI handles backend installation automatically. ### Ring Backend (TCP) -Create a JSON hostfile listing all ranks: +The Ring backend uses TCP for pipeline parallelism. Each rank listens on a TCP port for ring communication with its neighbors. The **hostfile** is a JSON array where entry `i` is the `"ip:port"` that **rank `i` binds to and listens on**. All ranks must use the same hostfile so they know how to reach each other. + +**Example:** Two Macs on a local network — Mac A is `192.168.1.10`, Mac B is `192.168.1.11`. + +Create `hosts.json` (identical on both machines): ```json ["192.168.1.10:5555", "192.168.1.11:5555"] ``` -Start rank 0 (the gRPC server): +- Entry 0 (`192.168.1.10:5555`) — the address rank 0 (Mac A) listens on +- Entry 1 (`192.168.1.11:5555`) — the address rank 1 (Mac B) listens on + +Each rank binds to its own entry and connects to its neighbors for the ring pipeline. Port 5555 is arbitrary — use any available port, but it must be open in your firewall. + +Start rank 0 on **Mac A** (`192.168.1.10`): ```bash -python backend.py --addr localhost:50051 --backend ring --hostfile hosts.json --rank 0 +local-ai worker mlx-distributed --hostfile hosts.json --rank 0 --addr localhost:50051 ``` -Start rank 1 (worker): +Start rank 1 on **Mac B** (`192.168.1.11`): ```bash -python backend.py --worker --backend ring --hostfile hosts.json --rank 1 +local-ai worker mlx-distributed --hostfile hosts.json --rank 1 ``` +Rank 0 starts a gRPC server (on `--addr`) that LocalAI connects to for inference requests. The `--addr` flag is separate from the ring hostfile — it controls where the gRPC API listens, not the ring communication. All other ranks run as workers that participate in each forward pass. + ### JACCL Backend (RDMA/Thunderbolt) -Create a JSON device matrix (`null` on diagonal): +For Thunderbolt-connected Macs, JACCL provides tensor parallelism via RDMA for higher throughput. + +The **device matrix** is a JSON 2D array describing the RDMA device name used to communicate between each pair of ranks. Entry `[i][j]` is the RDMA device that rank `i` uses to talk to rank `j`. The diagonal is `null` (a rank doesn't talk to itself). ```json [ @@ -92,15 +105,56 @@ Create a JSON device matrix (`null` on diagonal): ] ``` -Start with `--backend jaccl` and `--coordinator `. +The **coordinator** is a TCP endpoint where one node (typically rank 0) runs a coordination service that helps all ranks establish their RDMA connections. Rank 0 binds to this address; all other ranks connect to it. Use rank 0's IP address and any available port. -## Environment Variables +**Example:** Mac A (`192.168.1.10`) is rank 0, Mac B is rank 1, connected via Thunderbolt. -| Variable | Description | -|----------|-------------| -| `MLX_DISTRIBUTED_HOSTFILE` | Path to hostfile JSON (auto-set by P2P) | -| `MLX_LISTEN_PORT` | Port for MLX communication (default: 5555) | -| `MLX_DISTRIBUTED_BACKEND` | Backend type: `ring` or `jaccl` (default: ring) | +Start rank 0 on **Mac A** (`192.168.1.10`): + +```bash +local-ai worker mlx-distributed \ + --hostfile devices.json \ + --rank 0 \ + --backend jaccl \ + --coordinator 192.168.1.10:5555 \ + --addr localhost:50051 +``` + +Start rank 1 on **Mac B**: + +```bash +local-ai worker mlx-distributed \ + --hostfile devices.json \ + --rank 1 \ + --backend jaccl \ + --coordinator 192.168.1.10:5555 +``` + +Both ranks point `--coordinator` to rank 0's IP. Rank 0 binds to that address to accept RDMA setup connections from other ranks. + +## CLI Reference + +### `worker mlx-distributed` + +Standalone mode — run with a manual hostfile. + +| Flag | Env | Default | Description | +|------|-----|---------|-------------| +| `--hostfile` | `MLX_DISTRIBUTED_HOSTFILE` | *(required)* | Path to hostfile JSON. For Ring: array of `"ip:port"` where entry `i` is rank `i`'s listen address. For JACCL: device matrix of RDMA device names. | +| `--rank` | `MLX_RANK` | *(required)* | Rank of this process (0 = gRPC server + ring participant, >0 = worker only) | +| `--backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | Backend type: `ring` (TCP pipeline parallelism) or `jaccl` (RDMA tensor parallelism) | +| `--addr` | `MLX_DISTRIBUTED_ADDR` | `localhost:50051` | gRPC API listen address for LocalAI to connect to (rank 0 only, separate from ring communication) | +| `--coordinator` | `MLX_JACCL_COORDINATOR` | | JACCL coordinator `ip:port` — rank 0's IP where it accepts RDMA setup connections (jaccl only, required for all ranks) | + +### `worker p2p-mlx` + +P2P mode — auto-discovers peers and generates hostfile. + +| Flag | Env | Default | Description | +|------|-----|---------|-------------| +| `--token` | `TOKEN` | *(required)* | P2P network token | +| `--mlx-listen-port` | `MLX_LISTEN_PORT` | `5555` | Port for MLX communication | +| `--mlx-backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | Backend type: `ring` or `jaccl` | ## Troubleshooting @@ -108,3 +162,7 @@ Start with `--backend jaccl` and `--coordinator `. - **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. - **Rank assignment:** In P2P mode, rank 0 is always the LocalAI server. Worker ranks are assigned by sorting node IDs. - **Performance:** Pipeline parallelism adds latency proportional to the number of ranks. For best results, use the fewest ranks needed to fit your model in memory. + +## Acknowledgements + +The MLX distributed auto-parallel sharding implementation is based on [exo](https://github.com/exo-explore/exo). From dfebb53ed6ff5562fa2f663351ff48a8eba83dc1 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 5 Mar 2026 22:40:16 +0000 Subject: [PATCH 3/4] feat: make manual rank0 configurable via model configs Signed-off-by: Ettore Di Giacinto --- backend/python/mlx-distributed/backend.py | 175 ++++++++++++++-------- docs/content/features/mlx-distributed.md | 126 +++++++++++----- 2 files changed, 195 insertions(+), 106 deletions(-) diff --git a/backend/python/mlx-distributed/backend.py b/backend/python/mlx-distributed/backend.py index c6454af54f2e..4212f555b724 100644 --- a/backend/python/mlx-distributed/backend.py +++ b/backend/python/mlx-distributed/backend.py @@ -2,8 +2,15 @@ """ MLX Distributed Inference Backend for LocalAI. -Rank 0 mode: Starts a gRPC server that coordinates distributed inference. -Worker mode: Enters a loop waiting for commands from rank 0. +Two startup modes: + +1. Server mode (started by LocalAI automatically): + run.sh --addr localhost:50051 + Distributed config comes from LoadModel options or env vars. + +2. Worker mode (started by CLI for remote ranks): + run.sh --worker --hostfile hosts.json --rank 1 --backend ring + Enters a loop waiting for commands from rank 0. """ import asyncio from concurrent import futures @@ -66,12 +73,33 @@ def is_int(s): return False +def parse_options(options): + """Parse key:value option strings into a dict.""" + result = {} + for opt in options: + if ":" not in opt: + continue + key, value = opt.split(":", 1) + if is_float(value): + value = float(value) + elif is_int(value): + value = int(value) + elif value.lower() in ["true", "false"]: + value = value.lower() == "true" + result[key] = value + return result + + class BackendServicer(backend_pb2_grpc.BackendServicer): - """gRPC servicer for distributed MLX inference (runs only on rank 0).""" + """gRPC servicer for distributed MLX inference (runs on rank 0). + + When started by LocalAI (server mode), distributed init happens at + LoadModel time using config from model options or environment variables. + """ - def __init__(self, group, dist_backend="ring"): - self.group = group - self.dist_backend = dist_backend + def __init__(self): + self.group = None + self.dist_backend = None self.model = None self.tokenizer = None self.coordinator = None @@ -87,27 +115,34 @@ async def LoadModel(self, request, context): from coordinator import DistributedCoordinator, CMD_LOAD_MODEL from sharding import pipeline_auto_parallel - print(f"[Rank 0] Loading distributed model: {request.Model}", file=sys.stderr) - - options = request.Options - self.options = {} - for opt in options: - if ":" not in opt: - continue - key, value = opt.split(":", 1) - if is_float(value): - value = float(value) - elif is_int(value): - value = int(value) - elif value.lower() in ["true", "false"]: - value = value.lower() == "true" - self.options[key] = value - - self.coordinator = DistributedCoordinator(self.group) - - # Broadcast load command to all ranks - self.coordinator.broadcast_command(CMD_LOAD_MODEL) - self.coordinator.broadcast_model_name(request.Model) + print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr) + + self.options = parse_options(request.Options) + + # Get distributed config from model options, falling back to env vars. + # If neither is set, run as single-node (no distributed). + hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", "")) + dist_backend = str(self.options.get("distributed_backend", + os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring"))) + # JACCL coordinator: rank 0 reads from env (set by CLI --coordinator). + # Not in model options — rank 0 is the coordinator, workers get + # the address via their own --coordinator CLI flag. + jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "") + + if hostfile: + print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr) + self.dist_backend = dist_backend + self.group = mlx_distributed_init( + rank=0, + hostfile=hostfile, + backend=dist_backend, + coordinator=jaccl_coordinator or None, + ) + self.coordinator = DistributedCoordinator(self.group) + self.coordinator.broadcast_command(CMD_LOAD_MODEL) + self.coordinator.broadcast_model_name(request.Model) + else: + print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr) tokenizer_config = {} if request.TrustRemoteCode or self.options.get("trust_remote_code", False): @@ -118,16 +153,17 @@ async def LoadModel(self, request, context): else: self.model, self.tokenizer = load(request.Model) - # Apply pipeline parallelism - self.model = pipeline_auto_parallel(self.model, self.group) - - print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr) + if self.group is not None: + self.model = pipeline_auto_parallel(self.model, self.group) + print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr) + else: + print("[Rank 0] Model loaded (single-node)", file=sys.stderr) except Exception as err: print(f"[Rank 0] Error loading model: {err}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Error loading model: {err}") - return backend_pb2.Result(message="Model loaded with distributed sharding", success=True) + return backend_pb2.Result(message="Model loaded successfully", success=True) async def Predict(self, request, context): try: @@ -141,16 +177,19 @@ async def Predict(self, request, context): if hasattr(tokens, 'tolist'): tokens = tokens.tolist() - # Broadcast generate command + tokens - self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) - self.coordinator.broadcast_tokens(tokens) + if self.coordinator: + self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) + self.coordinator.broadcast_tokens(tokens) max_tokens, sampler_params = self._build_generation_params(request) - gen_params = self.coordinator.broadcast_generation_params( - max_tokens=max_tokens, - temperature=sampler_params.get('temp', 0.6), - top_p=sampler_params.get('top_p', 1.0), - ) + + if self.coordinator: + gen_params = self.coordinator.broadcast_generation_params( + max_tokens=max_tokens, + temperature=sampler_params.get('temp', 0.6), + top_p=sampler_params.get('top_p', 1.0), + ) + max_tokens = gen_params["max_tokens"] sampler = make_sampler(**sampler_params) @@ -159,7 +198,7 @@ async def Predict(self, request, context): self.model, self.tokenizer, prompt=tokens, - max_tokens=gen_params["max_tokens"], + max_tokens=max_tokens, sampler=sampler, ): generated.append(response.text) @@ -184,15 +223,19 @@ async def PredictStream(self, request, context): if hasattr(tokens, 'tolist'): tokens = tokens.tolist() - self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) - self.coordinator.broadcast_tokens(tokens) + if self.coordinator: + self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) + self.coordinator.broadcast_tokens(tokens) max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512) - gen_params = self.coordinator.broadcast_generation_params( - max_tokens=max_tokens, - temperature=sampler_params.get('temp', 0.6), - top_p=sampler_params.get('top_p', 1.0), - ) + + if self.coordinator: + gen_params = self.coordinator.broadcast_generation_params( + max_tokens=max_tokens, + temperature=sampler_params.get('temp', 0.6), + top_p=sampler_params.get('top_p', 1.0), + ) + max_tokens = gen_params["max_tokens"] sampler = make_sampler(**sampler_params) @@ -200,7 +243,7 @@ async def PredictStream(self, request, context): self.model, self.tokenizer, prompt=tokens, - max_tokens=gen_params["max_tokens"], + max_tokens=max_tokens, sampler=sampler, ): yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) @@ -305,7 +348,6 @@ def run_worker(group): top_p=gen_params["top_p"], ) - # Participate in distributed compute, discard output for _ in stream_generate( model, tokenizer, prompt=tokens, @@ -319,7 +361,7 @@ def run_worker(group): break -async def serve(address, group, dist_backend): +async def serve(address): server = grpc.aio.server( migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ @@ -328,9 +370,7 @@ async def serve(address, group, dist_backend): ('grpc.max_receive_message_length', 50 * 1024 * 1024), ], ) - backend_pb2_grpc.add_BackendServicer_to_server( - BackendServicer(group, dist_backend), server - ) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) loop = asyncio.get_event_loop() @@ -345,21 +385,26 @@ async def serve(address, group, dist_backend): if __name__ == "__main__": parser = argparse.ArgumentParser(description="MLX Distributed Backend") parser.add_argument("--addr", default="localhost:50051", - help="gRPC API listen address for LocalAI (rank 0 only, separate from ring communication)") - parser.add_argument("--worker", action="store_true", help="Run in worker mode (rank > 0)") + help="gRPC listen address (used by LocalAI to send requests)") + parser.add_argument("--worker", action="store_true", + help="Run in worker mode (for remote ranks started by CLI)") parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"], help="ring = TCP pipeline parallelism, jaccl = RDMA tensor parallelism") - parser.add_argument("--hostfile", required=True, - help="Ring: JSON array of 'ip:port' where entry i is rank i's listen address. " - "JACCL: JSON 2D matrix of RDMA device names (null on diagonal).") - parser.add_argument("--rank", type=int, required=True, help="Rank of this process (0 = server, >0 = worker)") + parser.add_argument("--hostfile", default=None, + help="Path to hostfile JSON (required for --worker mode)") + parser.add_argument("--rank", type=int, default=0, + help="Rank of this process (0 = server, >0 = worker)") parser.add_argument("--coordinator", default=None, - help="JACCL only: coordinator ip:port — rank 0's address for RDMA setup (same value on all ranks)") + help="JACCL coordinator ip:port (jaccl backend only)") args = parser.parse_args() - group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator) - - if args.worker or args.rank > 0: + if args.worker: + if not args.hostfile: + print("Error: --hostfile is required in worker mode", file=sys.stderr) + sys.exit(1) + group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator) run_worker(group) else: - asyncio.run(serve(args.addr, group, args.backend)) + # Server mode: started by LocalAI with just --addr. + # Distributed init deferred to LoadModel (reads config from model options/env vars). + asyncio.run(serve(args.addr)) diff --git a/docs/content/features/mlx-distributed.md b/docs/content/features/mlx-distributed.md index 7970c3a6d109..d4b22ca1edfd 100644 --- a/docs/content/features/mlx-distributed.md +++ b/docs/content/features/mlx-distributed.md @@ -57,46 +57,48 @@ parameters: model: mlx-community/Llama-3.2-1B-Instruct-4bit ``` -## Manual Setup with CLI +## Model Configuration -For setups without P2P, use the `worker mlx-distributed` command. LocalAI handles backend installation automatically. +The `mlx-distributed` backend is started automatically by LocalAI like any other backend. You configure distributed inference through the model YAML file using the `options` field: ### Ring Backend (TCP) -The Ring backend uses TCP for pipeline parallelism. Each rank listens on a TCP port for ring communication with its neighbors. The **hostfile** is a JSON array where entry `i` is the `"ip:port"` that **rank `i` binds to and listens on**. All ranks must use the same hostfile so they know how to reach each other. +```yaml +name: llama-distributed +backend: mlx-distributed +parameters: + model: mlx-community/Llama-3.2-1B-Instruct-4bit +options: + - "hostfile:/path/to/hosts.json" + - "distributed_backend:ring" +``` -**Example:** Two Macs on a local network — Mac A is `192.168.1.10`, Mac B is `192.168.1.11`. +The **hostfile** is a JSON array where entry `i` is the `"ip:port"` that **rank `i` listens on** for ring communication. All ranks must use the same hostfile so they know how to reach each other. -Create `hosts.json` (identical on both machines): +**Example:** Two Macs — Mac A (`192.168.1.10`) and Mac B (`192.168.1.11`): ```json ["192.168.1.10:5555", "192.168.1.11:5555"] ``` -- Entry 0 (`192.168.1.10:5555`) — the address rank 0 (Mac A) listens on -- Entry 1 (`192.168.1.11:5555`) — the address rank 1 (Mac B) listens on +- Entry 0 (`192.168.1.10:5555`) — the address rank 0 (Mac A) listens on for ring communication +- Entry 1 (`192.168.1.11:5555`) — the address rank 1 (Mac B) listens on for ring communication -Each rank binds to its own entry and connects to its neighbors for the ring pipeline. Port 5555 is arbitrary — use any available port, but it must be open in your firewall. - -Start rank 0 on **Mac A** (`192.168.1.10`): - -```bash -local-ai worker mlx-distributed --hostfile hosts.json --rank 0 --addr localhost:50051 -``` - -Start rank 1 on **Mac B** (`192.168.1.11`): - -```bash -local-ai worker mlx-distributed --hostfile hosts.json --rank 1 -``` - -Rank 0 starts a gRPC server (on `--addr`) that LocalAI connects to for inference requests. The `--addr` flag is separate from the ring hostfile — it controls where the gRPC API listens, not the ring communication. All other ranks run as workers that participate in each forward pass. +Port 5555 is arbitrary — use any available port, but it must be open in your firewall. ### JACCL Backend (RDMA/Thunderbolt) -For Thunderbolt-connected Macs, JACCL provides tensor parallelism via RDMA for higher throughput. +```yaml +name: llama-distributed +backend: mlx-distributed +parameters: + model: mlx-community/Llama-3.2-1B-Instruct-4bit +options: + - "hostfile:/path/to/devices.json" + - "distributed_backend:jaccl" +``` -The **device matrix** is a JSON 2D array describing the RDMA device name used to communicate between each pair of ranks. Entry `[i][j]` is the RDMA device that rank `i` uses to talk to rank `j`. The diagonal is `null` (a rank doesn't talk to itself). +The **device matrix** is a JSON 2D array describing the RDMA device name between each pair of ranks. The diagonal is `null` (a rank doesn't talk to itself): ```json [ @@ -105,22 +107,43 @@ The **device matrix** is a JSON 2D array describing the RDMA device name used to ] ``` -The **coordinator** is a TCP endpoint where one node (typically rank 0) runs a coordination service that helps all ranks establish their RDMA connections. Rank 0 binds to this address; all other ranks connect to it. Use rank 0's IP address and any available port. +JACCL requires a **coordinator** — a TCP service that helps all ranks establish RDMA connections. Rank 0 (the LocalAI machine) is always the coordinator. Workers are told the coordinator address via their `--coordinator` CLI flag (see [Starting Workers](#jaccl-workers) below). + +### Without hostfile (single-node) -**Example:** Mac A (`192.168.1.10`) is rank 0, Mac B is rank 1, connected via Thunderbolt. +If no `hostfile` option is set and no `MLX_DISTRIBUTED_HOSTFILE` environment variable exists, the backend runs as a regular single-node MLX backend. This is useful for testing or when you don't need distributed inference. -Start rank 0 on **Mac A** (`192.168.1.10`): +### Available Options + +| Option | Description | +|--------|-------------| +| `hostfile` | Path to the hostfile JSON. Ring: array of `"ip:port"`. JACCL: device matrix. | +| `distributed_backend` | `ring` (default) or `jaccl` | +| `trust_remote_code` | Allow trust_remote_code for the tokenizer | +| `max_tokens` | Override default max generation tokens | +| `temperature` / `temp` | Sampling temperature | +| `top_p` | Top-p sampling | + +These can also be set via environment variables (`MLX_DISTRIBUTED_HOSTFILE`, `MLX_DISTRIBUTED_BACKEND`) which are used as fallbacks when the model options don't specify them. + +## Starting Workers + +LocalAI starts the rank 0 process (gRPC server) automatically when the model is loaded. But you still need to start **worker processes** (ranks 1, 2, ...) on the other machines. These workers participate in every forward pass but don't serve any API — they wait for commands from rank 0. + +### Ring Workers + +On each worker machine, start a worker with the same hostfile: ```bash -local-ai worker mlx-distributed \ - --hostfile devices.json \ - --rank 0 \ - --backend jaccl \ - --coordinator 192.168.1.10:5555 \ - --addr localhost:50051 +local-ai worker mlx-distributed --hostfile hosts.json --rank 1 ``` -Start rank 1 on **Mac B**: +The `--rank` must match the worker's position in the hostfile. For example, if `hosts.json` is `["192.168.1.10:5555", "192.168.1.11:5555", "192.168.1.12:5555"]`, then: +- Rank 0: started automatically by LocalAI on `192.168.1.10` +- Rank 1: `local-ai worker mlx-distributed --hostfile hosts.json --rank 1` on `192.168.1.11` +- Rank 2: `local-ai worker mlx-distributed --hostfile hosts.json --rank 2` on `192.168.1.12` + +### JACCL Workers ```bash local-ai worker mlx-distributed \ @@ -130,21 +153,42 @@ local-ai worker mlx-distributed \ --coordinator 192.168.1.10:5555 ``` -Both ranks point `--coordinator` to rank 0's IP. Rank 0 binds to that address to accept RDMA setup connections from other ranks. +The `--coordinator` address is the IP of the machine running LocalAI (rank 0) with any available port. Rank 0 binds the coordinator service there; workers connect to it to establish RDMA connections. + +### Worker Startup Order + +Start workers **before** loading the model in LocalAI. When LocalAI sends the LoadModel request, rank 0 initializes `mx.distributed` which tries to connect to all ranks listed in the hostfile. If workers aren't running yet, it will time out. + +## Advanced: Manual Rank 0 + +For advanced use cases, you can also run rank 0 manually as an external gRPC backend instead of letting LocalAI start it automatically: + +```bash +# On Mac A: start rank 0 manually +local-ai worker mlx-distributed --hostfile hosts.json --rank 0 --addr 192.168.1.10:50051 + +# On Mac B: start rank 1 +local-ai worker mlx-distributed --hostfile hosts.json --rank 1 + +# On any machine: start LocalAI pointing at rank 0 +local-ai run --external-grpc-backends "mlx-distributed:192.168.1.10:50051" +``` + +Then use a model config with `backend: mlx-distributed` (no need for `hostfile` in options since rank 0 already has it from CLI args). ## CLI Reference ### `worker mlx-distributed` -Standalone mode — run with a manual hostfile. +Starts a worker or manual rank 0 process. | Flag | Env | Default | Description | |------|-----|---------|-------------| -| `--hostfile` | `MLX_DISTRIBUTED_HOSTFILE` | *(required)* | Path to hostfile JSON. For Ring: array of `"ip:port"` where entry `i` is rank `i`'s listen address. For JACCL: device matrix of RDMA device names. | +| `--hostfile` | `MLX_DISTRIBUTED_HOSTFILE` | *(required)* | Path to hostfile JSON. Ring: array of `"ip:port"` where entry `i` is rank `i`'s listen address. JACCL: device matrix of RDMA device names. | | `--rank` | `MLX_RANK` | *(required)* | Rank of this process (0 = gRPC server + ring participant, >0 = worker only) | -| `--backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | Backend type: `ring` (TCP pipeline parallelism) or `jaccl` (RDMA tensor parallelism) | -| `--addr` | `MLX_DISTRIBUTED_ADDR` | `localhost:50051` | gRPC API listen address for LocalAI to connect to (rank 0 only, separate from ring communication) | -| `--coordinator` | `MLX_JACCL_COORDINATOR` | | JACCL coordinator `ip:port` — rank 0's IP where it accepts RDMA setup connections (jaccl only, required for all ranks) | +| `--backend` | `MLX_DISTRIBUTED_BACKEND` | `ring` | `ring` (TCP pipeline parallelism) or `jaccl` (RDMA tensor parallelism) | +| `--addr` | `MLX_DISTRIBUTED_ADDR` | `localhost:50051` | gRPC API listen address (rank 0 only, for LocalAI or external access) | +| `--coordinator` | `MLX_JACCL_COORDINATOR` | | JACCL coordinator `ip:port` — rank 0's address for RDMA setup (all ranks must use the same value) | ### `worker p2p-mlx` @@ -159,7 +203,7 @@ P2P mode — auto-discovers peers and generates hostfile. ## Troubleshooting - **All ranks must have the model downloaded.** MLX distributed does not transfer model weights over the network. Ensure each node has the model in its Hugging Face cache. -- **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. +- **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. Start workers before loading the model. - **Rank assignment:** In P2P mode, rank 0 is always the LocalAI server. Worker ranks are assigned by sorting node IDs. - **Performance:** Pipeline parallelism adds latency proportional to the number of ranks. For best results, use the fewest ranks needed to fit your model in memory. From c2ac2bde2978e7b5296612a6086f6a8181f22ab1 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 5 Mar 2026 22:46:17 +0000 Subject: [PATCH 4/4] Add missing features from mlx backend Signed-off-by: Ettore Di Giacinto --- backend/python/mlx-distributed/backend.py | 129 ++++++++- backend/python/mlx-distributed/mlx_cache.py | 266 ++++++++++++++++++ .../mlx-distributed/requirements-cpu.txt | 2 + .../mlx-distributed/requirements-cublas12.txt | 2 + .../mlx-distributed/requirements-cublas13.txt | 2 + .../mlx-distributed/requirements-l4t12.txt | 2 + .../mlx-distributed/requirements-l4t13.txt | 2 + backend/python/mlx-distributed/test.py | 58 +++- docs/content/features/mlx-distributed.md | 2 +- 9 files changed, 447 insertions(+), 18 deletions(-) create mode 100644 backend/python/mlx-distributed/mlx_cache.py create mode 100644 backend/python/mlx-distributed/requirements-cpu.txt create mode 100644 backend/python/mlx-distributed/requirements-cublas12.txt create mode 100644 backend/python/mlx-distributed/requirements-cublas13.txt create mode 100644 backend/python/mlx-distributed/requirements-l4t12.txt create mode 100644 backend/python/mlx-distributed/requirements-l4t13.txt diff --git a/backend/python/mlx-distributed/backend.py b/backend/python/mlx-distributed/backend.py index 4212f555b724..b21a98070612 100644 --- a/backend/python/mlx-distributed/backend.py +++ b/backend/python/mlx-distributed/backend.py @@ -20,6 +20,7 @@ import signal import sys import tempfile +from typing import List import grpc @@ -104,6 +105,9 @@ def __init__(self): self.tokenizer = None self.coordinator = None self.options = {} + self.lru_cache = None + self.model_key = None + self.max_kv_size = None def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) @@ -112,12 +116,12 @@ async def LoadModel(self, request, context): try: import mlx.core as mx from mlx_lm import load - from coordinator import DistributedCoordinator, CMD_LOAD_MODEL - from sharding import pipeline_auto_parallel + from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr) self.options = parse_options(request.Options) + print(f"Options: {self.options}", file=sys.stderr) # Get distributed config from model options, falling back to env vars. # If neither is set, run as single-node (no distributed). @@ -130,6 +134,9 @@ async def LoadModel(self, request, context): jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "") if hostfile: + from coordinator import DistributedCoordinator, CMD_LOAD_MODEL + from sharding import pipeline_auto_parallel + print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr) self.dist_backend = dist_backend self.group = mlx_distributed_init( @@ -144,20 +151,38 @@ async def LoadModel(self, request, context): else: print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr) + # Build tokenizer config from request and options tokenizer_config = {} if request.TrustRemoteCode or self.options.get("trust_remote_code", False): tokenizer_config["trust_remote_code"] = True + # Token overrides from options + for key in ["eos_token", "pad_token", "bos_token", "unk_token", + "sep_token", "cls_token", "mask_token"]: + if key in self.options: + tokenizer_config[key] = self.options[key] if tokenizer_config: + print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr) self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config) else: self.model, self.tokenizer = load(request.Model) if self.group is not None: + from sharding import pipeline_auto_parallel self.model = pipeline_auto_parallel(self.model, self.group) print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr) else: - print("[Rank 0] Model loaded (single-node)", file=sys.stderr) + # Single-node: set up prompt cache for efficient generation + from mlx_cache import ThreadSafeLRUPromptCache + max_cache_entries = self.options.get("max_cache_entries", 10) + self.max_kv_size = self.options.get("max_kv_size", None) + self.model_key = request.Model + self.lru_cache = ThreadSafeLRUPromptCache( + max_size=max_cache_entries, + can_trim_fn=can_trim_prompt_cache, + trim_fn=trim_prompt_cache, + ) + print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr) except Exception as err: print(f"[Rank 0] Error loading model: {err}", file=sys.stderr) @@ -166,18 +191,19 @@ async def LoadModel(self, request, context): return backend_pb2.Result(message="Model loaded successfully", success=True) async def Predict(self, request, context): + prompt_cache = None + cache_key = None + try: import mlx.core as mx from mlx_lm import stream_generate from mlx_lm.sample_utils import make_sampler - from coordinator import CMD_GENERATE prompt_text = self._prepare_prompt(request) - tokens = self.tokenizer.encode(prompt_text) - if hasattr(tokens, 'tolist'): - tokens = tokens.tolist() + tokens = self._get_tokens_from_prompt(prompt_text) if self.coordinator: + from coordinator import CMD_GENERATE self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) self.coordinator.broadcast_tokens(tokens) @@ -193,6 +219,20 @@ async def Predict(self, request, context): sampler = make_sampler(**sampler_params) + # Use prompt cache in single-node mode + gen_kwargs = {} + if self.lru_cache is not None: + from mlx_lm.models.cache import make_prompt_cache + cache_key = list(tokens) + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( + self.model_key, cache_key + ) + if prompt_cache is None: + prompt_cache = make_prompt_cache(self.model, self.max_kv_size) + remaining_tokens = cache_key + gen_kwargs['prompt_cache'] = prompt_cache + tokens = remaining_tokens if remaining_tokens else cache_key + generated = [] for response in stream_generate( self.model, @@ -200,8 +240,14 @@ async def Predict(self, request, context): prompt=tokens, max_tokens=max_tokens, sampler=sampler, + **gen_kwargs, ): generated.append(response.text) + if cache_key is not None: + cache_key.append(response.token) + + if self.lru_cache is not None and cache_key is not None: + self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8')) @@ -212,18 +258,19 @@ async def Predict(self, request, context): return backend_pb2.Reply(message=bytes("", encoding='utf-8')) async def PredictStream(self, request, context): + prompt_cache = None + cache_key = None + try: import mlx.core as mx from mlx_lm import stream_generate from mlx_lm.sample_utils import make_sampler - from coordinator import CMD_GENERATE prompt_text = self._prepare_prompt(request) - tokens = self.tokenizer.encode(prompt_text) - if hasattr(tokens, 'tolist'): - tokens = tokens.tolist() + tokens = self._get_tokens_from_prompt(prompt_text) if self.coordinator: + from coordinator import CMD_GENERATE self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) self.coordinator.broadcast_tokens(tokens) @@ -239,13 +286,30 @@ async def PredictStream(self, request, context): sampler = make_sampler(**sampler_params) + # Use prompt cache in single-node mode + gen_kwargs = {} + if self.lru_cache is not None: + from mlx_lm.models.cache import make_prompt_cache + cache_key = list(tokens) + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( + self.model_key, cache_key + ) + if prompt_cache is None: + prompt_cache = make_prompt_cache(self.model, self.max_kv_size) + remaining_tokens = cache_key + gen_kwargs['prompt_cache'] = prompt_cache + tokens = remaining_tokens if remaining_tokens else cache_key + for response in stream_generate( self.model, self.tokenizer, prompt=tokens, max_tokens=max_tokens, sampler=sampler, + **gen_kwargs, ): + if cache_key is not None: + cache_key.append(response.token) yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) except Exception as e: @@ -254,6 +318,19 @@ async def PredictStream(self, request, context): context.set_details(f"Streaming failed: {str(e)}") yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) + finally: + if self.lru_cache is not None and prompt_cache is not None and cache_key is not None: + try: + self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) + except Exception as e: + print(f"Error inserting cache: {e}", file=sys.stderr) + + def Embedding(self, request, context): + print("Embeddings not supported in MLX distributed backend", file=sys.stderr) + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Embeddings are not supported in the MLX distributed backend.") + return backend_pb2.EmbeddingResult() + def _prepare_prompt(self, request): if not request.Prompt and request.UseTokenizerTemplate and request.Messages: messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages] @@ -262,7 +339,15 @@ def _prepare_prompt(self, request): ) return request.Prompt + def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]: + tokens = self.tokenizer.encode(prompt_text) + if hasattr(tokens, 'tolist'): + return tokens.tolist() + return list(tokens) + def _build_generation_params(self, request, default_max_tokens=200): + import mlx.core as mx + max_tokens = getattr(request, 'Tokens', default_max_tokens) if max_tokens == 0: max_tokens = default_max_tokens @@ -286,23 +371,37 @@ def _build_generation_params(self, request, default_max_tokens=200): seed = getattr(request, 'Seed', 0) if seed != 0: - import mlx.core as mx mx.random.seed(seed) if hasattr(self, 'options'): if 'max_tokens' in self.options: max_tokens = self.options['max_tokens'] option_mapping = { - 'temp': 'temp', 'temperature': 'temp', - 'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k', + 'temp': 'temp', + 'temperature': 'temp', + 'top_p': 'top_p', + 'min_p': 'min_p', + 'top_k': 'top_k', + 'xtc_threshold': 'xtc_threshold', + 'xtc_probability': 'xtc_probability', } for opt_key, param_key in option_mapping.items(): if opt_key in self.options: sampler_params[param_key] = self.options[opt_key] + if 'seed' in self.options: + mx.random.seed(self.options['seed']) + # XTC special tokens xtc_special_tokens = [] - if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: + if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids: + xtc_special_tokens = list(self.tokenizer.eos_token_ids) + elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: xtc_special_tokens = [self.tokenizer.eos_token_id] + try: + newline_tokens = self.tokenizer.encode("\n") + xtc_special_tokens.extend(newline_tokens) + except: + pass sampler_params['xtc_special_tokens'] = xtc_special_tokens return max_tokens, sampler_params diff --git a/backend/python/mlx-distributed/mlx_cache.py b/backend/python/mlx-distributed/mlx_cache.py new file mode 100644 index 000000000000..6ec2bb9baabb --- /dev/null +++ b/backend/python/mlx-distributed/mlx_cache.py @@ -0,0 +1,266 @@ +""" +Thread-safe LRU prompt cache for MLX-based backends. + +Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.) +with thread-safety additions for LocalAI's gRPC backend. + +Usage: + from mlx_cache import ThreadSafeLRUPromptCache + + # In LoadModel: + self.lru_cache = ThreadSafeLRUPromptCache(max_size=10) + + # In Predict/PredictStream: + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens) + # ... generate ... + self.lru_cache.insert_cache(model_key, tokens, prompt_cache) +""" +import copy +import threading +from collections import deque +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + + +@dataclass +class CacheEntry: + """A cache entry with reference counting.""" + prompt_cache: List[Any] + count: int + + +@dataclass +class SearchResult: + """Result of searching the cache trie.""" + model: Any + exact: Optional[List[int]] + shorter: Optional[List[int]] + longer: Optional[List[int]] + common_prefix: int + + +class ThreadSafeLRUPromptCache: + """ + Thread-safe LRU cache with prefix matching for prompt KV caches. + + This cache stores KV caches keyed by token sequences and supports: + - Exact match: Return the cache for the exact token sequence + - Shorter prefix match: Return a cache for a prefix of the tokens + - Longer prefix match: If a longer sequence is cached and can be trimmed + - LRU eviction: When max_size is exceeded, evict least recently used + + Thread safety is provided via a threading.Lock that protects all + cache operations. + + Args: + max_size: Maximum number of cache entries (default: 10) + can_trim_fn: Optional function to check if a cache can be trimmed + trim_fn: Optional function to trim a cache + """ + + def __init__( + self, + max_size: int = 10, + can_trim_fn: Optional[Any] = None, + trim_fn: Optional[Any] = None, + ): + self.max_size = max_size + self._cache = {} + self._lru = deque() + self._lock = threading.Lock() + + # Optional trim functions (for longer prefix reuse) + self._can_trim_fn = can_trim_fn + self._trim_fn = trim_fn + + def _search(self, model, tokens: List[int]) -> SearchResult: + """ + Search the cache for a prompt cache. Return exact or close match. + + The cache is organized as a trie where each node is keyed by a token. + This allows efficient prefix matching. + """ + if model not in self._cache: + return SearchResult(model, None, None, None, 0) + + current = self._cache[model] + last_cache_index = -1 + index = 0 + + # Traverse the trie following the token sequence + while index < len(tokens) and tokens[index] in current: + current = current[tokens[index]] + if "cache" in current: + last_cache_index = index + index += 1 + + # Exact match - no need to search for longer or shorter caches + if last_cache_index == len(tokens) - 1: + return SearchResult(model, tuple(tokens), None, None, 0) + + # Find the shorter cache (a prefix that has a cache) + # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior. + # Single-token prefixes are not matched, which allows longer cached + # sequences to be preferred for trimming. This is acceptable because + # real prompts with chat templates are always many tokens. + shorter = None + if last_cache_index > 0: + shorter = tuple(tokens[: last_cache_index + 1]) + + # Check for caches that are longer than our token sequence + longer = None + common_prefix = index + if index > 0 and last_cache_index <= 0: + best = None + stack = [(current, [])] + while stack: + current, extra = stack.pop() + if "cache" in current: + if best is None or len(extra) < len(best): + best = extra + else: + for tok in current: + stack.append((current[tok], extra + [tok])) + if best is not None: + longer = tuple(tokens[:index] + best) + + return SearchResult(model, None, shorter, longer, common_prefix) + + def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry: + """Get a cache entry by traversing the trie.""" + current = self._cache[model] + for tok in tokens: + current = current[tok] + return current["cache"] + + def _delete(self, model, tokens: Tuple[int, ...]) -> None: + """Delete a cache entry and clean up empty trie nodes.""" + path = [self._cache[model]] + for tok in tokens: + path.append(path[-1][tok]) + del path[-1]["cache"] + + # Clean up empty nodes bottom-up + for i in reversed(range(len(tokens))): + d_prev, d, t = path[i], path[i + 1], tokens[i] + if len(d) > 0: + break + del d_prev[t] + + def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry: + """ + Extract a cache entry for exclusive use. + + If the entry has count > 1, deep copy and decrement. + If count == 1, remove from cache entirely. + """ + cache_entry = self._get(model, tokens) + if cache_entry.count == 1: + self._delete(model, tokens) + self._lru.remove((model, tokens)) + return cache_entry + + cache_entry.count -= 1 + return CacheEntry( + copy.deepcopy(cache_entry.prompt_cache), + 1, + ) + + def fetch_nearest_cache( + self, model, tokens: List[int] + ) -> Tuple[Optional[List[Any]], List[int]]: + """ + Fetch the nearest cache for the given token sequence. + + Thread-safe. Returns (cache, remaining_tokens) where: + - cache: The KV cache to use (or None if no cache found) + - remaining_tokens: Tokens that still need to be processed + + Args: + model: Model identifier (used to namespace caches) + tokens: The full token sequence for the prompt + + Returns: + Tuple of (prompt_cache, remaining_tokens) + """ + with self._lock: + tokens_tuple = tuple(tokens) + result = self._search(model, tokens) + + # Exact match - extract and return + if result.exact is not None: + cache_entry = self._extract(result.model, result.exact) + return cache_entry.prompt_cache, [] + + # Shorter prefix match - extract and return remaining + if result.shorter is not None: + cache_entry = self._extract(result.model, result.shorter) + prefix_len = len(result.shorter) + return cache_entry.prompt_cache, list(tokens[prefix_len:]) + + # Longer prefix match - try to trim if possible + if result.longer is not None and self._can_trim_fn is not None: + cache_entry = self._get(result.model, result.longer) + if self._can_trim_fn(cache_entry.prompt_cache): + # Deep copy and trim + trimmed_cache = copy.deepcopy(cache_entry.prompt_cache) + prefix = min(len(tokens) - 1, result.common_prefix) + num_to_trim = len(result.longer) - prefix + if self._trim_fn is not None: + self._trim_fn(trimmed_cache, num_to_trim) + return trimmed_cache, list(tokens[prefix:]) + + # No match found + return None, list(tokens) + + def insert_cache( + self, model, tokens: List[int], prompt_cache: List[Any] + ) -> None: + """ + Insert a cache entry after generation completes. + + Thread-safe. Handles LRU eviction if max_size is exceeded. + + Args: + model: Model identifier (used to namespace caches) + tokens: The full token sequence (prompt + generated) + prompt_cache: The KV cache to store + """ + with self._lock: + tokens_tuple = tuple(tokens) + + if model not in self._cache: + self._cache[model] = {} + current = self._cache[model] + + # Build trie path + for tok in tokens_tuple: + if tok not in current: + current[tok] = {} + current = current[tok] + + # Update or create entry + if "cache" in current: + current["cache"].count += 1 + self._lru.remove((model, tokens_tuple)) + else: + current["cache"] = CacheEntry(prompt_cache, 1) + + # Update LRU order + self._lru.append((model, tokens_tuple)) + + # Evict if over capacity + if len(self._lru) > self.max_size: + evict_model, evict_tokens = self._lru.popleft() + self._delete(evict_model, evict_tokens) + + def clear(self) -> None: + """Clear all cache entries. Thread-safe.""" + with self._lock: + self._cache.clear() + self._lru.clear() + + def __len__(self) -> int: + """Return the number of cache entries. Thread-safe.""" + with self._lock: + return len(self._lru) diff --git a/backend/python/mlx-distributed/requirements-cpu.txt b/backend/python/mlx-distributed/requirements-cpu.txt new file mode 100644 index 000000000000..a381b1461765 --- /dev/null +++ b/backend/python/mlx-distributed/requirements-cpu.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cpu] diff --git a/backend/python/mlx-distributed/requirements-cublas12.txt b/backend/python/mlx-distributed/requirements-cublas12.txt new file mode 100644 index 000000000000..dc057533bc36 --- /dev/null +++ b/backend/python/mlx-distributed/requirements-cublas12.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cuda12] diff --git a/backend/python/mlx-distributed/requirements-cublas13.txt b/backend/python/mlx-distributed/requirements-cublas13.txt new file mode 100644 index 000000000000..40cc694535fd --- /dev/null +++ b/backend/python/mlx-distributed/requirements-cublas13.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cuda13] diff --git a/backend/python/mlx-distributed/requirements-l4t12.txt b/backend/python/mlx-distributed/requirements-l4t12.txt new file mode 100644 index 000000000000..dc057533bc36 --- /dev/null +++ b/backend/python/mlx-distributed/requirements-l4t12.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cuda12] diff --git a/backend/python/mlx-distributed/requirements-l4t13.txt b/backend/python/mlx-distributed/requirements-l4t13.txt new file mode 100644 index 000000000000..40cc694535fd --- /dev/null +++ b/backend/python/mlx-distributed/requirements-l4t13.txt @@ -0,0 +1,2 @@ +mlx-lm +mlx[cuda13] diff --git a/backend/python/mlx-distributed/test.py b/backend/python/mlx-distributed/test.py index 16f96d04c692..4cb1440edc71 100644 --- a/backend/python/mlx-distributed/test.py +++ b/backend/python/mlx-distributed/test.py @@ -10,8 +10,7 @@ class TestBackendServicer(unittest.TestCase): def setUp(self): self.service = subprocess.Popen( - ["python", "backend.py", "--addr", "localhost:50051", - "--hostfile", "/dev/null", "--rank", "0"] + ["python", "backend.py", "--addr", "localhost:50051"] ) time.sleep(10) @@ -31,3 +30,58 @@ def test_server_startup(self): self.fail("Server failed to start") finally: self.tearDown() + + def test_load_model(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + self.assertEqual(response.message, "Model loaded successfully") + except Exception as err: + print(err) + self.fail("LoadModel service failed") + finally: + self.tearDown() + + def test_text(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + req = backend_pb2.PredictOptions(Prompt="The capital of France is") + resp = stub.Predict(req) + self.assertIsNotNone(resp.message) + except Exception as err: + print(err) + self.fail("text service failed") + finally: + self.tearDown() + + def test_sampling_params(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) + self.assertTrue(response.success) + + req = backend_pb2.PredictOptions( + Prompt="The capital of France is", + TopP=0.8, + Tokens=50, + Temperature=0.7, + TopK=40, + MinP=0.05, + Seed=42, + ) + resp = stub.Predict(req) + self.assertIsNotNone(resp.message) + except Exception as err: + print(err) + self.fail("sampling params service failed") + finally: + self.tearDown() diff --git a/docs/content/features/mlx-distributed.md b/docs/content/features/mlx-distributed.md index d4b22ca1edfd..ee6d5070e4db 100644 --- a/docs/content/features/mlx-distributed.md +++ b/docs/content/features/mlx-distributed.md @@ -202,7 +202,7 @@ P2P mode — auto-discovers peers and generates hostfile. ## Troubleshooting -- **All ranks must have the model downloaded.** MLX distributed does not transfer model weights over the network. Ensure each node has the model in its Hugging Face cache. +- **All ranks download the model independently.** Each node auto-downloads from Hugging Face on first use via `mlx_lm.load()`. On rank 0 (started by LocalAI), models are downloaded to LocalAI's model directory (`HF_HOME` is set automatically). On workers, models go to the default HF cache (`~/.cache/huggingface/hub`) unless you set `HF_HOME` yourself. - **Timeout errors:** If ranks can't connect, check firewall rules. The Ring backend uses TCP on the ports listed in the hostfile. Start workers before loading the model. - **Rank assignment:** In P2P mode, rank 0 is always the LocalAI server. Worker ranks are assigned by sorting node IDs. - **Performance:** Pipeline parallelism adds latency proportional to the number of ranks. For best results, use the fewest ranks needed to fit your model in memory.