
Commit dc3706a

Add 'auto' MODEL_IMPL_TYPE that resolves based on architecture
- Add 'auto' as default value for MODEL_IMPL_TYPE env var
- For GptOssForCausalLM, 'auto' resolves to 'vllm' for better performance
- For all other architectures, 'auto' resolves to 'flax_nnx'
- Add _VLLM_REQUIRED_ARCHITECTURES frozenset in model_loader.py
- Use match/case pattern in get_model() for implementation selection
- Add tests for 'auto' resolution behavior

Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent 774cdbe commit dc3706a
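A minimal sketch of the resolution rule described in the commit message, assuming a hypothetical helper name (resolve_impl) and sample architecture lists; the actual logic lives in get_model() in the diff below.

_VLLM_REQUIRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})

def resolve_impl(impl: str, architectures: list[str]) -> str:
    # Only "auto" is resolved; explicit values pass through unchanged.
    if impl != "auto":
        return impl
    # Any architecture that needs the vLLM PyTorch backend wins.
    if any(arch in _VLLM_REQUIRED_ARCHITECTURES for arch in architectures):
        return "vllm"
    return "flax_nnx"

assert resolve_impl("auto", ["GptOssForCausalLM"]) == "vllm"
assert resolve_impl("auto", ["Qwen3ForCausalLM"]) == "flax_nnx"
assert resolve_impl("flax_nnx", ["GptOssForCausalLM"]) == "flax_nnx"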


5 files changed: +81 -23 lines


tests/models/common/test_model_loader.py

Lines changed: 40 additions & 0 deletions
@@ -381,3 +381,43 @@ def test_get_model_not_implemented(self, mock_get_flax, mock_get_vllm,
 
         mock_get_flax.assert_not_called()
         mock_get_vllm.assert_not_called()
+
+    @patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
+    @patch("tpu_inference.models.common.model_loader.get_vllm_model")
+    @patch("tpu_inference.models.common.model_loader.get_flax_model")
+    def test_get_model_auto_resolves_to_flax_nnx(self, mock_get_flax,
+                                                 mock_get_vllm, vllm_config,
+                                                 rng, mesh):
+        """
+        Tests that 'auto' resolves to 'flax_nnx' for standard architectures
+        (not in _VLLM_REQUIRED_ARCHITECTURES).
+        """
+        # vllm_config uses Qwen3 which is NOT in _VLLM_REQUIRED_ARCHITECTURES
+        mock_get_flax.return_value = "flax_model_sentinel"
+
+        result = model_loader.get_model(vllm_config, rng, mesh)
+
+        mock_get_flax.assert_called_once_with(vllm_config, rng, mesh, False)
+        mock_get_vllm.assert_not_called()
+        assert result == "flax_model_sentinel"
+
+    @patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
+    @patch("tpu_inference.models.common.model_loader.get_vllm_model")
+    @patch("tpu_inference.models.common.model_loader.get_flax_model")
+    def test_get_model_auto_resolves_to_vllm_for_gpt_oss(
+            self, mock_get_flax, mock_get_vllm, vllm_config, rng, mesh):
+        """
+        Tests that 'auto' resolves to 'vllm' for architectures in
+        _VLLM_REQUIRED_ARCHITECTURES (e.g., GptOssForCausalLM).
+        """
+        # Mock the architecture to be GptOssForCausalLM
+        vllm_config.model_config.hf_config.architectures = [
+            "GptOssForCausalLM"
+        ]
+        mock_get_vllm.return_value = "vllm_model_sentinel"
+
+        result = model_loader.get_model(vllm_config, rng, mesh)
+
+        mock_get_flax.assert_not_called()
+        mock_get_vllm.assert_called_once_with(vllm_config, rng, mesh)
+        assert result == "vllm_model_sentinel"

tests/test_envs.py

Lines changed: 1 addition & 1 deletion
@@ -256,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+    assert envs.MODEL_IMPL_TYPE == "auto"
 
 
 def test_cache_preserves_values_across_env_changes(

tpu_inference/envs.py

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
-    MODEL_IMPL_TYPE: str = "flax_nnx"
+    MODEL_IMPL_TYPE: str = "auto"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
@@ -128,8 +128,8 @@ def _get_bool_env() -> bool:
     env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
-                     ["vllm", "flax_nnx", "jetpack"]),
+    env_with_choices("MODEL_IMPL_TYPE", "auto",
+                     ["auto", "vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
     env_bool("NEW_MODEL_DESIGN", default=False),

tpu_inference/models/common/model_loader.py

Lines changed: 36 additions & 17 deletions
@@ -333,6 +333,12 @@ def get_vllm_model(
     return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model
 
 
+# Architectures that require "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
+# These architectures are listed here because they have better performance with the
+# vLLM PyTorch backend compared to the flax_nnx JAX backend for now.
+_VLLM_REQUIRED_ARCHITECTURES: frozenset[str] = frozenset({"GptOssForCausalLM"})
+
+
 def get_model(
     vllm_config: VllmConfig,
     rng: jax.Array,
@@ -342,24 +348,37 @@ def get_model(
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "flax_nnx":
-        try:
-            # Try to load the flax model first
-            return get_flax_model(vllm_config, rng, mesh, is_draft_model)
-        except UnsupportedArchitectureError as e:
-            # Convert the error message to a string to check its contents
-            error_msg = str(e)
-
-            logger.warning(error_msg)
-
-            # Fall back to the vLLM model and updating the dtype accordingly
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
+    if impl == "auto":
+        # Resolve "auto" based on architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        for arch in architectures:
+            if arch in _VLLM_REQUIRED_ARCHITECTURES:
+                impl = "vllm"
+                break
+        else:
+            impl = "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and updating the dtype accordingly
+                vllm_config.model_config.dtype = j2t_dtype(
+                    vllm_config.model_config.dtype.dtype)
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-    elif impl == "vllm":
-        return get_vllm_model(vllm_config, rng, mesh)
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
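The "auto" branch above leans on Python's for/else: the else clause runs only when the loop finishes without hitting break. A standalone illustration with sample data (not code from this file):

required = frozenset({"GptOssForCausalLM"})
architectures = ["Qwen3ForCausalLM", "LlamaForCausalLM"]

for arch in architectures:
    if arch in required:
        impl = "vllm"
        break
else:
    # Reached only because no architecture triggered the break above.
    impl = "flax_nnx"

print(impl)  # -> "flax_nnx"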

tpu_inference/runner/tpu_runner.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 import copy
 import functools
-import os
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -1719,7 +1718,7 @@ def _sync_weights(
                              shard=shard)
 
     def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
         jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
         num_padded_tokens = runner_utils.get_padded_token_len(
             self.num_tokens_paddings, num_tokens)

0 commit comments
