Skip to content

Commit 06bc56c

Browse files
committed
Revised the lowering pass according to Bo's suggestion
1 parent 95aad4b commit 06bc56c

File tree

1 file changed

+164
-21
lines changed

1 file changed

+164
-21
lines changed

examples/dynamo/llama2_flashinfer_rmsnorm.py

Lines changed: 164 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.
1616
"""
1717

18-
from typing import Callable, Optional, Sequence, Union
18+
from typing import Any, Callable, Optional, Sequence, Union
1919

2020
import flashinfer
2121
import torch
2222
import torch_tensorrt
23+
from torch._subclasses import FakeTensor
2324
from torch.fx.passes.shape_prop import TensorMetadata
2425
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
2526
_aten_lowering_pass,
@@ -51,6 +52,7 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tenso
5152
def replace_rmsnorm(
5253
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
5354
) -> torch.fx.GraphModule:
55+
5456
for node in gm.graph.nodes:
5557
if (
5658
node.target == torch.ops.aten._to_copy.default
@@ -90,13 +92,60 @@ def replace_rmsnorm(
9092
weight_mul_node = list(copy_node.users)[0]
9193

9294
weight = weight_mul_node.args[0]
95+
hidden_states_node = node.args[0]
9396

94-
original_meta = weight_mul_node.meta.get(
97+
original_meta = hidden_states_node.meta.get(
9598
"tensor_meta", {}
9699
)
97100
memory_format = original_meta.memory_format
101+
from torch.fx.experimental.symbolic_shapes import (
102+
ShapeEnv,
103+
)
104+
105+
shape_env = ShapeEnv()
98106

99107
with gm.graph.inserting_after(weight_mul_node):
108+
input_meta = node.args[0].meta["val"]
109+
batch_size = input_meta.shape[0]
110+
seq_len = input_meta.shape[1]
111+
head_dim = input_meta.shape[2]
112+
113+
# Create symbolic ints for batch_size
114+
if isinstance(batch_size, int):
115+
batch_size_unbacked_symint = (
116+
shape_env.create_unbacked_symint()
117+
)
118+
torch._check(
119+
batch_size_unbacked_symint >= batch_size
120+
)
121+
torch._check(
122+
batch_size_unbacked_symint <= batch_size
123+
)
124+
elif isinstance(batch_size, torch.SymInt):
125+
pass
126+
else:
127+
raise ValueError(
128+
"Batch size must be a sym int"
129+
)
130+
131+
# Create symbolic ints for head_dim
132+
if isinstance(head_dim, int):
133+
head_dim_unbacked_symint = (
134+
shape_env.create_unbacked_symint()
135+
)
136+
torch._check(
137+
head_dim_unbacked_symint >= head_dim
138+
)
139+
torch._check(
140+
head_dim_unbacked_symint <= head_dim
141+
)
142+
elif isinstance(head_dim, torch.SymInt):
143+
pass
144+
else:
145+
raise ValueError(
146+
"head_dim must be a sym int"
147+
)
148+
100149
b = gm.graph.create_node(
101150
op="call_function",
102151
target=torch.ops.aten.sym_size.int,
@@ -111,19 +160,24 @@ def replace_rmsnorm(
111160
is_quantized=False,
112161
qparams={},
113162
)
163+
164+
batch_size = node.args[0].meta["val"].shape[0]
165+
b.meta["val"] = batch_size_unbacked_symint
166+
114167
s = gm.graph.create_node(
115168
op="call_function",
116169
target=torch.ops.aten.sym_size.int,
117170
args=(node.args[0], 1),
118171
)
119172
s.meta.update(b.meta)
120-
173+
s.meta["val"] = seq_len
121174
d = gm.graph.create_node(
122175
op="call_function",
123176
target=torch.ops.aten.sym_size.int,
124177
args=(node.args[0], 2),
125178
)
126179
d.meta.update(b.meta)
180+
d.meta["val"] = head_dim_unbacked_symint
127181

128182
with gm.graph.inserting_after(b):
129183
new_first_dim = gm.graph.create_node(
@@ -150,11 +204,11 @@ def replace_rmsnorm(
150204
[b_val * s_val, d_val]
151205
),
152206
dtype=original_meta.dtype,
153-
requires_grad=True,
154207
stride=None,
155208
memory_format=memory_format,
156209
is_quantized=False,
157210
qparams={},
211+
requires_grad=False,
158212
)
159213
)
160214

@@ -183,11 +237,22 @@ def replace_rmsnorm(
183237
[b, s, d],
184238
),
185239
)
240+
reshapback_node.meta["tensor_meta"] = (
241+
TensorMetadata(
242+
shape=torch.Size([b_val, s_val, d_val]),
243+
dtype=original_meta.dtype,
244+
stride=None,
245+
memory_format=memory_format,
246+
is_quantized=False,
247+
qparams={},
248+
requires_grad=False,
249+
)
250+
)
186251

252+
# reshapback_node.meta.update(weight_mul_node.meta)
187253
weight_mul_node.replace_all_uses_with(
188254
reshapback_node
189255
)
190-
reshapback_node.meta.update(weight_mul_node.meta)
191256

192257
modified_graph = True
193258

@@ -207,6 +272,43 @@ def replace_rmsnorm(
207272
return gm
208273

209274

275+
@_aten_lowering_pass
def set_copy_node_meta_data(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Fill in missing ``tensor_meta`` on ``aten._to_copy`` nodes.

    For every ``torch.ops.aten._to_copy.default`` node that has no
    ``"tensor_meta"`` entry in ``node.meta``, a new ``TensorMetadata`` is
    built from the metadata of the node's first argument: the shape and
    memory format are copied, and the dtype is taken from the node's
    ``dtype`` keyword argument.

    Args:
        gm: Graph module whose nodes are inspected and annotated in place.
        sample_inputs: Unused here; part of the lowering-pass signature.

    Returns:
        The result of ``clean_up_graph_after_modifications(gm)``.
    """
    for node in gm.graph.nodes:
        if node.target != torch.ops.aten._to_copy.default:
            continue
        if "tensor_meta" in node.meta:
            # Already annotated; leave the existing metadata untouched.
            continue

        input_node = node.args[0]
        input_meta = input_node.meta.get("tensor_meta")
        if input_meta is None:
            # Nothing to derive the output metadata from.
            print(f"Warning: Input node {input_node} has no tensor_meta")
            continue

        # `dtype` is an optional kwarg of `_to_copy`; when it is absent the
        # copy keeps the input dtype, so fall back to it instead of storing
        # `None` in the metadata.
        target_dtype = node.kwargs.get("dtype")
        if target_dtype is None:
            target_dtype = input_meta.dtype

        # NOTE(review): requires_grad is hard-coded to True here, while the
        # TensorMetadata objects built in replace_rmsnorm use False --
        # confirm which value is intended.
        node.meta["tensor_meta"] = TensorMetadata(
            shape=input_meta.shape,
            dtype=target_dtype,
            requires_grad=True,
            stride=None,
            memory_format=input_meta.memory_format,
            is_quantized=False,
            qparams={},
        )

    gm = clean_up_graph_after_modifications(gm)

    return gm
