@@ -20,6 +20,7 @@
 import signal
 import sys
 import tempfile
+from typing import List
 
 import grpc
 
@@ -104,6 +105,9 @@ def __init__(self):
         self.tokenizer = None
         self.coordinator = None
         self.options = {}
+        self.lru_cache = None
+        self.model_key = None
+        self.max_kv_size = None
 
     def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
@@ -112,12 +116,12 @@ async def LoadModel(self, request, context):
         try:
             import mlx.core as mx
             from mlx_lm import load
-            from coordinator import DistributedCoordinator, CMD_LOAD_MODEL
-            from sharding import pipeline_auto_parallel
+            from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
 
             print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr)
 
             self.options = parse_options(request.Options)
+            print(f"Options: {self.options}", file=sys.stderr)
 
             # Get distributed config from model options, falling back to env vars.
             # If neither is set, run as single-node (no distributed).
@@ -130,6 +134,9 @@ async def LoadModel(self, request, context):
             jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "")
 
             if hostfile:
+                from coordinator import DistributedCoordinator, CMD_LOAD_MODEL
+                from sharding import pipeline_auto_parallel
+
                 print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr)
                 self.dist_backend = dist_backend
                 self.group = mlx_distributed_init(
@@ -144,20 +151,38 @@ async def LoadModel(self, request, context):
             else:
                 print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr)
 
+            # Build tokenizer config from request and options
             tokenizer_config = {}
             if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
                 tokenizer_config["trust_remote_code"] = True
+            # Token overrides from options
+            for key in ["eos_token", "pad_token", "bos_token", "unk_token",
+                        "sep_token", "cls_token", "mask_token"]:
+                if key in self.options:
+                    tokenizer_config[key] = self.options[key]
 
             if tokenizer_config:
+                print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
                 self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
             else:
                 self.model, self.tokenizer = load(request.Model)
 
             if self.group is not None:
+                from sharding import pipeline_auto_parallel
                 self.model = pipeline_auto_parallel(self.model, self.group)
                 print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr)
             else:
-                print("[Rank 0] Model loaded (single-node)", file=sys.stderr)
+                # Single-node: set up prompt cache for efficient generation
+                from mlx_cache import ThreadSafeLRUPromptCache
+                max_cache_entries = self.options.get("max_cache_entries", 10)
+                self.max_kv_size = self.options.get("max_kv_size", None)
+                self.model_key = request.Model
+                self.lru_cache = ThreadSafeLRUPromptCache(
+                    max_size=max_cache_entries,
+                    can_trim_fn=can_trim_prompt_cache,
+                    trim_fn=trim_prompt_cache,
+                )
+                print("[Rank 0] Model loaded (single-node with prompt cache)", file=sys.stderr)
 
         except Exception as err:
             print(f"[Rank 0] Error loading model: {err}", file=sys.stderr)
@@ -166,18 +191,19 @@ async def LoadModel(self, request, context):
         return backend_pb2.Result(message="Model loaded successfully", success=True)
 
     async def Predict(self, request, context):
+        prompt_cache = None
+        cache_key = None
+
         try:
             import mlx.core as mx
             from mlx_lm import stream_generate
             from mlx_lm.sample_utils import make_sampler
-            from coordinator import CMD_GENERATE
 
             prompt_text = self._prepare_prompt(request)
-            tokens = self.tokenizer.encode(prompt_text)
-            if hasattr(tokens, 'tolist'):
-                tokens = tokens.tolist()
+            tokens = self._get_tokens_from_prompt(prompt_text)
 
             if self.coordinator:
+                from coordinator import CMD_GENERATE
                 self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                 self.coordinator.broadcast_tokens(tokens)
@@ -193,15 +219,35 @@ async def Predict(self, request, context):
 
             sampler = make_sampler(**sampler_params)
 
+            # Use prompt cache in single-node mode
+            gen_kwargs = {}
+            if self.lru_cache is not None:
+                from mlx_lm.models.cache import make_prompt_cache
+                cache_key = list(tokens)
+                prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
+                    self.model_key, cache_key
+                )
+                if prompt_cache is None:
+                    prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
+                    remaining_tokens = cache_key
+                gen_kwargs['prompt_cache'] = prompt_cache
+                tokens = remaining_tokens if remaining_tokens else cache_key
+
             generated = []
             for response in stream_generate(
                 self.model,
                 self.tokenizer,
                 prompt=tokens,
                 max_tokens=max_tokens,
                 sampler=sampler,
+                **gen_kwargs,
             ):
                 generated.append(response.text)
+                if cache_key is not None:
+                    cache_key.append(response.token)
+
+            if self.lru_cache is not None and cache_key is not None:
+                self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
 
             return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8'))
 
@@ -212,18 +258,19 @@ async def Predict(self, request, context):
             return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
 
     async def PredictStream(self, request, context):
+        prompt_cache = None
+        cache_key = None
+
         try:
             import mlx.core as mx
             from mlx_lm import stream_generate
             from mlx_lm.sample_utils import make_sampler
-            from coordinator import CMD_GENERATE
 
             prompt_text = self._prepare_prompt(request)
-            tokens = self.tokenizer.encode(prompt_text)
-            if hasattr(tokens, 'tolist'):
-                tokens = tokens.tolist()
+            tokens = self._get_tokens_from_prompt(prompt_text)
 
             if self.coordinator:
+                from coordinator import CMD_GENERATE
                 self.coordinator.broadcast_command(CMD_GENERATE, len(tokens))
                 self.coordinator.broadcast_tokens(tokens)
@@ -239,13 +286,30 @@ async def PredictStream(self, request, context):
 
             sampler = make_sampler(**sampler_params)
 
+            # Use prompt cache in single-node mode
+            gen_kwargs = {}
+            if self.lru_cache is not None:
+                from mlx_lm.models.cache import make_prompt_cache
+                cache_key = list(tokens)
+                prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
+                    self.model_key, cache_key
+                )
+                if prompt_cache is None:
+                    prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
+                    remaining_tokens = cache_key
+                gen_kwargs['prompt_cache'] = prompt_cache
+                tokens = remaining_tokens if remaining_tokens else cache_key
+
             for response in stream_generate(
                 self.model,
                 self.tokenizer,
                 prompt=tokens,
                 max_tokens=max_tokens,
                 sampler=sampler,
+                **gen_kwargs,
             ):
+                if cache_key is not None:
+                    cache_key.append(response.token)
                 yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
 
         except Exception as e:
@@ -254,6 +318,19 @@ async def PredictStream(self, request, context):
             context.set_details(f"Streaming failed: {str(e)}")
             yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
 
+        finally:
+            if self.lru_cache is not None and prompt_cache is not None and cache_key is not None:
+                try:
+                    self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
+                except Exception as e:
+                    print(f"Error inserting cache: {e}", file=sys.stderr)
+
+    def Embedding(self, request, context):
+        print("Embeddings not supported in MLX distributed backend", file=sys.stderr)
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Embeddings are not supported in the MLX distributed backend.")
+        return backend_pb2.EmbeddingResult()
+
     def _prepare_prompt(self, request):
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
             messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages]
@@ -262,7 +339,15 @@ def _prepare_prompt(self, request):
             )
         return request.Prompt
 
+    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
+        tokens = self.tokenizer.encode(prompt_text)
+        if hasattr(tokens, 'tolist'):
+            return tokens.tolist()
+        return list(tokens)
+
     def _build_generation_params(self, request, default_max_tokens=200):
+        import mlx.core as mx
+
         max_tokens = getattr(request, 'Tokens', default_max_tokens)
         if max_tokens == 0:
             max_tokens = default_max_tokens
@@ -286,23 +371,37 @@ def _build_generation_params(self, request, default_max_tokens=200):
 
         seed = getattr(request, 'Seed', 0)
         if seed != 0:
-            import mlx.core as mx
             mx.random.seed(seed)
 
         if hasattr(self, 'options'):
             if 'max_tokens' in self.options:
                 max_tokens = self.options['max_tokens']
             option_mapping = {
-                'temp': 'temp', 'temperature': 'temp',
-                'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k',
+                'temp': 'temp',
+                'temperature': 'temp',
+                'top_p': 'top_p',
+                'min_p': 'min_p',
+                'top_k': 'top_k',
+                'xtc_threshold': 'xtc_threshold',
+                'xtc_probability': 'xtc_probability',
             }
             for opt_key, param_key in option_mapping.items():
                 if opt_key in self.options:
                     sampler_params[param_key] = self.options[opt_key]
+            if 'seed' in self.options:
+                mx.random.seed(self.options['seed'])
 
+        # XTC special tokens
         xtc_special_tokens = []
-        if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
+        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
+            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
+        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            xtc_special_tokens = [self.tokenizer.eos_token_id]
+        try:
+            newline_tokens = self.tokenizer.encode("\n")
+            xtc_special_tokens.extend(newline_tokens)
+        except Exception:
+            pass
         sampler_params['xtc_special_tokens'] = xtc_special_tokens
 
         return max_tokens, sampler_params
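For reference, the single-node path above relies on a local mlx_cache.ThreadSafeLRUPromptCache helper that is not part of this changeset; the diff only shows its constructor and the fetch_nearest_cache / insert_cache calls. The following is a minimal, illustrative sketch of an interface that would satisfy that usage, assuming longest-prefix matching on token lists and ignoring the trim callbacks; it is not the implementation shipped in mlx_cache.py.

import threading
from collections import OrderedDict


class ThreadSafeLRUPromptCache:
    """Sketch only: LRU map from (model_key, token prefix) to an mlx-lm prompt cache."""

    def __init__(self, max_size=10, can_trim_fn=None, trim_fn=None):
        self.max_size = max_size
        self.can_trim_fn = can_trim_fn  # accepted for API parity; unused in this sketch
        self.trim_fn = trim_fn          # accepted for API parity; unused in this sketch
        self._lock = threading.Lock()
        self._entries = OrderedDict()   # (model_key, tuple(tokens)) -> prompt_cache

    def fetch_nearest_cache(self, model_key, tokens):
        # Return (cache, remaining_tokens) for the longest cached prefix of `tokens`,
        # or (None, tokens) when nothing usable is stored.
        with self._lock:
            best_key, best_len = None, -1
            for mk, cached in self._entries:
                if mk != model_key or len(cached) > len(tokens):
                    continue
                if list(cached) == tokens[: len(cached)] and len(cached) > best_len:
                    best_key, best_len = (mk, cached), len(cached)
            if best_key is None:
                return None, list(tokens)
            # Pop the entry: the caller mutates the cache while generating and
            # re-inserts it under the extended token sequence afterwards.
            cache = self._entries.pop(best_key)
            return cache, list(tokens[best_len:])

    def insert_cache(self, model_key, tokens, prompt_cache):
        # Store under the full token sequence and evict the oldest entries beyond max_size.
        with self._lock:
            self._entries[(model_key, tuple(tokens))] = prompt_cache
            self._entries.move_to_end((model_key, tuple(tokens)))
            while len(self._entries) > self.max_size:
                self._entries.popitem(last=False)

Popping the matched entry in fetch_nearest_cache mirrors how the handlers use it: the KV cache object is mutated in place during stream_generate and then stored again under the prompt-plus-generated token sequence, so the same object should not stay reachable under its old key.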