diff --git a/ltx_video/models/transformers/symmetric_patchifier.py b/ltx_video/models/transformers/symmetric_patchifier.py index ba1bd6c..e6e6dde 100644 --- a/ltx_video/models/transformers/symmetric_patchifier.py +++ b/ltx_video/models/transformers/symmetric_patchifier.py @@ -44,7 +44,7 @@ def get_grid( grid_h = torch.arange(h, dtype=torch.float32, device=device) grid_w = torch.arange(w, dtype=torch.float32, device=device) grid_f = torch.arange(f, dtype=torch.float32, device=device) - grid = torch.meshgrid(grid_f, grid_h, grid_w) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing=None) grid = torch.stack(grid, dim=0) grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)