From e4792d29ae9a3ed494caab2cb1cc9ebe7f645676 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 9 Feb 2026 14:58:47 +0100 Subject: [PATCH 01/11] Add converter torch aten::histc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Xavier Dupré --- .../function_libs/torch_lib/ops/core.py | 34 ++++++++++++++++++- .../function_libs/torch_lib/e2e_ops_tests.py | 20 +++++++++++ .../function_libs/torch_lib/ops_test_data.py | 1 + 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 860b878edb..05ea9c7bfe 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4608,12 +4608,44 @@ def aten_hinge_embedding_loss( raise NotImplementedError() +@torch_op("aten::histc", trace_only=True) def aten_histc( self: TensorType, bins: int = 100, min: float = 0.0, max: float = 0.0 ) -> TensorType: """histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor""" + delta = (max - min) / (bins * 1.0) + values = [min + delta * i for i in range(bins + 1)] - raise NotImplementedError() + flat_self = op.Reshape(self, [-1]) + computation_type = self.type.dtype + + cond = op.And( + op.GreaterOrEqual(flat_self, op.CastLike([min], self)), + op.LessOrEqual(flat_self, op.CastLike([max], self)), + ) + if self.type.dtype in {ir.DataType.INT32, ir.DataType.INT64}: + # max is included. + values[-1] += 1 + else: + cond = op.And(cond, op.Not(op.IsNaN(flat_self))) + # max is included. + dtype = self.type.dtype.numpy() + values[-1] = np.nextafter(values[-1], np.array(np.inf, dtype=dtype), dtype=dtype) + typed_values = op.Constant(value=ir.tensor(values, dtype=self.type.dtype)) + + clipped = op.Where(cond, flat_self, op.CastLike([min - 1], self)) + bins = op.Unsqueeze(typed_values, [1]) + + less = op.Cast( + op.Less(op.Unsqueeze(clipped, [0]), bins), + to=computation_type, + ) + sums = op.ReduceSum(less, [1], keepdims=0) + res = op.Sub( + op.Slice(sums, [1], op.Shape(sums), [0]), + op.Slice(sums, [0], [-1], [0]), + ) + return res def aten_histogramdd( diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 9ee12f3ac3..381dcced23 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -914,6 +914,26 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_aten_histc_float(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.histc(x, 3, 0, 2) + + inputs = ((torch.arange(20) / 10).to(torch.float32),) + model = Model() + expected = model(*inputs) + onnx_program = torch.onnx.export( + Model(), (torch.rand(10, 10, 10),), dynamo=True, verbose=False, dynamic_shapes=({0: "batch"},) + ) + _testing.assert_onnx_program(onnx_program) + + for k in range(101): + with self.subTest(k=k): + inputs = (torch.tensor([(k - 1) / 49.0], dtype=torch.float32),) + expected = model(*inputs) + got = onnx_program.call_reference({"x": inputs[0]}) + torch.testing.assert_close(expected, got[0]) + if __name__ == "__main__": unittest.main() diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index f9870ed840..273c778023 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -747,6 +747,7 @@ def _where_input_wrangler( ), 
TorchLibOpInfo("ge", core_ops.aten_ge), TorchLibOpInfo("gt", core_ops.aten_gt), + TorchLibOpInfo("histc", core_ops.aten_histc), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), From f98ee582d5faeffb4f580cd539ac43c1d973438f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 9 Feb 2026 15:03:56 +0100 Subject: [PATCH 02/11] Potential fix for pull request finding 'Variable defined multiple times' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- tests/function_libs/torch_lib/e2e_ops_tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 381dcced23..ea7081a407 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -921,7 +921,6 @@ def forward(self, x): inputs = ((torch.arange(20) / 10).to(torch.float32),) model = Model() - expected = model(*inputs) onnx_program = torch.onnx.export( Model(), (torch.rand(10, 10, 10),), dynamo=True, verbose=False, dynamic_shapes=({0: "batch"},) ) From 8ac594cad8acdb502570b75bd0f614617ba624b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 9 Feb 2026 15:06:42 +0100 Subject: [PATCH 03/11] Potential fix for pull request finding 'Variable defined multiple times' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- tests/function_libs/torch_lib/e2e_ops_tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index ea7081a407..5556ca17a3 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -919,7 +919,6 @@ class Model(torch.nn.Module): def forward(self, x): return torch.histc(x, 3, 0, 2) - inputs = ((torch.arange(20) / 10).to(torch.float32),) model = Model() onnx_program = torch.onnx.export( Model(), (torch.rand(10, 10, 10),), dynamo=True, verbose=False, dynamic_shapes=({0: "batch"},) From a92e79162b611da555fbb3cdf8de9c4df2655b00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 12:47:04 +0100 Subject: [PATCH 04/11] disable tests on float16 --- .../function_libs/torch_lib/ops/core.py | 1 + .../function_libs/torch_lib/e2e_ops_tests.py | 23 +++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 13 ++++++++++- 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 05ea9c7bfe..c610c97de9 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4630,6 +4630,7 @@ def aten_histc( cond = op.And(cond, op.Not(op.IsNaN(flat_self))) # max is included. 
dtype = self.type.dtype.numpy() + values = np.array(values, dtype=dtype) values[-1] = np.nextafter(values[-1], np.array(np.inf, dtype=dtype), dtype=dtype) typed_values = op.Constant(value=ir.tensor(values, dtype=self.type.dtype)) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 5556ca17a3..f99e57c155 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -932,6 +932,29 @@ def forward(self, x): got = onnx_program.call_reference({"x": inputs[0]}) torch.testing.assert_close(expected, got[0]) + @unittest.skip("see https://github.com/pytorch/pytorch/issues/174668") + def test_aten_histc_float16(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.histc(x, 60, -10, 10) + + model = Model() + onnx_program = torch.onnx.export( + Model(), + (torch.rand((10, 10, 10), dtype=torch.float16),), + dynamo=True, + verbose=False, + dynamic_shapes=({0: "batch"},), + ) + _testing.assert_onnx_program(onnx_program) + + for k in range(101): + with self.subTest(k=k): + inputs = (torch.tensor([(k - 1) / 49.0], dtype=torch.float16),) + expected = model(*inputs) + got = onnx_program.call_reference({"x": inputs[0]}) + torch.testing.assert_close(expected, got[0]) + if __name__ == "__main__": unittest.main() diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 273c778023..78c5c4609d 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -747,7 +747,18 @@ def _where_input_wrangler( ), TorchLibOpInfo("ge", core_ops.aten_ge), TorchLibOpInfo("gt", core_ops.aten_gt), - TorchLibOpInfo("histc", core_ops.aten_histc), + TorchLibOpInfo("histc", core_ops.aten_histc).skip( + matcher=lambda sample: ( + sample.kwargs.get("min") == sample.kwargs.get("max") + or sample.dtype == torch.float16 + ), + reason="then min=sample.min(), max=sample.max(), torch.histc does not" + "define what happens when both are equal (1 sample with one element " + "for example). 
torch does something, maybe " + "something like zeros(bins)[bins // 2 + 1] = 1., " + "we skip float16 because of " + "https://github.com/pytorch/pytorch/issues/174668", + ), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), From 0ca2a3fe6206be8de848aa7949986fdd5b71aa4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 13:38:12 +0100 Subject: [PATCH 05/11] better syntax --- tests/function_libs/torch_lib/ops_test_data.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 78c5c4609d..8de962c13d 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -748,10 +748,8 @@ def _where_input_wrangler( TorchLibOpInfo("ge", core_ops.aten_ge), TorchLibOpInfo("gt", core_ops.aten_gt), TorchLibOpInfo("histc", core_ops.aten_histc).skip( - matcher=lambda sample: ( - sample.kwargs.get("min") == sample.kwargs.get("max") - or sample.dtype == torch.float16 - ), + dtypes=(torch.float16,), + matcher=lambda sample: sample.kwargs.get("min") == sample.kwargs.get("max"), reason="then min=sample.min(), max=sample.max(), torch.histc does not" "define what happens when both are equal (1 sample with one element " "for example). torch does something, maybe " From a500508057236a683300a0d22baa3168a7b41e3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 11:04:31 +0100 Subject: [PATCH 06/11] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index c610c97de9..21b6d72249 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4617,7 +4617,7 @@ def aten_histc( values = [min + delta * i for i in range(bins + 1)] flat_self = op.Reshape(self, [-1]) - computation_type = self.type.dtype + computation_type = self.dtype cond = op.And( op.GreaterOrEqual(flat_self, op.CastLike([min], self)), From 456e953efd7918c6289117eda2734e6edcddc18a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 11:11:58 +0100 Subject: [PATCH 07/11] better skip --- tests/function_libs/torch_lib/e2e_ops_tests.py | 6 +++++- tests/function_libs/torch_lib/ops_test_data.py | 7 ++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index f99e57c155..019e6f7fe5 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -921,7 +921,11 @@ def forward(self, x): model = Model() onnx_program = torch.onnx.export( - Model(), (torch.rand(10, 10, 10),), dynamo=True, verbose=False, dynamic_shapes=({0: "batch"},) + Model(), + (torch.rand(10, 10, 10),), + dynamo=True, + verbose=False, + dynamic_shapes=({0: "batch"},), ) _testing.assert_onnx_program(onnx_program) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 8de962c13d..bf03857574 100644 --- 
a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -748,13 +748,14 @@ def _where_input_wrangler( TorchLibOpInfo("ge", core_ops.aten_ge), TorchLibOpInfo("gt", core_ops.aten_gt), TorchLibOpInfo("histc", core_ops.aten_histc).skip( - dtypes=(torch.float16,), matcher=lambda sample: sample.kwargs.get("min") == sample.kwargs.get("max"), reason="then min=sample.min(), max=sample.max(), torch.histc does not" "define what happens when both are equal (1 sample with one element " "for example). torch does something, maybe " - "something like zeros(bins)[bins // 2 + 1] = 1., " - "we skip float16 because of " + "something like zeros(bins)[bins // 2 + 1] = 1.", + ).skip( + dtypes=(torch.float16,), + reason="we skip float16 because of " "https://github.com/pytorch/pytorch/issues/174668", ), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB From 500dc019e1f62279e03ccb4eb3885ee59c8af116 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 11:19:43 +0100 Subject: [PATCH 08/11] lint --- tests/function_libs/torch_lib/ops_test_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index bf03857574..041045fbcf 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -747,7 +747,8 @@ def _where_input_wrangler( ), TorchLibOpInfo("ge", core_ops.aten_ge), TorchLibOpInfo("gt", core_ops.aten_gt), - TorchLibOpInfo("histc", core_ops.aten_histc).skip( + TorchLibOpInfo("histc", core_ops.aten_histc) + .skip( matcher=lambda sample: sample.kwargs.get("min") == sample.kwargs.get("max"), reason="then min=sample.min(), max=sample.max(), torch.histc does not" "define what happens when both are equal (1 sample with one element " @@ -755,8 +756,7 @@ def _where_input_wrangler( "something like zeros(bins)[bins // 2 + 1] = 1.", ).skip( dtypes=(torch.float16,), - reason="we skip float16 because of " - "https://github.com/pytorch/pytorch/issues/174668", + reason="we skip float16 because of https://github.com/pytorch/pytorch/issues/174668", ), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB From cae0f02f651ca66eb9b1320d56bf3e2a0364d864 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 11:34:42 +0100 Subject: [PATCH 09/11] lint --- tests/function_libs/torch_lib/ops_test_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 041045fbcf..ebaea3fd77 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -754,7 +754,8 @@ def _where_input_wrangler( "define what happens when both are equal (1 sample with one element " "for example). 
torch does something, maybe " "something like zeros(bins)[bins // 2 + 1] = 1.", - ).skip( + ) + .skip( dtypes=(torch.float16,), reason="we skip float16 because of https://github.com/pytorch/pytorch/issues/174668", ), From 65b365f369f29137c205e67dd2e4ee5d0c0fa7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 17:07:14 +0100 Subject: [PATCH 10/11] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 21b6d72249..8c2078346f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4613,6 +4613,13 @@ def aten_histc( self: TensorType, bins: int = 100, min: float = 0.0, max: float = 0.0 ) -> TensorType: """histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor""" + if min == max: + # This ONNXScript implementation precomputes static bin edges and cannot + # faithfully reproduce torch.histc's dynamic behavior when min == max + # (including the default min=0, max=0, which infers the range from data). + raise NotImplementedError( + f"aten_histc with min == max ({min}) is not supported in this export path." + ) delta = (max - min) / (bins * 1.0) values = [min + delta * i for i in range(bins + 1)] From 86dd6e018065e44d9ab9a72fdbcc7ef575f8477e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 17:17:40 +0100 Subject: [PATCH 11/11] remove int case --- .../function_libs/torch_lib/ops/core.py | 19 +- .../function_libs/torch_lib/ops_test_data.py | 170 +++++++++++------- 2 files changed, 115 insertions(+), 74 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8c2078346f..9dfdda3a90 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4630,15 +4630,16 @@ def aten_histc( op.GreaterOrEqual(flat_self, op.CastLike([min], self)), op.LessOrEqual(flat_self, op.CastLike([max], self)), ) - if self.type.dtype in {ir.DataType.INT32, ir.DataType.INT64}: - # max is included. - values[-1] += 1 - else: - cond = op.And(cond, op.Not(op.IsNaN(flat_self))) - # max is included. - dtype = self.type.dtype.numpy() - values = np.array(values, dtype=dtype) - values[-1] = np.nextafter(values[-1], np.array(np.inf, dtype=dtype), dtype=dtype) + + assert self.type.dtype not in {ir.DataType.INT32, ir.DataType.INT64}, ( + f"torch.histc only works on float but {self.type.dtype=}" + ) + + cond = op.And(cond, op.Not(op.IsNaN(flat_self))) + # max is included. 
+ dtype = self.type.dtype.numpy() + values = np.array(values, dtype=dtype) + values[-1] = np.nextafter(values[-1], np.array(np.inf, dtype=dtype), dtype=dtype) typed_values = op.Constant(value=ir.tensor(values, dtype=self.type.dtype)) clipped = op.Where(cond, flat_self, op.CastLike([min - 1], self)) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index ebaea3fd77..a40535f4ba 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -480,8 +480,9 @@ def _where_input_wrangler( ), TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip( - matcher=lambda sample: not (len(sample.kwargs) > 0) - or isinstance(sample.kwargs.get("dim"), tuple), + matcher=lambda sample: ( + not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple) + ), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", ), TorchLibOpInfo("all_dims", core_ops.aten_all_dims).skip( @@ -518,9 +519,11 @@ def _where_input_wrangler( ) .skip( "decomposed", - matcher=lambda sample: torch.numel(sample.input) == 0 - or torch.numel(sample.args[0]) == 0 - or torch.numel(sample.args[1]) == 0, + matcher=lambda sample: ( + torch.numel(sample.input) == 0 + or torch.numel(sample.args[0]) == 0 + or torch.numel(sample.args[1]) == 0 + ), reason="zero sized inputs cannot be compared", ), TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (2e-3, 2e-2)}), @@ -534,8 +537,9 @@ def _where_input_wrangler( reason="this Aten overload only support one tensor as input by design", ), TorchLibOpInfo("any_dim", core_ops.aten_any_dim).skip( - matcher=lambda sample: not (len(sample.kwargs) > 0) - or isinstance(sample.kwargs.get("dim"), tuple), + matcher=lambda sample: ( + not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple) + ), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", ), TorchLibOpInfo("any_dims", core_ops.aten_any_dims).skip( @@ -750,7 +754,7 @@ def _where_input_wrangler( TorchLibOpInfo("histc", core_ops.aten_histc) .skip( matcher=lambda sample: sample.kwargs.get("min") == sample.kwargs.get("max"), - reason="then min=sample.min(), max=sample.max(), torch.histc does not" + reason="then min=sample.min(), max=sample.max(), torch.histc does not " "define what happens when both are equal (1 sample with one element " "for example). 
torch does something, maybe " "something like zeros(bins)[bins // 2 + 1] = 1.", @@ -883,8 +887,10 @@ def _where_input_wrangler( TorchLibOpInfo("mH", core_ops.aten_mH), TorchLibOpInfo("mH", core_ops.aten_mH_complex, complex=True), TorchLibOpInfo("min_dim", core_ops.aten_min_dim).xfail( - matcher=lambda sample: len(sample.args) == 0 - or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), + matcher=lambda sample: ( + len(sample.args) == 0 + or (len(sample.args) > 0 and not isinstance(sample.args[0], int)) + ), reason="this ATen overload only support one tensor as input and another int as args", ), TorchLibOpInfo("min", core_ops.aten_min).skip( @@ -921,8 +927,13 @@ def _where_input_wrangler( tolerance={torch.float16: (1e-2, 1e-2)}, input_wrangler=_cross_entropy_input_wrangler, ).xfail( - matcher=lambda sample: len(sample.args) < 1 - or (isinstance(sample.args[0], torch.Tensor) and sample.args[0].dtype != torch.int64), + matcher=lambda sample: ( + len(sample.args) < 1 + or ( + isinstance(sample.args[0], torch.Tensor) + and sample.args[0].dtype != torch.int64 + ) + ), reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[target] as int type", ), TorchLibOpInfo( @@ -1029,10 +1040,12 @@ def _where_input_wrangler( nn_ops.aten_replication_pad3d, input_wrangler=_replication_pad3d_input_wrangler, ).skip( - matcher=lambda sample: not ( - len(sample.args) > 1 - and sample.args[1] == "replicate" - and len(sample.input.shape) == 5 + matcher=lambda sample: ( + not ( + len(sample.args) > 1 + and sample.args[1] == "replicate" + and len(sample.input.shape) == 5 + ) ), reason="this Aten overload need args[1] == 'replicate' for pad mode, and 3D tensor", ), @@ -1080,16 +1093,18 @@ def _where_input_wrangler( TorchLibOpInfo("polar", core_ops.aten_polar), TorchLibOpInfo("pow", core_ops.aten_pow), TorchLibOpInfo("prod", core_ops.aten_prod).skip( - matcher=lambda sample: sample.kwargs.get("dim") is not None - or sample.kwargs.get("keepdim") is not None - or len(sample.args) > 0, + matcher=lambda sample: ( + sample.kwargs.get("dim") is not None + or sample.kwargs.get("keepdim") is not None + or len(sample.args) > 0 + ), reason="this Aten overload only accept 1 inputs: self", ), TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip( matcher=lambda sample: ( - sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None - ) - or sample.kwargs.get("dtype") != -1, + (sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None) + or sample.kwargs.get("dtype") != -1 + ), reason="this Aten overload can accept 3 inputs:(self, dim, keepdim)", ), TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu), @@ -1222,8 +1237,9 @@ def _where_input_wrangler( reason="this Aten overload only support one tensor as input and one int as args by design", ) .skip( - matcher=lambda sample: len(sample.input.shape) != 0 - and sample.input.shape[sample.args[0]] != 1, + matcher=lambda sample: ( + len(sample.input.shape) != 0 and sample.input.shape[sample.args[0]] != 1 + ), reason="this Aten overload only support squeeze dim with size 1", ), TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True) @@ -1232,8 +1248,9 @@ def _where_input_wrangler( reason="this Aten overload only support one tensor as input and one int as args by design", ) .skip( - matcher=lambda sample: len(sample.input.shape) != 0 - and sample.input.shape[sample.args[0]] != 1, + matcher=lambda sample: ( + len(sample.input.shape) != 0 and sample.input.shape[sample.args[0]] != 1 + ), reason="this Aten 
overload only support squeeze dim with size 1", ), TorchLibOpInfo("squeeze", core_ops.aten_squeeze).skip( @@ -1253,8 +1270,9 @@ def _where_input_wrangler( TorchLibOpInfo("tan", core_ops.aten_tan), TorchLibOpInfo("tanh", core_ops.aten_tanh), TorchLibOpInfo("tile", core_ops.aten_tile).skip( - matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) - or not sample.input.shape, + matcher=lambda sample: ( + any(dim == 0 for dim in sample.input.shape) or not sample.input.shape + ), reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", ), TorchLibOpInfo("topk", core_ops.aten_topk) @@ -1391,8 +1409,9 @@ def _where_input_wrangler( tolerance={torch.float16: (8e-2, 2e-3)}, ).skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT - matcher=lambda sample: sample.kwargs.get("mode") == "bicubic" - or len(sample.args[0].shape) != 4, + matcher=lambda sample: ( + sample.kwargs.get("mode") == "bicubic" or len(sample.args[0].shape) != 4 + ), reason="fixme: 'bicubic' mode in ORT implemented differently with Torch and only support 4D-tensor", ), TorchLibOpInfo( @@ -1425,8 +1444,10 @@ def _where_input_wrangler( enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING, ) .xfail( - matcher=lambda sample: len(sample.args) == 0 - or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), + matcher=lambda sample: ( + len(sample.args) == 0 + or (len(sample.args) > 0 and not isinstance(sample.args[0], int)) + ), reason="this ATen overload only support one tensor as input and another int as args", ), TorchLibOpInfo("max", core_ops.aten_max).skip( @@ -1467,8 +1488,9 @@ def _where_input_wrangler( reason="native_batch_norm outputs different dtypes on CPU and CUDA. Our implematation is based on that for CUDA", ) .skip( - matcher=lambda sample: sample.kwargs.get("training") is True - or sample.args[-3] is True, + matcher=lambda sample: ( + sample.kwargs.get("training") is True or sample.args[-3] is True + ), reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( @@ -1482,8 +1504,9 @@ def _where_input_wrangler( reason="native_batch_norm outputs different shapes on CPU and CUDA when training is False. Our implematation is based on that for CUDA", ) .skip( - matcher=lambda sample: sample.kwargs.get("training") is True - or sample.args[-3] is True, + matcher=lambda sample: ( + sample.kwargs.get("training") is True or sample.args[-3] is True + ), reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( @@ -1500,8 +1523,9 @@ def _where_input_wrangler( reason="native_batch_norm outputs different results on CPU and CUDA when training is False. 
Our implematation is based on that for CUDA", ) .skip( - matcher=lambda sample: sample.kwargs.get("training") is True - or sample.args[-3] is True, + matcher=lambda sample: ( + sample.kwargs.get("training") is True or sample.args[-3] is True + ), reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( @@ -1531,17 +1555,21 @@ def _where_input_wrangler( input_wrangler=_avg_pool_input_wrangler, ) .xfail( - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), + matcher=lambda sample: ( + (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None) + ), reason="ONNX doesn't support divisor_override argument", ) .xfail( - matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True) - and ( - sample.kwargs.get("count_include_pad") is True - or sample.input.shape[2] - % (sample.args[0][0] if isinstance(sample.args[0], tuple) else sample.args[0]) - != 0 + matcher=lambda sample: ( + (sample.kwargs.get("ceil_mode") is True) + and ( + sample.kwargs.get("count_include_pad") is True + or sample.input.shape[2] + % (sample.args[0][0] if isinstance(sample.args[0], tuple) else sample.args[0]) + != 0 + ) ), reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", ), @@ -1550,8 +1578,10 @@ def _where_input_wrangler( nn_ops.aten_avg_pool2d, input_wrangler=_avg_pool_input_wrangler, ).xfail( - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), + matcher=lambda sample: ( + (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None) + ), reason="ONNX doesn't support divisor_override argument", ), TorchLibOpInfo( @@ -1560,8 +1590,10 @@ def _where_input_wrangler( input_wrangler=_avg_pool_input_wrangler, ) .xfail( - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), + matcher=lambda sample: ( + (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None) + ), reason="ONNX doesn't support divisor_override argument", ) .xfail( @@ -1625,8 +1657,9 @@ def _where_input_wrangler( nn_ops.aten_im2col, input_wrangler=_im2col_input_wrangler, ).xfail( - matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) - or not sample.input.shape, + matcher=lambda sample: ( + any(dim == 0 for dim in sample.input.shape) or not sample.input.shape + ), reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", ), TorchLibOpInfo( @@ -1667,8 +1700,9 @@ def _where_input_wrangler( input_wrangler=_max_pool_input_wrangler, ) .skip( - matcher=lambda sample: sample.kwargs.get("ceil_mode") is True - and sample.kwargs.get("padding") == 1, + matcher=lambda sample: ( + sample.kwargs.get("ceil_mode") is True and sample.kwargs.get("padding") == 1 + ), reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", ) .skip( @@ -1681,8 +1715,9 @@ def _where_input_wrangler( input_wrangler=_max_pool_input_wrangler, ) .skip( - matcher=lambda sample: sample.kwargs.get("ceil_mode") is True - and sample.kwargs.get("padding") == 1, + matcher=lambda sample: ( + sample.kwargs.get("ceil_mode") is True and sample.kwargs.get("padding") == 1 + ), reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", ) .skip( @@ -1704,9 +1739,11 @@ def 
_where_input_wrangler( test_class_name="TestOutputConsistencyFullGraph", ) .xfail( - matcher=lambda sample: len(sample.input.shape) != 4 - or len(sample.args[0].shape) != 4 - or len(sample.args[1].shape) != 4, + matcher=lambda sample: ( + len(sample.input.shape) != 4 + or len(sample.args[0].shape) != 4 + or len(sample.args[1].shape) != 4 + ), reason="torch sdpa is expected to pass in 4d q, k, and v.", ), TorchLibOpInfo( @@ -1732,8 +1769,9 @@ def _where_input_wrangler( "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, ).xfail( - matcher=lambda sample: sample.args[1] is False - and sample.kwargs.get("scales_h") is not None, + matcher=lambda sample: ( + sample.args[1] is False and sample.kwargs.get("scales_h") is not None + ), reason="fixme: align_corners=False output mismatch when scales are provided", ), TorchLibOpInfo( @@ -1752,8 +1790,9 @@ def _where_input_wrangler( "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, ).xfail( - matcher=lambda sample: sample.args[1] is False - and sample.kwargs.get("scales_h") is not None, + matcher=lambda sample: ( + sample.args[1] is False and sample.kwargs.get("scales_h") is not None + ), reason="fixme: align_corners=False output mismatch when scales are provided", ), TorchLibOpInfo( @@ -1772,8 +1811,9 @@ def _where_input_wrangler( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d, ).xfail( - matcher=lambda sample: sample.args[1] is False - and sample.kwargs.get("scales") is not None, + matcher=lambda sample: ( + sample.args[1] is False and sample.kwargs.get("scales") is not None + ), reason="fixme: align_corners=False output mismatch when scales are provided", ), TorchLibOpInfo("ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d),
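
Note on the counting strategy the series converges on: the converter avoids a per-bin loop. It precomputes bins + 1 static edges, compares the flattened input against every edge at once (op.Less over broadcast shapes), reduces each row to a cumulative count with op.ReduceSum, and takes adjacent differences via op.Slice/op.Sub. NaNs and out-of-range values are remapped below the first edge so they contribute equally to every cumulative count and cancel out of the differences, and the last edge is nudged upward with np.nextafter so that values exactly equal to max land in the final bin, as torch.histc requires. A minimal NumPy sketch of that strategy, assuming the float-only path PATCH 11 ends with (histc_sketch is a hypothetical helper for illustration, not part of the patches):

    import numpy as np

    def histc_sketch(x, bins=100, min=0.0, max=0.0):
        # Static bin edges, mirroring the converter (assumes min < max).
        delta = (max - min) / bins
        edges = np.array([min + delta * i for i in range(bins + 1)], dtype=x.dtype)
        # Nudge the last edge so values equal to `max` fall in the final bin.
        edges[-1] = np.nextafter(edges[-1], np.array(np.inf, dtype=x.dtype))
        flat = x.reshape(-1)
        in_range = (flat >= min) & (flat <= max) & ~np.isnan(flat)
        # NaNs and out-of-range values are pushed below every edge, so they
        # add the same constant to each cumulative count and cancel out.
        clipped = np.where(in_range, flat, min - 1).astype(x.dtype)
        # counts[i] = number of kept values strictly less than edges[i].
        counts = (clipped[None, :] < edges[:, None]).sum(axis=1)
        # Adjacent differences of cumulative counts give per-bin counts.
        return counts[1:] - counts[:-1]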
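
A quick parity check of the sketch above against torch.histc on the float path (torch.histc returns float counts while the sketch returns integers; both count values equal to max in the last bin):

    import torch

    x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
    print(torch.histc(x, bins=3, min=0, max=2))               # tensor([2., 1., 2.])
    print(histc_sketch(x.numpy(), bins=3, min=0.0, max=2.0))  # [2 1 2]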
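
The min == max guard added in PATCH 10, and the matching skip in ops_test_data.py, exist because torch.histc treats min == max (including the defaults min=0, max=0) as "infer the range from the data", which cannot be expressed with bin edges precomputed at export time. A small demonstration of the dynamic behavior the static edges cannot reproduce:

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    # With the defaults (min=0, max=0), torch infers the range from the data:
    print(torch.histc(x, bins=3))                # tensor([1., 1., 1.]) over [1, 3]
    print(torch.histc(x, bins=3, min=1, max=3))  # identical, with an explicit range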