@@ -708,10 +708,10 @@ def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="in
708708 "data, kernel, data_format, strides, padding" ,
709709 [
710710 # Tf on CI (CPU) support only NHWC
711- # ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"),
712- # ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"),
713- ((1 , 32 , 128 , 128 ), (3 , 3 , 32 , 32 ), "NCHW" , (1 , 1 , 1 , 1 ), "SAME" ),
714- ((1 , 32 , 128 , 128 ), (3 , 3 , 32 , 32 ), "NCHW" , (1 , 1 , 1 , 1 ), "VALID" ),
711+ ((1 , 128 , 128 , 32 ), (3 , 3 , 32 , 32 ), "NHWC" , (1 , 1 , 1 , 1 ), "SAME" ),
712+ ((1 , 128 , 128 , 32 ), (3 , 3 , 32 , 32 ), "NHWC" , (1 , 1 , 1 , 1 ), "VALID" ),
713+ # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "SAME"),
714+ # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "VALID"),
715715 ],
716716)
717717def test_conv2d (data , kernel , data_format , strides , padding ):
@@ -742,10 +742,10 @@ def func(self, data, kernel):
742742 "data, kernel, data_format, strides, padding" ,
743743 [
744744 # Tf on CI (CPU) support only NHWC
745- # ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"),
746- # ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"),
747- ((1 , 32 , 128 , 128 ), (3 , 3 ), "NCHW" , (1 , 1 , 1 , 1 ), "SAME" ),
748- ((1 , 32 , 128 , 128 ), (3 , 3 ), "NCHW" , (1 , 1 , 1 , 1 ), "VALID" ),
745+ ((1 , 128 , 128 , 32 ), (2 , 2 ), "NHWC" , (1 , 1 , 1 , 1 ), "SAME" ),
746+ ((1 , 128 , 128 , 32 ), (2 , 2 ), "NHWC" , (1 , 1 , 1 , 1 ), "VALID" ),
747+ # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "SAME"),
748+ # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "VALID"),
749749 ],
750750)
751751def test_pool_2d (pool , data , kernel , data_format , strides , padding ):
0 commit comments