21 changes: 12 additions & 9 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Collaborator:
Do we have a test to cover this bug fix? If possible, add a test to verify the hang is fixed.

Collaborator Author:
We have some tests in post-merge CI that reproduce the hang reliably. Let me add one of them to L0.

@@ -7,7 +7,7 @@
import traceback
from contextlib import contextmanager
from enum import IntEnum
-from queue import Queue
+from queue import Empty, Queue
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
@@ -1572,6 +1572,16 @@ def handle_executed_batches(executed_batch_num: int):
self._handle_executed_batch(executed_batch)
self.unhandled_batch_counter -= 1

+    def _get_executed_batch(self):
+        while True:
+            try:
+                return self.executed_batch_queue.get(timeout=0.001)
+            except Empty:
+                # Call MPI_Test on pending isend handles while idle to prevent potential hangs.
+                for handle in self.send_handles:
+                    if handle is not None:
+                        handle.test()

def _broadcast_sample_state_loop(self):
logger.debug(
f"Starting broadcast sample state loop for pp_rank {self.dist.pp_rank}"
@@ -1588,17 +1598,10 @@ def _broadcast_sample_state_loop(self):
new_mpi_comm = mpi_comm().Dup()
set_thread_local_mpi_comm(new_mpi_comm)
while True:
-            executed_batch = self.executed_batch_queue.get()
+            executed_batch = self._get_executed_batch()
if executed_batch is None:
break
self._ring_broadcast_sample_state(executed_batch)
-            # Flush the last isend before this thread goes idle on
-            # queue.get() — otherwise no MPI call will be made to drive
-            # progress and the non-blocking send data will never reach
-            # the receiver, causing a deadlock.
-            if self.executed_batch_queue.empty():
-                self.wait_on_pp_send_handles(self.send_handles,
-                                             executed_batch.microbatch_id)
set_thread_local_mpi_comm(None)
new_mpi_comm.Free()
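The core of the fix above is the polling pattern in `_get_executed_batch`: instead of blocking forever on `queue.get()`, the loop wakes up every millisecond and calls `test()` on any pending isend handles, which gives the MPI library a chance to make progress on the non-blocking sends. The sketch below illustrates that pattern in plain Python with a hypothetical `FakeSendHandle` standing in for an MPI request (names and the 1 ms timeout mirror the patch; nothing here is the real TensorRT-LLM API):

```python
import queue
import threading
import time


class FakeSendHandle:
    """Hypothetical stand-in for an MPI isend request; test() models MPI_Test."""

    def __init__(self):
        self.test_calls = 0

    def test(self):
        # In real MPI, this non-blocking completion check also drives progress.
        self.test_calls += 1


def get_executed_batch(q, send_handles, poll_timeout=0.001):
    """Poll the queue with a short timeout instead of blocking indefinitely.

    On each timeout, call test() on pending send handles so the underlying
    non-blocking sends keep making progress while this thread is idle.
    """
    while True:
        try:
            return q.get(timeout=poll_timeout)
        except queue.Empty:
            for handle in send_handles:
                if handle is not None:
                    handle.test()


q = queue.Queue()
handle = FakeSendHandle()
# Deliver an item after a short delay so the loop spins (and tests) a few times.
t = threading.Thread(target=lambda: (time.sleep(0.05), q.put("batch-0")))
t.start()
item = get_executed_batch(q, [handle, None])
t.join()
print(item)                    # prints batch-0
print(handle.test_calls > 0)   # prints True: test() was driven while idle
```

The key property is that the thread never sits in a pure `queue.get()` wait: every idle millisecond it re-enters the MPI layer via `test()`, which is what the removed `wait_on_pp_send_handles` flush used to approximate only at the moment the queue drained.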

2 changes: 1 addition & 1 deletion tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -48,6 +48,7 @@ l0_dgx_b200:
- unittest/_torch/visual_gen/test_wan_i2v.py::TestWanI2VCombinedOptimizations::test_all_optimizations_combined
- unittest/_torch/visual_gen/test_flux_pipeline.py::TestFluxParallelism::test_ulysses_2gpu_correctness
- unittest/_torch/visual_gen/test_flux_pipeline.py::TestFluxCombinedOptimizations::test_all_optimizations_combined
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- condition:
ranges:
system_gpu_count:
@@ -212,7 +213,6 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_kv_cache_aware_routing[mtp_nextn=0]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_kv_cache_aware_routing[mtp_nextn=2]
2 changes: 0 additions & 2 deletions tests/integration/test_lists/waives.txt
@@ -339,8 +339,6 @@ examples/test_visual_gen.py::test_vbench_dimension_score_wan SKIP (https://nvbug
examples/test_visual_gen.py::test_vbench_dimension_score_wan22_a14b_fp8 SKIP (https://nvbugs/6050483)
visual_gen/test_visual_gen_benchmark.py::test_offline_benchmark SKIP (https://nvbugs/6050483)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/6050487)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/6050489)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/6050489)
disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp1-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6057459)
disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp4-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6057460)
perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5844149)