
Commit adc343e

Add 'auto' MODEL_IMPL_TYPE that resolves based on architecture
- Add 'auto' as default value for MODEL_IMPL_TYPE env var
- For GptOssForCausalLM, 'auto' resolves to 'vllm' for better performance
- For all other architectures, 'auto' resolves to 'flax_nnx'
- Add _VLLM_REQUIRED_ARCHITECTURES frozenset in model_loader.py
- Use match/case pattern in get_model() for implementation selection
- Add tests for 'auto' resolution behavior

Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent 5150ed2 commit adc343e

File tree

5 files changed: +81, -23 lines

- tests/models/common/test_model_loader.py
- tests/test_envs.py
- tpu_inference/envs.py
- tpu_inference/models/common/model_loader.py
- tpu_inference/runner/tpu_runner.py

tests/models/common/test_model_loader.py

Lines changed: 40 additions & 0 deletions
@@ -381,3 +381,43 @@ def test_get_model_not_implemented(self, mock_get_flax, mock_get_vllm,
 
         mock_get_flax.assert_not_called()
         mock_get_vllm.assert_not_called()
+
+    @patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
+    @patch("tpu_inference.models.common.model_loader.get_vllm_model")
+    @patch("tpu_inference.models.common.model_loader.get_flax_model")
+    def test_get_model_auto_resolves_to_flax_nnx(self, mock_get_flax,
+                                                 mock_get_vllm, vllm_config,
+                                                 rng, mesh):
+        """
+        Tests that 'auto' resolves to 'flax_nnx' for standard architectures
+        (not in _VLLM_REQUIRED_ARCHITECTURES).
+        """
+        # vllm_config uses Qwen3 which is NOT in _VLLM_REQUIRED_ARCHITECTURES
+        mock_get_flax.return_value = "flax_model_sentinel"
+
+        result = model_loader.get_model(vllm_config, rng, mesh)
+
+        mock_get_flax.assert_called_once_with(vllm_config, rng, mesh, False)
+        mock_get_vllm.assert_not_called()
+        assert result == "flax_model_sentinel"
+
+    @patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
+    @patch("tpu_inference.models.common.model_loader.get_vllm_model")
+    @patch("tpu_inference.models.common.model_loader.get_flax_model")
+    def test_get_model_auto_resolves_to_vllm_for_gpt_oss(
+            self, mock_get_flax, mock_get_vllm, vllm_config, rng, mesh):
+        """
+        Tests that 'auto' resolves to 'vllm' for architectures in
+        _VLLM_REQUIRED_ARCHITECTURES (e.g., GptOssForCausalLM).
+        """
+        # Mock the architecture to be GptOssForCausalLM
+        vllm_config.model_config.hf_config.architectures = [
+            "GptOssForCausalLM"
+        ]
+        mock_get_vllm.return_value = "vllm_model_sentinel"
+
+        result = model_loader.get_model(vllm_config, rng, mesh)
+
+        mock_get_flax.assert_not_called()
+        mock_get_vllm.assert_called_once_with(vllm_config, rng, mesh)
+        assert result == "vllm_model_sentinel"
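
The two new cases can be run on their own. A hypothetical invocation through pytest's Python API, assuming the repository's usual test layout and no extra command-line configuration is needed:

    import pytest

    # Select only the new 'auto' resolution tests by keyword.
    pytest.main(["tests/models/common/test_model_loader.py", "-k", "auto_resolves", "-q"])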

tests/test_envs.py

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+    assert envs.MODEL_IMPL_TYPE == "auto"
 
 
 def test_cache_preserves_values_across_env_changes(

tpu_inference/envs.py

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@
 DECODE_SLICES: str = ""
 SKIP_JAX_PRECOMPILE: bool = False
 VLLM_XLA_CHECK_RECOMPILATION: bool = False
-MODEL_IMPL_TYPE: str = "flax_nnx"
+MODEL_IMPL_TYPE: str = "auto"
 NEW_MODEL_DESIGN: bool = False
 PHASED_PROFILING_DIR: str = ""
 PYTHON_TRACER_LEVEL: int = 1
@@ -99,8 +99,8 @@ def _get_validated_env() -> str | None:
     lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
-                     ["vllm", "flax_nnx", "jetpack"]),
+    env_with_choices("MODEL_IMPL_TYPE", "auto",
+                     ["auto", "vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
     lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
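
From a user's point of view the new default only changes what happens when the variable is unset. A minimal sketch of reading the value, assuming the import path used in tests/test_envs.py and that MODEL_IMPL_TYPE is set before it is first read (the existing cache test suggests later environment changes are not picked up); the valid choices come from the env_with_choices registration above:

    import os

    # Pin the implementation explicitly; leaving MODEL_IMPL_TYPE unset now means "auto".
    os.environ["MODEL_IMPL_TYPE"] = "vllm"  # one of "auto", "vllm", "flax_nnx", "jetpack"

    from tpu_inference import envs  # same access pattern as tests/test_envs.py

    print(envs.MODEL_IMPL_TYPE)  # -> "vllm"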

tpu_inference/models/common/model_loader.py

Lines changed: 36 additions & 17 deletions
@@ -330,6 +330,12 @@ def get_vllm_model(
     return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model
 
 
+# Architectures that require "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
+# These architectures are listed here because they have better performance with the
+# vLLM PyTorch backend compared to the flax_nnx JAX backend for now.
+_VLLM_REQUIRED_ARCHITECTURES: frozenset[str] = frozenset({"GptOssForCausalLM"})
+
+
 def get_model(
     vllm_config: VllmConfig,
     rng: jax.Array,
@@ -339,24 +345,37 @@
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "flax_nnx":
-        try:
-            # Try to load the flax model first
-            return get_flax_model(vllm_config, rng, mesh, is_draft_model)
-        except UnsupportedArchitectureError as e:
-            # Convert the error message to a string to check its contents
-            error_msg = str(e)
-
-            logger.warning(error_msg)
-
-            # Fall back to the vLLM model and updating the dtype accordingly
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
+    if impl == "auto":
+        # Resolve "auto" based on architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        for arch in architectures:
+            if arch in _VLLM_REQUIRED_ARCHITECTURES:
+                impl = "vllm"
+                break
+        else:
+            impl = "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and updating the dtype accordingly
+                vllm_config.model_config.dtype = j2t_dtype(
+                    vllm_config.model_config.dtype.dtype)
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-    elif impl == "vllm":
-        return get_vllm_model(vllm_config, rng, mesh)
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
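
The resolution rule itself is small enough to restate in isolation. An illustrative-only sketch of the behavior added above; resolve_impl is a hypothetical helper name, since the commit keeps this logic inline in get_model():

    # Hypothetical standalone helper mirroring the inline 'auto' resolution in get_model().
    _VLLM_REQUIRED_ARCHITECTURES: frozenset[str] = frozenset({"GptOssForCausalLM"})


    def resolve_impl(impl: str, architectures: list[str]) -> str:
        """Map MODEL_IMPL_TYPE to a concrete backend, resolving 'auto' by architecture."""
        if impl != "auto":
            return impl  # explicit settings pass through unchanged
        if any(arch in _VLLM_REQUIRED_ARCHITECTURES for arch in architectures):
            return "vllm"  # e.g. GPT-OSS currently performs better on the vLLM PyTorch backend
        return "flax_nnx"  # default JAX backend for everything else


    assert resolve_impl("auto", ["GptOssForCausalLM"]) == "vllm"
    assert resolve_impl("auto", ["Qwen3ForCausalLM"]) == "flax_nnx"
    assert resolve_impl("flax_nnx", ["GptOssForCausalLM"]) == "flax_nnx"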

tpu_inference/runner/tpu_runner.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 import copy
 import functools
-import os
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -1712,7 +1711,7 @@ def _sync_weights(
             shard=shard)
 
     def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
         jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
         num_padded_tokens = runner_utils.get_padded_token_len(
             self.num_tokens_paddings, num_tokens)
