40 changes: 40 additions & 0 deletions tests/models/common/test_model_loader.py
@@ -381,3 +381,43 @@ def test_get_model_not_implemented(self, mock_get_flax, mock_get_vllm,

mock_get_flax.assert_not_called()
mock_get_vllm.assert_not_called()

@patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
@patch("tpu_inference.models.common.model_loader.get_vllm_model")
@patch("tpu_inference.models.common.model_loader.get_flax_model")
def test_get_model_auto_resolves_to_flax_nnx(self, mock_get_flax,
mock_get_vllm, vllm_config,
rng, mesh):
"""
Tests that 'auto' resolves to 'flax_nnx' for standard architectures
(not in _VLLM_PREFERRED_ARCHITECTURES).
"""
# vllm_config uses Qwen3, which is NOT in _VLLM_PREFERRED_ARCHITECTURES
mock_get_flax.return_value = "flax_model_sentinel"

result = model_loader.get_model(vllm_config, rng, mesh)

mock_get_flax.assert_called_once_with(vllm_config, rng, mesh, False)
mock_get_vllm.assert_not_called()
assert result == "flax_model_sentinel"

@patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
@patch("tpu_inference.models.common.model_loader.get_vllm_model")
@patch("tpu_inference.models.common.model_loader.get_flax_model")
def test_get_model_auto_resolves_to_vllm_for_gpt_oss(
self, mock_get_flax, mock_get_vllm, vllm_config, rng, mesh):
"""
Tests that 'auto' resolves to 'vllm' for architectures in
_VLLM_PREFERRED_ARCHITECTURES (e.g., GptOssForCausalLM).
"""
# Mock the architecture to be GptOssForCausalLM
vllm_config.model_config.hf_config.architectures = [
"GptOssForCausalLM"
]
mock_get_vllm.return_value = "vllm_model_sentinel"

result = model_loader.get_model(vllm_config, rng, mesh)

mock_get_flax.assert_not_called()
mock_get_vllm.assert_called_once_with(vllm_config, rng, mesh)
assert result == "vllm_model_sentinel"
2 changes: 1 addition & 1 deletion tests/test_envs.py
@@ -256,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):

def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
assert envs.MODEL_IMPL_TYPE == "flax_nnx"
assert envs.MODEL_IMPL_TYPE == "auto"


def test_cache_preserves_values_across_env_changes(
6 changes: 3 additions & 3 deletions tpu_inference/envs.py
@@ -16,7 +16,7 @@
DECODE_SLICES: str = ""
SKIP_JAX_PRECOMPILE: bool = False
VLLM_XLA_CHECK_RECOMPILATION: bool = False
MODEL_IMPL_TYPE: str = "flax_nnx"
MODEL_IMPL_TYPE: str = "auto"
NEW_MODEL_DESIGN: bool = False
PHASED_PROFILING_DIR: str = ""
PYTHON_TRACER_LEVEL: int = 1
@@ -128,8 +128,8 @@ def _get_bool_env() -> bool:
env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
# Model implementation type (e.g., "flax_nnx")
"MODEL_IMPL_TYPE":
env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
["vllm", "flax_nnx", "jetpack"]),
env_with_choices("MODEL_IMPL_TYPE", "auto",
["auto", "vllm", "flax_nnx", "jetpack"]),
# Enable new experimental model design
"NEW_MODEL_DESIGN":
env_bool("NEW_MODEL_DESIGN", default=False),
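For context, a minimal usage sketch (illustrative only, not part of this diff) of how the new default is consumed, assuming envs resolves the variable the same way tests/test_envs.py above does: leaving MODEL_IMPL_TYPE unset now yields "auto", and env_with_choices restricts the value to the listed options.

import os

# Pin a backend explicitly ("vllm", "flax_nnx", "jetpack"), or leave the
# variable unset to get the new "auto" default.
os.environ["MODEL_IMPL_TYPE"] = "vllm"

from tpu_inference import envs  # import path assumed from tests/test_envs.py

assert envs.MODEL_IMPL_TYPE in ("auto", "vllm", "flax_nnx", "jetpack")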
52 changes: 35 additions & 17 deletions tpu_inference/models/common/model_loader.py
@@ -24,6 +24,12 @@

_MODEL_REGISTRY = {}

# Architectures that prefer the "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
# They are listed here because they currently perform better with the vLLM PyTorch
# backend than with the flax_nnx JAX backend.
_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
{"GptOssForCausalLM"})


class UnsupportedArchitectureError(ValueError):
"""Raised when a model architecture is not supported in the registry."""
@@ -342,24 +348,36 @@ def get_model(
impl = envs.MODEL_IMPL_TYPE
logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")

if impl == "flax_nnx":
try:
# Try to load the flax model first
return get_flax_model(vllm_config, rng, mesh, is_draft_model)
except UnsupportedArchitectureError as e:
# Convert the error message to a string to check its contents
error_msg = str(e)

logger.warning(error_msg)

# Fall back to the vLLM model and updating the dtype accordingly
vllm_config.model_config.dtype = j2t_dtype(
vllm_config.model_config.dtype.dtype)
if impl == "auto":
# Resolve "auto" based on architecture
architectures = getattr(vllm_config.model_config.hf_config,
"architectures", [])
assert len(architectures) == 1, (
f"Expected exactly one architecture, got {len(architectures)}: "
f"{architectures}")
arch = architectures[0]
impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")

match impl:
case "flax_nnx":
try:
# Try to load the flax model first
return get_flax_model(vllm_config, rng, mesh, is_draft_model)
except UnsupportedArchitectureError as e:
# Convert the error message to a string to check its contents
error_msg = str(e)

logger.warning(error_msg)

# Fall back to the vLLM model and update the dtype accordingly
vllm_config.model_config.dtype = j2t_dtype(
vllm_config.model_config.dtype.dtype)
return get_vllm_model(vllm_config, rng, mesh)
case "vllm":
return get_vllm_model(vllm_config, rng, mesh)
elif impl == "vllm":
return get_vllm_model(vllm_config, rng, mesh)
else:
raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
case _:
raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")


def _validate_model_interface(model: Any) -> None:
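The "auto" resolution rule added to get_model above is small enough to restate on its own. A minimal standalone sketch (illustrative only, not part of this diff); Qwen3ForCausalLM is used purely as an example of a standard architecture, mirroring the Qwen3-based vllm_config fixture in the tests:

from typing import Sequence

# Mirrors _VLLM_PREFERRED_ARCHITECTURES above: these resolve to the vLLM
# backend under "auto"; every other architecture resolves to flax_nnx.
_VLLM_PREFERRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})


def resolve_impl(architectures: Sequence[str]) -> str:
    # get_model asserts there is exactly one architecture before resolving.
    assert len(architectures) == 1, architectures
    return ("vllm" if architectures[0] in _VLLM_PREFERRED_ARCHITECTURES
            else "flax_nnx")


assert resolve_impl(["GptOssForCausalLM"]) == "vllm"
assert resolve_impl(["Qwen3ForCausalLM"]) == "flax_nnx"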
3 changes: 1 addition & 2 deletions tpu_inference/runner/tpu_runner.py
@@ -1,6 +1,5 @@
import copy
import functools
import os
import random
from contextlib import nullcontext
from dataclasses import dataclass
@@ -1719,7 +1718,7 @@ def _sync_weights(
shard=shard)

def get_intermediate_tensor_spec(self, num_tokens: int):
impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
impl = envs.MODEL_IMPL_TYPE
jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
num_padded_tokens = runner_utils.get_padded_token_len(
self.num_tokens_paddings, num_tokens)