From 3c550df961b07c00d2466287998a12414c05f42d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 12 Feb 2026 20:37:51 -0800 Subject: [PATCH 1/4] Create slice and reshape improvements Signed-off-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 4 ++ onnxscript/rewriter/__init__.py | 2 + onnxscript/rewriter/rules/common/__init__.py | 4 ++ .../rewriter/rules/common/_collapse_slices.py | 7 ++- .../common/_materialize_reshape_shape.py | 63 +++++++++++++++++++ 5 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 onnxscript/rewriter/rules/common/_materialize_reshape_shape.py diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 574ddd8aef..bfa32bd821 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -196,6 +196,10 @@ def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None: def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None: const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10) + if const_value is None: + # Reshape accepts shape input of INT32 type as well, so we also check for INT32 here + # This is common for tflite models + const_value = _get_numpy_value(value, ir.DataType.INT32, size_limit=10) if const_value is not None: if const_value.ndim == 1: return ir.Shape(const_value.tolist()) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index fb93bc703f..6dd19f28c2 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -42,6 +42,7 @@ _fuse_batchnorm, _fuse_pad_into_conv, _fuse_relus_clips, + _materialize_reshape_shape, _min_max_to_clip, _no_op, _redundant_scatter_nd, @@ -54,6 +55,7 @@ *_broadcast_to_matmul.rules, *_cast_constant_of_shape.rules, *_collapse_slices.rules, + *_materialize_reshape_shape.rules, *_min_max_to_clip.rules, *_fuse_relus_clips.rules, *_basic_rules.basic_optimization_rules(), diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index 76d9e4f4b0..e804854ca2 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -25,6 +25,7 @@ "max_min_rule", "gemm_to_matmul_add_rule", "matmul_add_to_gemm_rule", + "materialize_reshape_shape_rule", "mul_by_1_rule", "no_op_cast_rule", "no_op_dynamic_scatter_nd_rule", @@ -107,6 +108,9 @@ transpose_ab_matmul_add_to_gemm_rule, transpose_b_matmul_add_to_gemm_rule, ) +from onnxscript.rewriter.rules.common._materialize_reshape_shape import ( + materialize_reshape_shape_rule, +) from onnxscript.rewriter.rules.common._min_max_to_clip import ( max_max_rule, max_min_rule, diff --git a/onnxscript/rewriter/rules/common/_collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py index 21b2694b82..29ab40b7ae 100644 --- a/onnxscript/rewriter/rules/common/_collapse_slices.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices.py @@ -82,7 +82,12 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ if data.shape is None or slice_output.shape is None: return False - if not _ir_utils.is_singleton_value(steps, 1): + # All steps must be 1 + steps_np = _ir_utils.get_numpy_value(steps) + if steps_np is not None: + if not all(s == 1 for s in steps_np.flatten()): + return False + elif not _ir_utils.is_singleton_value(steps, 1): return False return _ir_utils.same_shape(data.shape, slice_output.shape) diff --git a/onnxscript/rewriter/rules/common/_materialize_reshape_shape.py b/onnxscript/rewriter/rules/common/_materialize_reshape_shape.py new file mode 100644 index 0000000000..fb7bb12361 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_materialize_reshape_shape.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Materialize Reshape shape input from known output shape. + +When symbolic shape inference has been run, a Reshape node may have a known +output shape even though its shape input is computed dynamically (e.g., via a +Shape → Cast → Split → Concat chain). This rule replaces the shape input +with a concrete constant, allowing the dynamic chain to become dead code and +be removed by unused-node elimination. + +- Fully static output shape → constant with exact dims. +- Exactly one symbolic dim → replace it with ``-1`` (Reshape infers it). +""" + +from __future__ import annotations + +from onnxscript import ir +from onnxscript.rewriter import _ir_utils as ir_utils +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class MaterializeReshapeShape(RewriteRuleClassBase): + """Replace a dynamic Reshape shape input with a constant when output shape is known.""" + + def pattern(self, op, data, shape): + return op.Reshape(data, shape) + + def check(self, context, data: ir.Value, shape: ir.Value) -> MatchResult: + check_result = MatchResult() + + # Shape input must not already be a constant + if ir_utils.get_numpy_value(shape) is not None: + return check_result.fail("Shape input is already a constant.") + + output = context.output_values[0] + if output.shape is None: + return check_result.fail("Output shape is not known.") + + dims = list(output.shape) + sym_count = sum(1 for d in dims if not isinstance(d, int)) + + if sym_count == 0: + self._new_dims = [int(d) for d in dims] + elif sym_count == 1: + self._new_dims = [-1 if not isinstance(d, int) else int(d) for d in dims] + else: + return check_result.fail( + f"Output shape has {sym_count} symbolic dims, cannot materialize." + ) + return check_result + + def rewrite(self, op, data: ir.Value, shape: ir.Value): + new_shape = op.Constant( + value=ir.tensor(self._new_dims, dtype=ir.DataType.INT64), + ) + return op.Reshape(data, new_shape) + + +materialize_reshape_shape_rule = MaterializeReshapeShape.rule() + +rules = RewriteRuleSet([materialize_reshape_shape_rule]) + From e790612f36f58982534f7ea82129a9298d2c66de Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 12 Feb 2026 20:59:46 -0800 Subject: [PATCH 2/4] Improve tests Signed-off-by: Justin Chu --- .../optimizer/_constant_folding_test.py | 16 ++ onnxscript/rewriter/rules/common/__init__.py | 6 +- .../rules/common/_collapse_slices_test.py | 59 ++++++ .../common/_materialize_reshape_shape.py | 6 +- .../common/_materialize_reshape_shape_test.py | 169 ++++++++++++++++++ 5 files changed, 251 insertions(+), 5 deletions(-) create mode 100644 onnxscript/rewriter/rules/common/_materialize_reshape_shape_test.py diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 080af9c2f3..72abe7d93b 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -614,6 +614,22 @@ def test_gather_symdim(self): optimized = self._fold(model) self.assertEqual(optimized.graph.node(-1).op_type, "Identity") + def test_reshape_identity_int32_shape(self): + """Reshape with a constant INT32 shape input should be recognized as identity.""" + model_ir = ir.from_onnx_text( + """ + + agraph (float[3, 4] x) => (float[3, 4] z) + { + shape_i64 = Constant () + shape = Cast (shape_i64) + z = Reshape (x, shape) + } + """ + ) + optimized = self._fold(model_ir) + self.assertEqual(optimized.graph.node(-1).op_type, "Identity") + def test_input_size_limit(self): model_text = """ diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index e804854ca2..9c4163157e 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -102,15 +102,15 @@ successive_relu_rule, ) from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule +from onnxscript.rewriter.rules.common._materialize_reshape_shape import ( + materialize_reshape_shape_rule, +) from onnxscript.rewriter.rules.common._matmul_add_to_gemm import ( matmul_add_to_gemm_rule, transpose_a_matmul_add_to_gemm_rule, transpose_ab_matmul_add_to_gemm_rule, transpose_b_matmul_add_to_gemm_rule, ) -from onnxscript.rewriter.rules.common._materialize_reshape_shape import ( - materialize_reshape_shape_rule, -) from onnxscript.rewriter.rules.common._min_max_to_clip import ( max_max_rule, max_min_rule, diff --git a/onnxscript/rewriter/rules/common/_collapse_slices_test.py b/onnxscript/rewriter/rules/common/_collapse_slices_test.py index 727240344d..bcdad4da7e 100644 --- a/onnxscript/rewriter/rules/common/_collapse_slices_test.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices_test.py @@ -119,3 +119,62 @@ def test_slice_equal_dynamic_shape_but_step_reverse(self): count = _collapse_slices.rules.apply_to_model(model) # Should not change the output shape if we did not use the default step of 1 self.assertEqual(count, 0) + + def test_multi_element_steps_all_ones_collapses(self): + """Slice with multi-axis steps=[1,1] and matching shapes should collapse.""" + model = ir.from_onnx_text( + """ + + agraph (float[L, M] data) => (float[L, M] output) + { + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + } + """ + ) + count = _collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 1) + self.assertIn("Identity", [node.op_type for node in model.graph]) + + def test_multi_element_steps_with_non_one_does_not_collapse(self): + """Slice with steps containing a non-1 element should not collapse.""" + model = ir.from_onnx_text( + """ + + agraph (float[10, 20] data) => (float[10, 10] output) + { + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + } + """ + ) + count = _collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 0) + + def test_multi_element_steps_numerical_correctness(self): + """Verify numerical correctness of multi-axis collapse.""" + model_text = """ + + agraph (float[4, 5] data) => (float[4, 5] output) + { + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + } + """ + original = ir.from_onnx_text(model_text) + model = ir.from_onnx_text(model_text) + _collapse_slices.rules.apply_to_model(model) + testing.assert_numerically_equal( + original, + model, + (np.random.rand(4, 5).astype(np.float32),), + ) diff --git a/onnxscript/rewriter/rules/common/_materialize_reshape_shape.py b/onnxscript/rewriter/rules/common/_materialize_reshape_shape.py index fb7bb12361..0a53407160 100644 --- a/onnxscript/rewriter/rules/common/_materialize_reshape_shape.py +++ b/onnxscript/rewriter/rules/common/_materialize_reshape_shape.py @@ -48,16 +48,18 @@ def check(self, context, data: ir.Value, shape: ir.Value) -> MatchResult: return check_result.fail( f"Output shape has {sym_count} symbolic dims, cannot materialize." ) + + # Preserve allowzero attribute from original node + self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0) return check_result def rewrite(self, op, data: ir.Value, shape: ir.Value): new_shape = op.Constant( value=ir.tensor(self._new_dims, dtype=ir.DataType.INT64), ) - return op.Reshape(data, new_shape) + return op.Reshape(data, new_shape, allowzero=self._allowzero or None) materialize_reshape_shape_rule = MaterializeReshapeShape.rule() rules = RewriteRuleSet([materialize_reshape_shape_rule]) - diff --git a/onnxscript/rewriter/rules/common/_materialize_reshape_shape_test.py b/onnxscript/rewriter/rules/common/_materialize_reshape_shape_test.py new file mode 100644 index 0000000000..522d8750c9 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_materialize_reshape_shape_test.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _materialize_reshape_shape + + +class MaterializeReshapeShapeTest(unittest.TestCase): + def test_fully_static_output_shape_materializes(self): + """When output shape is fully static, replace dynamic shape input with constant.""" + model = ir.from_onnx_text( + """ + + agraph (float[6] data) => (float[2, 3] output) + { + shape = Shape(data) + output = Reshape(data, shape) + } + """ + ) + for node in model.graph: + if node.op_type == "Reshape": + node.outputs[0].shape = ir.Shape([2, 3]) + break + count = _materialize_reshape_shape.rules.apply_to_model(model) + self.assertEqual(count, 1) + reshape_nodes = [n for n in model.graph if n.op_type == "Reshape"] + self.assertEqual(len(reshape_nodes), 1) + shape_input = reshape_nodes[0].inputs[1] + self.assertIsNotNone(shape_input.const_value) + self.assertEqual(shape_input.const_value.numpy().tolist(), [2, 3]) + + def test_one_symbolic_dim_uses_minus_one(self): + """When output has one symbolic dim, replace it with -1.""" + model = ir.from_onnx_text( + """ + + agraph (float[6] data) => (float[B, 3] output) + { + shape = Shape(data) + output = Reshape(data, shape) + } + """ + ) + for node in model.graph: + if node.op_type == "Reshape": + node.outputs[0].shape = ir.Shape(["B", 3]) + break + count = _materialize_reshape_shape.rules.apply_to_model(model) + self.assertEqual(count, 1) + reshape_nodes = [n for n in model.graph if n.op_type == "Reshape"] + self.assertEqual(len(reshape_nodes), 1) + shape_input = reshape_nodes[0].inputs[1] + self.assertIsNotNone(shape_input.const_value) + self.assertEqual(shape_input.const_value.numpy().tolist(), [-1, 3]) + + def test_two_symbolic_dims_not_materialized(self): + """When output has two symbolic dims, the rule should not fire.""" + model = ir.from_onnx_text( + """ + + agraph (float[6] data) => (float[B, C] output) + { + shape = Shape(data) + output = Reshape(data, shape) + } + """ + ) + for node in model.graph: + if node.op_type == "Reshape": + node.outputs[0].shape = ir.Shape(["B", "C"]) + break + count = _materialize_reshape_shape.rules.apply_to_model(model) + self.assertEqual(count, 0) + + def test_constant_shape_input_not_replaced(self): + """When the shape input is already a constant, the rule should not fire.""" + model = ir.from_onnx_text( + """ + + agraph (float[6] data) => (float[2, 3] output) + { + shape = Constant() + output = Reshape(data, shape) + } + """ + ) + count = _materialize_reshape_shape.rules.apply_to_model(model) + self.assertEqual(count, 0) + + def test_unknown_output_shape_not_materialized(self): + """When the output shape is unknown, the rule should not fire.""" + model = ir.from_onnx_text( + """ + + agraph (float[6] data) => (float output) + { + shape = Shape(data) + output = Reshape(data, shape) + } + """ + ) + for node in model.graph: + if node.op_type == "Reshape": + node.outputs[0].shape = None + break + count = _materialize_reshape_shape.rules.apply_to_model(model) + self.assertEqual(count, 0) + + def test_allowzero_attribute_preserved(self): + """The allowzero attribute should be preserved on the new Reshape.""" + model = ir.from_onnx_text( + """ + + agraph (float[6] data) => (float[2, 3] output) + { + shape = Shape(data) + output = Reshape(data, shape) + } + """ + ) + for node in model.graph: + if node.op_type == "Reshape": + node.outputs[0].shape = ir.Shape([2, 3]) + break + count = _materialize_reshape_shape.rules.apply_to_model(model) + self.assertEqual(count, 1) + reshape_nodes = [n for n in model.graph if n.op_type == "Reshape"] + self.assertEqual(len(reshape_nodes), 1) + allowzero = reshape_nodes[0].attributes.get_int("allowzero", 0) + self.assertEqual(allowzero, 1) + + def test_numerical_correctness_static(self): + """Verify numerical equivalence for fully static materialization.""" + # Build a model where a dynamic Concat produces the shape for Reshape. + # After materialization, the Reshape uses a constant shape. + model_text = """ + + agraph (float[12] data, float[3, 4] ref) => (float[3, 4] output) + { + shape = Shape(ref) + output = Reshape(data, shape) + } + """ + original = ir.from_onnx_text(model_text) + model = ir.from_onnx_text(model_text) + for node in model.graph: + if node.op_type == "Reshape": + node.outputs[0].shape = ir.Shape([3, 4]) + break + _materialize_reshape_shape.rules.apply_to_model(model) + testing.assert_numerically_equal( + original, + model, + ( + np.arange(12).astype(np.float32), + np.zeros((3, 4), dtype=np.float32), + ), + ) + + +if __name__ == "__main__": + unittest.main() From 0bcd3bc8781ac81ece3a7d0b2610e2ec1a9f28e4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 13 Feb 2026 07:26:54 -0800 Subject: [PATCH 3/4] Update onnxscript/rewriter/rules/common/_collapse_slices.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/rewriter/rules/common/_collapse_slices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/rules/common/_collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py index 29ab40b7ae..14e7e06d17 100644 --- a/onnxscript/rewriter/rules/common/_collapse_slices.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices.py @@ -85,7 +85,7 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ # All steps must be 1 steps_np = _ir_utils.get_numpy_value(steps) if steps_np is not None: - if not all(s == 1 for s in steps_np.flatten()): + if not all(s == 1 for s in steps_np.flat): return False elif not _ir_utils.is_singleton_value(steps, 1): return False From c2620f707258c5618123782768c9768324db7cf8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 13 Feb 2026 14:22:26 -0800 Subject: [PATCH 4/4] Reshape does not support INT32 Signed-off-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 4 ---- onnxscript/optimizer/_constant_folding_test.py | 16 ---------------- 2 files changed, 20 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index bfa32bd821..574ddd8aef 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -196,10 +196,6 @@ def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None: def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None: const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10) - if const_value is None: - # Reshape accepts shape input of INT32 type as well, so we also check for INT32 here - # This is common for tflite models - const_value = _get_numpy_value(value, ir.DataType.INT32, size_limit=10) if const_value is not None: if const_value.ndim == 1: return ir.Shape(const_value.tolist()) diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 72abe7d93b..080af9c2f3 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -614,22 +614,6 @@ def test_gather_symdim(self): optimized = self._fold(model) self.assertEqual(optimized.graph.node(-1).op_type, "Identity") - def test_reshape_identity_int32_shape(self): - """Reshape with a constant INT32 shape input should be recognized as identity.""" - model_ir = ir.from_onnx_text( - """ - - agraph (float[3, 4] x) => (float[3, 4] z) - { - shape_i64 = Constant () - shape = Cast (shape_i64) - z = Reshape (x, shape) - } - """ - ) - optimized = self._fold(model_ir) - self.assertEqual(optimized.graph.node(-1).op_type, "Identity") - def test_input_size_limit(self): model_text = """