Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,24 @@ def _generate_cuda_graph_batch_sizes(max_batch_size: int,
List of batch sizes to create CUDA graphs for
"""
if enable_padding:
# Start with [1, 2, 4, 8, 16, 24, ..., 128] (multiples of 8)
batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
# Sliding 64: extend by increments of 64 up to max_batch_size
while batch_sizes[-1] + 64 <= max_batch_size:
batch_sizes.append(batch_sizes[-1] + 64)
else:
batch_sizes = list(range(1, 32)) + [32, 64, 128]
# Add powers of 2 up to max_batch_size
batch_sizes += [
2**i for i in range(8, math.ceil(math.log(max_batch_size, 2)))
]

# Add powers of 2 up to max_batch_size
batch_sizes += [
2**i for i in range(8, math.ceil(math.log(max_batch_size, 2)))
]

# Filter and sort batch sizes
# Filter and sort batch sizes for both branches
batch_sizes = sorted(
[size for size in batch_sizes if size <= max_batch_size])

# Add max_batch_size if not already included
if max_batch_size != batch_sizes[-1]:
if not batch_sizes or max_batch_size != batch_sizes[-1]:
batch_sizes.append(max_batch_size)

return batch_sizes
Expand Down
10 changes: 10 additions & 0 deletions tests/unittest/llmapi/test_llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,16 @@ def test_cuda_graph_batch_sizes_case_2(self):
128, True)
assert args.cuda_graph_config.max_batch_size == 128

@pytest.mark.parametrize("max_batch_size", [64, 129, 320])
def test_generate_cuda_graph_batch_sizes_padding_edge_cases(
        self, max_batch_size):
    """Padding-mode batch sizes are bounded, sorted, and cover the cap.

    The parametrized caps exercise edge cases of the padding ladder:
    64 (inside the initial multiples-of-8 ramp), 129 (just past a
    generated size, forcing the cap to be appended), and 320 (a
    sliding-64 multiple).
    """
    # All sizes must be <= max_batch_size, sorted, and include max_batch_size
    batch_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
        max_batch_size, enable_padding=True)
    assert all(s <= max_batch_size for s in batch_sizes)
    assert batch_sizes == sorted(batch_sizes)
    assert max_batch_size in batch_sizes


class TestTrtLlmArgs:

Expand Down
Loading