|
14 | 14 | from .storage.grouped_tensor_storage import GroupedTensorStorage |
15 | 15 |
|
16 | 16 |
|
| 17 | +# For now, conservatively ban all shape manipulatimg ops. |
| 18 | +BANNED_SHAPE_OPS = { |
| 19 | + torch.ops.aten.view.default, |
| 20 | + torch.ops.aten._unsafe_view.default, |
| 21 | + torch.ops.aten.reshape.default, |
| 22 | + torch.ops.aten._reshape_alias.default, |
| 23 | + torch.ops.aten.flatten.using_ints, |
| 24 | + torch.ops.aten.unflatten.int, |
| 25 | + torch.ops.aten.squeeze.dim, |
| 26 | + torch.ops.aten.squeeze.dims, |
| 27 | + torch.ops.aten.unsqueeze.default, |
| 28 | + torch.ops.aten.transpose.int, |
| 29 | + torch.ops.aten.permute.default, |
| 30 | + torch.ops.aten.movedim.int, |
| 31 | + torch.ops.aten.t.default, |
| 32 | + torch.ops.aten.slice.Tensor, |
| 33 | + torch.ops.aten.narrow.default, |
| 34 | + torch.ops.aten.select.int, |
| 35 | + torch.ops.aten.split.Tensor, |
| 36 | + torch.ops.aten.chunk.default, |
| 37 | + torch.ops.aten.expand.default, |
| 38 | + torch.ops.aten.expand_as.default, |
| 39 | + torch.ops.aten.cat.default, |
| 40 | + torch.ops.aten.stack.default, |
| 41 | +} |
| 42 | + |
| 43 | + |
17 | 44 | class GroupedTensor(GroupedTensorStorage, torch.Tensor): |
18 | 45 | """Tensor wrapper class for grouped tensor storage.""" |
19 | 46 |
|
@@ -96,6 +123,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): |
96 | 123 | if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default): |
97 | 124 | return args[0] |
98 | 125 |
|
| 126 | + # Don't allow reshape/view etc. |
| 127 | + if func in BANNED_SHAPE_OPS: |
| 128 | + raise RuntimeError(f"{cls.__name__} forbids shape-manipulation op: {func} ") |
| 129 | + |
99 | 130 | def grouped_to_stacked_tensor(grouped: GroupedTensor) -> torch.Tensor: |
100 | 131 | if not grouped.all_same_shape(): |
101 | 132 | raise NotImplementedError( |
|
0 commit comments