Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3084,6 +3084,25 @@ def _impl_v1(cls, bb, inputs, attr, params):
return inputs[0]


class Dropout(OnnxOpConverter):
    """Converts an onnx Dropout node into an equivalent Relax expression.

    Only the data input and the drop ratio are forwarded to
    ``relax.op.nn.dropout``; any further inputs of the node are not read.
    """

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        """Convert Dropout when the ratio is carried as the ``ratio`` attribute.

        A missing attribute falls back to a ratio of 0.5.
        """
        ratio = float(attr.get("ratio", 0.5))
        return relax.op.nn.dropout(inputs[0], ratio)

    @classmethod
    def _impl_v12(cls, bb, inputs, attr, params):
        """Convert Dropout when the ratio is supplied as an optional second input.

        The ratio is only honoured when ``get_constant`` resolves it to a
        ``relax.Constant``; if the input is absent or does not resolve to a
        constant, the default of 0.5 is used instead.
        """
        # Since opset 12 ratio is the optional second input rather than an attribute.
        ratio = 0.5
        if len(inputs) >= 2 and inputs[1] is not None:
            const = get_constant(inputs[1], params)
            if isinstance(const, relax.Constant):
                # Extract the single element explicitly: calling float() directly on
                # an ndarray with ndim > 0 (e.g. a shape-[1] ratio) is deprecated in
                # recent NumPy releases, whereas item() handles any one-element array.
                ratio = float(const.data.numpy().item())
        return relax.op.nn.dropout(inputs[0], ratio)


def _onnx_resize_spatial_roi_vector(roi_full: relax.Expr, rank: int) -> relax.Expr:
"""Map ONNX ROI [starts..., ends...] to TOPI spatial ROI (drop N/C axes)."""
return relax.op.concat(
Expand Down Expand Up @@ -5314,6 +5333,7 @@ def _get_convert_map():
"ConvTranspose": ConvTranspose,
"Flatten": Flatten,
"Identity": Identity,
"Dropout": Dropout,
"Resize": Resize,
"Einsum": Einsum,
"Range": Range,
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,12 @@ def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:

@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize ``relax.nn.dropout`` into a TE-generated prim func call.

    Parameters
    ----------
    bb : BlockBuilder
        Builder used to emit the generated prim func and the call to it.
    call : Call
        The ``relax.nn.dropout`` call; only its first argument (the data
        tensor) is used.

    Returns
    -------
    Expr
        A call producing two tensors: a copy of the input and a mask filled
        with ones that has the same shape and dtype as the input.
    """
    # Dropout is a no-op at inference: pass the input through and return an all-ones mask.
    return bb.call_te(
        lambda x: [topi.identity(x), topi.full_like(x, 1.0)],
        call.args[0],
        primfunc_name_hint="dropout",
    )


def _te_attention(
Expand Down
26 changes: 26 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4324,6 +4324,32 @@ def test_maxunpool(kernel_shape, pads, strides):
check_correctness(model, inputs={"I": indices})


def test_dropout():
    """Check Dropout conversion, including a ratio passed as a graph input."""
    verify_unary("Dropout", [1, 3, 32, 32])
    verify_unary("Dropout", [1, 3, 32, 32], opset=11, attrs={"ratio": 0.5})

    # Opset 12+ passes ratio as an optional input; check it is captured into the relax op.
    dropout_node = helper.make_node("Dropout", ["x", "ratio"], ["y"])
    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 4, 4])
    ratio_init = helper.make_tensor("ratio", TensorProto.FLOAT, [], [0.3])
    y_info = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 4, 4])
    graph = helper.make_graph(
        [dropout_node],
        "dropout_ratio_input",
        inputs=[x_info],
        initializer=[ratio_init],
        outputs=[y_info],
    )
    model = helper.make_model(graph, producer_name="dropout_ratio_input")
    model.opset_import[0].version = 13
    mod = from_onnx(model, opset=13)

    # Collect the rate attribute of every relax.nn.dropout call bound in the module.
    rates = []
    for func in mod.functions.values():
        # Function bodies without a "blocks" attribute contribute nothing.
        for block in getattr(func.body, "blocks", []):
            for binding in block.bindings:
                bound_op = getattr(binding.value, "op", None)
                if getattr(bound_op, "name", "") == "relax.nn.dropout":
                    rates.append(float(binding.value.attrs.rate))
    assert rates == pytest.approx([0.3])


def test_flatten():
verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": 0})
verify_unary("Flatten", [1, 3, 32, 32], attrs={"axis": -1})
Expand Down
38 changes: 38 additions & 0 deletions tests/python/relax/test_transform_legalize_ops_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4116,5 +4116,43 @@ def main(x: R.Tensor(ndim=4, dtype="float32")) -> R.Tensor(ndim=2, dtype="float3
tvm.ir.assert_structural_equal(mod, BatchFlattenUndefinedShape)


def test_dropout():
    """Check that LegalizeOps lowers R.nn.dropout to a call_tir of a generated prim func."""
    # fmt: off
    # Input module: a single dropout call whose result is a (output, mask) tuple.
    @tvm.script.ir_module
    class Dropout:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")):
            gv: R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")) = R.nn.dropout(x, rate=0.5)
            return gv

    # Expected module after legalization: dropout becomes a prim func with two outputs.
    @tvm.script.ir_module
    class Expected:
        @T.prim_func(private=True, s_tir=True)
        def dropout(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_full_like: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tirx.noalias": True})
            # First output: element-wise copy of the input.
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.sblock("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(x[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = x[v_i0, v_i1]
            # Second output: mask filled with 1.0; it reads nothing from the input.
            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                with T.sblock("T_full_like"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads()
                    T.writes(T_full_like[v_ax0, v_ax1])
                    T_full_like[v_ax0, v_ax1] = T.float32(1.0)

        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")):
            cls = Expected
            gv = R.call_tir(cls.dropout, (x,), out_sinfo=[R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")])
            return gv
    # fmt: on

    # Legalize the input module and require an exact structural match with Expected.
    mod = LegalizeOps()(Dropout)
    tvm.ir.assert_structural_equal(mod, Expected)


# When this file is executed directly as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()
Loading