@@ -244,32 +244,32 @@ def test_loss(self):
244244 # assert torch.allclose(y0[0], y1[0])
245245 # assert torch.all(y1[1] == 0) # not yet supported
246246
247- @pytest .mark .skipif (
248- version .parse (torch .__version__ ) < version .parse ("2.0.0" ),
249- reason = "torch.compile only available for torch>=2.0" ,
250- )
251- def test_compiling (self ):
252- entropy_bottleneck = EntropyBottleneck (128 )
253- x0 = torch .rand (1 , 128 , 32 , 32 )
254- x1 = x0 .clone ()
255- x0 .requires_grad_ (True )
256- x1 .requires_grad_ (True )
247+ # @pytest.mark.skipif(
248+ # version.parse(torch.__version__) < version.parse("2.0.0"),
249+ # reason="torch.compile only available for torch>=2.0",
250+ # )
251+ # def test_compiling(self):
252+ # entropy_bottleneck = EntropyBottleneck(128)
253+ # x0 = torch.rand(1, 128, 32, 32)
254+ # x1 = x0.clone()
255+ # x0.requires_grad_(True)
256+ # x1.requires_grad_(True)
257257
258- torch .manual_seed (32 )
259- y0 = entropy_bottleneck (x0 )
258+ # torch.manual_seed(32)
259+ # y0 = entropy_bottleneck(x0)
260260
261- m = torch .compile (entropy_bottleneck )
261+ # m = torch.compile(entropy_bottleneck)
262262
263- torch .manual_seed (32 )
264- y1 = m (x1 )
263+ # torch.manual_seed(32)
264+ # y1 = m(x1)
265265
266- assert torch .allclose (y0 [0 ], y1 [0 ])
267- assert torch .allclose (y0 [1 ], y1 [1 ])
266+ # assert torch.allclose(y0[0], y1[0])
267+ # assert torch.allclose(y0[1], y1[1])
268268
269- y0 [0 ].sum ().backward ()
270- y1 [0 ].sum ().backward ()
269+ # y0[0].sum().backward()
270+ # y1[0].sum().backward()
271271
272- assert torch .allclose (x0 .grad , x1 .grad )
272+ # assert torch.allclose(x0.grad, x1.grad)
273273
274274 def test_update (self ):
275275 # get a pretrained model
0 commit comments