4 changes: 4 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
@@ -196,6 +196,10 @@ def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:

def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
if const_value is None:
# The shape input may also be INT32 in practice, so fall back to INT32 here.
# This is common in tflite-converted models.
const_value = _get_numpy_value(value, ir.DataType.INT32, size_limit=10)
if const_value is not None:
if const_value.ndim == 1:
return ir.Shape(const_value.tolist())
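A minimal standalone sketch of the fallback above, with the two candidate constant lookups passed in directly (standing in for _get_numpy_value), showing that an INT32 shape tensor still yields a usable shape:

from __future__ import annotations

import numpy as np

def shape_from_const(arr64: np.ndarray | None, arr32: np.ndarray | None) -> list[int] | None:
    # Prefer the INT64 constant; fall back to INT32 (the tflite case above).
    arr = arr64 if arr64 is not None else arr32
    if arr is not None and arr.ndim == 1:
        return [int(d) for d in arr]
    return None

assert shape_from_const(None, np.array([3, 4], dtype=np.int32)) == [3, 4]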
16 changes: 16 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -614,6 +614,22 @@ def test_gather_symdim(self):
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_reshape_identity_int32_shape(self):
"""Reshape with a constant INT32 shape input should be recognized as identity."""
model_ir = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[3, 4] x) => (float[3, 4] z)
{
shape_i64 = Constant <value_ints=[3, 4]> ()
shape = Cast <to=6> (shape_i64)
z = Reshape (x, shape)
}
"""
)
optimized = self._fold(model_ir)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_input_size_limit(self):
model_text = """
<ir_version: 7, opset_import: [ "" : 17]>
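For reference on the magic number in the test above: Cast's `to` attribute takes a TensorProto dtype enum, and 6 selects INT32 (INT64 is 7):

import onnx

# to=6 in the test above is onnx.TensorProto.INT32; INT64 would be to=7.
assert onnx.TensorProto.INT32 == 6
assert onnx.TensorProto.INT64 == 7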
2 changes: 2 additions & 0 deletions onnxscript/rewriter/__init__.py
@@ -42,6 +42,7 @@
_fuse_batchnorm,
_fuse_pad_into_conv,
_fuse_relus_clips,
_materialize_reshape_shape,
_min_max_to_clip,
_no_op,
_redundant_scatter_nd,
@@ -54,6 +55,7 @@
*_broadcast_to_matmul.rules,
*_cast_constant_of_shape.rules,
*_collapse_slices.rules,
*_materialize_reshape_shape.rules,
*_min_max_to_clip.rules,
*_fuse_relus_clips.rules,
*_basic_rules.basic_optimization_rules(),
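Once registered here, the new rule runs as part of the default rule set. A sketch of exercising it through the module-level entry point, assuming onnxscript.rewriter.rewrite accepts an ir.Model and applies the default rules when no explicit rule set is given:

from onnxscript import ir, rewriter

model = ir.from_onnx_text(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float[3, 4] x) => (float[3, 4] z)
    {
        shape = Constant <value_ints=[3, 4]> ()
        z = Reshape (x, shape)
    }
    """
)
rewritten = rewriter.rewrite(model)  # applies the default rules registered above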
4 changes: 4 additions & 0 deletions onnxscript/rewriter/rules/common/__init__.py
@@ -25,6 +25,7 @@
"max_min_rule",
"gemm_to_matmul_add_rule",
"matmul_add_to_gemm_rule",
"materialize_reshape_shape_rule",
"mul_by_1_rule",
"no_op_cast_rule",
"no_op_dynamic_scatter_nd_rule",
@@ -101,6 +102,9 @@
successive_relu_rule,
)
from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule
from onnxscript.rewriter.rules.common._materialize_reshape_shape import (
materialize_reshape_shape_rule,
)
from onnxscript.rewriter.rules.common._matmul_add_to_gemm import (
matmul_add_to_gemm_rule,
transpose_a_matmul_add_to_gemm_rule,
7 changes: 6 additions & 1 deletion onnxscript/rewriter/rules/common/_collapse_slices.py
@@ -82,7 +82,12 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_
if data.shape is None or slice_output.shape is None:
return False

if not _ir_utils.is_singleton_value(steps, 1):
# All steps must be 1
steps_np = _ir_utils.get_numpy_value(steps)
if steps_np is not None:
if not all(s == 1 for s in steps_np.flat):
return False
elif not _ir_utils.is_singleton_value(steps, 1):
return False

return _ir_utils.same_shape(data.shape, slice_output.shape)
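A self-contained restatement of the new steps check, showing which constant steps tensors allow the collapse:

from __future__ import annotations

import numpy as np

def steps_all_ones(steps_np: np.ndarray | None) -> bool:
    # Mirrors the check above: a constant steps tensor allows the collapse
    # only if every element is 1; a non-constant steps input never does.
    return steps_np is not None and all(s == 1 for s in steps_np.flat)

assert steps_all_ones(np.array([1, 1]))      # multi-axis unit steps: collapse
assert not steps_all_ones(np.array([1, 2]))  # any non-1 step: keep the Slice
assert not steps_all_ones(None)              # dynamic steps: keep the Slice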
59 changes: 59 additions & 0 deletions onnxscript/rewriter/rules/common/_collapse_slices_test.py
@@ -119,3 +119,62 @@ def test_slice_equal_dynamic_shape_but_step_reverse(self):
count = _collapse_slices.rules.apply_to_model(model)
# Should not change the output shape if we did not use the default step of 1
self.assertEqual(count, 0)

def test_multi_element_steps_all_ones_collapses(self):
"""Slice with multi-axis steps=[1,1] and matching shapes should collapse."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[L, M] data) => (float[L, M] output)
{
starts = Constant<value: tensor = int64[2] {0, 0}>()
ends = Constant<value: tensor = int64[2] {9999, 9999}>()
axes = Constant<value: tensor = int64[2] {0, 1}>()
steps = Constant<value: tensor = int64[2] {1, 1}>()
output = Slice (data, starts, ends, axes, steps)
}
"""
)
count = _collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertIn("Identity", [node.op_type for node in model.graph])

def test_multi_element_steps_with_non_one_does_not_collapse(self):
"""Slice with steps containing a non-1 element should not collapse."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[10, 20] data) => (float[10, 10] output)
{
starts = Constant<value: tensor = int64[2] {0, 0}>()
ends = Constant<value: tensor = int64[2] {10, 20}>()
axes = Constant<value: tensor = int64[2] {0, 1}>()
steps = Constant<value: tensor = int64[2] {1, 2}>()
output = Slice (data, starts, ends, axes, steps)
}
"""
)
count = _collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 0)

def test_multi_element_steps_numerical_correctness(self):
"""Verify numerical correctness of multi-axis collapse."""
model_text = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[4, 5] data) => (float[4, 5] output)
{
starts = Constant<value: tensor = int64[2] {0, 0}>()
ends = Constant<value: tensor = int64[2] {100, 100}>()
axes = Constant<value: tensor = int64[2] {0, 1}>()
steps = Constant<value: tensor = int64[2] {1, 1}>()
output = Slice (data, starts, ends, axes, steps)
}
"""
original = ir.from_onnx_text(model_text)
model = ir.from_onnx_text(model_text)
_collapse_slices.rules.apply_to_model(model)
testing.assert_numerically_equal(
original,
model,
(np.random.rand(4, 5).astype(np.float32),),
)
65 changes: 65 additions & 0 deletions onnxscript/rewriter/rules/common/_materialize_reshape_shape.py
@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Materialize Reshape shape input from known output shape.

When symbolic shape inference has been run, a Reshape node may have a known
output shape even though its shape input is computed dynamically (e.g., via a
Shape → Cast → Split → Concat chain). This rule replaces the shape input
with a concrete constant, allowing the dynamic chain to become dead code and
be removed by unused-node elimination.

- Fully static output shape → constant with exact dims.
- Exactly one symbolic dim → replace it with ``-1`` (Reshape infers it).
"""

from __future__ import annotations

from onnxscript import ir
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


class MaterializeReshapeShape(RewriteRuleClassBase):
"""Replace a dynamic Reshape shape input with a constant when output shape is known."""

def pattern(self, op, data, shape):
return op.Reshape(data, shape)

def check(self, context, data: ir.Value, shape: ir.Value) -> MatchResult:
check_result = MatchResult()

# Shape input must not already be a constant
if ir_utils.get_numpy_value(shape) is not None:
return check_result.fail("Shape input is already a constant.")

output = context.output_values[0]
if output.shape is None:
return check_result.fail("Output shape is not known.")

dims = list(output.shape)
sym_count = sum(1 for d in dims if not isinstance(d, int))

if sym_count == 0:
self._new_dims = [int(d) for d in dims]
elif sym_count == 1:
self._new_dims = [-1 if not isinstance(d, int) else int(d) for d in dims]
else:
return check_result.fail(
f"Output shape has {sym_count} symbolic dims, cannot materialize."
)

# Preserve allowzero attribute from original node
self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0)
return check_result

def rewrite(self, op, data: ir.Value, shape: ir.Value):
new_shape = op.Constant(
value=ir.tensor(self._new_dims, dtype=ir.DataType.INT64),
)
return op.Reshape(data, new_shape, allowzero=self._allowzero or None)


materialize_reshape_shape_rule = MaterializeReshapeShape.rule()

rules = RewriteRuleSet([materialize_reshape_shape_rule])
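A quick check of the single -1 convention the rule relies on: Reshape infers exactly one -1 dimension from the element count, matching numpy's reshape semantics, which is why only one symbolic dim can be materialized:

import numpy as np

# With one -1, the remaining dim is inferred from the element count; two
# unknowns would be ambiguous, hence the rule's single-symbolic-dim limit.
x = np.arange(12, dtype=np.float32)
assert x.reshape([-1, 3]).shape == (4, 3)
assert x.reshape([2, -1]).shape == (2, 6)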
169 changes: 169 additions & 0 deletions onnxscript/rewriter/rules/common/_materialize_reshape_shape_test.py
@@ -0,0 +1,169 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import testing
from onnxscript.rewriter.rules.common import _materialize_reshape_shape


class MaterializeReshapeShapeTest(unittest.TestCase):
def test_fully_static_output_shape_materializes(self):
"""When output shape is fully static, replace dynamic shape input with constant."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[6] data) => (float[2, 3] output)
{
shape = Shape(data)
output = Reshape(data, shape)
}
"""
)
for node in model.graph:
if node.op_type == "Reshape":
node.outputs[0].shape = ir.Shape([2, 3])
break
count = _materialize_reshape_shape.rules.apply_to_model(model)
self.assertEqual(count, 1)
reshape_nodes = [n for n in model.graph if n.op_type == "Reshape"]
self.assertEqual(len(reshape_nodes), 1)
shape_input = reshape_nodes[0].inputs[1]
self.assertIsNotNone(shape_input.const_value)
self.assertEqual(shape_input.const_value.numpy().tolist(), [2, 3])

def test_one_symbolic_dim_uses_minus_one(self):
"""When output has one symbolic dim, replace it with -1."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[6] data) => (float[B, 3] output)
{
shape = Shape(data)
output = Reshape(data, shape)
}
"""
)
for node in model.graph:
if node.op_type == "Reshape":
node.outputs[0].shape = ir.Shape(["B", 3])
break
count = _materialize_reshape_shape.rules.apply_to_model(model)
self.assertEqual(count, 1)
reshape_nodes = [n for n in model.graph if n.op_type == "Reshape"]
self.assertEqual(len(reshape_nodes), 1)
shape_input = reshape_nodes[0].inputs[1]
self.assertIsNotNone(shape_input.const_value)
self.assertEqual(shape_input.const_value.numpy().tolist(), [-1, 3])

def test_two_symbolic_dims_not_materialized(self):
"""When output has two symbolic dims, the rule should not fire."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[6] data) => (float[B, C] output)
{
shape = Shape(data)
output = Reshape(data, shape)
}
"""
)
for node in model.graph:
if node.op_type == "Reshape":
node.outputs[0].shape = ir.Shape(["B", "C"])
break
count = _materialize_reshape_shape.rules.apply_to_model(model)
self.assertEqual(count, 0)

def test_constant_shape_input_not_replaced(self):
"""When the shape input is already a constant, the rule should not fire."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[6] data) => (float[2, 3] output)
{
shape = Constant<value: tensor = int64[2] {2, 3}>()
output = Reshape(data, shape)
}
"""
)
count = _materialize_reshape_shape.rules.apply_to_model(model)
self.assertEqual(count, 0)

def test_unknown_output_shape_not_materialized(self):
"""When the output shape is unknown, the rule should not fire."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[6] data) => (float output)
{
shape = Shape(data)
output = Reshape(data, shape)
}
"""
)
for node in model.graph:
if node.op_type == "Reshape":
node.outputs[0].shape = None
break
count = _materialize_reshape_shape.rules.apply_to_model(model)
self.assertEqual(count, 0)

def test_allowzero_attribute_preserved(self):
"""The allowzero attribute should be preserved on the new Reshape."""
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[6] data) => (float[2, 3] output)
{
shape = Shape(data)
output = Reshape<allowzero=1>(data, shape)
}
"""
)
for node in model.graph:
if node.op_type == "Reshape":
node.outputs[0].shape = ir.Shape([2, 3])
break
count = _materialize_reshape_shape.rules.apply_to_model(model)
self.assertEqual(count, 1)
reshape_nodes = [n for n in model.graph if n.op_type == "Reshape"]
self.assertEqual(len(reshape_nodes), 1)
allowzero = reshape_nodes[0].attributes.get_int("allowzero", 0)
self.assertEqual(allowzero, 1)

def test_numerical_correctness_static(self):
"""Verify numerical equivalence for fully static materialization."""
# Build a model where a dynamic Concat produces the shape for Reshape.
# After materialization, the Reshape uses a constant shape.
model_text = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[12] data, float[3, 4] ref) => (float[3, 4] output)
{
shape = Shape(ref)
output = Reshape(data, shape)
}
"""
original = ir.from_onnx_text(model_text)
model = ir.from_onnx_text(model_text)
for node in model.graph:
if node.op_type == "Reshape":
node.outputs[0].shape = ir.Shape([3, 4])
break
_materialize_reshape_shape.rules.apply_to_model(model)
testing.assert_numerically_equal(
original,
model,
(
np.arange(12).astype(np.float32),
np.zeros((3, 4), dtype=np.float32),
),
)


if __name__ == "__main__":
unittest.main()