Skip to content

Commit 9248896

Browse files
committed
Added perf script
1 parent 46829e6 commit 9248896

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import time
2+
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
import torch_tensorrt as torchtrt
7+
import torchvision
8+
from pyinstrument import Profiler
9+
from torch_tensorrt.dynamo.utils import get_model_device
10+
11+
torch.manual_seed(0)
12+
torch.cuda.manual_seed_all(0)
13+
import argparse
14+
15+
16+
def benchmark_model(model, input, label, profile=False):
17+
if profile:
18+
profiler = Profiler(interval=0.01)
19+
profiler.start()
20+
start_time = time.time()
21+
for _ in range(1000):
22+
model_outputs = model(*input)
23+
end_time = time.time()
24+
print(f"{label} 1000 runs: {end_time - start_time:.4f} seconds")
25+
if profile:
26+
profiler.stop()
27+
profiler.write_html(
28+
f"/home/other/{label.replace(' ', '_')}.html", timeline=False, show_all=True
29+
)
30+
31+
32+
def main(args):
33+
profile = args.profile
34+
use_python_runtime = args.use_python_runtime
35+
model_name = args.model
36+
37+
with torchtrt.dynamo.Debugger(log_level="debug", engine_builder_monitor=False):
38+
39+
model = (
40+
torchvision.models.detection.ssd300_vgg16(pretrained=True).eval().to("cuda")
41+
)
42+
input = [torch.randn((1, 3, 224, 224)).to("cuda")]
43+
44+
BATCH = torch.export.Dim("BATCH", min=1, max=16)
45+
exp_program = torch.export.export(model, tuple(input), strict=True)
46+
trt_mod2 = trt_gm = torchtrt.dynamo.compile(
47+
exp_program,
48+
tuple(input),
49+
use_python_runtime=use_python_runtime,
50+
enabled_precisions={torch.float},
51+
min_block_size=1,
52+
immutable_weights=False,
53+
reuse_cached_engines=False,
54+
)
55+
56+
trt_mod1 = trt_gm = torchtrt.dynamo.compile(
57+
exp_program,
58+
tuple(input),
59+
use_python_runtime=use_python_runtime,
60+
enabled_precisions={torch.float},
61+
min_block_size=1,
62+
immutable_weights=False,
63+
torch_executed_ops={torch.ops.aten.relu.default},
64+
reuse_cached_engines=False,
65+
)
66+
67+
# AOTI
68+
if not use_python_runtime:
69+
torchtrt.save(
70+
trt_mod1,
71+
"/home/other/aoti.pt2",
72+
output_format="aot_inductor",
73+
inputs=input,
74+
retrace=True,
75+
)
76+
aoti_model_gb = torch._inductor.aoti_load_package("/home/other/aoti.pt2")
77+
torchtrt.save(
78+
trt_mod2,
79+
"/home/other/aoti_no_gb.pt2",
80+
output_format="aot_inductor",
81+
inputs=input,
82+
retrace=True,
83+
)
84+
aoti_model_no_gb = torch._inductor.aoti_load_package(
85+
"/home/other/aoti_no_gb.pt2"
86+
)
87+
88+
# Warmup runs to avoid measuring first-run overheads
89+
for _ in range(100):
90+
trt_mod2(*input)
91+
model(*input)
92+
if not use_python_runtime:
93+
aoti_model_gb(*input)
94+
aoti_model_no_gb(*input)
95+
96+
time.sleep(1)
97+
benchmark_model(trt_mod1, input, "trt_mod1 (with graph break)", profile=profile)
98+
benchmark_model(trt_mod2, input, "trt_mod2 (without graph break)", profile=profile)
99+
if not use_python_runtime:
100+
benchmark_model(aoti_model_gb, input, "aoti_model_gb", profile=profile)
101+
benchmark_model(aoti_model_no_gb, input, "aoti_model_no_gb", profile=profile)
102+
103+
out1 = trt_mod1(*input)
104+
out2 = trt_mod2(*input)
105+
if not use_python_runtime:
106+
out3 = aoti_model_gb(*input)
107+
out4 = aoti_model_no_gb(*input)
108+
109+
def _to_tuple(x):
110+
if isinstance(x, (tuple, list)):
111+
return tuple(x)
112+
return (x,)
113+
114+
outs1 = _to_tuple(out1)
115+
outs2 = _to_tuple(out2)
116+
if not use_python_runtime:
117+
outs3 = _to_tuple(out3)
118+
outs4 = _to_tuple(out4)
119+
120+
def compare_outputs(a, b, name1="A", name2="B"):
121+
if len(a) != len(b):
122+
print(f"Number of outputs differ: {len(a)} vs {len(b)}")
123+
return False
124+
all_equal = True
125+
for i, (x, y) in enumerate(zip(a, b)):
126+
if not torch.allclose(x, y, atol=1e-3, rtol=1e-3):
127+
print(f"Output {i} differs between {name1} and {name2}")
128+
print(f"max diff: {torch.max(torch.abs(x - y))}")
129+
print(f"Mean diff: {torch.mean(torch.abs(x - y))}")
130+
all_equal = False
131+
if all_equal:
132+
print(f"All outputs match between {name1} and {name2}")
133+
return all_equal
134+
135+
compare_outputs(outs1, outs2, "trt_mod1", "trt_mod2")
136+
if not use_python_runtime:
137+
compare_outputs(outs1, outs3, "trt_mod1", "aoti_model_gb")
138+
compare_outputs(outs1, outs4, "trt_mod1", "aoti_model_no_gb")
139+
compare_outputs(outs2, outs3, "trt_mod2", "aoti_model")
140+
141+
142+
if __name__ == "__main__":
143+
arg_parser = argparse.ArgumentParser()
144+
arg_parser.add_argument("--profile", action="store_true")
145+
arg_parser.add_argument("--use_python_runtime", action="store_true")
146+
arg_parser.add_argument(
147+
"--model", type=str, default="resnet18", choices=["resnet18", "resnet152"]
148+
)
149+
args = arg_parser.parse_args()
150+
main(args)

0 commit comments

Comments
 (0)