
Commit b5a57d1

Andrew Grebenisan authored and facebook-github-bot committed
Update AdvanceQuantizeOpAboveDefChainPass, PostponeDequantizeOpBelowUseChainPass, and ScalarToTensorPass to correctly set modified bit (#16230)
Summary: Update AdvanceQuantizeOpAboveDefChainPass and PostponeDequantizeOpBelowUseChainPass to correctly track the modified bit, and update the tests to check for numerical correctness.

Differential Revision: D87900822
1 parent eb55377 commit b5a57d1
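
For context, the contract these passes now respect is the standard torch.fx PassResult one: a pass should report modified=True only when it actually rewrote the graph, and it can skip recompilation and dead-code elimination otherwise. The code below is a minimal, hypothetical sketch of that pattern (not a pass from this repo); it only assumes PassBase and PassResult from torch.fx.passes.infra.pass_base.

import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class RemoveAddZeroPass(PassBase):
    """Hypothetical example pass: drop `x + 0` nodes and track whether it did."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        # Iterate over a copy of the node list since we erase nodes as we go.
        for node in list(graph_module.graph.nodes):
            if (
                node.op == "call_function"
                and node.target in (torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar)
                and isinstance(node.args[1], (int, float))
                and node.args[1] == 0
            ):
                node.replace_all_uses_with(node.args[0])
                graph_module.graph.erase_node(node)
                modified = True
        if modified:
            # Only pay for DCE/recompile when the graph actually changed.
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified)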

File tree

2 files changed, +53 -25 lines

backends/cadence/aot/reorder_ops.py

Lines changed: 23 additions & 11 deletions
@@ -299,8 +299,9 @@ def advancing_feasible(self, quant_node: torch.fx.Node):
         # All the conditions satisfied, we advance.
         return True

-    def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
+    def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
         graph = graph_module.graph
+        modified = False
         for node in reversed(graph.nodes):
             if get_overload_packet(node.target) not in (
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor,
@@ -339,15 +340,19 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
             # We can safely remove the quant node and trivially quantizable op
             graph.erase_node(node)
             graph.erase_node(trivially_quantizable_op)
+            modified = True

-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
+        return modified

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         self.graph_module = graph_module
-        self.advance_quantize_op(graph_module)
-        result = super().call(graph_module)
-        return result
+        modified = self.advance_quantize_op(graph_module)
+        if modified:
+            graph_module.recompile()
+            graph_module.graph.eliminate_dead_code()
+            return super().call(graph_module)
+
+        return PassResult(graph_module, False)


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
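
A hypothetical way a caller might consume the updated pass, inferred only from the PassResult contract shown in the hunk above (this snippet is not part of the commit):

result = AdvanceQuantizeOpAboveDefChainPass()(graph_module)
if result.modified:
    # Pick up the rewritten graph only when the pass reports a change.
    graph_module = result.graph_module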
@@ -474,14 +479,21 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # the graph (up to 3 times max, to avoid potential infinite loops)
         self.graph_module = graph_module
         iter_count = 0
-        modified = True
+        local_modified = False
+        overall_modified = False
+
+        while local_modified or iter_count == 0:
+            local_modified = self.postpone_dequantize_op(self.graph_module)
+            overall_modified |= local_modified
+
+            if local_modified:
+                self.graph_module = super().call(self.graph_module).graph_module

-        while modified and iter_count < 3:
-            modified = self.postpone_dequantize_op(self.graph_module)
-            self.graph_module = super().call(self.graph_module).graph_module
             iter_count += 1
+            if iter_count == 3:
+                break

-        return super().call(self.graph_module)
+        return PassResult(graph_module, overall_modified)


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
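
The hunk above replaces the unconditional three-iteration loop with one that stops as soon as postpone_dequantize_op reports no change, still capped at three iterations, and ORs the per-iteration results into the overall modified bit returned in the PassResult. A generic sketch of that bounded fixed-point pattern (names here are illustrative, not from the repo):

import torch
from torch.fx.passes.infra.pass_base import PassResult


def run_to_fixed_point(graph_module: torch.fx.GraphModule, rewrite_step, max_iters: int = 3) -> PassResult:
    # rewrite_step mutates graph_module in place and returns True if it changed anything.
    overall_modified = False
    for _ in range(max_iters):
        local_modified = rewrite_step(graph_module)
        overall_modified |= local_modified
        if not local_modified:
            break
    return PassResult(graph_module, overall_modified)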

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 30 additions & 14 deletions
@@ -287,13 +287,14 @@ def test_advance_branched_quantize(self) -> None:
     @torch.no_grad()
     def test_advance_quantize(self) -> None:
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32))
-        weights = builder.placeholder(
-            "weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8)
-        )
+        x_data = torch.randn(16, 1, 32, 6, dtype=torch.float32)
+        weight_data = torch.randint(-128, 127, (32, 32), dtype=torch.int8)
+        x = builder.placeholder("x", x_data)
+        weights = builder.placeholder("weights", weight_data)
         full = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
             args=([1], -7),
+            kwargs={"dtype": torch.int32},
         )
         full_1 = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
@@ -305,7 +306,8 @@ def test_advance_quantize(self) -> None:
         )
         full_3 = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
-            args=([12], 0.0),
+            args=([1], 0),
+            kwargs={"dtype": torch.int32},
         )
         permute = builder.call_operator(
             op=exir_ops.edge.aten.permute_copy.default,
@@ -338,8 +340,14 @@ def test_advance_quantize(self) -> None:

         p1 = AdvanceQuantizeOpAboveDefInBranchPass()
         tmp_graph = cast(PassResult, p1(original_graph)).graph_module
-        p2 = AdvanceQuantizeOpAboveDefChainPass()
-        converted_graph = cast(PassResult, p2(tmp_graph)).graph_module
+        result = transform_and_check_numerics(
+            tmp_graph,
+            (x_data, weight_data),
+            AdvanceQuantizeOpAboveDefChainPass(),
+            "AdvanceQuantizeOpAboveDefChainPass",
+        )
+        self.assertFalse(result.modified)
+        converted_graph = result.graph_module
         # Assert that permute node is now the successor of the quant node.
         self.assertTrue(
             get_node_pos(
@@ -350,13 +358,14 @@ def test_advance_quantize(self) -> None:

     def test_postpone_dequantize1(self) -> None:
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32))
-        weights = builder.placeholder(
-            "weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8)
-        )
+        x_data = torch.randn(1, 16, 32, 6, dtype=torch.float32)
+        weight_data = torch.randint(-128, 127, (6, 6), dtype=torch.int8)
+        x = builder.placeholder("x", x_data)
+        weights = builder.placeholder("weights", weight_data)
         full = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
             args=([1], -7),
+            kwargs={"dtype": torch.int32},
         )
         full_1 = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
@@ -368,7 +377,8 @@ def test_postpone_dequantize1(self) -> None:
         )
         full_3 = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
-            args=([12], 0.0),
+            args=([1], 0),
+            kwargs={"dtype": torch.int32},
         )
         quantize_per_tensor = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
@@ -398,8 +408,14 @@ def test_postpone_dequantize1(self) -> None:
         )
         builder.output([permute])
         original_graph = builder.get_graph_module()
-        p = PostponeDequantizeOpBelowUseChainPass()
-        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        result = transform_and_check_numerics(
+            original_graph,
+            (x_data, weight_data),
+            PostponeDequantizeOpBelowUseChainPass(),
+            "PostponeDequantizeOpBelowUseChainPass",
+        )
+        self.assertTrue(result.modified)
+        converted_graph = result.graph_module
         # Assert that dequant node is now the successor of the permute node.
         self.assertTrue(
             get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
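
The tests now route graphs through a transform_and_check_numerics helper whose implementation is not part of this diff. Judging only from its call sites above, it takes the graph module, sample inputs, a pass instance, and a pass name, and returns a PassResult-like object exposing graph_module and modified, presumably after comparing outputs before and after the transform. A rough, speculative sketch of such a helper (the real one in the repo may differ):

from typing import Tuple

import torch
from torch.fx.passes.infra.pass_base import PassResult


def transform_and_check_numerics(
    graph_module: torch.fx.GraphModule,
    sample_inputs: Tuple[torch.Tensor, ...],
    graph_pass,
    pass_name: str,
) -> PassResult:
    # Capture reference outputs before the pass runs.
    expected = graph_module(*sample_inputs)
    result = graph_pass(graph_module)
    actual = result.graph_module(*sample_inputs)
    # A numerics mismatch means the reordering pass changed behavior.
    torch.testing.assert_close(actual, expected, msg=f"{pass_name} changed numerics")
    return result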
