Commit 1cac262

Add missing features from mlx backend
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent dfebb53 commit 1cac262

8 files changed

Lines changed: 446 additions & 17 deletions

File tree

backend/python/mlx-distributed/backend.py

Lines changed: 114 additions & 15 deletions
@@ -20,6 +20,7 @@
 import signal
 import sys
 import tempfile
+from typing import List
 
 import grpc
 
@@ -104,6 +105,9 @@ def __init__(self):
         self.tokenizer = None
         self.coordinator = None
         self.options = {}
+        self.lru_cache = None
+        self.model_key = None
+        self.max_kv_size = None
 
     def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
@@ -112,12 +116,12 @@ async def LoadModel(self, request, context):
         try:
             import mlx.core as mx
             from mlx_lm import load
-            from coordinator import DistributedCoordinator, CMD_LOAD_MODEL
-            from sharding import pipeline_auto_parallel
+            from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
 
             print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr)
 
             self.options = parse_options(request.Options)
+            print(f"Options: {self.options}", file=sys.stderr)
 
             # Get distributed config from model options, falling back to env vars.
             # If neither is set, run as single-node (no distributed).
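
Note on the hunk above: parse_options is defined elsewhere in this backend and is not part of this commit, so the shape below is an illustrative assumption only. Every key shown is one the changed code actually reads later in this file; the values are invented.

    # Hypothetical result of parse_options(request.Options) for this backend.
    # Only keys consumed by the code in this diff are listed; values are examples.
    options = {
        "trust_remote_code": True,    # merged into tokenizer_config
        "eos_token": "<|im_end|>",    # token override copied into tokenizer_config
        "max_cache_entries": 10,      # size of the single-node LRU prompt cache
        "max_kv_size": 4096,          # forwarded to make_prompt_cache
        "max_tokens": 512,            # overrides the request's token budget
        "temperature": 0.7,           # mapped to the sampler's "temp"
        "top_p": 0.95,
        "top_k": 40,
        "min_p": 0.05,
        "xtc_threshold": 0.1,         # new XTC sampling options
        "xtc_probability": 0.5,
        "seed": 42,                   # seeds mx.random when present
    }
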
@@ -130,6 +134,9 @@ async def LoadModel(self, request, context):
             jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "")
 
             if hostfile:
+                from coordinator import DistributedCoordinator, CMD_LOAD_MODEL
+                from sharding import pipeline_auto_parallel
+
                 print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr)
                 self.dist_backend = dist_backend
                 self.group = mlx_distributed_init(
@@ -144,20 +151,38 @@ async def LoadModel(self, request, context):
             else:
                 print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr)
 
+            # Build tokenizer config from request and options
             tokenizer_config = {}
             if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
                 tokenizer_config["trust_remote_code"] = True
+            # Token overrides from options
+            for key in ["eos_token", "pad_token", "bos_token", "unk_token",
+                        "sep_token", "cls_token", "mask_token"]:
+                if key in self.options:
+                    tokenizer_config[key] = self.options[key]
 
             if tokenizer_config:
+                print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
                 self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
             else:
                 self.model, self.tokenizer = load(request.Model)
 
             if self.group is not None:
+                from sharding import pipeline_auto_parallel
                 self.model = pipeline_auto_parallel(self.model, self.group)
                 print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr)
             else:
-                print("[Rank 0] Model loaded (single-node)", file=sys.stderr)
+                # Single-node: set up prompt cache for efficient generation
+                from mlx_cache import ThreadSafeLRUPromptCache
+                max_cache_entries = self.options.get("max_cache_entries", 10)
+                self.max_kv_size = self.options.get("max_kv_size", None)
+                self.model_key = request.Model
+                self.lru_cache = ThreadSafeLRUPromptCache(
+                    max_size=max_cache_entries,
+                    can_trim_fn=can_trim_prompt_cache,
+                    trim_fn=trim_prompt_cache,
+                )
+                print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)
 
         except Exception as err:
             print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
@@ -166,18 +191,19 @@ async def LoadModel(self, request, context):
         return backend_pb2.Result(message="Model loaded successfully", success=True)
 
     async def Predict(self, request, context):
+        prompt_cache = None
+        cache_key = None
+
         try:
             import mlx.core as mx
             from mlx_lm import stream_generate
             from mlx_lm.sample_utils import make_sampler
-            from coordinator import CMD_GENERATE
 
             prompt_text = self._prepare_prompt(request)
-            tokens = self.tokenizer.encode(prompt_text)
-            if hasattr(tokens, 'tolist'):
-                tokens = tokens.tolist()
+            tokens = self._get_tokens_from_prompt(prompt_text)
 
             if self.coordinator:
+                from coordinator import CMD_GENERATE
                 self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                 self.coordinator.broadcast_tokens(tokens)
 
@@ -193,15 +219,35 @@ async def Predict(self, request, context):
 
             sampler = make_sampler(**sampler_params)
 
+            # Use prompt cache in single-node mode
+            gen_kwargs = {}
+            if self.lru_cache is not None:
+                from mlx_lm.models.cache import make_prompt_cache
+                cache_key = list(tokens)
+                prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
+                    self.model_key, cache_key
+                )
+                if prompt_cache is None:
+                    prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
+                    remaining_tokens = cache_key
+                gen_kwargs['prompt_cache'] = prompt_cache
+                tokens = remaining_tokens if remaining_tokens else cache_key
+
             generated = []
             for response in stream_generate(
                 self.model,
                 self.tokenizer,
                 prompt=tokens,
                 max_tokens=max_tokens,
                 sampler=sampler,
+                **gen_kwargs,
             ):
                 generated.append(response.text)
+                if cache_key is not None:
+                    cache_key.append(response.token)
+
+            if self.lru_cache is not None and cache_key is not None:
+                self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
 
             return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))
 
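Note on the hunk above: Predict now prefills only the part of the prompt that is not already covered by a cached prefix, appends every generated token id to cache_key, and stores the cache back once generation finishes, so the next turn of the same conversation can pick it up. A tiny usage illustration of that flow, reusing the hypothetical sketch class from the earlier note (model name and token ids are made up):

    # Illustration only; real prompt caches come from mlx_lm's make_prompt_cache.
    cache = ThreadSafeLRUPromptCache(max_size=2)
    kv_state = object()                        # stand-in for the mlx_lm KV cache object

    first_turn = [101, 7, 8, 9]                # prompt tokens of the first request
    first_turn_plus_reply = first_turn + [42, 43]
    cache.insert_cache("my-model", first_turn_plus_reply, kv_state)

    # A follow-up request repeats the whole conversation and adds two new tokens;
    # only the suffix [44, 45] still needs to be prefilled.
    kv, remaining = cache.fetch_nearest_cache("my-model", first_turn_plus_reply + [44, 45])
    assert kv is kv_state
    assert remaining == [44, 45]
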
@@ -212,18 +258,19 @@ async def Predict(self, request, context):
             return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
 
     async def PredictStream(self, request, context):
+        prompt_cache = None
+        cache_key = None
+
         try:
             import mlx.core as mx
             from mlx_lm import stream_generate
             from mlx_lm.sample_utils import make_sampler
-            from coordinator import CMD_GENERATE
 
             prompt_text = self._prepare_prompt(request)
-            tokens = self.tokenizer.encode(prompt_text)
-            if hasattr(tokens, 'tolist'):
-                tokens = tokens.tolist()
+            tokens = self._get_tokens_from_prompt(prompt_text)
 
             if self.coordinator:
+                from coordinator import CMD_GENERATE
                 self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                 self.coordinator.broadcast_tokens(tokens)
 
@@ -239,13 +286,30 @@ async def PredictStream(self, request, context):
 
             sampler = make_sampler(**sampler_params)
 
+            # Use prompt cache in single-node mode
+            gen_kwargs = {}
+            if self.lru_cache is not None:
+                from mlx_lm.models.cache import make_prompt_cache
+                cache_key = list(tokens)
+                prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
+                    self.model_key, cache_key
+                )
+                if prompt_cache is None:
+                    prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
+                    remaining_tokens = cache_key
+                gen_kwargs['prompt_cache'] = prompt_cache
+                tokens = remaining_tokens if remaining_tokens else cache_key
+
             for response in stream_generate(
                 self.model,
                 self.tokenizer,
                 prompt=tokens,
                 max_tokens=max_tokens,
                 sampler=sampler,
+                **gen_kwargs,
             ):
+                if cache_key is not None:
+                    cache_key.append(response.token)
                 yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
 
         except Exception as e:
@@ -254,6 +318,19 @@ async def PredictStream(self, request, context):
             context.set_details(f"Streaming failed: {str(e)}")
             yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
 
+        finally:
+            if self.lru_cache is not None and prompt_cache is not None and cache_key is not None:
+                try:
+                    self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
+                except Exception as e:
+                    print(f"Error inserting cache: {e}", file=sys.stderr)
+
+    def Embedding(self, request, context):
+        print("Embeddings not supported in MLX distributed backend", file=sys.stderr)
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Embeddings are not supported in the MLX distributed backend.")
+        return backend_pb2.EmbeddingResult()
+
     def _prepare_prompt(self, request):
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
             messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
@@ -262,7 +339,15 @@ def _prepare_prompt(self, request):
             )
         return request.Prompt
 
+    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
+        tokens = self.tokenizer.encode(prompt_text)
+        if hasattr(tokens, 'tolist'):
+            return tokens.tolist()
+        return list(tokens)
+
     def _build_generation_params(self, request, default_max_tokens=200):
+        import mlx.core as mx
+
         max_tokens = getattr(request, 'Tokens', default_max_tokens)
         if max_tokens == 0:
             max_tokens = default_max_tokens
@@ -286,23 +371,37 @@ def _build_generation_params(self, request, default_max_tokens=200):
 
         seed = getattr(request, 'Seed', 0)
         if seed != 0:
-            import mlx.core as mx
             mx.random.seed(seed)
 
         if hasattr(self, 'options'):
             if 'max_tokens' in self.options:
                 max_tokens = self.options['max_tokens']
             option_mapping = {
-                'temp': 'temp', 'temperature': 'temp',
-                'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k',
+                'temp': 'temp',
+                'temperature': 'temp',
+                'top_p': 'top_p',
+                'min_p': 'min_p',
+                'top_k': 'top_k',
+                'xtc_threshold': 'xtc_threshold',
+                'xtc_probability': 'xtc_probability',
             }
             for opt_key, param_key in option_mapping.items():
                 if opt_key in self.options:
                     sampler_params[param_key] = self.options[opt_key]
+            if 'seed' in self.options:
+                mx.random.seed(self.options['seed'])
 
+        # XTC special tokens
         xtc_special_tokens = []
-        if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
+        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
+            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
+        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
             xtc_special_tokens = [self.tokenizer.eos_token_id]
+        try:
+            newline_tokens = self.tokenizer.encode("\n")
+            xtc_special_tokens.extend(newline_tokens)
+        except:
+            pass
         sampler_params['xtc_special_tokens'] = xtc_special_tokens
 
         return max_tokens, sampler_params
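
Note on the hunk above: the expanded option_mapping forwards the new XTC options straight into sampler_params, and the eos/newline ids collected at the end ride along as xtc_special_tokens so XTC never drops them. A minimal sketch of the call this produces, assuming an mlx_lm release whose make_sampler accepts these keyword arguments (the code above already passes them this way); all values are illustrative.

    from mlx_lm.sample_utils import make_sampler

    # Example sampler_params as _build_generation_params could assemble them.
    sampler_params = {
        "temp": 0.8,
        "top_p": 0.95,
        "xtc_threshold": 0.1,          # probability threshold XTC uses to pick "top choices"
        "xtc_probability": 0.5,        # chance of applying XTC on a given sampling step
        "xtc_special_tokens": [2, 13], # e.g. EOS and newline ids, exempt from removal
    }
    sampler = make_sampler(**sampler_params)   # callable mapping logprobs to sampled token ids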

0 commit comments
