@@ -24,25 +24,55 @@ c10::optional<IValue> getIValue(
2424 return toIValue (getValue (name, match_vmap, vmap));
2525}
2626
27+ // FuseShuffle is matching the channelshuffle pattern, where:
28+ // (1) the first view is [n, c, h, w] => [n, groups, c // groups, h, w]
29+ // (2) the transpose is for groups => [n, c // groups, groups, h, w]
30+ // (3) the output view shape should be the same as the input tensor shape
2731void FuseShuffle (std::shared_ptr<Graph>& graph) {
28- std::string shuffle = R"(
32+ // below is the channelshuffle pattern for a static view shape
33+ std::string channelshuffle_with_static_shape = R"(
2934 graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
3035 %r1 = aten::view(%input, %view_shape)
3136 %r2 = aten::transpose(%r1, %trans_dim0, %trans_dim1)
3237 %r3 = aten::contiguous(%r2, %mem_format)
3338 %r4 = aten::view(%r3, %flattern_shape)
3439 return (%r4) )" ;
3540
36- std::string shuffle_2d_fusion = R"(
41+ std::string shuffle_2d_fusion_with_static_shape = R"(
3742 graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
3843 %r = ipex::shuffle_2d(%input, %view_shape, %trans_dim0, %trans_dim1)
3944 return (%r) )" ;
4045
41- // this filter passes only for the following conditions:
42- // (1) the first view is [n, c, h, w] => [n, groups, c // groups, h, w]
43- // (2) the tranpose is for groups => [n, c // groups, grpups, h, w]
44- // (3) the output view shape should be the same as the input tensor shape
45- auto filter_shuffle_2d_fusion =
46+ // below is channelshuffle for dynamic view shape pattern
47+ std::string dynamic_shape_input = R"(
48+ graph(%input, %idx_0:int, %idx_1:int, %idx_2:int, %idx_3:int, %div_g, %g:int, %type, %flattern_c):
49+ %n_ = aten::size(%input, %idx_0)
50+ %c_ = aten::size(%input, %idx_1)
51+ %tensor_c_ = prim::NumToTensor(%c_)
52+ %h_ = aten::size(%input, %idx_2)
53+ %w_ = aten::size(%input, %idx_3)
54+ %c_div_g_ = aten::div(%tensor_c_, %div_g, %type)
55+ %int_c_div_g_ = aten::Int(%c_div_g_)
56+ %view_shape:int[] = prim::ListConstruct(%n_, %g, %int_c_div_g_, %h_, %w_) )" ;
57+
58+ std::string channelshuffle_for_dynamic_shape = R"(
59+ %r1 = aten::view(%input, %view_shape)
60+ %r2 = aten::transpose(%r1, %idx_1, %idx_2)
61+ %r3 = aten::contiguous(%r2, %idx_0)
62+ %flattern_shape:int[] = prim::ListConstruct(%n_, %flattern_c, %h_, %w_)
63+ %r4 = aten::view(%r3, %flattern_shape)
64+ return (%r4) )" ;
65+
66+ std::string shuffle_2d_fusion_for_dynamic_shape = R"(
67+ %r = ipex::shuffle_2d(%input, %view_shape, %idx_1, %idx_2)
68+ return (%r) )" ;
69+
70+ std::string channelshuffle_with_dynamic_shape =
71+ dynamic_shape_input + channelshuffle_for_dynamic_shape;
72+ std::string shuffle_2d_fusion_with_dynamic_shape =
73+ dynamic_shape_input + shuffle_2d_fusion_for_dynamic_shape;
74+
75+ auto filter_shuffle_2d_static_fusion =
4676 [](const Match& match,
4777 const std::unordered_map<std::string, Value*>& vmap) {
4878 const auto & match_vmap = match.values_map ;
@@ -86,11 +116,12 @@ void FuseShuffle(std::shared_ptr<Graph>& graph) {
86116 return false ;
87117 }
88118
89- // if the view shape and flattern shape is not set
119+ // if the view shape or flattern shape is not set
90120 if (!toIValue (view_shape_).has_value () ||
91121 !toIValue (flattern_shape_).has_value ()) {
92122 return false ;
93123 }
124+
94125 auto view_shape_list = toIValue (view_shape_).value ().toIntVector ();
95126 auto flattern_shape_list =
96127 toIValue (flattern_shape_).value ().toIntVector ();
@@ -134,10 +165,43 @@ void FuseShuffle(std::shared_ptr<Graph>& graph) {
134165 return true ;
135166 };
136167
137- SubgraphRewriter rewriter_shuffle_2d;
138- rewriter_shuffle_2d.RegisterRewritePattern (shuffle, shuffle_2d_fusion);
139- rewriter_shuffle_2d.runOnGraph (graph, filter_shuffle_2d_fusion);
140- }
168+ auto filter_shuffle_2d_dynamic_fusion =
169+ [](const Match& match,
170+ const std::unordered_map<std::string, Value*>& vmap) {
171+ const auto & match_vmap = match.values_map ;
172+
173+ auto n_idx = getIValue (" idx_0" , match_vmap, vmap);
174+ auto c_idx = getIValue (" idx_1" , match_vmap, vmap);
175+ auto h_idx = getIValue (" idx_2" , match_vmap, vmap);
176+ auto w_idx = getIValue (" idx_3" , match_vmap, vmap);
177+ if (!n_idx.has_value () || !c_idx.has_value () || !h_idx.has_value () ||
178+ !w_idx.has_value ()) {
179+ return false ;
180+ }
181+
182+ auto n_idx_ = n_idx.value ().toInt ();
183+ auto c_idx_ = c_idx.value ().toInt ();
184+ auto h_idx_ = h_idx.value ().toInt ();
185+ auto w_idx_ = w_idx.value ().toInt ();
186+
187+ if ((n_idx_ != 0 ) || (c_idx_ != 1 ) || (h_idx_ != 2 ) || (w_idx_ != 3 )) {
188+ return false ;
189+ }
190+
191+ return true ;
192+ };
193+
194+ SubgraphRewriter rewriter_shuffle_2d_dynamic;
195+ rewriter_shuffle_2d_dynamic.RegisterRewritePattern (
196+ channelshuffle_with_dynamic_shape, shuffle_2d_fusion_with_dynamic_shape);
197+ rewriter_shuffle_2d_dynamic.runOnGraph (
198+ graph, filter_shuffle_2d_dynamic_fusion);
199+ SubgraphRewriter rewriter_shuffle_2d_static;
200+ rewriter_shuffle_2d_static.RegisterRewritePattern (
201+ channelshuffle_with_static_shape, shuffle_2d_fusion_with_static_shape);
202+ rewriter_shuffle_2d_static.runOnGraph (graph, filter_shuffle_2d_static_fusion);
203+
204+ } // end of FuseShuffle
141205
142206void FuseAddLayerNorm (std::shared_ptr<Graph>& graph) {
143207 std::string aten_add_layernorm = R"(
0 commit comments