1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -694,6 +694,7 @@ def _merge_helix_requests(self, new_requests: list[RequestQueueItem],
                 position_ids=position_ids_this_rank,
             )
             req.total_input_len_cp = input_len
+            req.seqlen_this_rank_cp = len(input_ids_this_rank)
             req_with_children.append(req)
             if req.child_requests:
                 req_with_children.extend(req.child_requests)
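For context, a minimal standalone sketch (not the repository's actual merge helper) of how the two bookkeeping fields relate once a helix/CP prompt has been split across ranks: total_input_len_cp records the full prompt length on every rank, while seqlen_this_rank_cp records only the slice held locally. FakeHelixRequest, split_for_rank, and the contiguous chunking scheme are assumptions made for illustration.

# Illustrative sketch only: shows the intended relationship between
# total_input_len_cp and seqlen_this_rank_cp when a prompt is split across
# cp_size ranks. FakeHelixRequest and split_for_rank are hypothetical.
from dataclasses import dataclass


@dataclass
class FakeHelixRequest:
    total_input_len_cp: int = 0
    seqlen_this_rank_cp: int = 0


def split_for_rank(input_ids: list[int], cp_rank: int, cp_size: int) -> list[int]:
    # Assumed contiguous chunking of the prompt across CP ranks.
    chunk = (len(input_ids) + cp_size - 1) // cp_size
    return input_ids[cp_rank * chunk:(cp_rank + 1) * chunk]


input_ids = list(range(10))  # a 10-token prompt
cp_size = 4
for cp_rank in range(cp_size):
    ids_this_rank = split_for_rank(input_ids, cp_rank, cp_size)
    req = FakeHelixRequest(
        total_input_len_cp=len(input_ids),       # full prompt length, same on every rank
        seqlen_this_rank_cp=len(ids_this_rank),  # only this rank's slice
    )
    print(cp_rank, req.total_input_len_cp, req.seqlen_this_rank_cp)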
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -489,6 +489,8 @@ def __init__(
         self.py_max_new_tokens = self.max_new_tokens
         self.py_min_length = self.sampling_config.min_length
         self.py_helix_is_inactive_rank = False
+        self.seqlen_this_rank_cp = 0
+        self.total_input_len_cp = 0
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0
23 changes: 11 additions & 12 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -568,13 +568,12 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         # Reset the global cuda graph dummy request to None in warmup.
         self.cuda_graph_runner.padding_dummy_request = None

-        cp_type = self.mapping.cp_config.get('cp_type', None)
-        if cp_type is not None:
-            if cp_type in [CpType.ULYSSES, CpType.STAR]:
-                logger.info(
-                    "[ModelEngine::warmup] Skipping warmup for cp_type: ",
-                    cp_type.name)
-                return
+        if self.mapping.cp_size > 1:
+            cp_type = self.mapping.cp_config.get("cp_type", None)
+            logger.info(
+                f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
+            )
+            return

         self._run_torch_compile_warmup(resource_manager)
         self._run_autotuner_warmup(resource_manager)
@@ -1671,12 +1670,12 @@ def _prepare_tp_inputs(
                 # Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called.
                 if not self.is_warmup and not request.is_cuda_graph_dummy:
                     position_id = request.total_input_len_cp + request.py_decoding_iter - 1
-                    # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
-                    if self.mapping.cp_rank == self.mapping.cp_size - 1:
-                        past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1
+                    if request.py_helix_is_inactive_rank:
+                        past_seen_token_num = request.seqlen_this_rank_cp
                     else:
-                        # past_seen_token_num doesn't grow on inactive ranks.
-                        past_seen_token_num = request.orig_prompt_len
+                        # Discount the token added to active rank in resource manager as it hasn't
+                        # been previously seen.
+                        past_seen_token_num = request.seqlen_this_rank_cp - 1

                 position_ids.append(position_id)
                 num_cached_tokens_per_seq.append(past_seen_token_num)
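The arithmetic in the hunk above can be read in isolation: position_id tracks the global sequence position (full prompt length across all CP ranks plus decode steps so far), while past_seen_token_num counts the KV tokens this rank already holds, minus one on the active rank because its counter was already bumped for the token being generated now. A small sketch that mirrors that logic; the four field names come from the diff, everything else is invented for illustration.

# Standalone sketch of the generation-step bookkeeping above; the field names
# mirror the diff, the function and example values are hypothetical.
def helix_gen_inputs(total_input_len_cp: int, seqlen_this_rank_cp: int,
                     py_decoding_iter: int, is_inactive_rank: bool) -> tuple[int, int]:
    # position_id counts the full prompt (across all CP ranks) plus the
    # decode tokens generated so far.
    position_id = total_input_len_cp + py_decoding_iter - 1
    if is_inactive_rank:
        # Inactive ranks hold only previously seen tokens.
        past_seen_token_num = seqlen_this_rank_cp
    else:
        # The active rank's counter was already incremented in prepare_resources
        # for the token being generated now, which hasn't been seen yet.
        past_seen_token_num = seqlen_this_rank_cp - 1
    return position_id, past_seen_token_num


# Example: 10-token prompt split 5/5 over 2 ranks, first decode step,
# with one rank inactive and the other active (its counter already at 6).
print(helix_gen_inputs(10, 5, 1, is_inactive_rank=True))   # (10, 5)
print(helix_gen_inputs(10, 6, 1, is_inactive_rank=False))  # (10, 5)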
14 changes: 9 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -468,13 +468,17 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                                          req, block_ids)

         for req in generation_batch:
-            # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
             if self.mapping.has_cp_helix():
-                if self.mapping.cp_rank != self.mapping.cp_size - 1:
+                # Distribute the decode blocks across CP ranks in a round-robin manner.
+                decode_block_id = (req.py_decoding_iter -
+                                   1) // self.tokens_per_block
+                if decode_block_id % self.mapping.cp_size == self.mapping.cp_rank:
+                    req.py_helix_is_inactive_rank = False
+                    req.seqlen_this_rank_cp += 1
+                else:
                     req.py_helix_is_inactive_rank = True
-                # Skip allocating KV cache at decode for inactive helix ranks.
-                if req.py_helix_is_inactive_rank:
-                    continue
+                    # Skip allocating KV cache at decode for inactive helix ranks.
+                    continue
             self.impl.add_token(req.py_request_id)
             for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)
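The round-robin ownership above is easy to check by hand: the decode token produced at py_decoding_iter falls into block (py_decoding_iter - 1) // tokens_per_block, and that block's owner (block id modulo cp_size) is the active rank that allocates KV cache and grows seqlen_this_rank_cp. A sketch mirroring that computation; the helper name and parameter values are made up for illustration.

# Sketch of the round-robin decode-block ownership used to pick the active
# helix rank; active_cp_rank and the example values are hypothetical.
def active_cp_rank(py_decoding_iter: int, tokens_per_block: int, cp_size: int) -> int:
    # Decode tokens are grouped into blocks of tokens_per_block;
    # block i is owned by rank i % cp_size.
    decode_block_id = (py_decoding_iter - 1) // tokens_per_block
    return decode_block_id % cp_size


tokens_per_block, cp_size = 4, 2
for it in range(1, 13):  # first 12 decode iterations
    print(it, active_cp_rank(it, tokens_per_block, cp_size))
# Iterations 1-4 -> rank 0, 5-8 -> rank 1, 9-12 -> rank 0, ...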
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/llm_function_core.txt
@@ -519,6 +519,7 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -221,3 +221,4 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix
@@ -407,6 +407,7 @@ def test_prepare_tp_inputs_with_helix_parallelism(self) -> None:
             req.sampling_config.beam_width = 1
             req.py_multimodal_data = {}
             req.total_input_len_cp = prompt_lens[idx] * 2
+            req.seqlen_this_rank_cp = prompt_lens[idx]
             req.py_decoding_iter = 1
             gen_requests.append(req)
         scheduled_requests.generation_requests = gen_requests