310+
311+
210312
# 1. Create a custom config with 1 layer
211313
config = LlamaConfig(
212314
vocab_size=32000,
@@ -222,12 +324,14 @@ def replace_rmsnorm(
222324
with torch.no_grad():
223325
model = LlamaForCausalLM(config).cuda().half().eval()
224326

327+
MAX_TOKENS = 64
328+
seq_len = torch.export.Dim("seq_len", min=2, max=MAX_TOKENS)
225329
# 3. Export with a dynamic sequence-length dimension (batch stays static)
226330
input_ids = torch.randint(0, 32000, (1, 64)) # Static [batch=1, seq=64]
227331
exported = torch.export.export(
228332
model,
229333
(input_ids,),
230-
dynamic_shapes=None, # Fully static
334+
dynamic_shapes=({1: seq_len},),
231335
)
232336

233337
# Test forward pass
@@ -238,21 +342,60 @@ def replace_rmsnorm(
238342
# Export validation
239343

240344
DEVICE = torch.device("cuda:0")
345+
stream = torch.cuda.Stream()
346+
with torch.cuda.stream(stream):
347+
with torch_tensorrt.dynamo.Debugger(
348+
log_level="info",
349+
# profile_format="trex",
350+
# save_engine_profile=True,
351+
capture_fx_graph_before=["remove_detach"],
352+
capture_fx_graph_after=["replace_rmsnorm"],
353+
logging_dir="/home/profile/logging/torchtrt",
354+
engine_builder_monitor=False,
355+
):
356+
trt_model = torch_tensorrt.dynamo.compile(
357+
exported,
358+
inputs=[input_ids],
359+
enabled_precisions={torch.float32, torch.float16},
360+
truncate_double=True,
361+
device=DEVICE,
362+
disable_tf32=True,
363+
use_explicit_typing=False,
364+
use_fp32_acc=True,
365+
use_python_runtime=True,
366+
)
367+
368+
input_ids = input_ids.to(DEVICE)
241369

242-
with torch_tensorrt.logging.errors():
243-
trt_model = torch_tensorrt.dynamo.compile(
244-
exported,
245-
inputs=[input_ids],
246-
enabled_precisions={torch.float32, torch.float16},
247-
truncate_double=True,
248-
device=DEVICE,
249-
disable_tf32=True,
250-
use_explicit_typing=False,
251-
use_fp32_acc=True,
252-
)
370+
res = trt_model.forward(input_ids)
253371

254-
input_ids = input_ids.to(DEVICE)
372+
# Benchmark TensorRT models
255373

256-
with torch.no_grad():
257-
res = trt_model.forward(input_ids)
258-
print(res)
374+
import time
375+
376+
def benchmark_model(model, input_ids, label, n_runs=100):
    """Time ``n_runs`` forward passes of ``model`` and print the total.

    Args:
        model: Callable invoked as ``model(input_ids)``.
        input_ids: Input passed unchanged to every call.
        label: Text printed in front of the timing summary.
        n_runs: Number of timed invocations.

    Returns:
        The output of the last invocation, or ``None`` when ``n_runs`` is 0.
    """
    # Only synchronize when a CUDA device exists so the helper can also be
    # used on CPU-only hosts; on CUDA hosts the behavior is unchanged.
    cuda_available = torch.cuda.is_available()

    out = None  # keeps the return well-defined when n_runs == 0
    if cuda_available:
        # Drain pending GPU work so it is not charged to the timed region.
        torch.cuda.synchronize()
    # perf_counter is monotonic and higher resolution than time.time().
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_runs):
            out = model(input_ids)
    if cuda_available:
        # Wait for the queued kernels to finish before stopping the clock.
        torch.cuda.synchronize()
    end = time.perf_counter()
    print(f"{label}: {n_runs} runs, total {(end - start):.4f} s")
    return out
386+
387+
# Warmup
388+
with torch.no_grad():
389+
_ = trt_model(input_ids)
390+
# Benchmark
391+
trt_out = benchmark_model(trt_model, input_ids, "TensorRT model")
392+
393+
# Compare outputs
394+
395+
pytorch_logits = output.logits
396+
trt_logits = trt_out.logits
397+
398+
pytorch_logits = pytorch_logits.to(DEVICE)
399+
trt_logits = trt_logits.to(DEVICE)
400+
print("Max abs diff:", (pytorch_logits - trt_logits).abs().max().item())
401+
print("Mean abs diff:", (pytorch_logits - trt_logits).abs().mean().item())

0 commit comments

Comments
 (0)