relax: convert ONNX Dropout and legalize relax.nn.dropout

Summary of what this patch changes, file by file.

* python/tvm/relax/frontend/onnx/onnx_frontend.py
  - Adds an `OnnxOpConverter` subclass named `Dropout` and registers it
    under the key "Dropout" in `_get_convert_map`.
  - `_impl_v1` takes `ratio` from the node attributes (0.5 when absent)
    and returns `relax.op.nn.dropout(inputs[0], ratio)`.
  - `_impl_v12` takes `ratio` from the optional second input instead.
    The value is only used when `get_constant` yields a `relax.Constant`;
    in every other case the ratio stays at 0.5. No input past the second
    one is read.
    NOTE(review): confirm that silently keeping 0.5 for a non-constant
    ratio input is the intended behaviour.

* python/tvm/relax/transform/legalize_ops/nn.py
  - `_nn_dropout` used to log a message and return the call unchanged.
    It now builds a TE function (name hint "dropout") that returns
    `topi.identity(x)` together with `topi.full_like(x, 1.0)`, i.e. the
    input passed through plus an all-ones mask.

* tests/python/relax/test_frontend_onnx.py
  - New `test_dropout`: runs `verify_unary` with the default opset and
    with opset 11 plus a `ratio` attribute, then imports an opset-13
    model whose `ratio` is an initializer (0.3) and asserts that the
    `rate` attribute of the resulting `relax.nn.dropout` call is 0.3.

* tests/python/relax/test_transform_legalize_ops_nn.py
  - New `test_dropout`: applies `LegalizeOps` to a module calling
    `R.nn.dropout(x, rate=0.5)` on a (2, 3) float32 tensor and checks
    structural equality with the expected `call_tir` module.

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 3cfe7c892c46..e6e735f05c37 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3084,6 +3084,25 @@ def _impl_v1(cls, bb, inputs, attr, params):
         return inputs[0]
 
 
+class Dropout(OnnxOpConverter):
+    """Converts an onnx Dropout node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        ratio = float(attr.get("ratio", 0.5))
+        return relax.op.nn.dropout(inputs[0], ratio)
+
+    @classmethod
+    def _impl_v12(cls, bb, inputs, attr, params):
+        # Since opset 12 ratio is the optional second input rather than an attribute.
+        ratio = 0.5
+        if len(inputs) >= 2 and inputs[1] is not None:
+            const = get_constant(inputs[1], params)
+            if isinstance(const, relax.Constant):
+                ratio = float(const.data.numpy())
+        return relax.op.nn.dropout(inputs[0], ratio)
+
+
 def _onnx_resize_spatial_roi_vector(roi_full: relax.Expr, rank: int) -> relax.Expr:
     """Map ONNX ROI [starts..., ends...] to TOPI spatial ROI (drop N/C axes)."""
     return relax.op.concat(
@@ -5314,6 +5333,7 @@ def _get_convert_map():
         "ConvTranspose": ConvTranspose,
         "Flatten": Flatten,
         "Identity": Identity,
+        "Dropout": Dropout,
         "Resize": Resize,
         "Einsum": Einsum,
         "Range": Range,
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 35d81f968b37..454e39e6d806 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -697,8 +697,12 @@ def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
 
 @register_legalize("relax.nn.dropout")
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
-    logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
-    return call
+    # Dropout is a no-op at inference: pass the input through and return an all-ones mask.
+    return bb.call_te(
+        lambda x: [topi.identity(x), topi.full_like(x, 1.0)],
+        call.args[0],
+        primfunc_name_hint="dropout",
+    )
 
 
 def _te_attention(
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 57f780868ccd..31cf91e9e7f1 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4324,6 +4324,32 @@ def test_maxunpool(kernel_shape, pads, strides):
     check_correctness(model, inputs={"I": indices})
 
 
+def test_dropout():
+    verify_unary("Dropout", [1, 3, 32, 32])
+    verify_unary("Dropout", [1, 3, 32, 32], opset=11, attrs={"ratio": 0.5})
+
+    # Opset 12+ passes ratio as an optional input; check it is captured into the relax op.
+    node = helper.make_node("Dropout", ["x", "ratio"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "dropout_ratio_input",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 4, 4])],
+        initializer=[helper.make_tensor("ratio", TensorProto.FLOAT, [], [0.3])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 4, 4])],
+    )
+    model = helper.make_model(graph, producer_name="dropout_ratio_input")
+    model.opset_import[0].version = 13
+    mod = from_onnx(model, opset=13)
+    rates = [
+        float(b.value.attrs.rate)
+        for f in mod.functions.values()
+        for block in getattr(f.body, "blocks", [])
+        for b in block.bindings
+        if getattr(getattr(b.value, "op", None), "name", "") == "relax.nn.dropout"
+    ]
+    assert rates == pytest.approx([0.3])
+
+
 def test_flatten():
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 0})
     verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": -1})
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 8136997cf66c..c48b1147e9c9 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -4116,5 +4116,43 @@ def main(x: R.Tensor(ndim=4, dtype="float32")) -> R.Tensor(ndim=2, dtype="float3
     tvm.ir.assert_structural_equal(mod, BatchFlattenUndefinedShape)
 
 
+def test_dropout():
+    # fmt: off
+    @tvm.script.ir_module
+    class Dropout:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")):
+            gv: R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")) = R.nn.dropout(x, rate=0.5)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True, s_tir=True)
+        def dropout(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_full_like: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tirx.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.sblock("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = x[v_i0, v_i1]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.sblock("T_full_like"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads()
+                    T.writes(T_full_like[v_ax0, v_ax1])
+                    T_full_like[v_ax0, v_ax1] = T.float32(1.0)
+
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")):
+            cls = Expected
+            gv = R.call_tir(cls.dropout, (x,), out_sinfo=[R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")])
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Dropout)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()