diff --git a/docs/oneflow2onnx/op_list.md b/docs/oneflow2onnx/op_list.md index 57e9aac..b39be89 100644 --- a/docs/oneflow2onnx/op_list.md +++ b/docs/oneflow2onnx/op_list.md @@ -30,4 +30,4 @@ | 90 | ScalarLogicalLess| 91| ScalarLogicalGreater| 92| Gather | 93 | Expand | | 94 | fill_ | 95 | GeLU | 96 | LayerNorm | 97 | AmpIdentity | | 98 | fast_gelu | 99 | quick_gelu | 100 | fused_self_attention |101 |RMSLayerNorm | -| 102 | RMSNorm | 103 | fused_bias_add_scale_mask_softmax_dropout | \ No newline at end of file +| 102 | RMSNorm | 103 | fused_bias_add_scale_mask_softmax_dropout | 104 | stack \ No newline at end of file diff --git a/examples/oneflow2onnx/nodes/CPU/test_stack.py b/examples/oneflow2onnx/nodes/CPU/test_stack.py new file mode 100644 index 0000000..65c0f60 --- /dev/null +++ b/examples/oneflow2onnx/nodes/CPU/test_stack.py @@ -0,0 +1,50 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import tempfile +import oneflow as flow +from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check + + +class Stack(flow.nn.Module): + def __init__(self) -> None: + super(Stack, self).__init__() + + def forward(self, x: flow.Tensor) -> flow.Tensor: + return flow.stack([x, x, x], dim=1) + + +stack = Stack() + + +class StackOpGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.m = stack + + def build(self, x): + out = self.m(x) + return out + + +def test_stack(): + + stack_graph = StackOpGraph() + stack_graph._compile(flow.randn(1, 3, 224, 224)) + + convert_to_onnx_and_check(stack_graph, onnx_model_path="./temp") + + +test_stack() diff --git a/oneflow_onnx/oneflow2onnx/handlers/array.py b/oneflow_onnx/oneflow2onnx/handlers/array.py index 174c099..93c554c 100644 --- a/oneflow_onnx/oneflow2onnx/handlers/array.py +++ b/oneflow_onnx/oneflow2onnx/handlers/array.py @@ -242,6 +242,55 @@ def Version_11(cls, ctx, node, **kwargs): cls.Version_1(ctx, node, **kwargs) +@flow_op("stack", "ConcatFromSequence") +class Stack: + @classmethod + def Version_11(cls, ctx, node, **kwargs): + print("version_11") + axis_val = node.attrs.get("axis", None) + dtypes = node.output_dtypes + ctx.RemoveNode(node.name) + ctx.MakeNode("ConcatFromSequence", node.input_tensor_names, outputs=[node.output_tensor_names[0]], op_name_scope=node.name, name="stack", dtypes=dtypes, attr={"new_axis": 1, "axis": axis_val}) + + @classmethod + def Version_1(cls, ctx, node, **kwargs): + print(f"version_1: {ctx.opset}") + axis_val = node.attrs.get("axis", None) + dtypes = node.output_dtypes + output_shape = node.output_shapes[0] + node_concat = ctx.MakeNode( + "Concat", + node.input_tensor_names, + op_name_scope=node.name, + name="concat", + dtypes=dtypes, + attr={"axis": axis_val}, + ) + ctx.RemoveNode(node.name) + # since opset 5 + # set_trace() + if ctx.opset > 4: + node_constant = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("shape"), np.array(output_shape).astype(np.int64)) + node_reshape = ctx.MakeNode( + "Reshape", + node_concat.output_tensor_names + node_constant.output_tensor_names, + outputs=node.output_tensor_names, + op_name_scope=node.name, + name="reshape", + dtypes=dtypes, + ) + else: + node_reshape = ctx.MakeNode( + "Reshape", + node_concat.output_tensor_names, + outputs=node.output_tensor_names, + op_name_scope=node.name, + name="reshape", + dtypes=dtypes, + attr={"shape": output_shape}, + ) + + @flow_op("slice", "Slice") class Slice: @classmethod