This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.
"""

-from typing import Callable, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union

import flashinfer
import torch
import torch_tensorrt
+from torch._subclasses import FakeTensor
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
@@ -51,6 +52,7 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
def replace_rmsnorm(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
+
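+    # This pass scans for the decomposed RMSNorm subgraph that HF Llama emits
+    # (an aten._to_copy upcast feeding the pow/mean/rsqrt chain, ending in a
+    # weight multiply) and rewires it to the flashinfer RMSNorm plugin op that
+    # the example registers earlier (registration not shown in these hunks).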
    for node in gm.graph.nodes:
        if (
            node.target == torch.ops.aten._to_copy.default
@@ -90,13 +92,60 @@ def replace_rmsnorm(
            weight_mul_node = list(copy_node.users)[0]

            weight = weight_mul_node.args[0]
+            hidden_states_node = node.args[0]

-            original_meta = weight_mul_node.meta.get(
+            original_meta = hidden_states_node.meta.get(
                "tensor_meta", {}
            )
            memory_format = original_meta.memory_format
+            from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+            shape_env = ShapeEnv()
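+            # A fresh ShapeEnv lets us mint unbacked SymInts below: the traced
+            # graph carries static batch/head dims, but the plugin converter
+            # should see symbolic ones. The torch._check bounds pin each
+            # symbol to its traced value.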

            with gm.graph.inserting_after(weight_mul_node):
+                input_meta = node.args[0].meta["val"]
+                batch_size = input_meta.shape[0]
+                seq_len = input_meta.shape[1]
+                head_dim = input_meta.shape[2]
+
+                # Create a symbolic int for batch_size (reuse the dim if it
+                # is already symbolic)
+                if isinstance(batch_size, int):
+                    batch_size_unbacked_symint = shape_env.create_unbacked_symint()
+                    torch._check(batch_size_unbacked_symint >= batch_size)
+                    torch._check(batch_size_unbacked_symint <= batch_size)
+                elif isinstance(batch_size, torch.SymInt):
+                    batch_size_unbacked_symint = batch_size
+                else:
+                    raise ValueError("batch_size must be an int or a SymInt")
+
+                # Same for head_dim
+                if isinstance(head_dim, int):
+                    head_dim_unbacked_symint = shape_env.create_unbacked_symint()
+                    torch._check(head_dim_unbacked_symint >= head_dim)
+                    torch._check(head_dim_unbacked_symint <= head_dim)
+                elif isinstance(head_dim, torch.SymInt):
+                    head_dim_unbacked_symint = head_dim
+                else:
+                    raise ValueError("head_dim must be an int or a SymInt")
+
                b = gm.graph.create_node(
                    op="call_function",
                    target=torch.ops.aten.sym_size.int,
@@ -111,19 +160,24 @@ def replace_rmsnorm(
                    is_quantized=False,
                    qparams={},
                )
+
+                b.meta["val"] = batch_size_unbacked_symint
+
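+                # Each aten.sym_size.int node (b, s, d) fetches a runtime dim;
+                # writing the SymInt into .meta["val"] is what downstream shape
+                # propagation reads, so batch and head_dim stay symbolic while
+                # seq_len keeps the symbol produced by the dynamic export.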
                s = gm.graph.create_node(
                    op="call_function",
                    target=torch.ops.aten.sym_size.int,
                    args=(node.args[0], 1),
                )
                s.meta.update(b.meta)
-
+                s.meta["val"] = seq_len
                d = gm.graph.create_node(
                    op="call_function",
                    target=torch.ops.aten.sym_size.int,
                    args=(node.args[0], 2),
                )
                d.meta.update(b.meta)
+                d.meta["val"] = head_dim_unbacked_symint

                with gm.graph.inserting_after(b):
                    new_first_dim = gm.graph.create_node(
@@ -150,11 +204,11 @@ def replace_rmsnorm(
                            [b_val * s_val, d_val]
                        ),
                        dtype=original_meta.dtype,
-                        requires_grad=True,
                        stride=None,
                        memory_format=memory_format,
                        is_quantized=False,
                        qparams={},
+                        requires_grad=False,
                    )
                )
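+                # requires_grad flips to False: this rewrite targets inference,
+                # so the flattened activation's metadata should not advertise
+                # gradient tracking to later passes.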

@@ -183,11 +237,22 @@ def replace_rmsnorm(
                        [b, s, d],
                    ),
                )
+                reshapback_node.meta["tensor_meta"] = TensorMetadata(
+                    shape=torch.Size([b_val, s_val, d_val]),
+                    dtype=original_meta.dtype,
+                    stride=None,
+                    memory_format=memory_format,
+                    is_quantized=False,
+                    qparams={},
+                    requires_grad=False,
+                )
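+                # tensor_meta is rebuilt by hand instead of copying
+                # weight_mul_node.meta (the line removed below): copying would
+                # reintroduce the static shape and requires_grad=True.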

                weight_mul_node.replace_all_uses_with(
                    reshapback_node
                )
-                reshapback_node.meta.update(weight_mul_node.meta)

            modified_graph = True

@@ -207,6 +272,43 @@ def replace_rmsnorm(
    return gm


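+# A second lowering pass: _to_copy nodes produced during decomposition can
+# lack tensor_meta, which the TRT converter needs. The @_aten_lowering_pass
+# decorator registers the function with Torch-TensorRT's ATen lowering passes,
+# so it runs automatically during dynamo compilation.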
+@_aten_lowering_pass
+def set_copy_node_meta_data(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten._to_copy.default and (
+            "tensor_meta" not in node.meta
+        ):
+            input_node = node.args[0]
+
+            # Check whether the input has metadata
+            if "tensor_meta" in input_node.meta:
+                # Copy the input metadata, swapping in the dtype that the
+                # _to_copy node casts to
+                output_meta = input_node.meta["tensor_meta"]
+                node.meta["tensor_meta"] = TensorMetadata(
+                    shape=output_meta.shape,
+                    dtype=node.kwargs.get("dtype"),
+                    requires_grad=True,
+                    stride=None,
+                    memory_format=output_meta.memory_format,
+                    is_quantized=False,
+                    qparams={},
+                )
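+                # Note: _to_copy carries its target dtype in kwargs; for a
+                # device-only copy node.kwargs.get("dtype") returns None, which
+                # is not guarded against here.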
+
+            else:
+                # Handle missing metadata (optional warning/logging)
+                print(f"Warning: Input node {input_node} has no tensor_meta")
+
+    gm = clean_up_graph_after_modifications(gm)
+
+    return gm
+
+
# 1. Create a custom config with 1 layer
config = LlamaConfig(
    vocab_size=32000,
@@ -222,12 +324,14 @@ def replace_rmsnorm(
with torch.no_grad():
    model = LlamaForCausalLM(config).cuda().half().eval()

+MAX_TOKENS = 64
+seq_len = torch.export.Dim("seq_len", min=2, max=MAX_TOKENS)
-# 3. Export with static shapes
+# 3. Export with a dynamic sequence length
-input_ids = torch.randint(0, 32000, (1, 64))  # Static [batch=1, seq=64]
+input_ids = torch.randint(0, 32000, (1, 64))  # example input: [batch=1, seq=64]
exported = torch.export.export(
    model,
    (input_ids,),
-    dynamic_shapes=None,  # Fully static
+    dynamic_shapes=({1: seq_len},),
)
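+# dynamic_shapes mirrors the positional args: ({1: seq_len},) marks dim 1 of
+# input_ids (the sequence dimension) as dynamic in [2, MAX_TOKENS]; batch stays
+# static and is handled through the unbacked SymInts in replace_rmsnorm.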

# Test forward pass
@@ -238,21 +342,60 @@ def replace_rmsnorm(
# Export validation

DEVICE = torch.device("cuda:0")
+stream = torch.cuda.Stream()
+with torch.cuda.stream(stream):
+    with torch_tensorrt.dynamo.Debugger(
+        log_level="info",
+        # profile_format="trex",
+        # save_engine_profile=True,
+        capture_fx_graph_before=["remove_detach"],
+        capture_fx_graph_after=["replace_rmsnorm"],
+        logging_dir="/home/profile/logging/torchtrt",
+        engine_builder_monitor=False,
+    ):
+        trt_model = torch_tensorrt.dynamo.compile(
+            exported,
+            inputs=[input_ids],
+            enabled_precisions={torch.float32, torch.float16},
+            truncate_double=True,
+            device=DEVICE,
+            disable_tf32=True,
+            use_explicit_typing=False,
+            use_fp32_acc=True,
+            use_python_runtime=True,
+        )
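+# The Debugger dumps the FX graph around the named passes (before remove_detach,
+# after replace_rmsnorm) into logging_dir, which makes it easy to verify the
+# rewrite actually fired; use_fp32_acc=True keeps matmul accumulation in FP32
+# under FP16 inputs, trading some speed for accuracy.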
+
+input_ids = input_ids.to(DEVICE)

-with torch_tensorrt.logging.errors():
-    trt_model = torch_tensorrt.dynamo.compile(
-        exported,
-        inputs=[input_ids],
-        enabled_precisions={torch.float32, torch.float16},
-        truncate_double=True,
-        device=DEVICE,
-        disable_tf32=True,
-        use_explicit_typing=False,
-        use_fp32_acc=True,
-    )
+res = trt_model.forward(input_ids)

-input_ids = input_ids.to(DEVICE)
+# Benchmark the TensorRT model

-with torch.no_grad():
-    res = trt_model.forward(input_ids)
-    print(res)
+import time
+
+def benchmark_model(model, input_ids, label, n_runs=100):
+    torch.cuda.synchronize()
+    start = time.time()
+    for _ in range(n_runs):
+        with torch.no_grad():
+            out = model(input_ids)
+    torch.cuda.synchronize()
+    end = time.time()
+    print(f"{label}: {n_runs} runs, total {(end - start):.4f} s")
+    return out
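+# torch.cuda.synchronize() before and after the loop is what makes the wall
+# clock meaningful: CUDA launches are asynchronous, so without it time.time()
+# would mostly measure launch overhead. Per-run latency ~= total / n_runs.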
+
+# Warmup
+with torch.no_grad():
+    _ = trt_model(input_ids)
+# Benchmark
+trt_out = benchmark_model(trt_model, input_ids, "TensorRT model")
+
+# Compare outputs
+
+pytorch_logits = output.logits
+trt_logits = trt_out.logits
+
+pytorch_logits = pytorch_logits.to(DEVICE)
+trt_logits = trt_logits.to(DEVICE)
+print("Max abs diff:", (pytorch_logits - trt_logits).abs().max().item())
+print("Mean abs diff:", (pytorch_logits - trt_logits).abs().mean().item())
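+# With FP16 weights and the plugin swap, small nonzero diffs are expected. A
+# tolerance-based check could replace the prints, e.g. (thresholds illustrative,
+# not from the source):
+# torch.testing.assert_close(trt_logits, pytorch_logits, rtol=1e-2, atol=1e-2)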