Commit 9919cfb

Add default 'auto' MODEL_IMPL_TYPE that resolves based on architecture (#1255)

1 parent d52ec89 commit 9919cfb

5 files changed: +80 -23 lines changed

tests/models/common/test_model_loader.py

Lines changed: 40 additions & 0 deletions

```diff
@@ -381,3 +381,43 @@ def test_get_model_not_implemented(self, mock_get_flax, mock_get_vllm,
 
         mock_get_flax.assert_not_called()
         mock_get_vllm.assert_not_called()
+
+    @patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
+    @patch("tpu_inference.models.common.model_loader.get_vllm_model")
+    @patch("tpu_inference.models.common.model_loader.get_flax_model")
+    def test_get_model_auto_resolves_to_flax_nnx(self, mock_get_flax,
+                                                 mock_get_vllm, vllm_config,
+                                                 rng, mesh):
+        """
+        Tests that 'auto' resolves to 'flax_nnx' for standard architectures
+        (not in _VLLM_PREFERRED_ARCHITECTURES).
+        """
+        # vllm_config uses Qwen3, which is NOT in _VLLM_PREFERRED_ARCHITECTURES
+        mock_get_flax.return_value = "flax_model_sentinel"
+
+        result = model_loader.get_model(vllm_config, rng, mesh)
+
+        mock_get_flax.assert_called_once_with(vllm_config, rng, mesh, False)
+        mock_get_vllm.assert_not_called()
+        assert result == "flax_model_sentinel"
+
+    @patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
+    @patch("tpu_inference.models.common.model_loader.get_vllm_model")
+    @patch("tpu_inference.models.common.model_loader.get_flax_model")
+    def test_get_model_auto_resolves_to_vllm_for_gpt_oss(
+            self, mock_get_flax, mock_get_vllm, vllm_config, rng, mesh):
+        """
+        Tests that 'auto' resolves to 'vllm' for architectures in
+        _VLLM_PREFERRED_ARCHITECTURES (e.g., GptOssForCausalLM).
+        """
+        # Mock the architecture to be GptOssForCausalLM
+        vllm_config.model_config.hf_config.architectures = [
+            "GptOssForCausalLM"
+        ]
+        mock_get_vllm.return_value = "vllm_model_sentinel"
+
+        result = model_loader.get_model(vllm_config, rng, mesh)
+
+        mock_get_flax.assert_not_called()
+        mock_get_vllm.assert_called_once_with(vllm_config, rng, mesh)
+        assert result == "vllm_model_sentinel"
```
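For readers unfamiliar with the mocking pattern, here is a self-contained demonstration of the patch.dict(os.environ, ..., clear=True) decorator the new tests lean on. This is a generic stdlib illustration, not part of the commit; UNRELATED_VAR is an illustrative name.

```python
import os
from unittest.mock import patch

os.environ["UNRELATED_VAR"] = "kept-outside"  # illustrative variable

# clear=True empties os.environ for the duration of the block and restores
# it afterwards, so the code under test sees MODEL_IMPL_TYPE and nothing else.
with patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True):
    assert dict(os.environ) == {"MODEL_IMPL_TYPE": "auto"}

assert os.environ["UNRELATED_VAR"] == "kept-outside"  # original env restored
```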

tests/test_envs.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -256,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+    assert envs.MODEL_IMPL_TYPE == "auto"
 
 
 def test_cache_preserves_values_across_env_changes(
```

tpu_inference/envs.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -16,7 +16,7 @@
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
-    MODEL_IMPL_TYPE: str = "flax_nnx"
+    MODEL_IMPL_TYPE: str = "auto"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
@@ -128,8 +128,8 @@ def _get_bool_env() -> bool:
     env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
-                     ["vllm", "flax_nnx", "jetpack"]),
+    env_with_choices("MODEL_IMPL_TYPE", "auto",
+                     ["auto", "vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
     env_bool("NEW_MODEL_DESIGN", default=False),
```

tpu_inference/models/common/model_loader.py

Lines changed: 35 additions & 17 deletions

```diff
@@ -24,6 +24,12 @@
 
 _MODEL_REGISTRY = {}
 
+# Architectures that prefer the "vllm" implementation type when MODEL_IMPL_TYPE
+# is "auto": for now, these perform better with the vLLM PyTorch backend than
+# with the flax_nnx JAX backend.
+_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+    {"GptOssForCausalLM"})
+
 
 class UnsupportedArchitectureError(ValueError):
     """Raised when a model architecture is not supported in the registry."""
@@ -342,24 +348,36 @@ def get_model(
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "flax_nnx":
-        try:
-            # Try to load the flax model first
-            return get_flax_model(vllm_config, rng, mesh, is_draft_model)
-        except UnsupportedArchitectureError as e:
-            # Convert the error message to a string to check its contents
-            error_msg = str(e)
-
-            logger.warning(error_msg)
-
-            # Fall back to the vLLM model and updating the dtype accordingly
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
+    if impl == "auto":
+        # Resolve "auto" based on the model architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        assert len(architectures) == 1, (
+            f"Expected exactly one architecture, got {len(architectures)}: "
+            f"{architectures}")
+        arch = architectures[0]
+        impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Stringify the error so it can be logged before falling back
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and update the dtype accordingly
+                vllm_config.model_config.dtype = j2t_dtype(
+                    vllm_config.model_config.dtype.dtype)
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-    elif impl == "vllm":
-        return get_vllm_model(vllm_config, rng, mesh)
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
```
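The resolution rule added above is small enough to state standalone. A minimal sketch of the behavior; resolve_impl is a hypothetical helper name used only for this illustration, not a function in the commit:

```python
# Sketch of the "auto" resolution rule from get_model() above.
_VLLM_PREFERRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})

def resolve_impl(impl: str, architectures: list[str]) -> str:
    """Map a MODEL_IMPL_TYPE value to a concrete backend name."""
    if impl != "auto":
        return impl  # explicit settings pass through unchanged
    # The loader asserts exactly one architecture before resolving.
    assert len(architectures) == 1, architectures
    return ("vllm" if architectures[0] in _VLLM_PREFERRED_ARCHITECTURES
            else "flax_nnx")

assert resolve_impl("auto", ["Qwen3ForCausalLM"]) == "flax_nnx"
assert resolve_impl("auto", ["GptOssForCausalLM"]) == "vllm"
assert resolve_impl("flax_nnx", ["GptOssForCausalLM"]) == "flax_nnx"
```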

tpu_inference/runner/tpu_runner.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -1,6 +1,5 @@
 import copy
 import functools
-import os
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -1719,7 +1718,7 @@ def _sync_weights(
             shard=shard)
 
     def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
         jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
         num_padded_tokens = runner_utils.get_padded_token_len(
             self.num_tokens_paddings, num_tokens)
```
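Worth noting why the raw os.getenv call was replaced rather than merely having its default bumped: with MODEL_IMPL_TYPE unset, the old expression still returned "flax_nnx", while the validated envs accessor now returns "auto", which get_model() resolves to a concrete backend as shown above. A minimal standard-library illustration of the divergence:

```python
import os

# With MODEL_IMPL_TYPE unset, the old expression keeps its stale default:
os.environ.pop("MODEL_IMPL_TYPE", None)
old_style = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
assert old_style == "flax_nnx"
# envs.MODEL_IMPL_TYPE would instead yield "auto", keeping the runner's
# dtype choice consistent with whichever backend the loader picks.
```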
