diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 641799aa2..c0ca466e0 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -2119,18 +2119,7 @@ ] }, "torch.Tensor.select_scatter": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.Tensor.select_scatter", - "min_input_args": 3, - "args_list": [ - "src", - "dim", - "index" - ], - "kwargs_change": { - "src": "values", - "dim": "axis" - } + "Matcher": "ChangePrefixMatcher" }, "torch.Tensor.set_": { "Matcher": "TensorSetMatcher", @@ -2352,12 +2341,7 @@ "Matcher": "ChangePrefixMatcher" }, "torch.Tensor.take": { - "Matcher": "TensorTakeMatcher", - "paddle_api": "paddle.Tensor.take", - "min_input_args": 1, - "args_list": [ - "index" - ] + "Matcher": "ChangePrefixMatcher" }, "torch.Tensor.take_along_dim": { "Matcher": "ChangePrefixMatcher" @@ -10627,20 +10611,7 @@ ] }, "torch.select_scatter": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.select_scatter", - "min_input_args": 4, - "args_list": [ - "input", - "src", - "dim", - "index" - ], - "kwargs_change": { - "input": "x", - "src": "values", - "dim": "axis" - } + "Matcher": "ChangePrefixMatcher" }, "torch.selu": { "Matcher": "GenericMatcher", @@ -10719,17 +10690,7 @@ } }, "torch.sgn": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.sgn", - "min_input_args": 1, - "args_list": [ - "input", - "*", - "out" - ], - "kwargs_change": { - "input": "x" - } + "Matcher": "ChangePrefixMatcher" }, "torch.sigmoid": { "Matcher": "ChangePrefixMatcher" @@ -10934,17 +10895,7 @@ } }, "torch.signbit": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.signbit", - "min_input_args": 1, - "args_list": [ - "input", - "*", - "out" - ], - "kwargs_change": { - "input": "x" - } + "Matcher": "ChangePrefixMatcher" }, "torch.sin": { "Matcher": "ChangePrefixMatcher" @@ -11531,16 +11482,7 @@ "Matcher": "ChangePrefixMatcher" }, "torch.take": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.take", - "min_input_args": 2, - "args_list": [ - "input", - "index" - ], - "kwargs_change": { - "input": "x" - } + "Matcher": "ChangePrefixMatcher" }, "torch.take_along_dim": { "Matcher": "ChangePrefixMatcher" @@ -11558,20 +11500,7 @@ "Matcher": "ChangePrefixMatcher" }, "torch.tensordot": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.tensordot", - "min_input_args": 2, - "args_list": [ - "a", - "b", - "dims", - "out" - ], - "kwargs_change": { - "a": "x", - "b": "y", - "dims": "axes" - } + "Matcher": "ChangePrefixMatcher" }, "torch.testing.assert_allclose": { "Matcher": "Assert_AllcloseMatcher", @@ -11656,42 +11585,14 @@ "Matcher": "ChangePrefixMatcher" }, "torch.tril_indices": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.tril_indices", - "min_input_args": 2, - "args_list": [ - "row", - "col", - "offset", - "*", - "dtype", - "device", - "layout" - ], - "kwargs_change": { - "dtype": "dtype" - } + "Matcher": "ChangePrefixMatcher" }, "torch.triplet_margin_loss": {}, "torch.triu": { "Matcher": "ChangePrefixMatcher" }, "torch.triu_indices": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.triu_indices", - "min_input_args": 2, - "args_list": [ - "row", - "col", - "offset", - "*", - "dtype", - "device", - "layout" - ], - "kwargs_change": { - "dtype": "dtype" - } + "Matcher": "ChangePrefixMatcher" }, "torch.true_divide": { "Matcher": "ChangePrefixMatcher" diff --git a/tests/test_Tensor_select_scatter.py b/tests/test_Tensor_select_scatter.py index 8664ddf92..3a705ce11 100644 --- a/tests/test_Tensor_select_scatter.py +++ b/tests/test_Tensor_select_scatter.py @@ -77,3 +77,94 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + """Test with 3D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.zeros((2, 3, 4)).type(torch.float32) + src = torch.ones((2, 4)).type(torch.float32) + result = input.select_scatter(src, 1, 1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """Test with 3D tensor and keyword arguments""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.zeros((2, 3, 4)).type(torch.float32) + src = torch.ones((2, 4)).type(torch.float32) + result = input.select_scatter(src=src, dim=1, index=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Test with 4D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.zeros((2, 3, 4, 5)).type(torch.float32) + src = torch.ones((2, 3, 5)).type(torch.float32) + result = input.select_scatter(src, dim=2, index=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Test with index=0""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.zeros((3, 4)).type(torch.float32) + src = torch.ones(4).type(torch.float32) + result = input.select_scatter(src, dim=0, index=0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Test with float64 dtype""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.zeros((2, 3)).type(torch.float64) + src = torch.ones(3).type(torch.float64) + result = input.select_scatter(src, dim=0, index=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_11(): + """Test with keyword arguments out of order""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.zeros((2, 3, 4)) + src = torch.ones((2, 4)) + result = input.select_scatter(index=1, src=src, dim=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_12(): + """Test on last dimension""" + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.zeros((2, 3, 4)) + src = torch.ones((2, 3)) + result = input.select_scatter(src=src, dim=2, index=0) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_sgn.py b/tests/test_Tensor_sgn.py index f418366cb..4ad2c387b 100644 --- a/tests/test_Tensor_sgn.py +++ b/tests/test_Tensor_sgn.py @@ -39,3 +39,99 @@ def test_case_2(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + """Test with 2D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[0.5950, -0.0872], [2.3298, -0.2972]]) + result = a.sgn() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + """Test with 3D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[0.5950, -0.0872], [2.3298, -0.2972]]]) + result = a.sgn() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + """Test with zero values""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0., 1., 0., -1., 0.]) + result = a.sgn() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + """Test with negative values only""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([-1., -2., -3., -0.5]) + result = a.sgn() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """Test with positive values only""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1., 2., 3., 0.5]) + result = a.sgn() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Test with single element tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([3.5]) + result = a.sgn() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Test with large values""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1e10, -1e10, 1e-10, -1e-10]) + result = a.sgn() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Test with float64 dtype""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0.5950, -0.0872, 2.3298, -0.2972], dtype=torch.float64) + result = a.sgn() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_signbit.py b/tests/test_Tensor_signbit.py index ced4f975f..47d84f6ec 100644 --- a/tests/test_Tensor_signbit.py +++ b/tests/test_Tensor_signbit.py @@ -39,3 +39,99 @@ def test_case_2(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + """Test with 2D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[-1., 2., -3.], [4., -5., 6.]], dtype=torch.float32) + result = x.signbit() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + """Test with 3D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[[-0., 1.1], [-2.1, 0.]], [[2.5, -3.], [4., -5.]]], dtype=torch.float32) + result = x.signbit() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + """Test with all positive values""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2., 3., 0.5, 10.], dtype=torch.float32) + result = x.signbit() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + """Test with all negative values""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-1., -2., -3., -0.5, -10.], dtype=torch.float32) + result = x.signbit() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """Test with single element tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-5.], dtype=torch.float32) + result = x.signbit() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Test with single positive element""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([5.], dtype=torch.float32) + result = x.signbit() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Test with mixed positive and negative""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., -2., 3., -0.5, 5.], dtype=torch.float32) + result = x.signbit() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Test with all zeros""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 0., 0., 0.], dtype=torch.float32) + result = x.signbit() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_take.py b/tests/test_Tensor_take.py index 6eba2d7cb..5735426d1 100644 --- a/tests/test_Tensor_take.py +++ b/tests/test_Tensor_take.py @@ -41,3 +41,75 @@ def test_case_2(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + """1D tensor test""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1, 2, 3, 4, 5]) + result = x.take(torch.tensor([0, 2, 4])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + """3D tensor test""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + result = x.take(torch.tensor([0, 3, 7])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + """Float tensor test""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]]) + result = x.take(torch.tensor([1, 4])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + """Multiple indices""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + result = x.take(torch.tensor([0, 4, 8])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """Repeated indices""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, 2], [3, 4]]) + result = x.take(torch.tensor([0, 0, 3])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Expression argument test""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1, 2, 3, 4, 5, 6]) + result = x.take(torch.tensor([1 + 1])) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_select_scatter.py b/tests/test_select_scatter.py index d3339e735..e938a5192 100644 --- a/tests/test_select_scatter.py +++ b/tests/test_select_scatter.py @@ -77,3 +77,68 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + """Test with 2D tensor on different dimension""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.zeros((3, 4)).type(torch.float32) + values = torch.ones(4).type(torch.float32) + result = torch.select_scatter(x, values, dim=0, index=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """Test with keyword arguments out of order""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.zeros((2,3,4)).type(torch.float32) + values = torch.ones((2,4)).type(torch.float32) + result = torch.select_scatter(index=1, src=values, dim=1, input=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Test with 4D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.zeros((2, 3, 4, 5)).type(torch.float32) + values = torch.ones((2, 3, 5)).type(torch.float32) + result = torch.select_scatter(input=x, src=values, dim=2, index=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Test with index=0""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.zeros((2,3,4)).type(torch.float32) + values = torch.ones((2,4)).type(torch.float32) + result = torch.select_scatter(x, values, dim=1, index=0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Test with different dtype""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.zeros((2,3,4)).type(torch.float64) + values = torch.ones((2,4)).type(torch.float64) + result = torch.select_scatter(input=x, src=values, dim=1, index=2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_sgn.py b/tests/test_sgn.py index 530e6e44a..83b5aa909 100644 --- a/tests/test_sgn.py +++ b/tests/test_sgn.py @@ -75,3 +75,90 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["out"]) + + +def test_case_6(): + """Test with 2D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[0.5950, -0.0872], [2.3298, -0.2972]]) + result = torch.sgn(a) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """Test with 3D tensor and out parameter""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[0.5950, -0.0872], [2.3298, -0.2972]]]) + out = torch.zeros_like(a) + torch.sgn(a, out=out) + """ + ) + obj.run(pytorch_code, ["out"]) + + +def test_case_8(): + """Test with zero values""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0., 1., 0., -1., 0.]) + result = torch.sgn(input=a) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Test with negative values only""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([-1., -2., -3., -0.5]) + result = torch.sgn(a) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Test out parameter with pre-allocated tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1., 2., -3., 0.]) + out = torch.zeros(4) + result = torch.sgn(input=a, out=out) + result_value = result is out + """ + ) + obj.run(pytorch_code, ["out", "result_value"], check_value=False) + + +def test_case_11(): + """Test with single element tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([3.5]) + result = torch.sgn(input=a) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_12(): + """Test with large values""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([1e10, -1e10, 1e-10, -1e-10]) + result = torch.sgn(a) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_signbit.py b/tests/test_signbit.py index bb87afebc..092dd2be5 100644 --- a/tests/test_signbit.py +++ b/tests/test_signbit.py @@ -75,3 +75,103 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result", "out"]) + + +def test_case_6(): + """Test with float64 dtype""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float64) + result = torch.signbit(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """Test with 2D tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[-1., 2., -3.], [4., -5., 6.]], dtype=torch.float32) + result = torch.signbit(input=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Test with 3D tensor and out parameter""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[[-0., 1.1], [-2.1, 0.]], [[2.5, -3.], [4., -5.]]], dtype=torch.float32) + out = torch.zeros(2, 2, 2, dtype=torch.bool) + torch.signbit(x, out=out) + """ + ) + obj.run(pytorch_code, ["out"]) + + +def test_case_9(): + """Test with all positive values""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2., 3., 0.5, 10.], dtype=torch.float32) + result = torch.signbit(input=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Test with all negative values""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-1., -2., -3., -0.5, -10.], dtype=torch.float32) + result = torch.signbit(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_11(): + """Test with single element tensor""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-5.], dtype=torch.float32) + result = torch.signbit(input=x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_12(): + """Test with keyword arguments out of order""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32) + out = torch.zeros(5, dtype=torch.bool) + result = torch.signbit(out=out, input=x) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_13(): + """Test with out parameter pre-allocated""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., -2., 3., -0.], dtype=torch.float32) + out = torch.zeros(4, dtype=torch.bool) + result = torch.signbit(input=x, out=out) + result_is_same = result is out + """ + ) + obj.run(pytorch_code, ["out", "result_is_same"], check_value=False) diff --git a/tests/test_take.py b/tests/test_take.py index 9dc45bc2d..11edb5642 100644 --- a/tests/test_take.py +++ b/tests/test_take.py @@ -90,3 +90,51 @@ def test_case_6(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """1D tensor test""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1, 2, 3, 4, 5]) + result = torch.take(x, torch.tensor([0, 2, 4])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """3D tensor test""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + result = torch.take(x, torch.tensor([0, 3, 7])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Mixed parameter types""" + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + idx = torch.tensor([1, 4]) + result = torch.take(input=x, index=idx) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Expression argument test""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.take(torch.tensor([1, 2, 3, 4, 5, 6]), torch.tensor([2 * 2])) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_tensordot.py b/tests/test_tensordot.py index 9c9cf438f..2b8e8474a 100644 --- a/tests/test_tensordot.py +++ b/tests/test_tensordot.py @@ -80,3 +80,68 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result", "out"]) + + +def test_case_6(): + """Single dimension contraction""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.arange(12.).reshape(3, 4) + b = torch.arange(12.).reshape(4, 3) + result = torch.tensordot(a, b, dims=([1], [0])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + """4D tensors""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.randn(2, 3, 4, 5) + b = torch.randn(4, 5, 2, 3) + result = torch.tensordot(a, b, dims=([2, 3], [0, 1])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Integer dims parameter""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.arange(6.).reshape(2, 3) + b = torch.arange(6.).reshape(2, 3) + result = torch.tensordot(a, b, dims=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Keyword arguments out of order""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.arange(24.).reshape(3, 8) + b = torch.arange(24.).reshape(3, 8) + result = torch.tensordot(b=b, a=a, dims=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """All keyword arguments""" + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.arange(12.).reshape(3, 4) + b = torch.arange(12.).reshape(4, 3) + result = torch.tensordot(a=a, b=b, dims=([1], [0])) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_tril_indices.py b/tests/test_tril_indices.py index d9a979398..35a66d2db 100644 --- a/tests/test_tril_indices.py +++ b/tests/test_tril_indices.py @@ -87,3 +87,69 @@ def test_case_7(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Rectangular matrix (more rows than cols)""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tril_indices(row=5, col=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Rectangular matrix (more cols than rows)""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tril_indices(row=2, col=5) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Large positive offset""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tril_indices(row=5, col=5, offset=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_11(): + """1x1 matrix""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tril_indices(row=1, col=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_12(): + """Default dtype and device""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tril_indices(3, 3, offset=0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_13(): + """Expression argument test""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tril_indices(2 + 1, 3, offset=-1) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_triu_indices.py b/tests/test_triu_indices.py index 0bf244968..5ab1ca22a 100644 --- a/tests/test_triu_indices.py +++ b/tests/test_triu_indices.py @@ -87,3 +87,69 @@ def test_case_7(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + """Rectangular matrix (more rows than cols)""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.triu_indices(row=5, col=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + """Rectangular matrix (more cols than rows)""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.triu_indices(row=2, col=5) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + """Large positive offset""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.triu_indices(row=5, col=5, offset=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_11(): + """1x1 matrix""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.triu_indices(row=1, col=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_12(): + """Default dtype and device""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.triu_indices(3, 3, offset=0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_13(): + """Expression argument test""" + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.triu_indices(2 + 1, 3, offset=-1) + """ + ) + obj.run(pytorch_code, ["result"])