From 7fe33c655496be7f5b9b5c7ca110cb8505cf1be2 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 19 Jun 2026 22:58:19 +0800 Subject: [PATCH] [Relax] Legalize dilated conv_transpose Signed-off-by: Guan-Ming (Wesley) Chiu <105915352+guan404ming@users.noreply.github.com> --- python/tvm/relax/transform/legalize_ops/nn.py | 131 ++++++++---------- .../relax/test_transform_legalize_ops_nn.py | 74 ++++++++++ 2 files changed, 135 insertions(+), 70 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 35d81f968b37..2de4bbd0eb49 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -20,6 +20,7 @@ import logging import math +import tvm from tvm import s_tir, te, tirx, topi from ...block_builder import BlockBuilder @@ -153,70 +154,62 @@ def _nn_conv3d(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.nn.conv1d_transpose") def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr: if call.attrs.out_layout != call.attrs.data_layout: - logging.info( - "TOPI conv1d_transpose does not support different input-output " - "layouts, and thus cannot be legalized by TOPI" + raise tvm.error.OpNotImplemented( + "nn.conv1d_transpose with out_layout != data_layout is not yet lowered by TOPI." ) - return call if call.attrs.data_layout != "NCW" or call.attrs.kernel_layout != "IOW": - logging.info( - "TOPI conv1d_transpose does not support input layout other than NCW, " - "and kernel layout other than IOW, so cannot be legalized by TOPI" + raise tvm.error.OpNotImplemented( + f"nn.conv1d_transpose with data_layout={call.attrs.data_layout!r}, " + f"kernel_layout={call.attrs.kernel_layout!r} is not yet lowered by TOPI (only NCW/IOW)." ) - return call - dilation = call.attrs.dilation - if len(dilation) != 1 or dilation[0] != 1: - logging.info( - "TOPI conv1d_transpose does not support dilations other than 1, " - "and thus cannot be legalized by TOPI" + strides = [int(s) for s in call.attrs.strides] + padding = [int(p) for p in call.attrs.padding] + output_padding = [int(o) for o in call.attrs.output_padding] + groups = int(call.attrs.groups) + out_dtype = call.struct_info.dtype + dilation = [int(d) for d in call.attrs.dilation] + + def te_conv1d_transpose(data, kernel): + # Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel. + if any(d != 1 for d in dilation): + kernel = topi.nn.dilate(kernel, [1, 1, dilation[0]], name="kernel_dilate") + return topi.nn.group_conv1d_transpose_ncw( + data, kernel, strides, padding, out_dtype, output_padding, groups ) - return call return bb.call_te( - topi.nn.group_conv1d_transpose_ncw, - call.args[0], - call.args[1], - stride=call.attrs.strides, - padding=call.attrs.padding, - out_dtype=call.struct_info.dtype, - output_padding=call.attrs.output_padding, - groups=call.attrs.groups, - primfunc_name_hint="conv1d_transpose", + te_conv1d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv1d_transpose" ) @register_legalize("relax.nn.conv2d_transpose") def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr: if call.attrs.out_layout != call.attrs.data_layout: - logging.info( - "TOPI conv2d_transpose does not support different input-output " - "layouts, and thus cannot be legalized by TOPI" + raise tvm.error.OpNotImplemented( + "nn.conv2d_transpose with out_layout != data_layout is not yet lowered by TOPI." ) - return call if call.attrs.data_layout != "NCHW" or call.attrs.kernel_layout != "IOHW": - logging.info( - "TOPI conv2d_transpose does not support input layout other than NCHW, " - "and kernel layout other than IOHW, so cannot be legalized by TOPI" + raise tvm.error.OpNotImplemented( + f"nn.conv2d_transpose with data_layout={call.attrs.data_layout!r}, " + f"kernel_layout={call.attrs.kernel_layout!r} is not yet lowered by TOPI (only NCHW/IOHW)." ) - return call - dilation = call.attrs.dilation - if len(dilation) != 2 or any(d != 1 for d in dilation): - logging.info( - "TOPI conv2d_transpose does not support dilations other than 1, " - "and thus cannot be legalized by TOPI" + strides = [int(s) for s in call.attrs.strides] + padding = [int(p) for p in call.attrs.padding] + output_padding = [int(o) for o in call.attrs.output_padding] + groups = int(call.attrs.groups) + out_dtype = call.struct_info.dtype + dilation = [int(d) for d in call.attrs.dilation] + + def te_conv2d_transpose(data, kernel): + # Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel. + if any(d != 1 for d in dilation): + kernel = topi.nn.dilate(kernel, [1, 1, dilation[0], dilation[1]], name="kernel_dilate") + return topi.nn.group_conv2d_transpose_nchw( + data, kernel, strides, padding, out_dtype, output_padding, groups ) - return call return bb.call_te( - topi.nn.group_conv2d_transpose_nchw, - call.args[0], - call.args[1], - stride=call.attrs.strides, - padding=call.attrs.padding, - out_dtype=call.struct_info.dtype, - output_padding=call.attrs.output_padding, - groups=call.attrs.groups, - primfunc_name_hint="conv2d_transpose", + te_conv2d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv2d_transpose" ) @@ -225,35 +218,33 @@ def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr: # Keep policy in sync with _nn_conv2d_transpose: only lower when TOPI supports # the layout/dilation. if call.attrs.out_layout != call.attrs.data_layout: - logging.info( - "TOPI conv3d_transpose does not support different input-output " - "layouts, and thus cannot be legalized by TOPI" + raise tvm.error.OpNotImplemented( + "nn.conv3d_transpose with out_layout != data_layout is not yet lowered by TOPI." ) - return call if call.attrs.data_layout != "NCDHW" or call.attrs.kernel_layout != "IODHW": - logging.info( - "TOPI conv3d_transpose does not support input layout other than NCDHW, " - "and kernel layout other than IODHW, so cannot be legalized by TOPI" - ) - return call - dilation = call.attrs.dilation - if len(dilation) != 3 or any(d != 1 for d in dilation): - logging.info( - "TOPI conv3d_transpose does not support dilations other than 1, " - "and thus cannot be legalized by TOPI" + raise tvm.error.OpNotImplemented( + f"nn.conv3d_transpose with data_layout={call.attrs.data_layout!r}, " + f"kernel_layout={call.attrs.kernel_layout!r} is not yet lowered by TOPI (only NCDHW/IODHW)." + ) + strides = [int(s) for s in call.attrs.strides] + padding = [int(p) for p in call.attrs.padding] + output_padding = [int(o) for o in call.attrs.output_padding] + groups = int(call.attrs.groups) + out_dtype = call.struct_info.dtype + dilation = [int(d) for d in call.attrs.dilation] + + def te_conv3d_transpose(data, kernel): + # Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel. + if any(d != 1 for d in dilation): + kernel = topi.nn.dilate( + kernel, [1, 1, dilation[0], dilation[1], dilation[2]], name="kernel_dilate" + ) + return topi.nn.group_conv3d_transpose_ncdhw( + data, kernel, strides, padding, out_dtype, output_padding, groups ) - return call return bb.call_te( - topi.nn.group_conv3d_transpose_ncdhw, - call.args[0], - call.args[1], - strides=call.attrs.strides, - padding=call.attrs.padding, - out_dtype=call.struct_info.dtype, - output_padding=call.attrs.output_padding, - groups=call.attrs.groups, - primfunc_name_hint="conv3d_transpose", + te_conv3d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv3d_transpose" ) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 8136997cf66c..1b66cf4a891b 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -725,6 +725,80 @@ def conv2d_transpose(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, tvm.ir.assert_structural_equal(mod, Expected) +def test_conv2d_transpose_dilation(): + # fmt: off + @tvm.script.ir_module + class Conv2dTranspose: + @R.function + def main(x: R.Tensor((1, 1, 3, 3), "float32"), w: R.Tensor((1, 1, 2, 2), "float32")): + gv = R.nn.conv2d_transpose(x, w, dilation=(2, 2)) + return gv + + @I.ir_module(s_tir=True) + class Expected: + @T.prim_func(private=True, s_tir=True) + def conv2d_transpose(x: T.Buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3)), "float32"), w: T.Buffer((T.int64(1), T.int64(1), T.int64(2), T.int64(2)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32")): + T.func_attr({"tirx.noalias": True}) + data_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3))) + data_pad = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(7), T.int64(7))) + kernel_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3))) + kernel_transform = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3))) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)): + with T.sblock("data_dilate"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + data_dilate[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2, v_i3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(7), T.int64(7)): + with T.sblock("data_pad"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(5) and T.int64(2) <= v_i3 and v_i3 < T.int64(5), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)], T.float32(0.0)) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)): + with T.sblock("kernel_dilate"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + kernel_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(2) == T.int64(0), w[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(2)], T.float32(0.0)) + for o, i, h, w_1 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)): + with T.sblock("kernel_transform"): + v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w_1]) + kernel_transform[v_o, v_i, v_h, v_w] = kernel_dilate[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w] + for b, c, h, w_1, dc, dh, dw in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5), T.int64(1), T.int64(3), T.int64(3)): + with T.sblock("compute"): + v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w_1, dc, dh, dw]) + with T.init(): + compute[v_b, v_c, v_h, v_w] = T.float32(0.0) + compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, v_dh, v_dw] + + @R.function + def main(x: R.Tensor((1, 1, 3, 3), dtype="float32"), w: R.Tensor((1, 1, 2, 2), dtype="float32")) -> R.Tensor((1, 1, 5, 5), dtype="float32"): + cls = Expected + gv = R.call_tir(cls.conv2d_transpose, (x, w), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(Conv2dTranspose) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv_transpose_unsupported_raises(): + # fmt: off + @tvm.script.ir_module + class Layout2d: + @R.function + def main(x: R.Tensor((1, 4, 4, 1), "float32"), w: R.Tensor((1, 2, 3, 3), "float32")): + gv = R.nn.conv2d_transpose(x, w, data_layout="NHWC", kernel_layout="IOHW") + return gv + + @tvm.script.ir_module + class OutLayout2d: + @R.function + def main(x: R.Tensor((1, 1, 5, 5), "float32"), w: R.Tensor((1, 2, 3, 3), "float32")): + gv = R.nn.conv2d_transpose(x, w, out_layout="NHWC") + return gv + # fmt: on + + for mod in [Layout2d, OutLayout2d]: + with pytest.raises(tvm.error.OpNotImplemented): + LegalizeOps()(mod) + + def test_max_pool2d(): # fmt: off @tvm.script.ir_module