Skip to content

Commit 6cc371f

Browse files
Freed-Wufracape
authored andcommitted
fix: fix wrong dtype
1 parent e5fe139 commit 6cc371f

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

compressai/models/google.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,7 @@ def decompress(self, strings, shape):
673673
y_hat = torch.zeros(
674674
(z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
675675
device=z_hat.device,
676+
dtype=z_hat.dtype,
676677
)
677678

678679
for i, y_string in enumerate(strings[0]):

0 commit comments

Comments
 (0)