Skip to content

Commit 1b947a8

Browse files
authored
Fix shufflenet reg with dynamic shape context shuffle2d pattern (#724)
* fix shufflenet reg with dynamic shape context pattern * refine code * refine filter code and add no match ut
1 parent 0e4228b commit 1b947a8

File tree

2 files changed

+114
-15
lines changed

2 files changed

+114
-15
lines changed

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,55 @@ c10::optional<IValue> getIValue(
2424
return toIValue(getValue(name, match_vmap, vmap));
2525
}
2626

27+
// FuseShuffle is matching the channelshuffle pattern, where:
28+
// (1) the first view is [n, c, h, w] => [n, groups, c // groups, h, w]
29+
// (2) the transpose is for groups => [n, c // groups, groups, h, w]
30+
// (3) the output view shape should be the same as the input tensor shape
2731
void FuseShuffle(std::shared_ptr<Graph>& graph) {
28-
std::string shuffle = R"(
32+
// below is channelshuffle for static view shape pattern
33+
std::string channelshuffle_with_static_shape = R"(
2934
graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
3035
%r1 = aten::view(%input, %view_shape)
3136
%r2 = aten::transpose(%r1, %trans_dim0, %trans_dim1)
3237
%r3 = aten::contiguous(%r2, %mem_format)
3338
%r4 = aten::view(%r3, %flattern_shape)
3439
return (%r4) )";
3540

36-
std::string shuffle_2d_fusion = R"(
41+
std::string shuffle_2d_fusion_with_static_shape = R"(
3742
graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
3843
%r = ipex::shuffle_2d(%input, %view_shape, %trans_dim0, %trans_dim1)
3944
return (%r) )";
4045

41-
// this filter passes only for the following conditions:
42-
// (1) the first view is [n, c, h, w] => [n, groups, c // groups, h, w]
43-
// (2) the tranpose is for groups => [n, c // groups, grpups, h, w]
44-
// (3) the output view shape should be the same as the input tensor shape
45-
auto filter_shuffle_2d_fusion =
46+
// below is channelshuffle for dynamic view shape pattern
47+
std::string dynamic_shape_input = R"(
48+
graph(%input, %idx_0:int, %idx_1:int, %idx_2:int, %idx_3:int, %div_g, %g:int, %type, %flattern_c):
49+
%n_ = aten::size(%input, %idx_0)
50+
%c_ = aten::size(%input, %idx_1)
51+
%tensor_c_ = prim::NumToTensor(%c_)
52+
%h_ = aten::size(%input, %idx_2)
53+
%w_ = aten::size(%input, %idx_3)
54+
%c_div_g_ = aten::div(%tensor_c_, %div_g, %type)
55+
%int_c_div_g_ = aten::Int(%c_div_g_)
56+
%view_shape:int[] = prim::ListConstruct(%n_, %g, %int_c_div_g_, %h_, %w_) )";
57+
58+
std::string channelshuffle_for_dynamic_shape = R"(
59+
%r1 = aten::view(%input, %view_shape)
60+
%r2 = aten::transpose(%r1, %idx_1, %idx_2)
61+
%r3 = aten::contiguous(%r2, %idx_0)
62+
%flattern_shape:int[] = prim::ListConstruct(%n_, %flattern_c, %h_, %w_)
63+
%r4 = aten::view(%r3, %flattern_shape)
64+
return (%r4) )";
65+
66+
std::string shuffle_2d_fusion_for_dynamic_shape = R"(
67+
%r = ipex::shuffle_2d(%input, %view_shape, %idx_1, %idx_2)
68+
return (%r) )";
69+
70+
std::string channelshuffle_with_dynamic_shape =
71+
dynamic_shape_input + channelshuffle_for_dynamic_shape;
72+
std::string shuffle_2d_fusion_with_dynamic_shape =
73+
dynamic_shape_input + shuffle_2d_fusion_for_dynamic_shape;
74+
75+
auto filter_shuffle_2d_static_fusion =
4676
[](const Match& match,
4777
const std::unordered_map<std::string, Value*>& vmap) {
4878
const auto& match_vmap = match.values_map;
@@ -86,11 +116,12 @@ void FuseShuffle(std::shared_ptr<Graph>& graph) {
86116
return false;
87117
}
88118

89-
// if the view shape and flattern shape is not set
119+
// if the view shape or flattern shape is not set
90120
if (!toIValue(view_shape_).has_value() ||
91121
!toIValue(flattern_shape_).has_value()) {
92122
return false;
93123
}
124+
94125
auto view_shape_list = toIValue(view_shape_).value().toIntVector();
95126
auto flattern_shape_list =
96127
toIValue(flattern_shape_).value().toIntVector();
@@ -134,10 +165,43 @@ void FuseShuffle(std::shared_ptr<Graph>& graph) {
134165
return true;
135166
};
136167

137-
SubgraphRewriter rewriter_shuffle_2d;
138-
rewriter_shuffle_2d.RegisterRewritePattern(shuffle, shuffle_2d_fusion);
139-
rewriter_shuffle_2d.runOnGraph(graph, filter_shuffle_2d_fusion);
140-
}
168+
auto filter_shuffle_2d_dynamic_fusion =
169+
[](const Match& match,
170+
const std::unordered_map<std::string, Value*>& vmap) {
171+
const auto& match_vmap = match.values_map;
172+
173+
auto n_idx = getIValue("idx_0", match_vmap, vmap);
174+
auto c_idx = getIValue("idx_1", match_vmap, vmap);
175+
auto h_idx = getIValue("idx_2", match_vmap, vmap);
176+
auto w_idx = getIValue("idx_3", match_vmap, vmap);
177+
if (!n_idx.has_value() || !c_idx.has_value() || !h_idx.has_value() ||
178+
!w_idx.has_value()) {
179+
return false;
180+
}
181+
182+
auto n_idx_ = n_idx.value().toInt();
183+
auto c_idx_ = c_idx.value().toInt();
184+
auto h_idx_ = h_idx.value().toInt();
185+
auto w_idx_ = w_idx.value().toInt();
186+
187+
if ((n_idx_ != 0) || (c_idx_ != 1) || (h_idx_ != 2) || (w_idx_ != 3)) {
188+
return false;
189+
}
190+
191+
return true;
192+
};
193+
194+
SubgraphRewriter rewriter_shuffle_2d_dynamic;
195+
rewriter_shuffle_2d_dynamic.RegisterRewritePattern(
196+
channelshuffle_with_dynamic_shape, shuffle_2d_fusion_with_dynamic_shape);
197+
rewriter_shuffle_2d_dynamic.runOnGraph(
198+
graph, filter_shuffle_2d_dynamic_fusion);
199+
SubgraphRewriter rewriter_shuffle_2d_static;
200+
rewriter_shuffle_2d_static.RegisterRewritePattern(
201+
channelshuffle_with_static_shape, shuffle_2d_fusion_with_static_shape);
202+
rewriter_shuffle_2d_static.runOnGraph(graph, filter_shuffle_2d_static_fusion);
203+
204+
} // FuseShuffle
141205

142206
void FuseAddLayerNorm(std::shared_ptr<Graph>& graph) {
143207
std::string aten_add_layernorm = R"(

tests/cpu/test_jit.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,9 +582,9 @@ def forward(self, x):
582582
x = self.conv_transpose2d(x)
583583
return x
584584

585-
class ChannelShuffle(nn.Module):
585+
class ChannelShuffle_with_Static_Shape(nn.Module):
586586
def __init__(self, batchsize, num_channels, height, width, groups):
587-
super(ChannelShuffle, self).__init__()
587+
super(ChannelShuffle_with_Static_Shape, self).__init__()
588588
self.batchsize = batchsize
589589
self.num_channels = num_channels
590590
self.height = height
@@ -598,6 +598,32 @@ def forward(self, x):
598598
x = x.view(self.batchsize, -1, self.height, self.width)
599599
return x
600600

601+
class ChannelShuffle_with_Dynamic_Shape(nn.Module):
602+
def __init__(self, groups):
603+
super(ChannelShuffle_with_Dynamic_Shape, self).__init__()
604+
self.groups = groups
605+
606+
def forward(self, x):
607+
batchsize, num_channels, height, width = x.size()
608+
channels_per_group = num_channels // self.groups
609+
x = x.view(batchsize, self.groups, channels_per_group, height, width)
610+
x = torch.transpose(x, 1, 2).contiguous()
611+
x = x.view(batchsize, -1, height, width)
612+
return x
613+
614+
class NotChannelShuffle(nn.Module):
615+
def __init__(self, groups):
616+
super(NotChannelShuffle, self).__init__()
617+
self.groups = groups
618+
619+
def forward(self, x):
620+
batchsize, num_channels, height, width = x.size()
621+
channels_per_group = num_channels // self.groups
622+
x = x.view(batchsize, self.groups, channels_per_group, width, height)
623+
x = torch.transpose(x, 1, 2).contiguous()
624+
x = x.view(batchsize, -1, width, height)
625+
return x
626+
601627
class MatmulDiv(nn.Module):
602628
def __init__(self, div_scalar=False, with_out=False):
603629
super(MatmulDiv, self).__init__()
@@ -2446,9 +2472,18 @@ def test_output_linear_sigmoid(self):
24462472

24472473
def test_channel_shuffle(self):
24482474
self._test_output(
2449-
ChannelShuffle(10, 16, 50, 50, 4),
2475+
ChannelShuffle_with_Static_Shape(10, 16, 50, 50, 4),
2476+
torch.rand(10, 16, 50, 50),
2477+
kind_in_graph="ipex::shuffle_2d")
2478+
self._test_output(
2479+
ChannelShuffle_with_Dynamic_Shape(4),
24502480
torch.rand(10, 16, 50, 50),
24512481
kind_in_graph="ipex::shuffle_2d")
2482+
self._test_output(
2483+
NotChannelShuffle(4),
2484+
torch.rand(10, 16, 50, 60),
2485+
kind_not_in_graph="ipex::shuffle_2d")
2486+
24522487

24532488
def test_jit_function(self):
24542489
# test hool trace and script can works for function

0 commit comments

Comments
 (0)