@@ -69,6 +69,34 @@ def _get_validated_env() -> str | None:
6969 return _get_validated_env
7070
7171
72+ def env_bool (env_name : str , default : bool = False ) -> Callable [[], bool ]:
73+ """
74+ Accepts both numeric strings ("0", "1") and boolean strings
75+ ("true", "false", "True", "False").
76+
77+ Args:
78+ env_name: Name of the environment variable
79+ default: Default boolean value if not set
80+ """
81+
82+ def _get_bool_env () -> bool :
83+ value = os .getenv (env_name )
84+ if value is None or value == "" :
85+ return default
86+
87+ value_lower = value .lower ()
88+ if value_lower in ("true" , "1" ):
89+ return True
90+ elif value_lower in ("false" , "0" ):
91+ return False
92+ else :
93+ raise ValueError (
94+ f"Invalid boolean value '{ value } ' for { env_name } . "
95+ f"Valid options: '0', '1', 'true', 'false', 'True', 'False'." )
96+
97+ return _get_bool_env
98+
99+
72100environment_variables : dict [str , Callable [[], Any ]] = {
73101 # JAX platform selection (e.g., "tpu", "cpu", "proxy")
74102 "JAX_PLATFORMS" :
@@ -93,17 +121,17 @@ def _get_validated_env() -> str | None:
93121 lambda : os .getenv ("DECODE_SLICES" , "" ),
94122 # Skip JAX precompilation step during initialization
95123 "SKIP_JAX_PRECOMPILE" :
96- lambda : bool ( int ( os . getenv ( "SKIP_JAX_PRECOMPILE" ) or "0" ) ),
124+ env_bool ( "SKIP_JAX_PRECOMPILE" , default = False ),
97125 # Check for XLA recompilation during execution
98126 "VLLM_XLA_CHECK_RECOMPILATION" :
99- lambda : bool ( int ( os . getenv ( "VLLM_XLA_CHECK_RECOMPILATION" ) or "0" ) ),
127+ env_bool ( "VLLM_XLA_CHECK_RECOMPILATION" , default = False ),
100128 # Model implementation type (e.g., "flax_nnx")
101129 "MODEL_IMPL_TYPE" :
102130 env_with_choices ("MODEL_IMPL_TYPE" , "flax_nnx" ,
103131 ["vllm" , "flax_nnx" , "jetpack" ]),
104132 # Enable new experimental model design
105133 "NEW_MODEL_DESIGN" :
106- lambda : bool ( int ( os . getenv ( "NEW_MODEL_DESIGN" ) or "0" ) ),
134+ env_bool ( "NEW_MODEL_DESIGN" , default = False ),
107135 # Directory to store phased profiling output
108136 "PHASED_PROFILING_DIR" :
109137 lambda : os .getenv ("PHASED_PROFILING_DIR" , "" ),
@@ -112,7 +140,7 @@ def _get_validated_env() -> str | None:
112140 lambda : int (os .getenv ("PYTHON_TRACER_LEVEL" ) or "1" ),
113141 # Use custom expert-parallel kernel for MoE (Mixture of Experts)
114142 "USE_MOE_EP_KERNEL" :
115- lambda : bool ( int ( os . getenv ( "USE_MOE_EP_KERNEL" ) or "0" ) ),
143+ env_bool ( "USE_MOE_EP_KERNEL" , default = False ),
116144 # Number of TPU slices for multi-slice mesh
117145 "NUM_SLICES" :
118146 lambda : int (os .getenv ("NUM_SLICES" ) or "1" ),
0 commit comments