56 changes: 50 additions & 6 deletions language/llama3.1-8b/SUT_VLLM.py
@@ -37,10 +37,17 @@ def __init__(
# Set this to True *only for test accuracy runs* in case your prior
# session was killed partway through
workers=1,
tensor_parallel_size=8
tensor_parallel_size=8,
gpu_memory_utilization=0.90,
max_num_batched_tokens=None,
max_num_seqs=256,
block_size=16,
enforce_eager=False,
enable_chunked_prefill=None,
max_model_len=None,
):

self.model_path = model_path or f"meta-llama/Meta-Llama-3.1-8B-Instruct"
self.model_path = model_path or "meta-llama/Meta-Llama-3.1-8B-Instruct"

if not batch_size:
batch_size = 1
@@ -49,6 +56,15 @@ def __init__(
self.dtype = dtype
self.tensor_parallel_size = tensor_parallel_size

# Store vLLM engine config
self.gpu_memory_utilization = gpu_memory_utilization
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_seqs = max_num_seqs
self.block_size = block_size
self.enforce_eager = enforce_eager
self.enable_chunked_prefill = enable_chunked_prefill
self.max_model_len = max_model_len

if not torch.cuda.is_available():
assert False, "torch gpu is not available, exiting..."

@@ -73,7 +89,7 @@ def __init__(
"top_k": 1,
"seed": 42,
"max_tokens": 128,
"min_tokens": 1
"min_tokens": 1,
}
self.sampling_params = SamplingParams(**gen_kwargs)
# self.sampling_params.all_stop_token_ids.add(self.model.get_tokenizer().eos_token_id)
@@ -162,6 +178,12 @@ def load_model(self):
self.model_path,
dtype=self.dtype,
tensor_parallel_size=self.tensor_parallel_size,
gpu_memory_utilization=self.gpu_memory_utilization,
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
block_size=self.block_size,
enforce_eager=self.enforce_eager,
enable_chunked_prefill=self.enable_chunked_prefill,
max_model_len=self.max_model_len,
)
log.info("Loaded model")

@@ -203,7 +225,14 @@ def __init__(
dataset_path=None,
batch_size=None,
workers=1,
tensor_parallel_size=8
tensor_parallel_size=8,
gpu_memory_utilization=0.90,
max_num_batched_tokens=None,
max_num_seqs=256,
block_size=16,
enforce_eager=False,
enable_chunked_prefill=None,
max_model_len=None,
):

super().__init__(
@@ -213,6 +242,13 @@ def __init__(
dataset_path=dataset_path,
workers=workers,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=gpu_memory_utilization,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
block_size=block_size,
enforce_eager=enforce_eager,
enable_chunked_prefill=enable_chunked_prefill,
max_model_len=max_model_len,
)
self.request_id = 0

@@ -287,10 +323,18 @@ def stop(self):
self.ft_response_thread.join()

def load_model(self):
log.info("Loading model")
log.info("Loading model...")
self.engine_args = AsyncEngineArgs(
self.model_path,
dtype=self.dtype,
tensor_parallel_size=self.tensor_parallel_size)
tensor_parallel_size=self.tensor_parallel_size,
gpu_memory_utilization=self.gpu_memory_utilization,
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=self.max_model_len,
block_size=self.block_size,
enforce_eager=self.enforce_eager,
enable_chunked_prefill=self.enable_chunked_prefill,
)
self.model = AsyncLLMEngine.from_engine_args(self.engine_args)
log.info("Loaded model")
54 changes: 52 additions & 2 deletions language/llama3.1-8b/main.py
@@ -104,7 +104,7 @@ def get_args():
"--tensor-parallel-size",
type=int,
default=8,
help="Number of workers to process queries",
help="Number of tensor parallel GPUs",
)
parser.add_argument("--vllm", action="store_true", help="vllm mode")
parser.add_argument(
@@ -127,6 +127,49 @@ def get_args():
help="Model name(specified in llm server)",
)

parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.90,
help="Fraction of GPU memory for vLLM to use (default: 0.90)",
)
parser.add_argument(
"--max-num-batched-tokens",
type=int,
default=None,
help="Max tokens in a single batch (default: vLLM engine default)",
)
parser.add_argument(
"--max-num-seqs",
type=int,
default=256,
help="Max concurrent sequences (default: 256)",
)
parser.add_argument(
"--block-size",
type=int,
default=16,
help="KV cache block size (default: 16)",
)
parser.add_argument(
"--enforce-eager",
action=argparse.BooleanOptionalAction,
default=False,
help="Use eager mode instead of CUDA graphs (default: disabled)",
)
parser.add_argument(
"--enable-chunked-prefill",
action=argparse.BooleanOptionalAction,
default=None,
help="Enable chunked prefill (default: vLLM engine default)",
)
parser.add_argument(
"--max-model-len",
type=int,
default=None,
help="Max model context length (default: vLLM engine default)",
)

args = parser.parse_args()
return args
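Separate from the diff, a small sketch of how the two `argparse.BooleanOptionalAction` flags behave, since that action also generates `--no-*` variants (e.g. `--enforce-eager` / `--no-enforce-eager`). It requires Python 3.9+, and the values below are illustrative only.

```python
# Minimal sketch, not part of the diff: BooleanOptionalAction creates paired
# --flag / --no-flag options; an omitted flag keeps its default (None is passed
# through so vLLM's own engine default applies).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gpu-memory-utilization", type=float, default=0.90)
parser.add_argument("--enforce-eager", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--enable-chunked-prefill", action=argparse.BooleanOptionalAction, default=None)

args = parser.parse_args(
    ["--gpu-memory-utilization", "0.95", "--enforce-eager", "--no-enable-chunked-prefill"]
)
assert args.gpu_memory_utilization == 0.95
assert args.enforce_eager is True            # --no-enforce-eager would yield False
assert args.enable_chunked_prefill is False  # omitting the flag would leave None
```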

@@ -177,7 +220,14 @@ def main():
dataset_path=args.dataset_path,
total_sample_count=args.total_sample_count,
workers=args.num_workers,
tensor_parallel_size=args.tensor_parallel_size
tensor_parallel_size=args.tensor_parallel_size,
gpu_memory_utilization=args.gpu_memory_utilization,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_seqs=args.max_num_seqs,
block_size=args.block_size,
enforce_eager=args.enforce_eager,
enable_chunked_prefill=args.enable_chunked_prefill,
max_model_len=args.max_model_len
)
else:
sut = sut_cls(