diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
index 2cbf5635a07..120c42dbd2c 100644
--- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
+++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -694,6 +694,7 @@ def _merge_helix_requests(self, new_requests: list[RequestQueueItem],
                 position_ids=position_ids_this_rank,
             )
             req.total_input_len_cp = input_len
+            req.seqlen_this_rank_cp = len(input_ids_this_rank)
             req_with_children.append(req)
             if req.child_requests:
                 req_with_children.extend(req.child_requests)
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 28314382564..5f81b94a013 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -489,6 +489,8 @@ def __init__(
         self.py_max_new_tokens = self.max_new_tokens
         self.py_min_length = self.sampling_config.min_length
         self.py_helix_is_inactive_rank = False
+        self.seqlen_this_rank_cp = 0
+        self.total_input_len_cp = 0
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index aaac2256c90..6c45c00361b 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -568,13 +568,12 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         # Reset the global cuda graph dummy request to None in warmup.
         self.cuda_graph_runner.padding_dummy_request = None

-        cp_type = self.mapping.cp_config.get('cp_type', None)
-        if cp_type is not None:
-            if cp_type in [CpType.ULYSSES, CpType.STAR]:
-                logger.info(
-                    "[ModelEngine::warmup] Skipping warmup for cp_type: ",
-                    cp_type.name)
-                return
+        if self.mapping.cp_size > 1:
+            cp_type = self.mapping.cp_config.get("cp_type", None)
+            logger.info(
+                f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
+            )
+            return

         self._run_torch_compile_warmup(resource_manager)
         self._run_autotuner_warmup(resource_manager)
@@ -1671,12 +1670,12 @@ def _prepare_tp_inputs(
             # Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called.
             if not self.is_warmup and not request.is_cuda_graph_dummy:
                 position_id = request.total_input_len_cp + request.py_decoding_iter - 1
-                # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
-                if self.mapping.cp_rank == self.mapping.cp_size - 1:
-                    past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1
+                if request.py_helix_is_inactive_rank:
+                    past_seen_token_num = request.seqlen_this_rank_cp
                 else:
-                    # past_seen_token_num doesn't grow on inactive ranks.
-                    past_seen_token_num = request.orig_prompt_len
+                    # Discount the token added to active rank in resource manager as it hasn't
+                    # been previously seen.
+                    past_seen_token_num = request.seqlen_this_rank_cp - 1
             position_ids.append(position_id)
             num_cached_tokens_per_seq.append(past_seen_token_num)
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index a70b35dfcfb..bd1d197786a 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -468,13 +468,17 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                     req, block_ids)

         for req in generation_batch:
-            # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
             if self.mapping.has_cp_helix():
-                if self.mapping.cp_rank != self.mapping.cp_size - 1:
+                # Distribute the decode blocks across CP ranks in a round-robin manner.
+                decode_block_id = (req.py_decoding_iter -
+                                   1) // self.tokens_per_block
+                if decode_block_id % self.mapping.cp_size == self.mapping.cp_rank:
+                    req.py_helix_is_inactive_rank = False
+                    req.seqlen_this_rank_cp += 1
+                else:
                     req.py_helix_is_inactive_rank = True
-                # Skip allocating KV cache at decode for inactive helix ranks.
-                if req.py_helix_is_inactive_rank:
-                    continue
+                    # Skip allocating KV cache at decode for inactive helix ranks.
+                    continue
             self.impl.add_token(req.py_request_id)
             for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index 7e58d0f5007..bd9dbaed3be 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -519,6 +519,7 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index 04a4278ba6f..06be8963820 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -221,3 +221,4 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix
diff --git a/tests/unittest/_torch/executor/test_pytorch_model_engine.py b/tests/unittest/_torch/executor/test_pytorch_model_engine.py
index ca75cbb3517..76f72629301 100644
--- a/tests/unittest/_torch/executor/test_pytorch_model_engine.py
+++ b/tests/unittest/_torch/executor/test_pytorch_model_engine.py
@@ -407,6 +407,7 @@ def test_prepare_tp_inputs_with_helix_parallelism(self) -> None:
             req.sampling_config.beam_width = 1
             req.py_multimodal_data = {}
             req.total_input_len_cp = prompt_lens[idx] * 2
+            req.seqlen_this_rank_cp = prompt_lens[idx]
             req.py_decoding_iter = 1
             gen_requests.append(req)
         scheduled_requests.generation_requests = gen_requests
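
For reference, below is a minimal standalone sketch (not part of the patch) of the two pieces of bookkeeping the diff introduces: round-robin ownership of decode KV-cache blocks across helix CP ranks, and the past_seen_token_num accounting that discounts the token the resource manager just added on the active rank. The helper functions and the example values are illustrative assumptions, not TensorRT-LLM APIs.

```python
# Illustrative sketch only -- mirrors the logic added in resource_manager.py
# and model_engine.py; these helpers do not exist in TensorRT-LLM.


def active_cp_rank(decoding_iter: int, tokens_per_block: int, cp_size: int) -> int:
    """Rank that owns the KV-cache block written at this decode step (round-robin by block)."""
    decode_block_id = (decoding_iter - 1) // tokens_per_block
    return decode_block_id % cp_size


def past_seen_tokens(seqlen_this_rank_cp: int, is_inactive_rank: bool) -> int:
    """Tokens already cached on this rank before the current step.

    The active rank has just had one token added by the resource manager for
    the current step, so that token is discounted; inactive ranks have not.
    """
    return seqlen_this_rank_cp if is_inactive_rank else seqlen_this_rank_cp - 1


if __name__ == "__main__":
    cp_size, tokens_per_block = 4, 32
    # Decode steps 1..64 fill one 32-token block on rank 0, then one on rank 1.
    owners = [active_cp_rank(step, tokens_per_block, cp_size) for step in range(1, 65)]
    assert owners[:32] == [0] * 32 and owners[32:] == [1] * 32
```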