diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 9f001b4e5ae..2eb6da5f614 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -167,21 +167,24 @@ def _generate_cuda_graph_batch_sizes(max_batch_size: int,
             List of batch sizes to create CUDA graphs for
         """
         if enable_padding:
+            # Start with [1, 2, 4, 8, 16, 24, ..., 128] (multiples of 8)
             batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+            # Sliding 64: extend by increments of 64 up to max_batch_size
+            while batch_sizes[-1] + 64 <= max_batch_size:
+                batch_sizes.append(batch_sizes[-1] + 64)
         else:
             batch_sizes = list(range(1, 32)) + [32, 64, 128]
+            # Add powers of 2 up to max_batch_size
+            batch_sizes += [
+                2**i for i in range(8, math.ceil(math.log(max_batch_size, 2)))
+            ]
 
-        # Add powers of 2 up to max_batch_size
-        batch_sizes += [
-            2**i for i in range(8, math.ceil(math.log(max_batch_size, 2)))
-        ]
-
-        # Filter and sort batch sizes
+        # Filter and sort batch sizes for both branches
         batch_sizes = sorted(
             [size for size in batch_sizes if size <= max_batch_size])
 
         # Add max_batch_size if not already included
-        if max_batch_size != batch_sizes[-1]:
+        if not batch_sizes or max_batch_size != batch_sizes[-1]:
             batch_sizes.append(max_batch_size)
 
         return batch_sizes
diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py
index 999edbcdcde..94d7b1cddaf 100644
--- a/tests/unittest/llmapi/test_llm_args.py
+++ b/tests/unittest/llmapi/test_llm_args.py
@@ -591,6 +591,16 @@ def test_cuda_graph_batch_sizes_case_2(self):
         assert args.cuda_graph_config.batch_sizes == CudaGraphConfig._generate_cuda_graph_batch_sizes(
             128, True)
         assert args.cuda_graph_config.max_batch_size == 128
+    @pytest.mark.parametrize("max_batch_size", [64, 129, 320])
+    def test_generate_cuda_graph_batch_sizes_padding_edge_cases(
+            self, max_batch_size):
+        # All sizes must be <= max_batch_size, sorted, and include max_batch_size
+        batch_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
+            max_batch_size, enable_padding=True)
+        assert all(s <= max_batch_size for s in batch_sizes)
+        assert batch_sizes == sorted(batch_sizes)
+        assert max_batch_size in batch_sizes
+
 
 
 class TestTrtLlmArgs: