Skip to content

Commit a32a640

Browse files
committed
Ban some ops
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent f413a93 commit a32a640

1 file changed

Lines changed: 31 additions & 0 deletions

File tree

transformer_engine/pytorch/tensor/grouped_tensor.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,33 @@
1414
from .storage.grouped_tensor_storage import GroupedTensorStorage
1515

1616

17+
# For now, conservatively ban all shape manipulatimg ops.
18+
BANNED_SHAPE_OPS = {
19+
torch.ops.aten.view.default,
20+
torch.ops.aten._unsafe_view.default,
21+
torch.ops.aten.reshape.default,
22+
torch.ops.aten._reshape_alias.default,
23+
torch.ops.aten.flatten.using_ints,
24+
torch.ops.aten.unflatten.int,
25+
torch.ops.aten.squeeze.dim,
26+
torch.ops.aten.squeeze.dims,
27+
torch.ops.aten.unsqueeze.default,
28+
torch.ops.aten.transpose.int,
29+
torch.ops.aten.permute.default,
30+
torch.ops.aten.movedim.int,
31+
torch.ops.aten.t.default,
32+
torch.ops.aten.slice.Tensor,
33+
torch.ops.aten.narrow.default,
34+
torch.ops.aten.select.int,
35+
torch.ops.aten.split.Tensor,
36+
torch.ops.aten.chunk.default,
37+
torch.ops.aten.expand.default,
38+
torch.ops.aten.expand_as.default,
39+
torch.ops.aten.cat.default,
40+
torch.ops.aten.stack.default,
41+
}
42+
43+
1744
class GroupedTensor(GroupedTensorStorage, torch.Tensor):
1845
"""Tensor wrapper class for grouped tensor storage."""
1946

@@ -96,6 +123,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
96123
if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default):
97124
return args[0]
98125

126+
# Don't allow reshape/view etc.
127+
if func in BANNED_SHAPE_OPS:
128+
raise RuntimeError(f"{cls.__name__} forbids shape-manipulation op: {func} ")
129+
99130
def grouped_to_stacked_tensor(grouped: GroupedTensor) -> torch.Tensor:
100131
if not grouped.all_same_shape():
101132
raise NotImplementedError(

0 commit comments

Comments
 (0)