Skip to content

Commit 0f1a96c

Browse files
committed
the barrier for TRT-LLM installation
1 parent 7721953 commit 0f1a96c

File tree

1 file changed

+45
-31
lines changed

1 file changed

+45
-31
lines changed

py/torch_tensorrt/_utils.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,55 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
148148
)
149149

150150

151+
def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
    """Extract the TensorRT-LLM wheel at ``wheel_path`` into ``extract_dir``.

    When ``torch.distributed`` is initialized, only rank 0 performs the
    extraction and every rank then synchronizes on a barrier so that no rank
    continues before the unzip has finished. In the single-process case the
    wheel is simply extracted.

    The barrier is entered from a ``finally`` clause: if extraction fails on
    rank 0, the other ranks are still released instead of blocking forever in
    the collective. They are not told about the failure, though.
    NOTE(review): callers on non-zero ranks should verify that the expected
    files exist after this returns.

    Args:
        wheel_path: Path to the downloaded wheel (a zip archive).
        extract_dir: Directory the wheel contents are extracted into.

    Raises:
        RuntimeError: If the wheel file is missing, is not a valid zip
            archive, or any other error occurs during extraction.
    """
    # zipfile is part of the standard library, so no ImportError guard is needed.
    import zipfile

    # this will not be encountered in case of platforms not supporting torch distributed/nccl/TRT-LLM
    from torch.distributed import barrier, get_rank, is_initialized

    distributed = is_initialized()
    # Single process case: just unzip. Otherwise only rank 0 does the unzip.
    is_master = (not distributed) or get_rank() == 0

    try:
        if is_master:
            try:
                with zipfile.ZipFile(wheel_path) as zip_ref:
                    zip_ref.extractall(extract_dir)
                    logger.debug(f"Extracted wheel to {extract_dir}")
            except FileNotFoundError as e:
                # This should capture the errors in the download failure above
                logger.error(f"Wheel file not found at {wheel_path}: {e}")
                raise RuntimeError(
                    f"Failed to find downloaded wheel file at {wheel_path}"
                ) from e
            except zipfile.BadZipFile as e:
                logger.error(f"Invalid or corrupted wheel file: {e}")
                raise RuntimeError(
                    "Downloaded wheel file is corrupted or not a valid zip archive"
                ) from e
            except Exception as e:
                logger.error(f"Unexpected error while extracting wheel: {e}")
                raise RuntimeError(
                    "Unexpected error during extraction of TensorRT-LLM wheel"
                ) from e
    finally:
        # Make sure others wait until unzip is done. Running this in `finally`
        # guarantees rank 0 joins the barrier even when extraction raised, so
        # the remaining ranks cannot deadlock waiting for it.
        if distributed:
            barrier()
193+
194+
151195
def download_and_get_plugin_lib_path() -> Optional[str]:
152196
"""
153197
Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
154-
155198
Args:
156199
platform (str): Platform identifier (e.g., 'linux_x86_64')
157-
158200
Returns:
159201
Optional[str]: Path to shared library or None if operation fails.
160202
"""
@@ -199,32 +241,7 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
199241
except OSError as e:
200242
logger.error(f"Local file write error: {e}")
201243

202-
try:
203-
import zipfile
204-
except ImportError as e:
205-
raise ImportError(
206-
"zipfile module is required but not found. Please install zipfile"
207-
)
208-
try:
209-
with zipfile.ZipFile(wheel_path) as zip_ref:
210-
zip_ref.extractall(extract_dir)
211-
logger.debug(f"Extracted wheel to {extract_dir}")
212-
except FileNotFoundError as e:
213-
# This should capture the errors in the download failure above
214-
logger.error(f"Wheel file not found at {wheel_path}: {e}")
215-
raise RuntimeError(
216-
f"Failed to find downloaded wheel file at {wheel_path}"
217-
) from e
218-
except zipfile.BadZipFile as e:
219-
logger.error(f"Invalid or corrupted wheel file: {e}")
220-
raise RuntimeError(
221-
"Downloaded wheel file is corrupted or not a valid zip archive"
222-
) from e
223-
except Exception as e:
224-
logger.error(f"Unexpected error while extracting wheel: {e}")
225-
raise RuntimeError(
226-
"Unexpected error during extraction of TensorRT-LLM wheel"
227-
) from e
244+
extract_wheel_file(wheel_path, extract_dir)
228245

229246
try:
230247
wheel_path.unlink(missing_ok=True)
@@ -243,10 +260,8 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
243260
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
244261
"""
245262
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
246-
247263
Args:
248264
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
249-
250265
Returns:
251266
bool: True if successful, False otherwise.
252267
"""
@@ -298,7 +313,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
298313
Attempts to load the TensorRT-LLM plugin and initialize it.
299314
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
300315
Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
301-
302316
Returns:
303317
bool: True if the plugin was successfully loaded and initialized, False otherwise.
304318
"""

0 commit comments

Comments
 (0)