diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py new file mode 100644 index 0000000000..aa2e03b0d4 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py @@ -0,0 +1,246 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fuses Reduce{Max,Min} and Arg{Max,Min} patterns into TopK. + +Supported transformations: +- ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0] +- ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0] + +Constraints: + - This rule only works for opset 18+. + - Both nodes must operate on the same input X. + - Both nodes must target the same axis. + - Both nodes must have the same keepdims attribute value. + - The Reduce node must operate on a single axis (len(axes) == 1). +""" + +from __future__ import annotations + +from abc import abstractmethod + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _FuseReduceArgToTopKBase(RewriteRuleClassBase): + """Base class for fusing Reduce{Max,Min} + Arg{Max,Min} into TopK. + + This base class contains the common logic for checking and rewriting patterns where + a Reduce operation and its corresponding Arg operation can be replaced with a single + TopK operation. + + Subclasses must implement: + - pattern(): Define the specific Reduce and Arg operations to match + - largest: Property returning 1 for Max operations, 0 for Min operations + """ + + @property + @abstractmethod + def largest(self) -> int: + """Return 1 for Max operations (largest elements), 0 for Min operations (smallest elements).""" + + @staticmethod + def _normalize_axis(axis: int, rank: int | None) -> int: + """Normalize a potentially negative axis to a positive axis index. + + Args: + axis: The axis to normalize (can be negative). + rank: The rank of the tensor, or None if unknown. + + Returns: + The normalized axis (non-negative if rank is known and axis was negative). + """ + if rank is not None and axis < 0: + return axis + rank + return axis + + def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: + """Check if Reduce and Arg operations can be safely fused into TopK. + + Conditions: + - Both nodes must have the same keepdims attribute. + - The Reduce node must operate on a single axis. + - Both nodes must operate on the same axis. + - The Arg node must not use select_last_index=1 (TopK doesn't support this). + + Args: + context: The rewrite context (unused). + reduce_val: The output of the Reduce operation (ReduceMax/ReduceMin). + arg_idx: The output of the Arg operation (ArgMax/ArgMin). + + Returns: + MatchResult: Success if the pattern can be fused, Failure otherwise. + """ + del context + check_result = MatchResult() + + reduce_node = reduce_val.producer() + arg_node = arg_idx.producer() + + # Step 1: Get keepdims attribute from both nodes + reduce_keepdims = reduce_node.attributes.get_int("keepdims", 1) + arg_keepdims = arg_node.attributes.get_int("keepdims", 1) + + # Step 2: Check if keepdims match + if reduce_keepdims != arg_keepdims: + return check_result.fail( + f"keepdims mismatch: {reduce_node.op_type} has {reduce_keepdims}, " + f"{arg_node.op_type} has {arg_keepdims}." + ) + + # Step 3: Get axes from Reduce ops's inputs + if len(reduce_node.inputs) >= 2 and reduce_node.inputs[1] is not None: + axes_input = reduce_node.inputs[1] + axes_tensor = ir.convenience.get_const_tensor(axes_input) + if axes_tensor is None: + return check_result.fail( + f"{reduce_node.op_type} axes input is not a constant." + ) + try: + axes_array = axes_tensor.numpy() + axes_list = axes_array.tolist() if axes_array.ndim > 0 else [int(axes_array)] + except Exception: + return check_result.fail(f"Cannot parse {reduce_node.op_type} axes input.") + else: + return check_result.fail(f"{reduce_node.op_type} axes not found in inputs.") + + # Step 4: Check that Reduce operates on exactly one axis + if len(axes_list) != 1: + return check_result.fail( + f"{reduce_node.op_type} must operate on a single axis, got {len(axes_list)} axes." + ) + + reduce_axis = axes_list[0] + + # Step 5: Get axis from Arg operation + # ONNX default: axis = 0 for ArgMax/ArgMin + arg_axis = arg_node.attributes.get_int("axis", 0) + + # Step 6: Check select_last_index attribute (if present) + # TopK always returns the first occurrence in case of ties + select_last_index_attr = arg_node.attributes.get_int("select_last_index", 0) + if select_last_index_attr != 0: + return check_result.fail( + f"{arg_node.op_type} has select_last_index=1, which is not supported by TopK." + ) + + # Step 7: Normalize axes if rank is known (handle negative indices) + input_x = reduce_node.inputs[0] + rank = input_x.shape.rank() if input_x.shape is not None else None + + if self._normalize_axis(reduce_axis, rank) != self._normalize_axis(arg_axis, rank): + return check_result.fail( + f"Axis mismatch: {reduce_node.op_type} operates on axis {reduce_axis}, " + f"{arg_node.op_type} operates on axis {arg_axis}." + ) + + return check_result + + def rewrite(self, op, x, reduce_val, arg_idx): + """Rewrite the matched pattern with TopK (and optionally Squeeze). + + Args: + op: The operation builder. + x: The input to both Reduce and Arg operations. + reduce_val: The output of the Reduce operation. + arg_idx: The output of the Arg operation. + + Returns: + Tuple of (values, indices) matching the original outputs. + """ + # Step 1: Get the nodes + arg_node = arg_idx.producer() + + # Step 2: Extract necessary attributes with ONNX default values + axis = arg_node.attributes.get_int("axis", 0) + keepdims = arg_node.attributes.get_int("keepdims", 1) + + # Step 2b: Normalize axis (convert negative to positive) if rank is known + if axis < 0 and x.shape is not None: + axis = x.shape.rank() + axis + + # Step 3: Create K constant + k_constant = op.Constant(value=ir.tensor(np.array([1], dtype=np.int64))) + + # Step 4: Create TopK node + topk_values, topk_indices = op.TopK( + x, + k_constant, + axis=axis, + largest=self.largest, + sorted=1, + _outputs=2, + ) + + # Step 5: Handle keepdims=0 case + if keepdims == 0: + # TopK output always keeps the dimension (just makes it size 1) + # We need to squeeze it to match the original Reduce/Arg behavior + axes_constant = op.Constant(value=ir.tensor(np.array([axis], dtype=np.int64))) + + new_values = op.Squeeze(topk_values, axes_constant) + new_indices = op.Squeeze(topk_indices, axes_constant) + else: + new_values = topk_values + new_indices = topk_indices + + return new_values, new_indices + + +class FuseReduceMaxArgMaxToTopK(_FuseReduceArgToTopKBase): + """Replaces ReduceMax + ArgMax with TopK(largest=1). + + Transformation: + ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k) + → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0] + + When keepdims=0, the output of TopK is squeezed to match the original output shapes. + """ + + @property + def largest(self) -> int: + return 1 # TopK returns largest elements + + def pattern(self, op, x): + """Define the pattern to match: ReduceMax and ArgMax on the same input. + + Note: For opset 18+, ReduceMax has a second input for axes, which we allow + but will validate in check() to ensure it's a constant. + """ + reduce_val = op.ReduceMax(x, _allow_other_inputs=True, _outputs=["reduce_val"]) + arg_idx = op.ArgMax(x, _outputs=["arg_idx"]) + return reduce_val, arg_idx + + +class FuseReduceMinArgMinToTopK(_FuseReduceArgToTopKBase): + """Replaces ReduceMin + ArgMin with TopK(largest=0). + + Transformation: + ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k) + → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0] + + When keepdims=0, the output of TopK is squeezed to match the original output shapes. + """ + + @property + def largest(self) -> int: + return 0 # TopK returns smallest elements + + def pattern(self, op, x): + """Define the pattern to match: ReduceMin and ArgMin on the same input. + + Note: For opset 18+, ReduceMin has a second input for axes, which we allow + but will validate in check() to ensure it's a constant. + """ + reduce_val = op.ReduceMin(x, _allow_other_inputs=True, _outputs=["reduce_val"]) + arg_idx = op.ArgMin(x, _outputs=["arg_idx"]) + return reduce_val, arg_idx + + +reduce_max_argmax_to_topk_rule = FuseReduceMaxArgMaxToTopK().rule() +reduce_min_argmin_to_topk_rule = FuseReduceMinArgMinToTopK().rule() + +rules = RewriteRuleSet([reduce_max_argmax_to_topk_rule, reduce_min_argmin_to_topk_rule]) diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py new file mode 100644 index 0000000000..58ed6dec4e --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py @@ -0,0 +1,579 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +from onnx_ir.passes.common import onnx_checker, shape_inference +from parameterized import parameterized + +from onnxscript.rewriter import MatchingTracer, MatchStatus, testing +from onnxscript.rewriter.rules.common._fuse_reduce_arg_to_topk import ( + reduce_max_argmax_to_topk_rule, + reduce_min_argmin_to_topk_rule, + rules, +) + + +class FuseReduceArgToTopKTestBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20260127) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + count = rules.apply_to_model(updated_model) + + # Check that the rule was applied + self.assertGreater(count, 0) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = ( + self.rng.uniform( + low=-10.0, + high=10.0, + size=(2, *updated_model.graph.inputs[0].shape[1:]), + ).astype(np.float32), + ) + + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + ) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[rule][0] + self.assertEqual(tracer_match.status, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class TestFuseReduceMaxArgMaxToTopK(FuseReduceArgToTopKTestBase): + @parameterized.expand( + [ + ("keepdims_1_axis_1", 1, 1), + ("keepdims_1_axis_2", 1, 2), + ("keepdims_1_axis_neg1", 1, -1), + ("keepdims_0_axis_1", 0, 1), + ("keepdims_0_axis_2", 0, 2), + ("keepdims_0_axis_neg1", 0, -1), + ] + ) + def test_successful_fuse_reduce_argmax_to_topk(self, _, keepdims, axis): + """Test fusion of ReduceMax + ArgMax into TopK with various keepdims and axis values.""" + # When keepdims=0, the output rank is reduced by 1 + if keepdims == 0: + output_shape_str = "[N, ?, ?]" + else: + output_shape_str = "[N, ?, ?, ?]" + + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} max_val, int64{output_shape_str} max_idx) + + {{ + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + }} + """) + + # Expected: Constant for K, TopK, possibly (Constant + Squeeze) x2 for keepdims=0 + if keepdims == 0: + expected_op_types = ["Constant", "TopK", "Constant", "Squeeze", "Squeeze"] + else: + expected_op_types = ["Constant", "TopK"] + + self.run_test(base_model, expected_op_types) + + def test_fuse_reduce_argmax_explicit_axis_0(self): + """Test fusion with explicit axis=0.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + @parameterized.expand( + [ + ("keepdims_1_axis_2", 1, 2), + ("keepdims_0_axis_2", 0, 2), + ] + ) + def test_fuse_reduce_argmax_axes_from_constant_node(self, _, keepdims, axis): + """Test fusion when axes come from a Constant node (opset 18+).""" + if keepdims == 0: + output_shape_str = "[N, ?, ?]" + else: + output_shape_str = "[N, ?, ?, ?]" + + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} max_val, int64{output_shape_str} max_idx) + {{ + axes = Constant() + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + }} + """) + + # Expected: Constant for axes, Constant for K, TopK, + # possibly (Constant + Squeeze) x2 for keepdims=0 + if keepdims == 0: + expected_op_types = [ + "Constant", + "Constant", + "TopK", + "Constant", + "Squeeze", + "Squeeze", + ] + else: + expected_op_types = ["Constant", "Constant", "TopK"] + + self.run_test(base_model, expected_op_types) + + def test_successful_fuse_reduce_argmax_mixed_negative_positive_axes(self): + """Test fusion when ReduceMax uses negative axis and ArgMax uses positive axis. + + Input shape is [N, 32, 14, 17], rank is 4. Axis -1 is equivalent to axis 3. + The rule should normalize both axes before comparison. + """ + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, 32, 14, 1] max_val, int64[N, 32, 14, 1] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_fail_keepdims_mismatch(self): + """Test that fusion fails when keepdims values don't match.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + + self.run_failed_condition_test( + base_model, reduce_max_argmax_to_topk_rule, "keepdims mismatch" + ) + + def test_fail_axis_mismatch(self): + """Test that fusion fails when axes don't match.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + + self.run_failed_condition_test( + base_model, reduce_max_argmax_to_topk_rule, "Axis mismatch" + ) + + def test_fail_multiple_axes_reduce_max(self): + """Test that fusion fails when ReduceMax operates on multiple axes.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + + self.run_failed_condition_test( + base_model, + reduce_max_argmax_to_topk_rule, + "ReduceMax must operate on a single axis", + ) + + def test_fail_select_last_index_argmax(self): + """Test that fusion fails when ArgMax has select_last_index=1.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + + self.run_failed_condition_test( + base_model, + reduce_max_argmax_to_topk_rule, + "ArgMax has select_last_index=1, which is not supported by TopK.", + ) + + def test_successful_fuse_with_default_keepdims(self): + """Test fusion with default keepdims (should be 1).""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + + # Both should use default keepdims=1, so fusion should succeed + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_successful_fuse_with_default_axis(self): + """Test fusion with default axis (should be 0).""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + + # ArgMax should use default axis=0, so fusion should succeed + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_successful_fuse_with_all_defaults(self): + """Test fusion with all default values (keepdims=1, axis=0).""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + } + """) + + # Both should use defaults: keepdims=1, axis=0 + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_no_fusion_different_inputs(self): + """Test that fusion doesn't happen when nodes have different inputs.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X, float[N, 32, 14, 17] Y) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx) + + { + max_val = ReduceMax(X, axes) + max_idx = ArgMax(Y) + } + """) + + # Pattern won't match at all because inputs are different + updated_model = self.clone_model(base_model) + count = rules.apply_to_model(updated_model) + self.assertEqual(count, 0) + + # Model should be unchanged + self.assertEqual( + [node.op_type for node in base_model.graph], + [node.op_type for node in updated_model.graph], + ) + + +class TestFuseReduceMinArgMinToTopK(FuseReduceArgToTopKTestBase): + """Test cases for ReduceMin + ArgMin → TopK(largest=0) fusion.""" + + @parameterized.expand( + [ + ("keepdims_1_axis_1", 1, 1), + ("keepdims_1_axis_2", 1, 2), + ("keepdims_1_axis_neg1", 1, -1), + ("keepdims_0_axis_1", 0, 1), + ("keepdims_0_axis_2", 0, 2), + ("keepdims_0_axis_neg1", 0, -1), + ] + ) + def test_successful_fuse_reduce_argmin_to_topk(self, _, keepdims, axis): + """Test fusion of ReduceMin + ArgMin into TopK with various keepdims and axis values.""" + if keepdims == 0: + output_shape_str = "[N, ?, ?]" + else: + output_shape_str = "[N, ?, ?, ?]" + + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} min_val, int64{output_shape_str} min_idx) + + {{ + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + }} + """) + + if keepdims == 0: + expected_op_types = ["Constant", "TopK", "Constant", "Squeeze", "Squeeze"] + else: + expected_op_types = ["Constant", "TopK"] + + self.run_test(base_model, expected_op_types) + + def test_fuse_reduce_argmin_explicit_axis_0(self): + """Test fusion with explicit axis=0.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 14, 17] X) => (float[1, 14, 17] min_val, int64[1, 14, 17] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + @parameterized.expand( + [ + ("keepdims_1_axis_2", 1, 2), + ("keepdims_0_axis_2", 0, 2), + ] + ) + def test_fuse_reduce_argmin_axes_from_constant_node(self, _, keepdims, axis): + """Test fusion when axes come from a Constant node for Min operations (opset 18+).""" + if keepdims == 0: + output_shape_str = "[N, ?, ?]" + else: + output_shape_str = "[N, ?, ?, ?]" + + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} min_val, int64{output_shape_str} min_idx) + {{ + axes = Constant() + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + }} + """) + + # Expected: Constant for axes, Constant for K, TopK, + # possibly (Constant + Squeeze) x2 for keepdims=0 + if keepdims == 0: + expected_op_types = [ + "Constant", + "Constant", + "TopK", + "Constant", + "Squeeze", + "Squeeze", + ] + else: + expected_op_types = ["Constant", "Constant", "TopK"] + + self.run_test(base_model, expected_op_types) + + def test_successful_fuse_reduce_argmin_mixed_axes(self): + """Test fusion with mixed negative/positive axes for Min operations. + + Axis -2 is equivalent to axis 2 for rank-4 tensors. + """ + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, 32, 1, 17] min_val, int64[N, 32, 1, 17] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_fail_axis_mismatch(self): + """Test that fusion fails when axes don't match for Min operations.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + self.run_failed_condition_test( + base_model, reduce_min_argmin_to_topk_rule, "Axis mismatch" + ) + + def test_fail_keepdims_mismatch(self): + """Test that fusion fails when keepdims values don't match for Min operations.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + self.run_failed_condition_test( + base_model, reduce_min_argmin_to_topk_rule, "keepdims mismatch" + ) + + def test_fail_multiple_axes_reduce_min(self): + """Test that fusion fails when ReduceMin operates on multiple axes.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + + self.run_failed_condition_test( + base_model, + reduce_min_argmin_to_topk_rule, + "ReduceMin must operate on a single axis", + ) + + def test_fail_select_last_index_argmin(self): + """Test that fusion fails when ArgMin has select_last_index=1.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + self.run_failed_condition_test( + base_model, + reduce_min_argmin_to_topk_rule, + "ArgMin has select_last_index=1, which is not supported by TopK.", + ) + + def test_successful_fuse_with_default_keepdims(self): + """Test fusion with default keepdims (should be 1).""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + + # Both should use default keepdims=1, so fusion should succeed + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_successful_fuse_with_default_axis(self): + """Test fusion with default axis (should be 0).""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 14, 17] X) => (float[1, 14, 17] min_val, int64[1, 14, 17] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + + # ArgMin should use default axis=0, so fusion should succeed + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_successful_fuse_with_all_defaults(self): + """Test fusion with all default values (keepdims=1, axis=0).""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 14, 17] X) => (float[1, 14, 17] min_val, int64[1, 14, 17] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(X) + } + """) + + # Both should use defaults: keepdims=1, axis=0 + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_no_fusion_different_inputs(self): + """Test that fusion doesn't happen when nodes have different inputs.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X, float[N, 32, 14, 17] Y) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + + { + min_val = ReduceMin(X, axes) + min_idx = ArgMin(Y) + } + """) + + # Pattern won't match at all because inputs are different + updated_model = self.clone_model(base_model) + count = rules.apply_to_model(updated_model) + self.assertEqual(count, 0) + + # Model should be unchanged + self.assertEqual( + [node.op_type for node in base_model.graph], + [node.op_type for node in updated_model.graph], + ) + + +if __name__ == "__main__": + unittest.main()