Skip to content

Commit 0c3d919

Browse files
committed
XFAIL conv tests of Ops without Python implementation
Mark overly specific tests as xfail
1 parent e8bc756 commit 0c3d919

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

tests/tensor/conv/test_abstract_conv.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
import pytensor
55
import pytensor.tensor as pt
6+
from pytensor.compile import get_default_mode
67
from pytensor.compile.mode import Mode
78
from pytensor.configdefaults import config
89
from pytensor.graph.rewriting.basic import check_stack_trace
10+
from pytensor.link.numba import NumbaLinker
911
from pytensor.tensor.conv import abstract_conv
1012
from pytensor.tensor.conv.abstract_conv import (
1113
AbstractConv2d,
@@ -757,6 +759,10 @@ def abstract_conv_gradinputs(filters_val, output_val):
757759
def run_test_case(self, *args, **kargs):
758760
raise NotImplementedError()
759761

762+
@pytest.mark.xfail(
763+
condition=isinstance(get_default_mode().linker, NumbaLinker),
764+
reason="Involves Ops with no Python implementation for numba to use as fallback",
765+
)
760766
def test_all(self):
761767
ds = self.default_subsamples
762768
db = self.default_border_mode
@@ -815,6 +821,10 @@ def setup_class(cls):
815821
def run_test_case_gi(self, *args, **kwargs):
816822
raise NotImplementedError()
817823

824+
@pytest.mark.xfail(
825+
condition=isinstance(get_default_mode().linker, NumbaLinker),
826+
reason="Involves Ops with no Python implementation for numba to use as fallback",
827+
)
818828
def test_gradinput_arbitrary_output_shapes(self):
819829
# this computes the grad wrt inputs for an output shape
820830
# that the forward convolution would not produce
@@ -948,10 +958,7 @@ def run_gradinput(
948958
)
949959

950960

951-
@pytest.mark.skipif(
952-
config.cxx == "",
953-
reason="SciPy and cxx needed",
954-
)
961+
@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
955962
class TestAbstractConvNoOptim(BaseTestConv2d):
956963
@classmethod
957964
def setup_class(cls):
@@ -1884,9 +1891,10 @@ def test_conv2d_grad_wrt_weights(self):
18841891
)
18851892

18861893

1887-
@pytest.mark.skipif(
1888-
config.cxx == "",
1889-
reason="SciPy and cxx needed",
1894+
@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
1895+
@pytest.mark.xfail(
1896+
condition=isinstance(get_default_mode().linker, NumbaLinker),
1897+
reason="Involves Ops with no Python implementation for numba to use as fallback",
18901898
)
18911899
class TestGroupedConvNoOptim:
18921900
conv = abstract_conv.AbstractConv2d
@@ -2096,9 +2104,10 @@ def conv_gradinputs(filters_val, output_val):
20962104
utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1)
20972105

20982106

2099-
@pytest.mark.skipif(
2100-
config.cxx == "",
2101-
reason="SciPy and cxx needed",
2107+
@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
2108+
@pytest.mark.xfail(
2109+
condition=isinstance(get_default_mode().linker, NumbaLinker),
2110+
reason="Involves Ops with no Python implementation for numba to use as fallback",
21022111
)
21032112
class TestGroupedConv3dNoOptim(TestGroupedConvNoOptim):
21042113
conv = abstract_conv.AbstractConv3d

0 commit comments

Comments
 (0)