@@ -287,13 +287,14 @@ def test_advance_branched_quantize(self) -> None:
287287 @torch .no_grad ()
288288 def test_advance_quantize (self ) -> None :
289289 builder = GraphBuilder ()
290- x = builder . placeholder ( "x" , torch .randn (16 , 1 , 6 , 32 , dtype = torch .float32 ) )
291- weights = builder . placeholder (
292- "weights " , torch . randint ( - 128 , 127 , ( 32 , 32 ), dtype = torch . int8 )
293- )
290+ x_data = torch .randn (16 , 1 , 32 , 6 , dtype = torch .float32 )
291+ weight_data = torch . randint ( - 128 , 127 , ( 32 , 32 ), dtype = torch . int8 )
292+ x = builder . placeholder ( "x " , x_data )
293+ weights = builder . placeholder ( "weights" , weight_data )
294294 full = builder .call_operator (
295295 op = exir_ops .edge .aten .full .default ,
296296 args = ([1 ], - 7 ),
297+ kwargs = {"dtype" : torch .int32 },
297298 )
298299 full_1 = builder .call_operator (
299300 op = exir_ops .edge .aten .full .default ,
@@ -305,7 +306,8 @@ def test_advance_quantize(self) -> None:
305306 )
306307 full_3 = builder .call_operator (
307308 op = exir_ops .edge .aten .full .default ,
308- args = ([12 ], 0.0 ),
309+ args = ([1 ], 0 ),
310+ kwargs = {"dtype" : torch .int32 },
309311 )
310312 permute = builder .call_operator (
311313 op = exir_ops .edge .aten .permute_copy .default ,
@@ -338,8 +340,14 @@ def test_advance_quantize(self) -> None:
338340
339341 p1 = AdvanceQuantizeOpAboveDefInBranchPass ()
340342 tmp_graph = cast (PassResult , p1 (original_graph )).graph_module
341- p2 = AdvanceQuantizeOpAboveDefChainPass ()
342- converted_graph = cast (PassResult , p2 (tmp_graph )).graph_module
343+ result = transform_and_check_numerics (
344+ tmp_graph ,
345+ (x_data , weight_data ),
346+ AdvanceQuantizeOpAboveDefChainPass (),
347+ "AdvanceQuantizeOpAboveDefChainPass" ,
348+ )
349+ self .assertFalse (result .modified )
350+ converted_graph = result .graph_module
343351 # Assert that permute node is now the successor of the quant node.
344352 self .assertTrue (
345353 get_node_pos (
@@ -350,13 +358,14 @@ def test_advance_quantize(self) -> None:
350358
351359 def test_postpone_dequantize1 (self ) -> None :
352360 builder = GraphBuilder ()
353- x = builder . placeholder ( "x" , torch .randn (1 , 16 , 32 , 6 , dtype = torch .float32 ) )
354- weights = builder . placeholder (
355- "weights " , torch . randint ( - 128 , 127 , ( 6 , 6 ), dtype = torch . int8 )
356- )
361+ x_data = torch .randn (1 , 16 , 32 , 6 , dtype = torch .float32 )
362+ weight_data = torch . randint ( - 128 , 127 , ( 6 , 6 ), dtype = torch . int8 )
363+ x = builder . placeholder ( "x " , x_data )
364+ weights = builder . placeholder ( "weights" , weight_data )
357365 full = builder .call_operator (
358366 op = exir_ops .edge .aten .full .default ,
359367 args = ([1 ], - 7 ),
368+ kwargs = {"dtype" : torch .int32 },
360369 )
361370 full_1 = builder .call_operator (
362371 op = exir_ops .edge .aten .full .default ,
@@ -368,7 +377,8 @@ def test_postpone_dequantize1(self) -> None:
368377 )
369378 full_3 = builder .call_operator (
370379 op = exir_ops .edge .aten .full .default ,
371- args = ([12 ], 0.0 ),
380+ args = ([1 ], 0 ),
381+ kwargs = {"dtype" : torch .int32 },
372382 )
373383 quantize_per_tensor = builder .call_operator (
374384 op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
@@ -398,8 +408,14 @@ def test_postpone_dequantize1(self) -> None:
398408 )
399409 builder .output ([permute ])
400410 original_graph = builder .get_graph_module ()
401- p = PostponeDequantizeOpBelowUseChainPass ()
402- converted_graph = cast (PassResult , p (original_graph )).graph_module
411+ result = transform_and_check_numerics (
412+ original_graph ,
413+ (x_data , weight_data ),
414+ PostponeDequantizeOpBelowUseChainPass (),
415+ "PostponeDequantizeOpBelowUseChainPass" ,
416+ )
417+ self .assertTrue (result .modified )
418+ converted_graph = result .graph_module
403419 # Assert that dequant node is now the successor of the permute node.
404420 self .assertTrue (
405421 get_node_pos (converted_graph , exir_ops .edge .aten .permute_copy .default )
0 commit comments