|
24 | 24 |
|
25 | 25 | _MODEL_REGISTRY = {} |
26 | 26 |
|
| 27 | +# Architectures that prefer "vllm" implementation type when MODEL_IMPL_TYPE is "auto". |
| 28 | +# These architectures are listed here because they have better performance with the |
| 29 | +# vLLM PyTorch backend compared to the flax_nnx JAX backend for now. |
| 30 | +_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset( |
| 31 | + {"GptOssForCausalLM"}) |
| 32 | + |
27 | 33 |
|
28 | 34 | class UnsupportedArchitectureError(ValueError): |
29 | 35 | """Raised when a model architecture is not supported in the registry.""" |
@@ -342,24 +348,36 @@ def get_model( |
342 | 348 | impl = envs.MODEL_IMPL_TYPE |
343 | 349 | logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}") |
344 | 350 |
|
345 | | - if impl == "flax_nnx": |
346 | | - try: |
347 | | - # Try to load the flax model first |
348 | | - return get_flax_model(vllm_config, rng, mesh, is_draft_model) |
349 | | - except UnsupportedArchitectureError as e: |
350 | | - # Convert the error message to a string to check its contents |
351 | | - error_msg = str(e) |
352 | | - |
353 | | - logger.warning(error_msg) |
354 | | - |
355 | | - # Fall back to the vLLM model and updating the dtype accordingly |
356 | | - vllm_config.model_config.dtype = j2t_dtype( |
357 | | - vllm_config.model_config.dtype.dtype) |
| 351 | + if impl == "auto": |
| 352 | + # Resolve "auto" based on architecture |
| 353 | + architectures = getattr(vllm_config.model_config.hf_config, |
| 354 | + "architectures", []) |
| 355 | + assert len(architectures) == 1, ( |
| 356 | + f"Expected exactly one architecture, got {len(architectures)}: " |
| 357 | + f"{architectures}") |
| 358 | + arch = architectures[0] |
| 359 | + impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx" |
| 360 | + logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'") |
| 361 | + |
| 362 | + match impl: |
| 363 | + case "flax_nnx": |
| 364 | + try: |
| 365 | + # Try to load the flax model first |
| 366 | + return get_flax_model(vllm_config, rng, mesh, is_draft_model) |
| 367 | + except UnsupportedArchitectureError as e: |
| 368 | + # Convert the error message to a string to check its contents |
| 369 | + error_msg = str(e) |
| 370 | + |
| 371 | + logger.warning(error_msg) |
| 372 | + |
| 373 | + # Fall back to the vLLM model and updating the dtype accordingly |
| 374 | + vllm_config.model_config.dtype = j2t_dtype( |
| 375 | + vllm_config.model_config.dtype.dtype) |
| 376 | + return get_vllm_model(vllm_config, rng, mesh) |
| 377 | + case "vllm": |
358 | 378 | return get_vllm_model(vllm_config, rng, mesh) |
359 | | - elif impl == "vllm": |
360 | | - return get_vllm_model(vllm_config, rng, mesh) |
361 | | - else: |
362 | | - raise NotImplementedError("Unsupported MODEL_IMPL_TYPE") |
| 379 | + case _: |
| 380 | + raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}") |
363 | 381 |
|
364 | 382 |
|
365 | 383 | def _validate_model_interface(model: Any) -> None: |
|
0 commit comments