Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 61 additions & 70 deletions python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
import math

import tvm
from tvm import s_tir, te, tirx, topi

from ...block_builder import BlockBuilder
Expand Down Expand Up @@ -153,70 +154,62 @@ def _nn_conv3d(bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.nn.conv1d_transpose")
def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize ``relax.nn.conv1d_transpose`` into a TE-generated PrimFunc call.

    Only the NCW data layout combined with the IOW kernel layout is lowered,
    and the output layout must equal the data layout. A non-unit dilation is
    applied by dilating the kernel along its spatial axis before handing it
    to ``topi.nn.group_conv1d_transpose_ncw``.

    Parameters
    ----------
    bb : BlockBuilder
        Builder used to emit the TE call.
    call : Call
        The ``relax.nn.conv1d_transpose`` call being legalized.

    Returns
    -------
    Expr
        The expression returned by ``bb.call_te`` for the generated compute.

    Raises
    ------
    tvm.error.OpNotImplemented
        If ``out_layout`` differs from ``data_layout``, or the layouts are
        not NCW (data) / IOW (kernel).
    """
    if call.attrs.out_layout != call.attrs.data_layout:
        raise tvm.error.OpNotImplemented(
            "nn.conv1d_transpose with out_layout != data_layout is not yet lowered by TOPI."
        )
    if call.attrs.data_layout != "NCW" or call.attrs.kernel_layout != "IOW":
        raise tvm.error.OpNotImplemented(
            f"nn.conv1d_transpose with data_layout={call.attrs.data_layout!r}, "
            f"kernel_layout={call.attrs.kernel_layout!r} is not yet lowered by TOPI (only NCW/IOW)."
        )
    # Convert the attribute values to plain Python ints/lists up front so the
    # nested TE function below captures ordinary Python values.
    strides = [int(s) for s in call.attrs.strides]
    padding = [int(p) for p in call.attrs.padding]
    output_padding = [int(o) for o in call.attrs.output_padding]
    groups = int(call.attrs.groups)
    out_dtype = call.struct_info.dtype
    dilation = [int(d) for d in call.attrs.dilation]

    def te_conv1d_transpose(data, kernel):
        # Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel.
        if any(d != 1 for d in dilation):
            kernel = topi.nn.dilate(kernel, [1, 1, dilation[0]], name="kernel_dilate")
        return topi.nn.group_conv1d_transpose_ncw(
            data, kernel, strides, padding, out_dtype, output_padding, groups
        )

    return bb.call_te(
        te_conv1d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv1d_transpose"
    )


@register_legalize("relax.nn.conv2d_transpose")
def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize ``relax.nn.conv2d_transpose`` into a TE-generated PrimFunc call.

    Only the NCHW data layout combined with the IOHW kernel layout is lowered,
    and the output layout must equal the data layout. A non-unit dilation is
    applied by dilating the kernel along its two spatial axes before handing
    it to ``topi.nn.group_conv2d_transpose_nchw``.

    Parameters
    ----------
    bb : BlockBuilder
        Builder used to emit the TE call.
    call : Call
        The ``relax.nn.conv2d_transpose`` call being legalized.

    Returns
    -------
    Expr
        The expression returned by ``bb.call_te`` for the generated compute.

    Raises
    ------
    tvm.error.OpNotImplemented
        If ``out_layout`` differs from ``data_layout``, or the layouts are
        not NCHW (data) / IOHW (kernel).
    """
    if call.attrs.out_layout != call.attrs.data_layout:
        raise tvm.error.OpNotImplemented(
            "nn.conv2d_transpose with out_layout != data_layout is not yet lowered by TOPI."
        )
    if call.attrs.data_layout != "NCHW" or call.attrs.kernel_layout != "IOHW":
        raise tvm.error.OpNotImplemented(
            f"nn.conv2d_transpose with data_layout={call.attrs.data_layout!r}, "
            f"kernel_layout={call.attrs.kernel_layout!r} is not yet lowered by TOPI (only NCHW/IOHW)."
        )
    # Convert the attribute values to plain Python ints/lists up front so the
    # nested TE function below captures ordinary Python values.
    strides = [int(s) for s in call.attrs.strides]
    padding = [int(p) for p in call.attrs.padding]
    output_padding = [int(o) for o in call.attrs.output_padding]
    groups = int(call.attrs.groups)
    out_dtype = call.struct_info.dtype
    dilation = [int(d) for d in call.attrs.dilation]

    def te_conv2d_transpose(data, kernel):
        # Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel.
        if any(d != 1 for d in dilation):
            kernel = topi.nn.dilate(kernel, [1, 1, dilation[0], dilation[1]], name="kernel_dilate")
        return topi.nn.group_conv2d_transpose_nchw(
            data, kernel, strides, padding, out_dtype, output_padding, groups
        )

    return bb.call_te(
        te_conv2d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv2d_transpose"
    )


Expand All @@ -225,35 +218,33 @@ def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr:
# Keep policy in sync with _nn_conv2d_transpose: only lower when TOPI supports
# the layout/dilation.
if call.attrs.out_layout != call.attrs.data_layout:
logging.info(
"TOPI conv3d_transpose does not support different input-output "
"layouts, and thus cannot be legalized by TOPI"
raise tvm.error.OpNotImplemented(
"nn.conv3d_transpose with out_layout != data_layout is not yet lowered by TOPI."
)
return call
if call.attrs.data_layout != "NCDHW" or call.attrs.kernel_layout != "IODHW":
logging.info(
"TOPI conv3d_transpose does not support input layout other than NCDHW, "
"and kernel layout other than IODHW, so cannot be legalized by TOPI"
)
return call
dilation = call.attrs.dilation
if len(dilation) != 3 or any(d != 1 for d in dilation):
logging.info(
"TOPI conv3d_transpose does not support dilations other than 1, "
"and thus cannot be legalized by TOPI"
raise tvm.error.OpNotImplemented(
f"nn.conv3d_transpose with data_layout={call.attrs.data_layout!r}, "
f"kernel_layout={call.attrs.kernel_layout!r} is not yet lowered by TOPI (only NCDHW/IODHW)."
)
strides = [int(s) for s in call.attrs.strides]
padding = [int(p) for p in call.attrs.padding]
output_padding = [int(o) for o in call.attrs.output_padding]
groups = int(call.attrs.groups)
out_dtype = call.struct_info.dtype
dilation = [int(d) for d in call.attrs.dilation]

def te_conv3d_transpose(data, kernel):
# Dilated transposed conv == transposed conv with a spatially dilated (zero-filled) kernel.
if any(d != 1 for d in dilation):
kernel = topi.nn.dilate(
kernel, [1, 1, dilation[0], dilation[1], dilation[2]], name="kernel_dilate"
)
return topi.nn.group_conv3d_transpose_ncdhw(
data, kernel, strides, padding, out_dtype, output_padding, groups
)
return call

return bb.call_te(
topi.nn.group_conv3d_transpose_ncdhw,
call.args[0],
call.args[1],
strides=call.attrs.strides,
padding=call.attrs.padding,
out_dtype=call.struct_info.dtype,
output_padding=call.attrs.output_padding,
groups=call.attrs.groups,
primfunc_name_hint="conv3d_transpose",
te_conv3d_transpose, call.args[0], call.args[1], primfunc_name_hint="conv3d_transpose"
)


Expand Down
74 changes: 74 additions & 0 deletions tests/python/relax/test_transform_legalize_ops_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,80 @@ def conv2d_transpose(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle,
tvm.ir.assert_structural_equal(mod, Expected)


def test_conv2d_transpose_dilation():
    """LegalizeOps lowers a conv2d_transpose with dilation=(2, 2) to the expected PrimFunc.

    The expected module contains a ``kernel_dilate`` stage ahead of the
    ``kernel_transform`` and ``compute`` stages, and the result is compared
    with ``tvm.ir.assert_structural_equal``.
    """
    # fmt: off
    @tvm.script.ir_module
    class Conv2dTranspose:
        @R.function
        def main(x: R.Tensor((1, 1, 3, 3), "float32"), w: R.Tensor((1, 1, 2, 2), "float32")):
            gv = R.nn.conv2d_transpose(x, w, dilation=(2, 2))
            return gv

    @I.ir_module(s_tir=True)
    class Expected:
        @T.prim_func(private=True, s_tir=True)
        def conv2d_transpose(x: T.Buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3)), "float32"), w: T.Buffer((T.int64(1), T.int64(1), T.int64(2), T.int64(2)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32")):
            T.func_attr({"tirx.noalias": True})
            data_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3)))
            data_pad = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(7), T.int64(7)))
            kernel_dilate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3)))
            kernel_transform = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(3), T.int64(3)))
            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)):
                with T.sblock("data_dilate"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    data_dilate[v_i0, v_i1, v_i2, v_i3] = x[v_i0, v_i1, v_i2, v_i3]
            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(7), T.int64(7)):
                with T.sblock("data_pad"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(5) and T.int64(2) <= v_i3 and v_i3 < T.int64(5), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)], T.float32(0.0))
            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)):
                with T.sblock("kernel_dilate"):
                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    kernel_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(2) == T.int64(0), w[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(2)], T.float32(0.0))
            for o, i, h, w_1 in T.grid(T.int64(1), T.int64(1), T.int64(3), T.int64(3)):
                with T.sblock("kernel_transform"):
                    v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w_1])
                    kernel_transform[v_o, v_i, v_h, v_w] = kernel_dilate[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w]
            for b, c, h, w_1, dc, dh, dw in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5), T.int64(1), T.int64(3), T.int64(3)):
                with T.sblock("compute"):
                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w_1, dc, dh, dw])
                    with T.init():
                        compute[v_b, v_c, v_h, v_w] = T.float32(0.0)
                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, v_dh, v_dw]

        @R.function
        def main(x: R.Tensor((1, 1, 3, 3), dtype="float32"), w: R.Tensor((1, 1, 2, 2), dtype="float32")) -> R.Tensor((1, 1, 5, 5), dtype="float32"):
            cls = Expected
            gv = R.call_tir(cls.conv2d_transpose, (x, w), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32"))
            return gv
    # fmt: on

    mod = LegalizeOps()(Conv2dTranspose)
    tvm.ir.assert_structural_equal(mod, Expected)


def test_conv_transpose_unsupported_raises():
    """LegalizeOps raises OpNotImplemented for conv2d_transpose layouts it does not lower.

    Two modules are checked: one with an NHWC ``data_layout`` and one whose
    ``out_layout`` is set to NHWC while the data layout is left at its default.
    """
    # fmt: off
    @tvm.script.ir_module
    class Layout2d:
        @R.function
        def main(x: R.Tensor((1, 4, 4, 1), "float32"), w: R.Tensor((1, 2, 3, 3), "float32")):
            gv = R.nn.conv2d_transpose(x, w, data_layout="NHWC", kernel_layout="IOHW")
            return gv

    @tvm.script.ir_module
    class OutLayout2d:
        @R.function
        def main(x: R.Tensor((1, 1, 5, 5), "float32"), w: R.Tensor((1, 2, 3, 3), "float32")):
            gv = R.nn.conv2d_transpose(x, w, out_layout="NHWC")
            return gv
    # fmt: on

    for mod in [Layout2d, OutLayout2d]:
        with pytest.raises(tvm.error.OpNotImplemented):
            LegalizeOps()(mod)


def test_max_pool2d():
# fmt: off
@tvm.script.ir_module
Expand Down
Loading