@@ -3,9 +3,11 @@
 
 import pytensor
 import pytensor.tensor as pt
+from pytensor.compile import get_default_mode
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph.rewriting.basic import check_stack_trace
+from pytensor.link.numba import NumbaLinker
 from pytensor.tensor.conv import abstract_conv
 from pytensor.tensor.conv.abstract_conv import (
     AbstractConv2d,
@@ -757,6 +759,10 @@ def abstract_conv_gradinputs(filters_val, output_val):
     def run_test_case(self, *args, **kargs):
         raise NotImplementedError()
 
+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Involves Ops with no Python implementation for numba to use as fallback",
+    )
     def test_all(self):
         ds = self.default_subsamples
         db = self.default_border_mode
@@ -815,6 +821,10 @@ def setup_class(cls):
     def run_test_case_gi(self, *args, **kwargs):
         raise NotImplementedError()
 
+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Involves Ops with no Python implementation for numba to use as fallback",
+    )
     def test_gradinput_arbitrary_output_shapes(self):
         # this computes the grad wrt inputs for an output shape
         # that the forward convolution would not produce
@@ -948,10 +958,7 @@ def run_gradinput(
         )
 
 
-@pytest.mark.skipif(
-    config.cxx == "",
-    reason="SciPy and cxx needed",
-)
+@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
 class TestAbstractConvNoOptim(BaseTestConv2d):
     @classmethod
     def setup_class(cls):
@@ -1884,9 +1891,10 @@ def test_conv2d_grad_wrt_weights(self):
         )
 
 
-@pytest.mark.skipif(
-    config.cxx == "",
-    reason="SciPy and cxx needed",
+@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
+@pytest.mark.xfail(
+    condition=isinstance(get_default_mode().linker, NumbaLinker),
+    reason="Involves Ops with no Python implementation for numba to use as fallback",
 )
 class TestGroupedConvNoOptim:
     conv = abstract_conv.AbstractConv2d
@@ -2096,9 +2104,10 @@ def conv_gradinputs(filters_val, output_val):
         utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1)
 
 
-@pytest.mark.skipif(
-    config.cxx == "",
-    reason="SciPy and cxx needed",
+@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
+@pytest.mark.xfail(
+    condition=isinstance(get_default_mode().linker, NumbaLinker),
+    reason="Involves Ops with no Python implementation for numba to use as fallback",
 )
 class TestGroupedConv3dNoOptim(TestGroupedConvNoOptim):
     conv = abstract_conv.AbstractConv3d
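The decorators added above mark these tests as expected failures whenever the default compilation mode uses the Numba linker. As a rough illustrative sketch (not part of the diff), the condition they evaluate can be checked on its own, using only the two imports the diff introduces:

from pytensor.compile import get_default_mode
from pytensor.link.numba import NumbaLinker

# True when PyTensor's default mode compiles graphs with the Numba backend,
# i.e. exactly the case in which the marked tests are expected to fail.
uses_numba = isinstance(get_default_mode().linker, NumbaLinker)
print(uses_numba)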