Skip to content

Commit 9724d4a

Browse files
committed
fix
Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent eb6c2d1 commit 9724d4a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tpu_inference/models/common/model_loader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
# Architectures that prefer "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
2828
# These architectures are listed here because they have better performance with the
2929
# vLLM PyTorch backend compared to the flax_nnx JAX backend for now.
30-
_VLLM_REQUIRED_ARCHITECTURES: frozenset[str] = frozenset({"GptOssForCausalLM"})
30+
_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
31+
{"GptOssForCausalLM"})
3132

3233

3334
class UnsupportedArchitectureError(ValueError):
@@ -355,7 +356,7 @@ def get_model(
355356
f"Expected exactly one architecture, got {len(architectures)}: "
356357
f"{architectures}")
357358
arch = architectures[0]
358-
impl = "vllm" if arch in _VLLM_REQUIRED_ARCHITECTURES else "flax_nnx"
359+
impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
359360
logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
360361

361362
match impl:

0 commit comments

Comments (0)