Commit 3455ed1
not for land: vllm bench for mxfp8
Summary:

```
with-proxy python benchmarks/mx_formats/vllm/create_quantized_hf_model.py ~/local/tmp/20251203_test_model_mxfp8

with-proxy vllm bench throughput --model ~/local/tmp/20251203_test_model_mxfp8/ --dataset-name sonnet --dataset-path ~/local/vllm/benchmarks/sonnet.txt --num-prompts 1024 --tensor-parallel-size 1 --max-model-len 2048 --gpu-memory-utilization 0.8
```

currently fails with compile error (PyTorch 2.9)

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 620d504
ghstack-comment-id: 3608599132
Pull-Request: #3426
1 parent 16aad7c commit 3455ed1

File tree

1 file changed: +136 -0 lines changed

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Create a quantized `meta-llama/Meta-Llama-3.1-8B-Instruct` model and save
it to disk for local benchmarking with `vllm`.
"""

import argparse
import random
from pathlib import Path

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)


# Set seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def parse_args():
    parser = argparse.ArgumentParser(description="Quantize a model with TorchAO")
    parser.add_argument(
        "output_dir",
        type=str,
        help="Directory to save the quantized model",
    )
    return parser.parse_args()


def main(args):
    """
    Args:
        args: Parsed command line arguments containing:
            output_dir: Directory to save the quantized model
    """

    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device_map = "auto"
    max_new_tokens = 20

    # Test prompts
    prompts = [
        "Why is Pytorch 2.0 the best machine learning compiler?",
    ]

    # Set seed before creating the model
    set_seed(42)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get quantization config
    # quantization_config = TorchAoConfig(Float8DynamicActivationFloat8WeightConfig())
    quantization_config = TorchAoConfig(
        MXDynamicActivationMXWeightConfig(
            activation_dtype=torch.float8_e4m3fn,
            weight_dtype=torch.float8_e4m3fn,
        )
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load and quantize model
    print("Loading and quantizing model...")
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="bfloat16",
        device_map=device_map,
        quantization_config=quantization_config,
    )
    print(quantized_model)

    if False:
        # Test generation
        print("\nTesting quantized model generation...")
        input_ids = tokenizer(prompts, return_tensors="pt", padding=False).to(
            quantized_model.device
        )
        outputs = quantized_model.generate(**input_ids, max_new_tokens=max_new_tokens)

        for i, (prompt, output) in enumerate(zip(prompts, outputs, strict=False)):
            generated_text = tokenizer.decode(output, skip_special_tokens=True)
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Save quantized model
    print(f"\nSaving quantized model to: {output_dir}")
    quantized_model.save_pretrained(
        output_dir,
        safe_serialization=False,
    )
    tokenizer.save_pretrained(output_dir)

    if False:
        # Load saved model to verify
        # TODO: do we really need `weights_only=False` here?
        loaded_model = AutoModelForCausalLM.from_pretrained(
            output_dir,
            device_map=device_map,
            torch_dtype="auto",
            weights_only=False,
        )

        # Test loaded model with first prompt
        test_prompt = prompts[0]
        input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
        output = loaded_model.generate(**input_ids, max_new_tokens=max_new_tokens)
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        print(
            f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
        )

    print("\nQuantization process completed successfully.")


if __name__ == "__main__":
    args = parse_args()
    main(args)
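For a quick local sanity check of the saved checkpoint outside of `vllm bench throughput`, a minimal sketch using vLLM's Python API; the checkpoint path, prompt, and sampling settings below are assumptions taken from the commands in the summary, not part of this commit:

```
# Sketch: load the quantized checkpoint saved by the script above into vLLM
# and generate from one prompt. Path and settings are assumptions taken from
# the benchmark commands in the commit summary.
import os

from vllm import LLM, SamplingParams

model_path = os.path.expanduser("~/local/tmp/20251203_test_model_mxfp8")

llm = LLM(
    model=model_path,
    tensor_parallel_size=1,
    max_model_len=2048,
    gpu_memory_utilization=0.8,
)

outputs = llm.generate(
    ["Why is Pytorch 2.0 the best machine learning compiler?"],
    SamplingParams(max_tokens=20),
)
for out in outputs:
    print(f"Prompt: {out.prompt!r}, Generated text: {out.outputs[0].text!r}")
```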
