From 6fd80a2a5c5d873f0ad57d696493808f9e459a86 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 6 Feb 2026 08:59:46 -0500 Subject: [PATCH] [REFACTOR][TEST] Migrate all codegen test to tvmscript This PR migrates all the codegen tests to use TVMScript explicitly instead of going indirectly through s_tir.Schedule. This makes each test more of a self-contained unit test, with fewer dependencies and easier maintenance. --- .../codegen/test_gpu_codegen_allreduce.py | 100 +- .../codegen/test_target_codegen_aarch64.py | 604 +++++---- .../python/codegen/test_target_codegen_arm.py | 87 +- .../codegen/test_target_codegen_bool.py | 92 +- .../codegen/test_target_codegen_c_host.py | 159 ++- .../codegen/test_target_codegen_cross_llvm.py | 39 +- .../codegen/test_target_codegen_cuda.py | 574 +++++---- .../codegen/test_target_codegen_cuda_fp4.py | 178 +-- .../codegen/test_target_codegen_cuda_fp8.py | 263 ++-- .../codegen/test_target_codegen_device.py | 101 +- .../codegen/test_target_codegen_gpu_common.py | 37 +- .../codegen/test_target_codegen_hexagon.py | 70 +- .../codegen/test_target_codegen_llvm.py | 1136 +++++++++-------- .../codegen/test_target_codegen_metal.py | 59 +- .../codegen/test_target_codegen_opencl.py | 225 ++-- .../codegen/test_target_codegen_rocm.py | 65 +- .../codegen/test_target_codegen_vulkan.py | 382 +++--- .../python/codegen/test_target_codegen_x86.py | 26 +- 18 files changed, 2343 insertions(+), 1854 deletions(-) diff --git a/tests/python/codegen/test_gpu_codegen_allreduce.py b/tests/python/codegen/test_gpu_codegen_allreduce.py index 083ae99ec654..3e9fd2cd6b67 100644 --- a/tests/python/codegen/test_gpu_codegen_allreduce.py +++ b/tests/python/codegen/test_gpu_codegen_allreduce.py @@ -18,35 +18,49 @@ import tvm_ffi import tvm.testing import numpy as np -from tvm.script import tir as T +from tvm.script import tir as T, ir as I import pytest -@T.prim_func -def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None: - A = T.match_buffer(a, [1, d1, d2, d3]) - B = 
T.match_buffer(b, [1, d1, d2]) - - for i, j, k, l in T.grid(1, d1, d2, d3): - with T.sblock("reduce"): - vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) - with T.init(): - B[vi, vj, vk] = 0.0 - B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl] - - -@T.prim_func -def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None: - A = T.match_buffer(a, [1, d1, d2, d3]) - B = T.match_buffer(b, [1, d1, d2]) - - for i, j, k, l in T.grid(1, d1, d2, d3): - with T.sblock("reduce"): - vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) - with T.init(): - B[vi, vj, vk] = T.float32(-3.4028234663852886e38) - B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl]) +def _reduce_sum_module(d1, d2, d3): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1, d1, d2, d3), "float32"), B: T.Buffer((1, d1, d2), "float32")): + for i in T.thread_binding(1, thread="blockIdx.x"): + for j in T.thread_binding(d1, thread="threadIdx.z"): + for k in T.thread_binding(d2, thread="threadIdx.y"): + for l in T.thread_binding(d3, thread="threadIdx.x"): + with T.sblock("reduce"): + vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) + T.reads(A[vi, vj, vk, vl]) + T.writes(B[vi, vj, vk]) + with T.init(): + B[vi, vj, vk] = T.float32(0.0) + B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl] + + return Module + + +def _reduce_max_module(d1, d2, d3): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1, d1, d2, d3), "float32"), B: T.Buffer((1, d1, d2), "float32")): + for i in T.thread_binding(1, thread="blockIdx.x"): + for j in T.thread_binding(d1, thread="threadIdx.z"): + for k in T.thread_binding(d2, thread="threadIdx.y"): + for l in T.thread_binding(d3, thread="threadIdx.x"): + with T.sblock("reduce"): + vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l]) + T.reads(A[vi, vj, vk, vl]) + T.writes(B[vi, vj, vk]) + with T.init(): + B[vi, vj, vk] = T.float32(-3.4028234663852886e38) + B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl]) + + 
return Module def generate_param_sets(): @@ -63,16 +77,8 @@ def generate_param_sets(): @tvm.testing.parametrize_targets("cuda", "metal") def test_allreduce_sum(dims, target, dev): d1, d2, d3 = dims - _, _, _d1, _d2, _d3 = reduce.params - mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3}) - sch = tvm.s_tir.Schedule(mod) - blk = sch.get_sblock("reduce") - i, j, k, l = sch.get_loops(blk) - sch.bind(i, "blockIdx.x") - sch.bind(j, "threadIdx.z") - sch.bind(k, "threadIdx.y") - sch.bind(l, "threadIdx.x") - f = tvm.compile(sch.mod["main"], target=target) + mod = _reduce_sum_module(d1, d2, d3) + f = tvm.compile(mod, target=target) # prepare input and output array a_np = np.random.rand(1, d1, d2, d3).astype("float32") @@ -117,31 +123,15 @@ def test_allreduce_sum_compile(optional_metal_compile_callback): target = "metal" d1, d2, d3 = dims - _, _, _d1, _d2, _d3 = reduce.params - mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3}) - sch = tvm.s_tir.Schedule(mod) - blk = sch.get_sblock("reduce") - i, j, k, l = sch.get_loops(blk) - sch.bind(i, "blockIdx.x") - sch.bind(j, "threadIdx.z") - sch.bind(k, "threadIdx.y") - sch.bind(l, "threadIdx.x") - tvm.compile(sch.mod["main"], target=target) + mod = _reduce_sum_module(d1, d2, d3) + tvm.compile(mod, target=target) @tvm.testing.parametrize_targets("cuda", "metal") def test_allreduce_max(dims, target, dev): d1, d2, d3 = dims - _, _, _d1, _d2, _d3 = reduce_max.params - mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3}) - sch = tvm.s_tir.Schedule(mod) - blk = sch.get_sblock("reduce") - i, j, k, l = sch.get_loops(blk) - sch.bind(i, "blockIdx.x") - sch.bind(j, "threadIdx.z") - sch.bind(k, "threadIdx.y") - sch.bind(l, "threadIdx.x") - f = tvm.compile(sch.mod["main"], target=target) + mod = _reduce_max_module(d1, d2, d3) + f = tvm.compile(mod, target=target) # prepare input and output array a_np = -np.random.rand(1, d1, d2, d3).astype("float32") diff --git a/tests/python/codegen/test_target_codegen_aarch64.py 
b/tests/python/codegen/test_target_codegen_aarch64.py index 90ad1d65c7aa..5517d988a35d 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -23,8 +23,7 @@ import pytest import tvm -from tvm import te -from tvm.script import tir as T +from tvm.script import tir as T, ir as I from tvm.target.codegen import llvm_version_major @@ -38,26 +37,33 @@ def test_mul(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: A[i] * B[i], name="C") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] * B[v_i] - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) - - # Verify we see SVE load instructions and mul instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"mul\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"mul\tz[0-9].[shdb],( p[0-9]/[m],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -70,26 +76,33 @@ def check_correct_assembly(type): def test_add(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: A[i] + B[i], name="C") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + B[v_i] - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) - - # Verify we see SVE load instructions and add instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"add\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"add\tz[0-9].[shdb],( p[0-9]/[m],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -102,26 +115,33 @@ def check_correct_assembly(type): def test_sub(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: A[i] - B[i], name="C") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] - B[v_i] - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) - - # Verify we see SVE load instructions and sub instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"sub\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"sub\tz[0-9].[shdb],( p[0-9]/[m],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -134,27 +154,34 @@ def check_correct_assembly(type): def test_muladd(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.placeholder(m, dtype=type, name="C") - D = te.compute((m), lambda i: A[i] * B[i] + C[i], name="D") - - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C, D])) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_D: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + D = T.match_buffer(var_D, (m,), dtype=dtype) + for i in range(m): + with T.sblock("D"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i], C[v_i]) + T.writes(D[v_i]) + D[v_i] = A[v_i] * B[v_i] + C[v_i] - # Verify we see SVE load instructions and either mad or mla instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"mad|mla\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"mad|mla\tz[0-9].[shdb],( p[0-9]/[m],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -167,30 +194,37 @@ def check_correct_assembly(type): def test_max(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: tvm.te.max(A[i], B[i])) - - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.max(A[v_i], B[v_i]) - # Verify we see SVE load instructions and cmgt + sel instructions or a max instruction, all using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - compare = re.findall( - r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) - select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly) - max = re.findall( - r"max\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert (len(compare) > 1 and len(select) == len(compare)) or len(max) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + compare = re.findall( + r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly) + max_instr = re.findall( + r"max\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert (len(compare) > 1 and len(select) == len(compare)) or len(max_instr) > 1 @pytest.mark.skipif( @@ -203,30 +237,37 @@ def check_correct_assembly(type): def test_min(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: tvm.te.min(A[i], B[i])) - - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.min(A[v_i], B[v_i]) - # Verify we see SVE load instructions and cmgt + sel instructions or a min instruction, all using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - compare = re.findall( - r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) - select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly) - min = re.findall( - r"min\tz[0-9].[shdb],( p[0-9]/[zm],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert (len(compare) > 1 and len(select) == len(compare)) or len(min) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + compare = re.findall( + r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly) + min_instr = re.findall( + r"min\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert (len(compare) > 1 and len(select) == len(compare)) or len(min_instr) > 1 @pytest.mark.skipif( @@ -239,26 +280,33 @@ def check_correct_assembly(type): def test_div(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: tvm.te.div(A[i], B[i])) - - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = tvm.tir.div(A[v_i], B[v_i]) - # Verify we see SVE load instructions and div instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"div\tz[0-9].[shdb],( p[0-9]/[m],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) >= 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"div\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) >= 1 @pytest.mark.skipif( @@ -270,26 +318,33 @@ def check_correct_assembly(type): def test_mod(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: tvm.te.floormod(A[i], B[i]), name="C") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.floormod(A[v_i], B[v_i]) - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) - - # Verify we see SVE load instructions and mls instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"mls\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 0 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"mls\tz[0-9].[shdb],( p[0-9]/[m],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 0 @pytest.mark.skipif( @@ -302,26 +357,33 @@ def check_correct_assembly(type): def test_eq(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: A[i] == B[i], name="C") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), "bool") + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] == B[v_i] - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) - - # Verify we see SVE load instructions and cmpeq or cmeq instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"cm(p)?eq\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"cm(p)?eq\tp[0-9].[shdb],( p[0-9]/[zm],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -334,26 +396,33 @@ def check_correct_assembly(type): def test_neq(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: A[i] != B[i], name="C") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), "bool") + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] != B[v_i] - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) - - # Verify we see SVE load instructions and cmpgt, cmgt, cmpne or cmne instructions, all using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"cm(p)?(gt|ne)\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"cm(p)?(gt|ne)\tp[0-9].[shdb],( p[0-9]/[zm],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -365,26 +434,33 @@ def check_correct_assembly(type): def test_or(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: A[i] | B[i], name="C") - - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] | B[v_i] - # Verify we see SVE load instructions and orr instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"orr\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"orr\tz[0-9].[shdb],( p[0-9]/[zm],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -396,26 +472,33 @@ def check_correct_assembly(type): def test_and(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype=type, name="B") - C = te.compute((m), lambda i: A[i] & B[i], name="C") - - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] & B[v_i] - # Verify we see SVE load instructions and and instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"and\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"and\tz[0-9].[shdb],( p[0-9]/[zm],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -427,25 +510,32 @@ def check_correct_assembly(type): def test_not(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - C = te.compute((m), lambda i: ~A[i], name="C") - - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, C])) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(A[v_i]) + T.writes(C[v_i]) + C[v_i] = ~A[v_i] - # Verify we see SVE load instructions and eor instructions using z registers - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) - matches = re.findall( - r"eor\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly - ) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 1 - assert len(matches) > 1 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"eor\tz[0-9].[shdb],( p[0-9]/[zm],)? 
z[0-9].[shdb], z[0-9].[shdb]", assembly + ) - check_correct_assembly(type=dtype) + assert len(loads) > 1 + assert len(matches) > 1 @pytest.mark.skipif( @@ -461,22 +551,29 @@ def check_correct_assembly(type): def test_memcpy(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - def check_correct_assembly(type): - m = te.var("m") - A = te.placeholder(m, dtype=type, name="A") - B = te.placeholder(m, dtype="int32", name="B") - C = te.compute((m), lambda i: A[B[i]], name="C") - - with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, B, C])) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,), dtype=dtype) + B = T.match_buffer(var_B, (m,), "int32") + C = T.match_buffer(var_C, (m,), dtype=dtype) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + T.reads(B[v_i], A[B[v_i]]) + T.writes(C[v_i]) + C[v_i] = A[B[v_i]] - # Verify we see gather instructions in the assembly - assembly = f.inspect_source("asm") - loads = re.findall("ld1[whdb] { z", assembly) + with tvm.target.Target(target): + f = tvm.tir.build(Module) - assert len(loads) > 0 + assembly = f.inspect_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) - check_correct_assembly(type=dtype) + assert len(loads) > 0 @pytest.mark.skipif( @@ -495,12 +592,23 @@ def check_correct_assembly(type): def test_vscale_range_function_attribute(mattr, expect_attr): target = f"llvm -mtriple=aarch64-linux-gnu -mattr={mattr}" - m = te.var("m") - A = te.placeholder(m, dtype="float32", name="A") - C = te.compute((m), lambda i: A[i] + 1, name="C") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.int32() + A = T.match_buffer(var_A, (m,)) + C = T.match_buffer(var_C, (m,)) + for i in range(m): + with T.sblock("C"): + v_i = T.axis.spatial(m, i) + 
T.reads(A[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + T.float32(1) with tvm.target.Target(target): - f = tvm.tir.build(te.create_prim_func([A, C])) + f = tvm.tir.build(Module) # Check if the vscale_range() attribute exists ll = f.inspect_source("ll") diff --git a/tests/python/codegen/test_target_codegen_arm.py b/tests/python/codegen/test_target_codegen_arm.py index b85b96bc0393..d6791c785977 100644 --- a/tests/python/codegen/test_target_codegen_arm.py +++ b/tests/python/codegen/test_target_codegen_arm.py @@ -15,20 +15,27 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te import re +from tvm.script import tir as T, ir as I def test_popcount(): target = "llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon" def check_correct_assembly(type, elements, counts): - n = tvm.runtime.convert(elements) - A = te.placeholder(n, dtype=type, name="A") - B = te.compute(A.shape, lambda i: tvm.tir.popcount(A[i]), name="B") - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - sch.vectorize(sch.get_loops("B")[0]) - f = tvm.tir.build(sch.mod, target=target) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((elements,), type), B: T.Buffer((elements,), type)): + T.func_attr({"tir.noalias": True}) + for i in T.vectorized(elements): + with T.sblock("B"): + v_i = T.axis.spatial(elements, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.popcount(A[v_i]) + + f = tvm.tir.build(Module, target=target) # Verify we see the correct number of vpaddl and vcnt instructions in the assembly assembly = f.inspect_source("asm") matches = re.findall("vpaddl", assembly) @@ -47,18 +54,27 @@ def test_vmlal_s16(): target = "llvm -mtriple=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon" def check_correct_assembly(N): - K = te.size_var("K") - A = te.placeholder((K, N), dtype="int8", name="A") - B = te.placeholder((K, N), dtype="int8", name="B") - k = te.reduce_axis((0, K)) - C = te.compute( - 
(N,), - lambda n: te.sum(A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), - name="C", - ) - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B, C])) - sch.vectorize(sch.get_loops("C")[0]) - f = tvm.tir.build(sch.mod, target=target) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, C: T.Buffer((N,), "int32")): + T.func_attr({"tir.noalias": True}) + K = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (K, N), "int8") + B = T.match_buffer(var_B, (K, N), "int8") + for n in T.vectorized(N): + for rv in range(K): + with T.sblock("C"): + v_n, v_rv = T.axis.remap("SR", [n, rv]) + T.reads(A[v_rv, v_n], B[v_rv, v_n]) + T.writes(C[v_n]) + with T.init(): + C[v_n] = 0 + C[v_n] = C[v_n] + T.Cast("int32", A[v_rv, v_n]) * T.Cast( + "int32", B[v_rv, v_n] + ) + + f = tvm.tir.build(Module, target=target) # Verify we see the correct number of vmlal.s16 instructions assembly = f.inspect_source("asm") @@ -71,18 +87,27 @@ def check_correct_assembly(N): check_correct_assembly(64) def check_broadcast_correct_assembly(N): - K = te.size_var("K") - A = te.placeholder((K, N), dtype="int8", name="A") - B = te.placeholder((K,), dtype="int8", name="B") - k = te.reduce_axis((0, K)) - C = te.compute( - (N,), - lambda n: te.sum(A[k, n].astype("int32") * B[k].astype("int32"), axis=[k]), - name="C", - ) - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B, C])) - sch.vectorize(sch.get_loops("C")[0]) - f = tvm.tir.build(sch.mod, target=target) + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, C: T.Buffer((N,), "int32")): + T.func_attr({"tir.noalias": True}) + K = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (K, N), "int8") + B = T.match_buffer(var_B, (K,), "int8") + for n in T.vectorized(N): + for rv in range(K): + with T.sblock("C"): + v_n, v_rv = T.axis.remap("SR", [n, rv]) + T.reads(A[v_rv, v_n], B[v_rv]) + T.writes(C[v_n]) + with T.init(): + C[v_n] = 0 + C[v_n] = C[v_n] + T.Cast("int32", 
A[v_rv, v_n]) * T.Cast( + "int32", B[v_rv] + ) + + f = tvm.tir.build(Module, target=target) # Verify we see the correct number of vmlal.s16 instructions assembly = f.inspect_source("asm") diff --git a/tests/python/codegen/test_target_codegen_bool.py b/tests/python/codegen/test_target_codegen_bool.py index afb8afe35564..d95379d80817 100644 --- a/tests/python/codegen/test_target_codegen_bool.py +++ b/tests/python/codegen/test_target_codegen_bool.py @@ -16,49 +16,75 @@ # under the License. """codegen related to bool types""" -import tvm -import tvm.testing -from tvm import te import numpy as np -import tvm.testing - -arr_size = tvm.testing.parameter(32) - -@tvm.testing.fixture -def compute(arr_size): - A = te.placeholder((arr_size,), name="A") - B = te.placeholder((arr_size,), name="B") - C = te.compute(A.shape, lambda *i: A(*i) > B(*i), name="C") - D = te.compute(C.shape, lambda *i: tvm.tir.all(C(*i), A(*i) > 1).astype("float32"), name="D") - return [A, B, C, D] +import tvm +import tvm.testing +from tvm.script import ir as I +from tvm.script import tir as T -@tvm.testing.fixture -def get_module(target, compute): - target = tvm.target.Target(target) - A, B, C, D = compute - if target.kind.name == "llvm": - return tvm.IRModule.from_expr(te.create_prim_func([A, B, D])) +@tvm.testing.uses_gpu +def test_cmp_load_store(target, dev): + @I.ir_module + class GPUModule: + @T.prim_func + def main( + A: T.Buffer((32,), "float32"), + B: T.Buffer((32,), "float32"), + D: T.Buffer((32,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + C = T.alloc_buffer((32,), "bool") + for i0_0 in T.thread_binding(8, thread="blockIdx.x"): + for i0_1 in T.thread_binding(4, thread="blockIdx.x"): + with T.sblock("C"): + v_i0 = T.axis.spatial(32, i0_0 * 4 + i0_1) + T.reads(B[v_i0], A[v_i0]) + T.writes(C[v_i0]) + C[v_i0] = B[v_i0] < A[v_i0] + for i0_0 in T.thread_binding(8, thread="blockIdx.x"): + for i0_1 in T.thread_binding(4, thread="blockIdx.x"): + with T.sblock("D"): + v_i0 = 
T.axis.spatial(32, i0_0 * 4 + i0_1) + T.reads(C[v_i0], A[v_i0]) + T.writes(D[v_i0]) + D[v_i0] = T.Cast("float32", C[v_i0] and T.float32(1.0) < A[v_i0]) - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B, D])) - for stage in ["C", "D"]: - xo, xi = sch.split(sch.get_loops(stage)[0], factors=[None, 4]) - sch.bind(xo, "blockIdx.x") - sch.bind(xi, "blockIdx.x") - return sch.mod + @I.ir_module + class CPUModule: + @T.prim_func + def main( + A: T.Buffer((32,), "float32"), + B: T.Buffer((32,), "float32"), + D: T.Buffer((32,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + C = T.alloc_buffer((32,), "bool") + for i0 in range(32): + with T.sblock("C"): + v_i0 = T.axis.spatial(32, i0) + T.reads(B[v_i0], A[v_i0]) + T.writes(C[v_i0]) + C[v_i0] = B[v_i0] < A[v_i0] + for i0 in range(32): + with T.sblock("D"): + v_i0 = T.axis.spatial(32, i0) + T.reads(C[v_i0], A[v_i0]) + T.writes(D[v_i0]) + D[v_i0] = T.Cast("float32", C[v_i0] and T.float32(1.0) < A[v_i0]) + arr_size = 32 + is_gpu = tvm.target.Target(target).kind.name != "llvm" + mod = GPUModule if is_gpu else CPUModule -@tvm.testing.uses_gpu -def test_cmp_load_store(target, dev, arr_size, compute, get_module): - A, B, _, D = compute - f = tvm.compile(get_module, target=target) + f = tvm.compile(mod, target=target) - a_np = np.random.uniform(size=arr_size).astype(A.dtype) - b_np = np.random.uniform(size=arr_size).astype(B.dtype) + a_np = np.random.uniform(size=arr_size).astype("float32") + b_np = np.random.uniform(size=arr_size).astype("float32") a = tvm.runtime.tensor(a_np, dev) b = tvm.runtime.tensor(b_np, dev) - d = tvm.runtime.tensor(np.zeros(arr_size, dtype=D.dtype), dev) + d = tvm.runtime.tensor(np.zeros(arr_size, dtype="float32"), dev) f(a, b, d) np.testing.assert_equal( d.numpy(), diff --git a/tests/python/codegen/test_target_codegen_c_host.py b/tests/python/codegen/test_target_codegen_c_host.py index e95108aeac17..0c46a0c6aee6 100644 --- a/tests/python/codegen/test_target_codegen_c_host.py +++ 
b/tests/python/codegen/test_target_codegen_c_host.py @@ -18,7 +18,6 @@ import tvm import tvm.testing -from tvm import te from tvm.contrib import utils from tvm.script import tir as T, ir as I @@ -27,29 +26,35 @@ def test_add(): nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") + + @I.ir_module + class Module: + @T.prim_func + def test_fadd( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + C: T.Buffer((1024,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0 in range(1024): + with T.sblock("C"): + v_i0 = T.axis.spatial(1024, i0) + T.reads(A[v_i0], B[v_i0]) + T.writes(C[v_i0]) + C[v_i0] = A[v_i0] + B[v_i0] def check_c(): - mhost = tvm.compile( - tvm.IRModule.from_expr( - te.create_prim_func([A, B, C]).with_attr("global_symbol", "test_fadd") - ), - target="c", - ) + mhost = tvm.compile(Module, target="c") temp = utils.tempdir() path_dso = temp.relpath("temp.so") mhost.export_library(path_dso) m = tvm.runtime.load_module(path_dso) fadd = m["test_fadd"] dev = tvm.cpu(0) - # launch the kernel. 
n = nn - a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) fadd(a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) @@ -58,19 +63,24 @@ def check_c(): def test_reinterpret(): nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A", dtype="int32") - B = te.compute( - A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", 2 + A(*i)), name="B" - ) + + @I.ir_module + class Module: + @T.prim_func + def test_reinterpret( + A: T.Buffer((1024,), "int32"), + B: T.Buffer((1024,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0 in range(1024): + with T.sblock("B"): + v_i0 = T.axis.spatial(1024, i0) + T.reads(A[v_i0]) + T.writes(B[v_i0]) + B[v_i0] = T.reinterpret("float32", A[v_i0] + 2) def check_c(): - mhost = tvm.compile( - tvm.IRModule.from_expr( - te.create_prim_func([A, B]).with_attr("global_symbol", "test_reinterpret") - ), - target="c", - ) + mhost = tvm.compile(Module, target="c") temp = utils.tempdir() path_dso = temp.relpath("temp.so") mhost.export_library(path_dso) @@ -78,8 +88,8 @@ def check_c(): fadd = m["test_reinterpret"] dev = tvm.cpu(0) n = nn - a = tvm.runtime.tensor(np.random.randint(-(2**30), 2**30, size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.randint(-(2**30), 2**30, size=n).astype("int32"), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) fadd(a, b) tvm.testing.assert_allclose(b.numpy(), (2 + a.numpy()).view("float32")) @@ -88,17 +98,24 @@ def check_c(): def test_ceil(): nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), 
name="A", dtype="float32") - B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.ceil", A(*i)), name="B") + + @I.ir_module + class Module: + @T.prim_func + def test_ceil( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0 in range(1024): + with T.sblock("B"): + v_i0 = T.axis.spatial(1024, i0) + T.reads(A[v_i0]) + T.writes(B[v_i0]) + B[v_i0] = T.ceil(A[v_i0]) def check_c(): - mhost = tvm.compile( - tvm.IRModule.from_expr( - te.create_prim_func([A, B]).with_attr("global_symbol", "test_ceil") - ), - target="c", - ) + mhost = tvm.compile(Module, target="c") temp = utils.tempdir() path_dso = temp.relpath("temp.so") mhost.export_library(path_dso) @@ -106,8 +123,8 @@ def check_c(): fceil = m["test_ceil"] dev = tvm.cpu(0) n = nn - a = tvm.runtime.tensor(np.random.rand(n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.rand(n).astype("float32"), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) fceil(a, b) tvm.testing.assert_allclose(b.numpy(), (np.ceil(a.numpy()).view("float32"))) @@ -116,17 +133,24 @@ def check_c(): def test_floor(): nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A", dtype="float32") - B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.floor", A(*i)), name="B") + + @I.ir_module + class Module: + @T.prim_func + def test_floor( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0 in range(1024): + with T.sblock("B"): + v_i0 = T.axis.spatial(1024, i0) + T.reads(A[v_i0]) + T.writes(B[v_i0]) + B[v_i0] = T.floor(A[v_i0]) def check_c(): - mhost = tvm.compile( - tvm.IRModule.from_expr( - te.create_prim_func([A, B]).with_attr("global_symbol", "test_floor") - ), - target="c", - ) + mhost = tvm.compile(Module, target="c") temp = utils.tempdir() path_dso = temp.relpath("temp.so") 
mhost.export_library(path_dso) @@ -134,8 +158,8 @@ def check_c(): ffloor = m["test_floor"] dev = tvm.cpu(0) n = nn - a = tvm.runtime.tensor(np.random.rand(n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.rand(n).astype("float32"), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) ffloor(a, b) tvm.testing.assert_allclose(b.numpy(), (np.floor(a.numpy()).view("float32"))) @@ -144,17 +168,24 @@ def check_c(): def test_round(): nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A", dtype="float32") - B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.round", A(*i)), name="B") + + @I.ir_module + class Module: + @T.prim_func + def test_round( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0 in range(1024): + with T.sblock("B"): + v_i0 = T.axis.spatial(1024, i0) + T.reads(A[v_i0]) + T.writes(B[v_i0]) + B[v_i0] = T.round(A[v_i0]) def check_c(): - mhost = tvm.compile( - tvm.IRModule.from_expr( - te.create_prim_func([A, B]).with_attr("global_symbol", "test_round") - ), - target="c", - ) + mhost = tvm.compile(Module, target="c") temp = utils.tempdir() path_dso = temp.relpath("temp.so") mhost.export_library(path_dso) @@ -162,8 +193,8 @@ def check_c(): fround = m["test_round"] dev = tvm.cpu(0) n = nn - a = tvm.runtime.tensor(np.random.rand(n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) + a = tvm.runtime.tensor(np.random.rand(n).astype("float32"), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) fround(a, b) tvm.testing.assert_allclose(b.numpy(), (np.round(a.numpy()).view("float32"))) @@ -172,17 +203,17 @@ def check_c(): def test_subroutine_call(): @I.ir_module - class mod: + class Module: @T.prim_func def main(A: T.Buffer(1, dtype="float32")): - mod.subroutine(A.data) + Module.subroutine(A.data) @T.prim_func(private=True) def 
subroutine(A_data: T.handle("float32")): A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 42.0 - built = tvm.tir.build(mod, target="c") + built = tvm.tir.build(Module, target="c") source = built.inspect_source() assert ( diff --git a/tests/python/codegen/test_target_codegen_cross_llvm.py b/tests/python/codegen/test_target_codegen_cross_llvm.py index 3942012a67d8..50220993a28c 100644 --- a/tests/python/codegen/test_target_codegen_cross_llvm.py +++ b/tests/python/codegen/test_target_codegen_cross_llvm.py @@ -17,26 +17,35 @@ """Test cross compilation""" import tvm import tvm.testing -from tvm import te import os import struct from tvm import rpc from tvm.contrib import utils, cc +from tvm.script import tir as T, ir as I import numpy as np +@I.ir_module +class AddModule: + @T.prim_func + def main( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + C: T.Buffer((1024,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0_0 in T.parallel(256): + for i0_1 in T.vectorized(4): + with T.sblock("C"): + v_i0 = T.axis.spatial(1024, i0_0 * 4 + i0_1) + T.reads(A[v_i0], B[v_i0]) + T.writes(C[v_i0]) + C[v_i0] = A[v_i0] + B[v_i0] + + @tvm.testing.requires_llvm def test_llvm_add_pipeline(): nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B, C])) - xo, xi = sch.split(sch.get_loops("C")[0], factors=[None, 4]) - sch.parallel(xo) - sch.vectorize(xi) def verify_elf(path, e_machine): with open(path, "rb") as fi: @@ -49,7 +58,7 @@ def verify_elf(path, e_machine): def build_i386(): temp = utils.tempdir() target = "llvm -mtriple=i386-pc-linux-gnu" - f = tvm.tir.build(sch.mod, target=target) + f = tvm.tir.build(AddModule, target=target) path = temp.relpath("myadd.o") f.write_to_file(path) verify_elf(path, 0x03) @@ -60,7 +69,7 @@ def build_arm(): print("Skip because %s is not 
enabled.." % target) return temp = utils.tempdir() - f = tvm.tir.build(sch.mod, target=target) + f = tvm.tir.build(AddModule, target=target) path = temp.relpath("myadd.o") f.write_to_file(path) verify_elf(path, 0x28) @@ -81,9 +90,9 @@ def build_arm(): farm = remote.load_module("myadd.o") dev = remote.cpu(0) n = nn - a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) farm(a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) print("Verification finish on remote..") diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index ad89619c5648..433a7ed0e2e0 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -22,7 +22,6 @@ import tvm import tvm.contrib.nvcc import tvm.testing -from tvm import te, topi from tvm.contrib.nvcc import have_bf16, have_fp16, have_int8 from tvm.script import ir as I from tvm.script import tir as T @@ -65,18 +64,29 @@ def check_cuda(dtype, n, lanes): if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version): print("skip because gpu does not support int8") return - A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes)) - B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B") - - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - xo, xi = sch.split(sch.get_loops("B")[0], factors=[None, num_thread]) - sch.bind(xo, "blockIdx.x") - sch.bind(xi, "threadIdx.x") - fun = tvm.compile(sch.mod, target="cuda") + vec_dtype = "%sx%d" % (dtype, lanes) + one = tvm.tir.const(1, vec_dtype) + num_blocks = (n + num_thread - 1) // 
num_thread + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): + for i_1 in T.thread_binding(num_thread, thread="threadIdx.x"): + with T.sblock("B"): + v_i = T.axis.spatial(n, i_0 * num_thread + i_1) + T.where(i_0 * num_thread + i_1 < n) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = A[v_i] + one + + fun = tvm.compile(Module, target="cuda") dev = tvm.cuda(0) - a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes))) - c = tvm.runtime.empty((n,), B.dtype, dev) + a = tvm.runtime.empty((n,), vec_dtype, dev).copyfrom(np.random.uniform(size=(n, lanes))) + c = tvm.runtime.empty((n,), vec_dtype, dev) fun(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) @@ -117,22 +127,32 @@ def np_bf162np_float(arr): return u32.view(" b, name="C") + half_const = tvm.tir.const(0.5, dtype="float16") - sch = tvm.s_tir.Schedule(te.create_prim_func([a, c])) - xo, xi = sch.split(sch.fuse(*sch.get_loops("C")), factors=[None, 64]) - sch.bind(xo, "blockIdx.x") - sch.bind(xi, "threadIdx.x") - func = tvm.compile(sch.mod, target="cuda") + @I.ir_module + class Module: + @T.prim_func + def main(a: T.Buffer((2, 3, 4), "float16"), C: T.Buffer((2, 3, 4), "bool")): + T.func_attr({"tir.noalias": True}) + for i_j_k_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i_j_k_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(2, (i_j_k_fused_0 * 64 + i_j_k_fused_1) // 12) + v_j = T.axis.spatial(3, (i_j_k_fused_0 * 64 + i_j_k_fused_1) % 12 // 4) + v_k = T.axis.spatial(4, (i_j_k_fused_0 * 64 + i_j_k_fused_1) % 4) + T.where(i_j_k_fused_0 * 64 + i_j_k_fused_1 < 24) + T.reads(a[v_i, v_j, v_k]) + T.writes(C[v_i, v_j, v_k]) + C[v_i, v_j, v_k] = half_const < a[v_i, v_j, v_k] + + func = tvm.compile(Module, target="cuda") dev = tvm.cuda(0) - a_np = 
np.random.uniform(size=shape).astype(a.dtype) - c_np = np.zeros(shape=shape, dtype=c.dtype) + shape = (2, 3, 4) + a_np = np.random.uniform(size=shape).astype("float16") + c_np = np.zeros(shape=shape, dtype="bool") a = tvm.runtime.tensor(a_np, dev) c = tvm.runtime.tensor(c_np, dev) func(a, c) - np.testing.assert_equal(c.numpy(), a_np > b.value) + np.testing.assert_equal(c.numpy(), a_np > 0.5) @tvm.testing.requires_gpu @@ -391,19 +492,25 @@ def test_cuda_floordiv_with_vectorization(): # B[i] = A[floordiv(i, k)] n = 256 k = 37 - A = te.placeholder((n,), name="A") - B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name="B") - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - xo, xi = sch.split(sch.get_loops("B")[0], factors=[1, None]) - xio, xii = sch.split(xi, factors=[None, 4]) - sch.vectorize(xii) - sch.bind(xo, "blockIdx.x") - sch.bind(xio, "threadIdx.x") - func = tvm.compile(sch.mod, target="cuda") + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((256,), "float32"), B: T.Buffer((256,), "float32")): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(1, thread="blockIdx.x"): + for i_1_0 in T.thread_binding(64, thread="threadIdx.x"): + for i_1_1 in T.vectorized(4): + with T.sblock("B"): + v_i = T.axis.spatial(256, i_0 * 256 + i_1_0 * 4 + i_1_1) + T.reads(A[v_i // 37]) + T.writes(B[v_i]) + B[v_i] = A[v_i // 37] + + func = tvm.compile(Module, target="cuda") dev = tvm.cuda(0) - a_np = np.random.uniform(size=(n,)).astype(A.dtype) + a_np = np.random.uniform(size=(n,)).astype("float32") b_np = np.array([a_np[i // k] for i in range(0, n)]) a_nd = tvm.runtime.tensor(a_np, dev) b_nd = tvm.runtime.tensor(np.zeros(b_np.shape, dtype=b_np.dtype), dev) @@ -418,18 +525,25 @@ def test_cuda_floormod_with_vectorization(): # B[i] = A[floormod(i, k)] n = 256 k = 37 - A = te.placeholder((n,), name="A") - B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name="B") - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - xo, xi = 
sch.split(sch.get_loops("B")[0], factors=[1, None]) - xio, xii = sch.split(xi, factors=[None, 4]) - sch.vectorize(xii) - sch.bind(xo, "blockIdx.x") - sch.bind(xio, "threadIdx.x") - func = tvm.compile(sch.mod, target="cuda") + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((256,), "float32"), B: T.Buffer((256,), "float32")): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(1, thread="blockIdx.x"): + for i_1_0 in T.thread_binding(64, thread="threadIdx.x"): + for i_1_1 in T.vectorized(4): + with T.sblock("B"): + v_i = T.axis.spatial(256, i_0 * 256 + i_1_0 * 4 + i_1_1) + T.reads(A[v_i % 37]) + T.writes(B[v_i]) + B[v_i] = A[v_i % 37] + + func = tvm.compile(Module, target="cuda") dev = tvm.cuda(0) - a_np = np.random.uniform(size=(n,)).astype(A.dtype) + a_np = np.random.uniform(size=(n,)).astype("float32") b_np = np.array([a_np[i % k] for i in range(0, n)]) a_nd = tvm.runtime.tensor(a_np, dev) b_nd = tvm.runtime.tensor(np.zeros(b_np.shape, dtype=b_np.dtype), dev) @@ -445,25 +559,30 @@ def check(t0, t1, factor): print("Skip because gpu does not have fp16 support") return - # compute n = 128 - A = te.placeholder((n,), dtype=t0, name="A") - B = te.placeholder((n,), dtype=t1, name="B") - C = te.compute((n,), lambda i: A[i] + topi.cast(B[i], A.dtype), name="C") - - # schedule - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B, C])) - ob, ib = sch.split(sch.get_loops("C")[0], factors=[None, factor]) - sch.vectorize(ib) - sch.bind(ob, "threadIdx.x") - func = tvm.compile(sch.mod, target="cuda") + num_thread = n // factor + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((n,), t0), B: T.Buffer((n,), t1), C: T.Buffer((n,), t0)): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(num_thread, thread="threadIdx.x"): + for i_1 in T.vectorized(factor): + with T.sblock("C"): + v_i = T.axis.spatial(n, i_0 * factor + i_1) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + T.Cast(t0, B[v_i]) + + 
func = tvm.compile(Module, target="cuda") # correctness dev = tvm.cuda(0) low, high = (0, 20) if t0.startswith("u") or t1.startswith("u") else (-10, 10) - a_np = np.random.randint(low, high, size=n).astype(A.dtype) - b_np = np.random.randint(low, high, size=n).astype(B.dtype) - c_np = (a_np + b_np).astype(A.dtype) + a_np = np.random.randint(low, high, size=n).astype(t0) + b_np = np.random.randint(low, high, size=n).astype(t1) + c_np = (a_np + b_np).astype(t0) a_nd = tvm.runtime.tensor(a_np, dev) b_nd = tvm.runtime.tensor(b_np, dev) c_nd = tvm.runtime.tensor(np.zeros(c_np.shape, dtype=c_np.dtype), dev) @@ -501,16 +620,30 @@ def skip(t0, t1): check("uint8", "int8", 16) -def sched(A, B): - # schedule - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - io, ii = sch.split(sch.get_loops("B")[0], factors=[1, None]) - iio, iii = sch.split(ii, factors=[32, None]) - _, iiii = sch.split(iii, factors=[None, 4]) - sch.vectorize(iiii) - sch.bind(io, "blockIdx.x") - sch.bind(iio, "threadIdx.x") - return tvm.compile(sch.mod, target="cuda") +def sched(compute_fn, dtype, n=128): + """Create a vectorized CUDA module with the given compute function. + + The schedule structure is: split [1, None] -> split [32, None] -> split [None, 4] + then vectorize innermost, bind blockIdx.x and threadIdx.x. + For n=128 this gives: blockIdx.x=1, threadIdx.x=32, serial=1, vectorized=4. 
+ """ + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((n,), dtype), B: T.Buffer((n,), dtype)): + T.func_attr({"tir.noalias": True}) + for i0_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_1_0 in T.thread_binding(32, thread="threadIdx.x"): + for i0_1_1_0 in range(1): + for i0_1_1_1 in T.vectorized(4): + with T.sblock("B"): + v_i0 = T.axis.spatial(n, i0_1_0 * 4 + i0_1_1_0 * 4 + i0_1_1_1) + T.reads(A[v_i0]) + T.writes(B[v_i0]) + B[v_i0] = compute_fn(A[v_i0]) + + return tvm.compile(Module, target="cuda") @tvm.testing.requires_gpu @@ -557,12 +690,10 @@ def run_test(tvm_intrin, np_func, dtype): return n = 128 - A = te.placeholder((n,), dtype=dtype, name="A") - B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name="B") - f = sched(A, B) + f = sched(tvm_intrin, dtype, n) dev = tvm.cuda(0) - a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(dtype), dev) + b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(dtype), dev) f(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) @@ -582,12 +713,10 @@ def test_vectorized_intrin2(dtype="float32"): def run_test(tvm_intrin, np_func): n = 128 - A = te.placeholder((n,), dtype=dtype, name="A") - B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name="B") - f = sched(A, B) + f = sched(lambda x: tvm_intrin(x, c2), dtype, n) dev = tvm.cuda(0) - a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(A.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(0, 1, size=n).astype(dtype), dev) + b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(dtype), dev) f(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) @@ -607,12 +736,10 @@ def ref_popcount(x): def run_test(dtype): n = 128 - A = te.placeholder((n,), 
dtype=dtype, name="A") - B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name="B") - f = sched(A, B) + f = sched(lambda x: tvm.tir.popcount(x), dtype, n) dev = tvm.cuda(0) - a = tvm.runtime.tensor(np.random.randint(0, 100000, size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(B.dtype), dev) + a = tvm.runtime.tensor(np.random.randint(0, 100000, size=n).astype(dtype), dev) + b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(dtype), dev) f(a, b) ref = np.vectorize(ref_popcount)(a.numpy()) tvm.testing.assert_allclose(b.numpy(), ref) @@ -630,27 +757,33 @@ def check_cuda(dtype, n, l, padding, lanes): return dev = tvm.cuda(0) - A = tvm.te.placeholder((n, l), name="A", dtype=dtype) - B = tvm.te.compute( - (n // lanes, l + 2 * padding, lanes), - lambda i, j, k: tvm.te.if_then_else( - tvm.te.any(j < padding, j >= l + padding), - tvm.tir.const(0, dtype), - A[i * lanes + k, j - padding], - ), - name="B", - ) - - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - block, thread, vectorize = sch.get_loops("B") - sch.bind(block, "blockIdx.x") - sch.bind(thread, "threadIdx.x") - sch.vectorize(vectorize) - fun = tvm.compile(sch.mod, target="cuda") - - np_a = np.random.randint(low=-128, high=127, size=(n, l)).astype(A.dtype) - a = tvm.runtime.empty((n, l), A.dtype, dev).copyfrom(np_a) - b = tvm.runtime.empty((n // lanes, l + padding * 2, lanes), B.dtype, dev) + zero = tvm.tir.const(0, dtype) + dim0 = n // lanes + dim1 = l + 2 * padding + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((n, l), dtype), B: T.Buffer((dim0, dim1, lanes), dtype)): + T.func_attr({"tir.noalias": True}) + for i in T.thread_binding(dim0, thread="blockIdx.x"): + for j in T.thread_binding(dim1, thread="threadIdx.x"): + for k in T.vectorized(lanes): + with T.sblock("B"): + v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k]) + T.reads(A[v_i * lanes + v_k, v_j - padding]) + T.writes(B[v_i, v_j, v_k]) + B[v_i, v_j, v_k] = T.if_then_else( + v_j < padding or 
l + padding <= v_j, + zero, + A[v_i * lanes + v_k, v_j - padding], + ) + + fun = tvm.compile(Module, target="cuda") + + np_a = np.random.randint(low=-128, high=127, size=(n, l)).astype(dtype) + a = tvm.runtime.empty((n, l), dtype, dev).copyfrom(np_a) + b = tvm.runtime.empty((dim0, dim1, lanes), dtype, dev) fun(a, b) np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1) ref = np.pad( @@ -670,48 +803,45 @@ def check_cuda(dtype, n, l, padding, lanes): @tvm.testing.requires_gpu @tvm.testing.requires_cuda def test_try_unaligned_vector_load(): - def get_compute(N, C_N, offset): - A = te.placeholder((N,), name="A", dtype="float16") - C = te.compute((C_N,), lambda i: A[i + offset], name="C") - return N, C_N, A, C - - def get_compute_unaligned(): - return get_compute(3, 2, 1) - - def get_compute_aligned(): - return get_compute(4, 2, 2) - - def build(A, C, N, C_N): - sch = tvm.s_tir.Schedule(te.create_prim_func([A, C])) - oi, ii = sch.split(sch.get_loops("C")[0], factors=[None, 2]) - sch.bind(oi, "threadIdx.x") - sch.vectorize(ii) # BUG: misalignment - - f = tvm.tir.build(sch.mod, target="cuda") + def build(N, C_N, offset): + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((N,), "float16"), C: T.Buffer((C_N,), "float16")): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(C_N // 2, thread="threadIdx.x"): + for i_1 in T.vectorized(2): + with T.sblock("C"): + v_i = T.axis.spatial(C_N, i_0 * 2 + i_1) + T.reads(A[v_i + offset]) + T.writes(C[v_i]) + C[v_i] = A[v_i + offset] + + f = tvm.tir.build(Module, target="cuda") kernel_source = f.imports[0].inspect_source() dev = tvm.cuda() - a_data = np.arange(0, N).astype(A.dtype) + a_data = np.arange(0, N).astype("float16") a = tvm.runtime.tensor(a_data, dev) - c = tvm.runtime.tensor(np.zeros(C_N, dtype=C.dtype), dev) + c = tvm.runtime.tensor(np.zeros(C_N, dtype="float16"), dev) f(a, c) return a_data, c.numpy(), kernel_source - N, C_N, A, C = get_compute_unaligned() - a_data, c, 
kernel_source = build(A, C, N, C_N) + # Unaligned case: N=3, C_N=2, offset=1 + a_data, c, kernel_source = build(3, 2, 1) # (uint1*)(A + (1)) is invalid assert "A + (1)" not in kernel_source - expected = a_data[1 : C_N + 1] + expected = a_data[1 : 2 + 1] assert np.allclose(c, expected), f"expected={expected}\nactual={c}" - N, C_N, A, C = get_compute_aligned() - a_data, c, kernel_source = build(A, C, N, C_N) + # Aligned case: N=4, C_N=2, offset=2 + a_data, c, kernel_source = build(4, 2, 2) # (uint1*)(A + (2)) is a valid vector load assert "A + 2" in kernel_source - expected = a_data[2 : C_N + 2] + expected = a_data[2 : 2 + 2] assert np.allclose(c, expected), f"expected={expected}\nactual={c}" diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py index dd0a486a54b3..2730dab20157 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp4.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -22,6 +22,7 @@ import tvm import tvm.testing +from tvm.script import ir as I from tvm.script import tir as T try: @@ -38,31 +39,28 @@ def test_e2m1_vector_conversions(promoted_dtype): native_dtype = "float4_e2m1fnx2" vector_length = 64 - @T.prim_func - def add( - A: T.Buffer((vector_length,), native_dtype), - B: T.Buffer((vector_length,), native_dtype), - C: T.Buffer((vector_length,), native_dtype), - ): - T.func_attr({"tir.noalias": True}) - for i in range(vector_length): - with T.sblock("C"): - v_i = T.axis.spatial(vector_length, i) - T.reads(A[v_i], B[v_i]) - T.writes(C[v_i]) - C[v_i] = T.Cast( - native_dtype, T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i]) - ) - - sch = tvm.s_tir.Schedule(add) - block = sch.get_sblock("C") - b = sch.get_loops(block) - bx, tx = sch.split(b[0], factors=[None, 32]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((vector_length,), native_dtype), + B: 
T.Buffer((vector_length,), native_dtype), + C: T.Buffer((vector_length,), native_dtype), + ): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(vector_length // 32, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(vector_length, i_0 * 32 + i_1) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast( + native_dtype, + T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i]), + ) target = "cuda" - fadd = tvm.compile(sch.mod, target=target) + fadd = tvm.compile(Module, target=target) dev = tvm.device(target, 0) if "x" in native_dtype: @@ -111,82 +109,98 @@ def add( assert c_result.dtype == promoted_base_dtype -@tvm.testing.requires_cuda_compute_version(10) -def test_e2m1_dequantize(): - n = 128 - - dev = tvm.device("cuda", 0) - target = tvm.target.Target.from_device(dev) - num_elem_per_storage = 32 // 4 - - def get_reinterpret_mod(func_type, vector_length): +def _shuffle_reinterpret_module(n, num_blocks, vector_length, num_elem_per_storage): + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((n // num_elem_per_storage,), "uint32"), + B: T.Buffer((n,), "float16"), + ): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + for i_2 in T.vectorized(vector_length): + with T.sblock("C"): + v_i = T.axis.spatial( + n, i_0 * 32 * vector_length + i_1 * vector_length + i_2 + ) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.Shuffle( + [ + T.reinterpret( + "float4_e2m1fnx2", + T.bitwise_and( + T.shift_right( + A[v_i // num_elem_per_storage], + ((v_i % num_elem_per_storage) // 2 * 4 * 2).astype( + "uint32" + ), + ), + T.uint32((1 << (4 * 2)) - 1), + ).astype("uint8"), + ).astype("float16x2") + ], + indices=[v_i % 2], + ) + + return Module + + +def _scalar_reinterpret_module(n, num_blocks, vector_length, num_elem_per_storage): + @I.ir_module + class 
Module: @T.prim_func - def shuffle_reinterpret( + def main( A: T.Buffer((n // num_elem_per_storage,), "uint32"), B: T.Buffer((n,), "float16"), ): T.func_attr({"tir.noalias": True}) - for i in range(n): - with T.sblock("C"): - v_i = T.axis.spatial(n, i) - T.reads(A[v_i]) - T.writes(B[v_i]) - B[v_i] = T.Shuffle( - [ - T.reinterpret( - "float4_e2m1fnx2", + for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + for i_2 in T.vectorized(vector_length): + with T.sblock("C"): + v_i = T.axis.spatial( + n, i_0 * 32 * vector_length + i_1 * vector_length + i_2 + ) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.reinterpret( + "float4_e2m1fn", T.bitwise_and( T.shift_right( A[v_i // num_elem_per_storage], - ((v_i % num_elem_per_storage) // 2 * 4 * 2).astype( - "uint32" - ), + (v_i % num_elem_per_storage * 4).astype("uint32"), ), - T.uint32((1 << (4 * 2)) - 1), + T.uint32((1 << 4) - 1), ).astype("uint8"), - ).astype("float16x2") - ], - indices=[v_i % 2], - ) + ).astype("float16") - @T.prim_func - def scalar_reinterpret( - A: T.Buffer((n // num_elem_per_storage,), "uint32"), - B: T.Buffer((n,), "float16"), - ): - T.func_attr({"tir.noalias": True}) - for i in range(n): - with T.sblock("C"): - v_i = T.axis.spatial(n, i) - T.reads(A[v_i]) - T.writes(B[v_i]) - B[v_i] = T.reinterpret( - "float4_e2m1fn", - T.bitwise_and( - T.shift_right( - A[v_i // num_elem_per_storage], - (v_i % num_elem_per_storage * 4).astype("uint32"), - ), - T.uint32((1 << 4) - 1), - ).astype("uint8"), - ).astype("float16") - - func = shuffle_reinterpret if func_type == "shuffle" else scalar_reinterpret - sch = tvm.s_tir.Schedule(func) - block = sch.get_sblock("C") - b = sch.get_loops(block) - bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - return sch.mod + return Module + + +@tvm.testing.requires_cuda_compute_version(10) +def test_e2m1_dequantize(): 
+ n = 128 + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + num_elem_per_storage = 32 // 4 # We only test the whether the code can be compiled. for func_type, vector_length in product(["shuffle", "scalar"], [1, 2, 4]): if func_type == "shuffle" and vector_length == 1: # Vectorize is necessary for shuffle. continue - mod = get_reinterpret_mod(func_type, vector_length) + + num_blocks = n // (32 * vector_length) + + if func_type == "shuffle": + mod = _shuffle_reinterpret_module(n, num_blocks, vector_length, num_elem_per_storage) + else: + mod = _scalar_reinterpret_module(n, num_blocks, vector_length, num_elem_per_storage) + tvm.compile(mod, target=target) diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index b9f31b195f5d..229839add9ed 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -47,29 +47,31 @@ def test_fp8_conversions(input): dtype, nv_dtype = input - @T.prim_func - def add( - A: T.Buffer((64,), dtype), - B: T.Buffer((64,), dtype), - C: T.Buffer((64,), dtype), - ): - T.func_attr({"tir.noalias": True}) - for i in range(64): - with T.sblock("C"): - v_i = T.axis.spatial(64, i) - T.reads(A[v_i], B[v_i]) - T.writes(C[v_i]) - C[v_i] = T.Cast(dtype, T.Cast("float16", A[v_i]) + T.Cast("float16", B[v_i])) - - sch = tvm.s_tir.Schedule(add) - block = sch.get_sblock("C") - b = sch.get_loops(block) - bx, tx = sch.split(b[0], factors=[None, 32]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - + def _create_mod(dtype): + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((64,), dtype), + B: T.Buffer((64,), dtype), + C: T.Buffer((64,), dtype), + ): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(2, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(64, i_0 * 32 + i_1) + 
T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast( + dtype, T.Cast("float16", A[v_i]) + T.Cast("float16", B[v_i]) + ) + + return Module + + mod = _create_mod(dtype) target = "cuda" - fadd = tvm.tir.build(sch.mod, target=target) + fadd = tvm.tir.build(mod, target=target) cuda_src = fadd.imports[0].inspect_source() assert nv_dtype in cuda_src, f"{nv_dtype} datatype not found in generated CUDA" @@ -96,41 +98,36 @@ def test_fp8_packing(dtype): vector_length = 4 native_dtype, packed_dtype = (f"{dtype}x{vector_length}", "uint32") - @T.prim_func - def add( - A: T.Buffer((length,), native_dtype), - R: T.Buffer((length,), packed_dtype), - B: T.Buffer((length,), native_dtype), - ): - T.func_attr({"tir.noalias": True}) - # with T.sblock("root"): - for i in range(length): - with T.sblock("R"): - v_i = T.axis.spatial(length, i) - T.reads(A[v_i]) - T.writes(R[v_i]) - R[v_i] = T.reinterpret(packed_dtype, A[v_i]) - for i in range(length): - with T.sblock("B"): - v_i = T.axis.spatial(length, i) - T.reads(R[v_i]) - T.writes(B[v_i]) - B[v_i] = T.reinterpret(native_dtype, R[v_i]) - - sch = tvm.s_tir.Schedule(add) - block = sch.get_sblock("R") - b = sch.get_loops(block) - bx, tx = sch.split(b[0], factors=[None, 32]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - block = sch.get_sblock("B") - b = sch.get_loops(block) - bx, tx = sch.split(b[0], factors=[None, 32]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - + def _create_mod(native_dtype, packed_dtype, length): + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((length,), native_dtype), + R: T.Buffer((length,), packed_dtype), + B: T.Buffer((length,), native_dtype), + ): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(2, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.sblock("R"): + v_i = T.axis.spatial(length, i_0 * 32 + i_1) + T.reads(A[v_i]) + T.writes(R[v_i]) + R[v_i] = T.reinterpret(packed_dtype, A[v_i]) + for i_0 
in T.thread_binding(2, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.sblock("B"): + v_i = T.axis.spatial(length, i_0 * 32 + i_1) + T.reads(R[v_i]) + T.writes(B[v_i]) + B[v_i] = T.reinterpret(native_dtype, R[v_i]) + + return Module + + mod = _create_mod(native_dtype, packed_dtype, length) target = "cuda" - f = tvm.compile(sch.mod, target=target) + f = tvm.compile(mod, target=target) dev = tvm.device(target, 0) np_shape = (length, vector_length) @@ -164,32 +161,32 @@ def add( def test_fp8_vector_conversions(native_dtype, promoted_dtype, numpytype): vector_length = 64 - @T.prim_func - def add( - A: T.Buffer((vector_length,), native_dtype), - B: T.Buffer((vector_length,), native_dtype), - C: T.Buffer((vector_length,), native_dtype), - ): - T.func_attr({"tir.noalias": True}) - # with T.sblock("root"): - for i in range(vector_length): - with T.sblock("C"): - v_i = T.axis.spatial(vector_length, i) - T.reads(A[v_i], B[v_i]) - T.writes(C[v_i]) - C[v_i] = T.Cast( - native_dtype, T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i]) - ) - - sch = tvm.s_tir.Schedule(add) - block = sch.get_sblock("C") - b = sch.get_loops(block) - bx, tx = sch.split(b[0], factors=[None, 32]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - + def _create_mod(native_dtype, promoted_dtype): + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((64,), native_dtype), + B: T.Buffer((64,), native_dtype), + C: T.Buffer((64,), native_dtype), + ): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(2, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(64, i_0 * 32 + i_1) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast( + native_dtype, + T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i]), + ) + + return Module + + mod = _create_mod(native_dtype, promoted_dtype) target = "cuda" - fadd = tvm.tir.build(sch.mod, 
target=target) + fadd = tvm.tir.build(mod, target=target) cuda_src = fadd.imports[0].inspect_source() dev = tvm.device(target, 0) @@ -225,21 +222,21 @@ def add( def test_half_broadcast(bcast_length): dtype = "float16" - @T.prim_func - def vector_broadcast(a: T.Buffer((), dtype), vec: T.Buffer((bcast_length,), dtype)): - for t in range(1): - with T.sblock("broadcast"): - vec[0:bcast_length] = T.broadcast(a[()], bcast_length) - - sch = tvm.s_tir.Schedule(vector_broadcast) - block = sch.get_sblock("broadcast") - b = sch.get_loops(block) - bx, tx = sch.split(b[0], factors=[None, 1]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") + def _create_mod(bcast_length, dtype): + @I.ir_module + class Module: + @T.prim_func + def main(a: T.Buffer((), dtype), vec: T.Buffer((bcast_length,), dtype)): + for i_0 in T.thread_binding(1, thread="blockIdx.x"): + for i_1 in T.thread_binding(1, thread="threadIdx.x"): + with T.sblock("broadcast"): + vec[0:bcast_length] = T.broadcast(a[()], bcast_length) + + return Module + mod = _create_mod(bcast_length, dtype) target = "cuda" - func = tvm.compile(sch.mod, target=target) + func = tvm.compile(mod, target=target) dev = tvm.device(target, 0) a_np = np.random.uniform(low=0, high=4, size=()).astype(dtype) @@ -298,30 +295,25 @@ def test_half4_vector_add(): vector_length = 4 vec_dtype = dtype + "x" + str(vector_length) - @T.prim_func - def add( - A: T.Buffer((length,), vec_dtype), - B: T.Buffer((length,), vec_dtype), - C: T.Buffer((length,), vec_dtype), - ): - T.func_attr({"tir.noalias": True}) - # with T.sblock("root"): - for i in range(length): - with T.sblock("C"): - v_i = T.axis.spatial(length, i) - T.reads(A[v_i], B[v_i]) - T.writes(C[v_i]) - C[v_i] = A[v_i] + B[v_i] - - sch = tvm.s_tir.Schedule(add) - block = sch.get_sblock("C") - b = sch.get_loops(block) - bx, tx = sch.split(b[0], factors=[None, 32]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") + @I.ir_module + class Module: + @T.prim_func + def main( + A: 
T.Buffer((64,), "float16x4"), + B: T.Buffer((64,), "float16x4"), + C: T.Buffer((64,), "float16x4"), + ): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(2, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(64, i_0 * 32 + i_1) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + B[v_i] target = "cuda" - fadd = tvm.compile(sch.mod, target=target) + fadd = tvm.compile(Module, target=target) dev = tvm.device(target, 0) a_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype) @@ -976,26 +968,29 @@ def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @tvm.testing.requires_cuda_compute_version(8, 9) def test_fp8_fp16_bf16_vectorize_arith(vec_length, dtype): - @T.prim_func - def func_vectorize( - A: T.Buffer((128,), "float8_e4m3fn"), - B: T.Buffer((128,), dtype), - C: T.Buffer((128,), dtype), - ) -> None: - for i in T.serial(128): - with T.sblock("compute"): - vi = T.axis.remap("S", [i]) - C[vi] = (A[vi].astype(dtype) * B[vi]) + T.bfloat16(3.0) - - sch = tvm.s_tir.Schedule(func_vectorize) - (l,) = sch.get_loops(sch.get_sblock("compute")) - lo, li = sch.split(l, [None, vec_length]) - sch.bind(lo, "threadIdx.x") - sch.vectorize(li) - + def _create_mod(vec_length, dtype): + num_threads = 128 // vec_length + + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((128,), "float8_e4m3fn"), + B: T.Buffer((128,), dtype), + C: T.Buffer((128,), dtype), + ) -> None: + for i_0 in T.thread_binding(num_threads, thread="threadIdx.x"): + for i_1 in T.vectorized(vec_length): + with T.sblock("compute"): + vi = T.axis.spatial(128, i_0 * vec_length + i_1) + C[vi] = (A[vi].astype(dtype) * B[vi]) + T.bfloat16(3.0) + + return Module + + mod = _create_mod(vec_length, dtype) device = tvm.cuda() target = tvm.target.Target.from_device(device) - f = tir.build(sch.mod, target=target) + f = 
tvm.tir.build(mod, target=target) a_np = np.random.rand(128).astype("float8_e4m3fn") b_np = np.random.rand(128).astype(dtype) diff --git a/tests/python/codegen/test_target_codegen_device.py b/tests/python/codegen/test_target_codegen_device.py index fc9d65cc1548..04e1857c9ec3 100644 --- a/tests/python/codegen/test_target_codegen_device.py +++ b/tests/python/codegen/test_target_codegen_device.py @@ -15,42 +15,36 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te -from tvm.contrib import utils import numpy as np import tvm.testing -from tvm import tir +from tvm.script import tir as T, ir as I @tvm.testing.requires_gpu def test_large_uint_imm(): value = (1 << 63) + 123 - other = tvm.tir.const(3, "uint64") - n = 12 - num_thread = 2 - - A = te.compute((n,), lambda *i: tvm.tir.const(value, "uint64") + other, name="A") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A]) - sch = tvm.s_tir.Schedule(mod) - - # Get block and loop - block = sch.get_sblock("A") - loop = sch.get_loops(block)[0] - - # Split and bind - xo, xi = sch.split(loop, factors=[None, num_thread]) - sch.bind(xi, "threadIdx.x") - sch.bind(xo, "blockIdx.x") + value_const = tvm.tir.const(value, "uint64") + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((12,), "uint64")): + T.func_attr({"tir.noalias": True}) + for i0_0 in T.thread_binding(6, thread="blockIdx.x"): + for i0_1 in T.thread_binding(2, thread="threadIdx.x"): + with T.sblock("A"): + v_i0 = T.axis.spatial(12, i0_0 * 2 + i0_1) + T.reads() + T.writes(A[v_i0]) + A[v_i0] = value_const + T.uint64(3) def check_target(device): if not tvm.testing.device_enabled(device): return dev = tvm.device(device, 0) - f = tvm.compile(sch.mod, target=device) + f = tvm.compile(Module, target=device) # launch the kernel. 
- a = tvm.runtime.empty((n,), dtype=A.dtype, device=dev) + a = tvm.runtime.empty((12,), dtype="uint64", device=dev) f(a) assert a.numpy()[0] == value + 3 @@ -60,47 +54,44 @@ def check_target(device): @tvm.testing.requires_gpu def test_add_pipeline(): - n = te.size_var("n") - A = te.placeholder((n,), name="A") - B = te.placeholder((), name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(), name="C") - D = te.compute(A.shape, lambda *i: C(*i) + 1, name="D") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, B, D]) - sch = tvm.s_tir.Schedule(mod) - - # Get blocks and loops - c_block = sch.get_sblock("C") - d_block = sch.get_sblock("D") - c_loop = sch.get_loops(c_block)[0] - d_loop = sch.get_loops(d_block)[0] - - # GPU schedule have to split by gridIdx and threadIdx - num_thread = 256 - - # Schedule C - c_xo, c_xi = sch.split(c_loop, factors=[None, num_thread]) - sch.bind(c_xi, "threadIdx.x") - sch.bind(c_xo, "blockIdx.x") - - # Schedule D - d_xo, d_xi = sch.split(d_loop, factors=[None, num_thread]) - sch.bind(d_xi, "threadIdx.x") - sch.bind(d_xo, "blockIdx.x") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, B: T.Buffer((), "float32"), var_D: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (n,)) + D = T.match_buffer(var_D, (n,)) + C = T.alloc_buffer((n,)) + for i0_0 in T.thread_binding((n + 255) // 256, thread="blockIdx.x"): + for i0_1 in T.thread_binding(256, thread="threadIdx.x"): + with T.sblock("C"): + v_i0 = T.axis.spatial(n, i0_0 * 256 + i0_1) + T.where(i0_0 * 256 + i0_1 < n) + T.reads(A[v_i0], B[()]) + T.writes(C[v_i0]) + C[v_i0] = A[v_i0] + B[()] + for i0_0 in T.thread_binding((n + 255) // 256, thread="blockIdx.x"): + for i0_1 in T.thread_binding(256, thread="threadIdx.x"): + with T.sblock("D"): + v_i0 = T.axis.spatial(n, i0_0 * 256 + i0_1) + T.where(i0_0 * 256 + i0_1 < n) + T.reads(C[v_i0]) + T.writes(D[v_i0]) + D[v_i0] = C[v_i0] + T.float32(1.0) def 
check_target(device, host): if not tvm.testing.device_enabled(device) or not tvm.testing.device_enabled(host): return dev = tvm.device(device, 0) target = tvm.target.Target(device, host) - mhost = tvm.tir.build(sch.mod, target=target) + mhost = tvm.tir.build(Module, target=target) f = mhost.main # launch the kernel. n = 1027 - a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.random.uniform(size=()).astype(B.dtype), dev) - d = tvm.runtime.tensor(np.zeros(n, dtype=D.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + b = tvm.runtime.tensor(np.random.uniform(size=()).astype("float32"), dev) + d = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) f(a, b, d) tvm.testing.assert_allclose(d.numpy(), a.numpy() + b.numpy() + 1) diff --git a/tests/python/codegen/test_target_codegen_gpu_common.py b/tests/python/codegen/test_target_codegen_gpu_common.py index ab80ef5e4995..bf2fc12083b4 100644 --- a/tests/python/codegen/test_target_codegen_gpu_common.py +++ b/tests/python/codegen/test_target_codegen_gpu_common.py @@ -21,7 +21,7 @@ import tvm import tvm.testing -from tvm import te +from tvm.script import tir as T, ir as I @tvm.testing.requires_gpu @@ -29,27 +29,34 @@ @pytest.mark.parametrize("dtype", ["int32", "uint32", "int64", "uint64"]) def test_int_intrin(target, dev, dtype): test_funcs = [ - (tvm.tir.clz, lambda x, dtype: int(dtype[-2:]) - (len(bin(x)) - 2)), + (T.clz, lambda x, dtype: int(dtype[-2:]) - (len(bin(x)) - 2)), ] - def run_test(tvm_intrin, np_func, dtype): + for tvm_intrin, np_func in test_funcs: n = 128 - A = te.placeholder((n,), name="A", dtype=dtype) - B = te.compute(A.shape, lambda *i: tvm_intrin(A(*i)), name="B") - func = te.create_prim_func([A, B]) - sch = tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("B")) - sch.bind(x, "threadIdx.x") - f = tvm.compile(sch.mod, target=target) - a = tvm.runtime.tensor(np.random.randint(0, 100000, 
size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(B.dtype), dev) + + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((n,), dtype), + B: T.Buffer((n,), dtype), + ): + T.func_attr({"tir.noalias": True}) + for i0 in T.thread_binding(n, thread="threadIdx.x"): + with T.sblock("B"): + v_i0 = T.axis.spatial(n, i0) + T.reads(A[v_i0]) + T.writes(B[v_i0]) + B[v_i0] = tvm_intrin(A[v_i0]) + + f = tvm.compile(Module, target=target) + a = tvm.runtime.tensor(np.random.randint(0, 100000, size=n).astype(dtype), dev) + b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(dtype), dev) f(a, b) ref = np.vectorize(partial(np_func, dtype=dtype))(a.numpy()) tvm.testing.assert_allclose(b.numpy(), ref) - for func in test_funcs: - run_test(*func, dtype) - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_hexagon.py b/tests/python/codegen/test_target_codegen_hexagon.py index f14005ad9d0b..a297e89bfbe3 100644 --- a/tests/python/codegen/test_target_codegen_hexagon.py +++ b/tests/python/codegen/test_target_codegen_hexagon.py @@ -15,15 +15,12 @@ # specific language governing permissions and limitations # under the License. 
-import os import re -import sys -import numpy as np import pytest import tvm import tvm.testing import tvm.contrib.hexagon as hexagon -from tvm import te +from tvm.script import tir as T, ir as I @pytest.fixture(autouse=True) @@ -40,27 +37,45 @@ def register_linker(): def test_basic(): target = tvm.target.hexagon("v66", hvx=128) - def check_add(): - A = tvm.te.placeholder((128,), dtype="uint8", name="A") - B = tvm.te.placeholder((128,), dtype="uint8", name="A") - C = tvm.te.compute((128,), lambda i: A[i] + B[i], name="C") - mod = tvm.IRModule.from_expr(te.create_prim_func([C, A, B])) - hexm = tvm.compile(mod, target=tvm.target.Target(target, target)) - asm = hexm.inspect_source("s") - vadds = re.findall(r"v[0-9]+.b = vadd\(v[0-9]+.b,v[0-9]+.b\)", asm) - assert vadds # Check that it's non-empty + @I.ir_module + class Module: + @T.prim_func + def main( + C: T.Buffer((128,), "uint8"), + A: T.Buffer((128,), "uint8"), + A_1: T.Buffer((128,), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + for i in range(128): + with T.sblock("C"): + v_i = T.axis.spatial(128, i) + T.reads(A[v_i], A_1[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + A_1[v_i] - check_add() + hexm = tvm.compile(Module, target=tvm.target.Target(target, target)) + asm = hexm.inspect_source("s") + vadds = re.findall(r"v[0-9]+.b = vadd\(v[0-9]+.b,v[0-9]+.b\)", asm) + assert vadds # Check that it's non-empty @tvm.testing.requires_hexagon def test_llvm_target_features(): target = tvm.target.hexagon("v66", hvx=128) - # Define some trivial compute - A = tvm.te.placeholder((128,), dtype="uint8", name="A") - C = tvm.te.compute((128,), lambda i: A[i] + 1, name="C") - mod = tvm.IRModule.from_expr(te.create_prim_func([C, A]).with_attr("global_symbol", "add_one")) - m = tvm.compile(mod, target=tvm.target.Target(target, target)) + + @I.ir_module + class Module: + @T.prim_func + def add_one(C: T.Buffer((128,), "int32"), A: T.Buffer((128,), "uint8")): + T.func_attr({"tir.noalias": True}) + for i in range(128): + with 
T.sblock("C"): + v_i = T.axis.spatial(128, i) + T.reads(A[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast("int32", A[v_i]) + 1 + + m = tvm.compile(Module, target=tvm.target.Target(target, target)) llvm_ir = m.inspect_source("ll") # Make sure we find +hvx-length128b in "attributes". fs = re.findall(r"attributes.*\+hvx-length128b", llvm_ir) @@ -70,11 +85,22 @@ def test_llvm_target_features(): @tvm.testing.requires_hexagon def test_llvm_options(): target = tvm.target.hexagon("v66", llvm_options="-hexagon-noopt") - Zero = tvm.te.compute((10,), lambda _: tvm.tir.const(0, "int32")) - mod = tvm.IRModule.from_expr(te.create_prim_func([Zero])) + + @I.ir_module + class Module: + @T.prim_func + def main(compute: T.Buffer((10,), "int32")): + T.func_attr({"tir.noalias": True}) + for _ in range(10): + with T.sblock("compute"): + v__ = T.axis.spatial(10, _) + T.reads() + T.writes(compute[v__]) + compute[v__] = 0 + # Check that BuildHexagon hasn't crashed because of target attribute # type mismatch. - tvm.compile(mod, target=tvm.target.Target(target, target)) + tvm.compile(Module, target=tvm.target.Target(target, target)) assert re.search("-hexagon-noopt", str(target)) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 657044246d89..da58f5bb459c 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -22,7 +22,6 @@ import tvm import tvm.testing -from tvm import te, tir from tvm.contrib import clang, utils from tvm.script import ir as I from tvm.script import tir as T @@ -31,25 +30,26 @@ @tvm.testing.requires_llvm def test_llvm_intrin(): - @T.prim_func - def prefetch(A: T.handle("float32")): - T.func_attr({"global_symbol": "prefetch"}) - A_buf = T.Buffer((4,), "float32", data=A) - T.evaluate(T.Call("void", "tir.prefetch", [T.address_of(A_buf[0]), 0, 3, 1])) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.handle("float32")): + A_buf = 
T.Buffer((4,), "float32", data=A) + T.evaluate(T.Call("void", "tir.prefetch", [T.address_of(A_buf[0]), 0, 3, 1])) - mod = tvm.IRModule.from_expr(prefetch) - fcode = tvm.compile(mod) + fcode = tvm.compile(Module) @tvm.testing.requires_llvm def test_llvm_void_intrin(): - @T.prim_func - def main(A: T.handle("uint8")): - # Create an intrinsic that returns void. - T.call_llvm_intrin("", "llvm.assume", T.bool(True)) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.handle("uint8")): + # Create an intrinsic that returns void. + T.call_llvm_intrin("", "llvm.assume", T.bool(True)) - mod = tvm.IRModule.from_expr(main) - fcode = tvm.compile(mod) + fcode = tvm.compile(Module) @tvm.testing.requires_llvm @@ -84,108 +84,94 @@ def main(A: T.Buffer((1, 1), "int32"), C: T.Buffer((1, 1), "int32")): @tvm.testing.requires_llvm def test_llvm_lookup_intrin(): - @T.prim_func - def main(A: T.handle("uint8x8")): - A_buf = T.Buffer((1,), "uint8x8", data=A) - T.evaluate(T.call_llvm_pure_intrin("uint8x8", "llvm.ctpop.v8i8", T.uint32(1), A_buf[0])) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.handle("uint8x8")): + A_buf = T.Buffer((1,), "uint8x8", data=A) + T.evaluate(T.call_llvm_pure_intrin("uint8x8", "llvm.ctpop.v8i8", T.uint32(1), A_buf[0])) - mod = tvm.IRModule.from_expr(main) - fcode = tvm.compile(mod, None) + fcode = tvm.compile(Module, None) @tvm.testing.requires_llvm def test_llvm_large_uintimm(): value = (1 << 63) + 123 - other = tvm.tir.const(3, "uint64") - A = te.compute((), lambda: tvm.tir.const(value, "uint64") + other, name="A") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A]) - sch = tvm.s_tir.Schedule(mod) + large_val = tvm.tir.const(value, "uint64") - def check_llvm(): - f = tvm.compile(sch.mod, target="llvm") - dev = tvm.cpu(0) - # launch the kernel. 
- a = tvm.runtime.empty((), dtype=A.dtype, device=dev) - f(a) - assert a.numpy() == value + 3 + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((), "uint64")): + T.func_attr({"tir.noalias": True}) + with T.sblock("A"): + vi = T.axis.spatial(1, 0) + T.reads() + T.writes(A[()]) + A[()] = large_val + T.uint64(3) - check_llvm() + f = tvm.compile(Module, target="llvm") + dev = tvm.cpu(0) + a = tvm.runtime.empty((), dtype="uint64", device=dev) + f(a) + assert a.numpy() == value + 3 @tvm.testing.requires_llvm def test_llvm_multi_parallel(): - n = 128 - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B") - C = te.compute(A.shape, lambda *i: te.sqrt(B(*i)) * 2 + 2, name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(mod) - - # Get blocks and loops - c_block = sch.get_sblock("C") - b_block = sch.get_sblock("B") - c_loop = sch.get_loops(c_block)[0] - - # Split and parallelize - xo, xi = sch.split(c_loop, factors=[None, 8]) - xo1, xo2 = sch.split(xo, factors=[1, None]) - - # Move computation of B - sch.compute_at(b_block, xo1) - - # Get B's loop after compute_at - b_loop = sch.get_loops(b_block)[0] - - # Apply parallel scheduling - sch.parallel(b_loop) - sch.parallel(xi) - - def check_llvm(): - # BUILD and invoke the kernel. - f = tvm.compile(sch.mod, target="llvm") - dev = tvm.cpu(0) - # launch the kernel. 
- a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) - f(a, c) - tvm.testing.assert_allclose(c.numpy(), np.sqrt(a.numpy() + 1) * 2 + 2, rtol=1e-5) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((128,), "float32"), C: T.Buffer((128,), "float32")): + T.func_attr({"tir.noalias": True}) + B = T.alloc_buffer((128,)) + for i0_0_0 in T.parallel(1): + for ax0 in range(128): + with T.sblock("B"): + v_i0 = T.axis.spatial(128, ax0) + T.reads(A[v_i0]) + T.writes(B[v_i0]) + B[v_i0] = A[v_i0] + T.float32(1.0) + for i0_0_1 in range(16): + for i0_1 in T.parallel(8): + with T.sblock("C"): + v_i0 = T.axis.spatial(128, i0_0_0 * 128 + i0_0_1 * 8 + i0_1) + T.reads(B[v_i0]) + T.writes(C[v_i0]) + C[v_i0] = T.sqrt(B[v_i0]) * T.float32(2.0) + T.float32(2.0) - check_llvm() + n = 128 + f = tvm.compile(Module, target="llvm") + dev = tvm.cpu(0) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) + f(a, c) + tvm.testing.assert_allclose(c.numpy(), np.sqrt(a.numpy() + 1) * 2 + 2, rtol=1e-5) @tvm.testing.requires_llvm def test_llvm_flip_pipeline(): def check_llvm(nn, base): - n = tvm.runtime.convert(nn) - A = te.placeholder((n + base), name="A") - C = te.compute((n,), lambda i: A(nn + base - i - 1), name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(mod) - - # Get block and loop - block = sch.get_sblock("C") - loop = sch.get_loops(block)[0] - - # Split and parallelize - xo, xi = sch.split(loop, factors=[None, 4]) - sch.parallel(xo) - sch.vectorize(xi) - - # build and invoke the kernel. 
- f = tvm.compile(sch.mod, target="llvm") + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((nn + base,), "float32"), C: T.Buffer((nn,), "float32")): + T.func_attr({"tir.noalias": True}) + for i_0 in T.parallel((nn + 3) // 4): + for i_1 in T.vectorized(4): + with T.sblock("C"): + v_i = T.axis.spatial(nn, i_0 * 4 + i_1) + T.where(i_0 * 4 + i_1 < nn) + T.reads(A[nn + base - 1 - v_i]) + T.writes(C[v_i]) + C[v_i] = A[nn + base - 1 - v_i] + + f = tvm.compile(Module, target="llvm") dev = tvm.cpu(0) - # launch the kernel. - n = nn - a = tvm.runtime.tensor(np.random.uniform(size=(n + base)).astype(A.dtype), dev) - c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=(nn + base)).astype("float32"), dev) + c = tvm.runtime.tensor(np.zeros(nn, dtype="float32"), dev) f(a, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy()[::-1][:n]) + tvm.testing.assert_allclose(c.numpy(), a.numpy()[::-1][:nn]) check_llvm(4, 0) check_llvm(128, 8) @@ -195,29 +181,30 @@ def check_llvm(nn, base): @tvm.testing.requires_llvm def test_llvm_vadd_pipeline(): - n = te.size_var("n") - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute((n,), lambda i: A[i] + B[i], name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, B, C]) - sch = tvm.s_tir.Schedule(mod) - - # Get block and loop - block = sch.get_sblock("C") - loop = sch.get_loops(block)[0] - - # Split the loop - _, inner = sch.split(loop, factors=[None, 4]) - sch.vectorize(inner) - # Build and verify - f = tvm.compile(sch.mod, target="llvm") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (n,)) + B = T.match_buffer(var_B, (n,)) + C = T.match_buffer(var_C, (n,)) + for i_0 in range((n + 3) // 4): + for i_1 in T.vectorized(4): + with T.sblock("C"): + v_i = 
T.axis.spatial(n, i_0 * 4 + i_1) + T.where(i_0 * 4 + i_1 < n) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + B[v_i] + + f = tvm.compile(Module, target="llvm") dev = tvm.cpu(0) n = 128 - a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) f(a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) @@ -225,30 +212,29 @@ def test_llvm_vadd_pipeline(): @tvm.testing.requires_llvm def test_llvm_madd_pipeline(): def check_llvm(nn, base, stride): - n = tvm.runtime.convert(nn) - A = te.placeholder((n + base, stride), name="A") - C = te.compute((n, stride), lambda i, j: A(base + i, j) + 1, name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(mod) - - # Get block and loops - block = sch.get_sblock("C") - i_loop, j_loop = sch.get_loops(block) - - # Split and parallelize - xo, xi = sch.split(i_loop, factors=[None, 4]) - sch.parallel(xo) - sch.vectorize(xi) - - # build and invoke the kernel. - f = tvm.compile(sch.mod, target="llvm") + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((nn + base, stride), "float32"), + C: T.Buffer((nn, stride), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i_0 in T.parallel((nn + 3) // 4): + for i_1 in T.vectorized(4): + for j in range(stride): + with T.sblock("C"): + v_i = T.axis.spatial(nn, i_0 * 4 + i_1) + v_j = T.axis.spatial(stride, j) + T.where(i_0 * 4 + i_1 < nn) + T.reads(A[v_i + base, v_j]) + T.writes(C[v_i, v_j]) + C[v_i, v_j] = A[v_i + base, v_j] + T.float32(1.0) + + f = tvm.compile(Module, target="llvm") dev = tvm.cpu(0) - # launch the kernel. 
- n = nn - a = tvm.runtime.tensor(np.random.uniform(size=(n + base, stride)).astype(A.dtype), dev) - c = tvm.runtime.tensor(np.zeros((n, stride), dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=(nn + base, stride)).astype("float32"), dev) + c = tvm.runtime.tensor(np.zeros((nn, stride), dtype="float32"), dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy()[base:] + 1) @@ -261,59 +247,73 @@ def check_llvm(nn, base, stride): @tvm.testing.requires_llvm def test_llvm_temp_space(): - nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda i: A(i) + 1, name="B") - C = te.compute(A.shape, lambda i: B(i) + 1, name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(mod) - - def check_llvm(): - # build and invoke the kernel. - f = tvm.compile(sch.mod, target="llvm") - dev = tvm.cpu(0) - # launch the kernel. - n = nn - a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) - f(a, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1 + 1) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32")): + T.func_attr({"tir.noalias": True}) + B = T.alloc_buffer((1024,)) + for i in range(1024): + with T.sblock("B"): + v_i = T.axis.spatial(1024, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = A[v_i] + T.float32(1.0) + for i in range(1024): + with T.sblock("C"): + v_i = T.axis.spatial(1024, i) + T.reads(B[v_i]) + T.writes(C[v_i]) + C[v_i] = B[v_i] + T.float32(1.0) - check_llvm() + nn = 1024 + f = tvm.compile(Module, target="llvm") + dev = tvm.cpu(0) + a = tvm.runtime.tensor(np.random.uniform(size=nn).astype("float32"), dev) + c = tvm.runtime.tensor(np.zeros(nn, dtype="float32"), dev) + f(a, c) + tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1 + 1) @tvm.testing.requires_llvm def test_multiple_func(): - 
# Define the computation - n = te.size_var("n") - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), name="B") - C = te.compute((n,), lambda i: A[i] + B[i], name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, B, C]) - sch = tvm.s_tir.Schedule(mod) - - # Create two functions with different names - mod = tvm.IRModule( - { - "fadd1": sch.mod["main"].with_attr("global_symbol", "fadd1"), - "fadd2": sch.mod["main"].with_attr("global_symbol", "fadd2"), - } - ) + @I.ir_module + class Module: + @T.prim_func + def fadd1(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (n,)) + B = T.match_buffer(var_B, (n,)) + C = T.match_buffer(var_C, (n,)) + for i in range(n): + with T.sblock("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + B[v_i] + + @T.prim_func + def fadd2(var_A: T.handle, var_B: T.handle, var_C: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (n,)) + B = T.match_buffer(var_B, (n,)) + C = T.match_buffer(var_C, (n,)) + for i in range(n): + with T.sblock("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + B[v_i] - # Build and verify - f = tvm.compile(mod, target="llvm") + f = tvm.compile(Module, target="llvm") dev = tvm.cpu(0) n = 10 - a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.runtime.tensor(np.zeros(n, dtype=C.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + b = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), dev) + c = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) - # Test both functions f["fadd1"](a, b, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) f["fadd2"](a, b, c) @@ -322,126 
+322,145 @@ def test_multiple_func(): @tvm.testing.requires_llvm def test_llvm_condition(): - def check_llvm(n, offset): - A = te.placeholder((n,), name="A") - C = te.compute((n,), lambda i: tvm.tir.if_then_else(i >= offset, A[i], 0.0), name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(mod) - - # build and invoke the kernel. - f = tvm.compile(sch.mod, target="llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.runtime.tensor(np.random.uniform(size=(n,)).astype(A.dtype), dev) - c = tvm.runtime.empty((n,), A.dtype, dev) - f(a, c) - c_np = a.numpy() - c_np[:offset] = 0 - tvm.testing.assert_allclose(c.numpy(), c_np) - - check_llvm(64, 8) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((64,), "float32"), C: T.Buffer((64,), "float32")): + T.func_attr({"tir.noalias": True}) + for i in range(64): + with T.sblock("C"): + v_i = T.axis.spatial(64, i) + T.reads(A[v_i]) + T.writes(C[v_i]) + C[v_i] = T.if_then_else(8 <= v_i, A[v_i], T.float32(0.0)) + + n = 64 + offset = 8 + f = tvm.compile(Module, target="llvm") + dev = tvm.cpu(0) + a = tvm.runtime.tensor(np.random.uniform(size=(n,)).astype("float32"), dev) + c = tvm.runtime.empty((n,), "float32", dev) + f(a, c) + c_np = a.numpy() + c_np[:offset] = 0 + tvm.testing.assert_allclose(c.numpy(), c_np) @tvm.testing.requires_llvm def test_llvm_bool(): - def check_llvm(n): - A = te.placeholder((n,), name="A", dtype="int32") - C = te.compute((n,), lambda i: A[i].equal(1).astype("float"), name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(mod) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((64,), "int32"), C: T.Buffer((64,), "float32")): + T.func_attr({"tir.noalias": True}) + for i in range(64): + with T.sblock("C"): + v_i = T.axis.spatial(64, i) + T.reads(A[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast("float32", A[v_i] == 1) + + n = 64 + f = tvm.compile(Module, 
target="llvm") + dev = tvm.cpu(0) + a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype("int32"), dev) + c = tvm.runtime.empty((n,), "float32", dev) + f(a, c) + c_np = a.numpy() == 1 + tvm.testing.assert_allclose(c.numpy(), c_np) - # build and invoke the kernel. - f = tvm.compile(sch.mod, target="llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - c = tvm.runtime.empty((n,), C.dtype, dev) - f(a, c) - c_np = a.numpy() == 1 - tvm.testing.assert_allclose(c.numpy(), c_np) - check_llvm(64) +@tvm.testing.requires_llvm +def test_rank_zero(): + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((64,), "float32"), + scale: T.Buffer((), "float32"), + compute: T.Buffer((), "float32"), + ): + T.func_attr({"tir.noalias": True}) + C = T.alloc_buffer(()) + for k in range(64): + with T.sblock("C"): + v_k = T.axis.reduce(64, k) + T.reads(A[v_k], scale[()]) + T.writes(C[()]) + with T.init(): + C[()] = T.float32(0.0) + C[()] = C[()] + A[v_k] * scale[()] + with T.sblock("compute"): + vi = T.axis.spatial(1, 0) + T.reads(C[()]) + T.writes(compute[()]) + compute[()] = C[()] + T.float32(1.0) + + n = 64 + f = tvm.compile(Module, target="llvm") + dev = tvm.cpu(0) + a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype("float32"), dev) + sc = tvm.runtime.tensor(np.random.randint(0, 2, size=()).astype("float32"), dev) + d = tvm.runtime.empty((), "float32", dev) + f(a, sc, d) + d_np = np.sum(a.numpy()) * sc.numpy() + 1 + tvm.testing.assert_allclose(d.numpy(), d_np) @tvm.testing.requires_llvm -def test_rank_zero(): - def check_llvm(n): - A = te.placeholder((n,), name="A") - scale = te.placeholder((), name="scale") - k = te.reduce_axis((0, n), name="k") - C = te.compute((), lambda: te.sum(A[k] * scale(), axis=k), name="C") - D = te.compute((), lambda: C() + 1) - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, scale, D]) - sch = tvm.s_tir.Schedule(mod) 
- - # build and invoke the kernel. - f = tvm.compile(sch.mod, target="llvm") +def test_rank_zero_bound_checkers(): + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((64,), "float32"), + scale: T.Buffer((), "float32"), + compute: T.Buffer((), "float32"), + ): + T.func_attr({"tir.noalias": True}) + C = T.alloc_buffer(()) + for k in range(64): + with T.sblock("C"): + v_k = T.axis.reduce(64, k) + T.reads(A[v_k], scale[()]) + T.writes(C[()]) + with T.init(): + C[()] = T.float32(0.0) + C[()] = C[()] + A[v_k] * scale[()] + with T.sblock("compute"): + vi = T.axis.spatial(1, 0) + T.reads(C[()]) + T.writes(compute[()]) + compute[()] = C[()] + T.float32(1.0) + + n = 64 + with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}): + f = tvm.compile(Module, target="llvm") dev = tvm.cpu(0) - # launch the kernel. - a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - sc = tvm.runtime.tensor(np.random.randint(0, 2, size=()).astype(scale.dtype), dev) - d = tvm.runtime.empty((), D.dtype, dev) + a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype("float32"), dev) + sc = tvm.runtime.tensor(np.random.randint(0, 2, size=()).astype("float32"), dev) + d = tvm.runtime.empty((), "float32", dev) f(a, sc, d) d_np = np.sum(a.numpy()) * sc.numpy() + 1 tvm.testing.assert_allclose(d.numpy(), d_np) - check_llvm(64) - - -@tvm.testing.requires_llvm -def test_rank_zero_bound_checkers(): - def check_llvm(n): - with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}): - A = te.placeholder((n,), name="A") - scale = te.placeholder((), name="scale") - k = te.reduce_axis((0, n), name="k") - C = te.compute((), lambda: te.sum(A[k] * scale(), axis=k), name="C") - D = te.compute((), lambda: C() + 1) - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, scale, D]) - sch = tvm.s_tir.Schedule(mod) - - # build and invoke the kernel. 
- f = tvm.compile(sch.mod, target="llvm") - dev = tvm.cpu(0) - # launch the kernel. - a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype(A.dtype), dev) - sc = tvm.runtime.tensor(np.random.randint(0, 2, size=()).astype(scale.dtype), dev) - d = tvm.runtime.empty((), D.dtype, dev) - f(a, sc, d) - d_np = np.sum(a.numpy()) * sc.numpy() + 1 - tvm.testing.assert_allclose(d.numpy(), d_np) - - check_llvm(64) - @tvm.testing.requires_llvm def test_alignment(): - n = tvm.runtime.convert(1024) - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda i: A[i] * 3, name="B") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, B]).with_attr("global_symbol", "test_alignment") - sch = tvm.s_tir.Schedule(mod) - - # Get block and loop - block = sch.get_sblock("B") - loop = sch.get_loops(block)[0] - - # Split and vectorize - _, tx = sch.split(loop, factors=[None, 8]) - sch.vectorize(tx) - - # Build with name - f = tvm.tir.build(sch.mod, target="llvm") + @I.ir_module + class Module: + @T.prim_func + def test_alignment(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")): + T.func_attr({"tir.noalias": True}) + for i_0 in range(128): + for i_1 in T.vectorized(8): + with T.sblock("B"): + v_i = T.axis.spatial(1024, i_0 * 8 + i_1) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = A[v_i] * T.float32(3.0) + + f = tvm.tir.build(Module, target="llvm") lines = f.inspect_source().split("\n") @@ -479,57 +498,63 @@ def test_llvm_div(): """Check that the semantics of div and mod is correct""" def check(start, end, dstart, dend, dtype, floor_div=False): - div = tvm.te.floordiv if floor_div else tvm.tir.truncdiv - mod = tvm.te.floormod if floor_div else tvm.tir.truncmod + a_size = end - start + 1 + b_size = dend - dstart + 1 - # A are dividends, B are divisors. Note that we add 1 to make include end in the range. 
- A = te.placeholder((end - start + 1,), name="A", dtype=dtype) - B = te.placeholder((dend - dstart + 1,), name="B", dtype=dtype) - # We clip values with min and max so that simplifiers know the ranges of values + div_fn = tvm.tir.floordiv if floor_div else tvm.tir.truncdiv + mod_fn = tvm.tir.floormod if floor_div else tvm.tir.truncmod - def clipa(x): - return tvm.te.min(tvm.tir.const(end, dtype), tvm.te.max(tvm.tir.const(start, dtype), x)) + # Build clipping helpers — capture TIR const values from env + _start = tvm.tir.const(start, dtype) + _end = tvm.tir.const(end, dtype) + _dstart = tvm.tir.const(dstart, dtype) + _dend = tvm.tir.const(dend, dtype) - def clipb(x): - return tvm.te.min( - tvm.tir.const(dend, dtype), tvm.te.max(tvm.tir.const(dstart, dtype), x) - ) - - # If the range is just a single point, use the constant itself if start == end: - - def clipa(x): - return tvm.tir.const(start, dtype) + clipa = lambda x: _start + else: + clipa = lambda x: T.min(_end, T.max(_start, x)) if dstart == dend: - - def clipb(x): - return tvm.tir.const(dstart, dtype) - - # D are division results and M are modulo results - [D, M] = te.compute( - (end - start + 1, dend - dstart + 1), - lambda i, j: (div(clipa(A[i]), clipb(B[j])), mod(clipa(A[i]), clipb(B[j]))), - ) - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, B, D, M]) - sch = tvm.s_tir.Schedule(mod) - - # Build from scheduled TIR - f = tvm.compile(sch.mod, target="llvm") + clipb = lambda x: _dstart + else: + clipb = lambda x: T.min(_dend, T.max(_dstart, x)) + + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((a_size,), dtype), + B: T.Buffer((b_size,), dtype), + D: T.Buffer((a_size, b_size), dtype), + M: T.Buffer((a_size, b_size), dtype), + ): + T.func_attr({"tir.noalias": True}) + for i, j in T.grid(a_size, b_size): + with T.sblock("D"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i], B[v_j]) + T.writes(D[v_i, v_j]) + D[v_i, v_j] = div_fn(clipa(A[v_i]), 
clipb(B[v_j])) + with T.sblock("M"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i], B[v_j]) + T.writes(M[v_i, v_j]) + M[v_i, v_j] = mod_fn(clipa(A[v_i]), clipb(B[v_j])) + + f = tvm.compile(Module, target="llvm") # Fill input arrays with values - A_arr = tvm.runtime.empty((end - start + 1,), dtype) - B_arr = tvm.runtime.empty((dend - dstart + 1,), dtype) + A_arr = tvm.runtime.empty((a_size,), dtype) + B_arr = tvm.runtime.empty((b_size,), dtype) A_arr.copyfrom(np.arange(start, end + 1, dtype=dtype)) B_np = np.arange(dstart, dend + 1, dtype=dtype) # If the range of the divisor contains 0, replace it with 1 to avoid division by zero if dend >= 0 and dstart <= 0: B_np[-dstart] = 1 B_arr.copyfrom(B_np) - D_arr = tvm.runtime.empty((end - start + 1, dend - dstart + 1), dtype) - M_arr = tvm.runtime.empty((end - start + 1, dend - dstart + 1), dtype) + D_arr = tvm.runtime.empty((a_size, b_size), dtype) + M_arr = tvm.runtime.empty((a_size, b_size), dtype) # Run the function and convert the results to numpy f(A_arr, B_arr, D_arr, M_arr) @@ -560,7 +585,7 @@ def _show_info(): raise AssertionError( "Incorrect division result: {}({}, {}) is {} " "but should be {}".format( - div.__name__, i, j, D_arr[i - start, j - dstart], dref + div_fn.__name__, i, j, D_arr[i - start, j - dstart], dref ) ) if M_arr[i - start, j - dstart] != mref: @@ -568,7 +593,7 @@ def _show_info(): raise AssertionError( "Incorrect modulo result: {}({}, {}) is {} " "but should be {}".format( - mod.__name__, i, j, M_arr[i - start, j - dstart], mref + mod_fn.__name__, i, j, M_arr[i - start, j - dstart], mref ) ) @@ -614,67 +639,73 @@ def _show_info(): @tvm.testing.requires_llvm def test_llvm_fp_math(): - def check_llvm_reciprocal(n): - A = te.placeholder((n,), name="A") - B = te.compute((n,), lambda i: te.div(1.0, (1e37 * A[i])), name="B") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, B]) - sch = tvm.s_tir.Schedule(mod) + @I.ir_module + class RecipModule: + @T.prim_func + def 
main(var_A: T.handle, var_B: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (n,)) + B = T.match_buffer(var_B, (n,)) + for i in range(n): + with T.sblock("B"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.float32(1.0) / ( + T.float32(9999999999999999538762658202121142272.0) * A[v_i] + ) - # Build from scheduled TIR - f = tvm.compile(sch.mod, target="llvm") + f_recip = tvm.compile(RecipModule, target="llvm") + for n in [4, 8, 16]: a = tvm.runtime.tensor(np.full((n,), 100, "float32")) b = tvm.runtime.empty((n,), "float32") - f(a, b) + f_recip(a, b) tvm.testing.assert_allclose(b.numpy(), np.zeros((n,), "float32")) - check_llvm_reciprocal(4) - check_llvm_reciprocal(8) - check_llvm_reciprocal(16) - - def check_llvm_sigmoid(n): - A = te.placeholder((n,), name="A") - B = te.compute((n,), lambda i: te.sigmoid(A[i]), name="B") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, B]) - sch = tvm.s_tir.Schedule(mod) - - # Build from scheduled TIR - f = tvm.compile(sch.mod, target="llvm") - + @I.ir_module + class SigmoidModule: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (n,)) + B = T.match_buffer(var_B, (n,)) + for i in range(n): + with T.sblock("B"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.sigmoid(A[v_i]) + + f_sigmoid = tvm.compile(SigmoidModule, target="llvm") + + for n in [4, 8, 16]: a = tvm.runtime.tensor(np.full((n,), -1000, "float32")) b = tvm.runtime.empty((n,), "float32") - f(a, b) + f_sigmoid(a, b) tvm.testing.assert_allclose(b.numpy(), np.zeros((n,), "float32")) - check_llvm_sigmoid(4) - check_llvm_sigmoid(8) - check_llvm_sigmoid(16) - @tvm.testing.requires_llvm def test_dwarf_debug_information(): - nn = 1024 - n = tvm.runtime.convert(nn) - A = te.placeholder((n,), name="A") - B = te.placeholder((n,), 
name="B") - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - - # Convert to TIR and create schedule - mod = te.create_prim_func([A, B, C]) - sch = tvm.s_tir.Schedule(mod) - - # Get block and loop - block = sch.get_sblock("C") - loop = sch.get_loops(block)[0] - - # Split and parallelize - xo, xi = sch.split(loop, factors=[None, 4]) - sch.parallel(xo) - sch.vectorize(xi) + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((1024,), "float32"), + B: T.Buffer((1024,), "float32"), + C: T.Buffer((1024,), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0_0 in T.parallel(256): + for i0_1 in T.vectorized(4): + with T.sblock("C"): + v_i0 = T.axis.spatial(1024, i0_0 * 4 + i0_1) + T.reads(A[v_i0], B[v_i0]) + T.writes(C[v_i0]) + C[v_i0] = A[v_i0] + B[v_i0] def check_llvm_object(): if tvm.target.codegen.llvm_version_major() < 5: @@ -684,8 +715,8 @@ def check_llvm_object(): # build two functions mod = tvm.IRModule( { - "fadd1": sch.mod["main"].with_attr("global_symbol", "fadd1"), - "fadd2": sch.mod["main"].with_attr("global_symbol", "fadd2"), + "fadd1": Module["main"].with_attr("global_symbol", "fadd1"), + "fadd2": Module["main"].with_attr("global_symbol", "fadd2"), } ) m = tvm.compile(mod, target="llvm") @@ -722,8 +753,8 @@ def check_llvm_ir(): # build two functions mod = tvm.IRModule( { - "fadd1": sch.mod["main"].with_attr("global_symbol", "fadd1"), - "fadd2": sch.mod["main"].with_attr("global_symbol", "fadd2"), + "fadd1": Module["main"].with_attr("global_symbol", "fadd1"), + "fadd2": Module["main"].with_attr("global_symbol", "fadd2"), } ) m = tvm.tir.build(mod, target="llvm -mtriple=aarch64-linux-gnu") @@ -748,24 +779,26 @@ def check_llvm_ir(): @tvm.testing.requires_llvm def test_llvm_bf16(): def dotest(do_vectorize): - np.random.seed(122) - A = te.placeholder((32,), dtype="bfloat16") - B = te.placeholder((32,), dtype="bfloat16") - D = te.compute((32,), lambda x: A[x] + B[x], name="D") - - # Convert to TIR and create schedule - mod 
= te.create_prim_func([A, B, D]) - sch = tvm.s_tir.Schedule(mod) - - # Get block and loop - block = sch.get_sblock("D") - loop = sch.get_loops(block)[0] - - # Apply vectorization if requested - if do_vectorize: - sch.vectorize(loop) + loop_kind = T.vectorized if do_vectorize else T.serial + + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((32,), "bfloat16"), + B: T.Buffer((32,), "bfloat16"), + D: T.Buffer((32,), "bfloat16"), + ): + T.func_attr({"tir.noalias": True}) + for x in loop_kind(32): + with T.sblock("D"): + v_x = T.axis.spatial(32, x) + T.reads(A[v_x], B[v_x]) + T.writes(D[v_x]) + D[v_x] = A[v_x] + B[v_x] - module = tvm.compile(sch.mod, target="llvm") + np.random.seed(122) + module = tvm.compile(Module, target="llvm") npa = np.random.rand(32).astype("bfloat16") npb = np.random.rand(32).astype("bfloat16") res = npa + npb @@ -783,12 +816,24 @@ def dotest(do_vectorize): @tvm.testing.requires_llvm def test_llvm_crt_static_lib(): - A = te.placeholder((32,), dtype="bfloat16") - B = te.placeholder((32,), dtype="bfloat16") - d = te.compute((32,), lambda x: A[x] + B[x]) - mod = tvm.IRModule.from_expr(te.create_prim_func([A, B, d])) + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((32,), "bfloat16"), + B: T.Buffer((32,), "bfloat16"), + C: T.Buffer((32,), "bfloat16"), + ): + T.func_attr({"tir.noalias": True}) + for x in range(32): + with T.sblock("compute"): + v_x = T.axis.spatial(32, x) + T.reads(A[v_x], B[v_x]) + T.writes(C[v_x]) + C[v_x] = A[v_x] + B[v_x] + module = tvm.tir.build( - mod.with_attr("system_lib_prefix", ""), + Module.with_attr("system_lib_prefix", ""), target=tvm.target.Target("llvm"), ) module.inspect_source() @@ -806,17 +851,14 @@ def test_llvm_order_functions(): class Module: @T.prim_func def Danny(v: T.float32) -> T.float32: - T.func_attr({"global_symbol": "Danny"}) T.ret(T.call_extern("float32", "Dave", v)) @T.prim_func def Sammy(v: T.float32) -> T.float32: - T.func_attr({"global_symbol": "Sammy"}) 
T.ret(T.call_extern("float32", "Eve", v)) @T.prim_func def Kirby(v: T.float32) -> T.float32: - T.func_attr({"global_symbol": "Kirby"}) T.ret(T.call_extern("float32", "Fred", v)) ir_text = tvm.tir.build(Module, target="llvm").inspect_source("ll") @@ -835,11 +877,6 @@ def test_llvm_import(): return x + y; } """ - n = 10 - A = te.placeholder((n,), name="A") - B = te.compute( - (n,), lambda *i: tvm.tir.call_pure_extern("float32", "my_add", A(*i), 1.0), name="B" - ) def check_llvm(use_file): if not clang.find_clang(required=False): @@ -848,18 +885,24 @@ def check_llvm(use_file): temp = utils.tempdir() ll_path = temp.relpath("temp.ll") ll_code = clang.create_llvm(cc_code, output=ll_path) - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - - if use_file: - sch.annotate(sch.get_loops("B")[0], "pragma_import_llvm", ll_path) - else: - sch.annotate(sch.get_loops("B")[0], "pragma_import_llvm", ll_code) - # BUILD and invoke the kernel. - f = tvm.compile(sch.mod, target="llvm") + import_val = ll_path if use_file else ll_code + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")): + T.func_attr({"tir.noalias": True}) + for i in T.serial(10, annotations={"pragma_import_llvm": import_val}): + with T.sblock("B"): + v_i = T.axis.spatial(10, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.call_pure_extern("float32", "my_add", A[v_i], T.float32(1.0)) + + f = tvm.compile(Module, target="llvm") dev = tvm.cpu(0) - # launch the kernel. 
- a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.random.uniform(size=n).astype(B.dtype), dev) + a = tvm.runtime.tensor(np.random.uniform(size=10).astype("float32"), dev) + b = tvm.runtime.tensor(np.random.uniform(size=10).astype("float32"), dev) f(a, b) tvm.testing.assert_allclose(b.numpy(), a.numpy() + 1.0) @@ -869,33 +912,31 @@ def check_llvm(use_file): @tvm.testing.requires_llvm def test_llvm_scalar_concat(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.decl_buffer((1,), "int32x2") - s = tvm.tir.Shuffle([x, y], [0, 1]) - f = tvm.tir.PrimFunc([x, y, z], z.vstore(0, s)) - - mod = tvm.ir.IRModule.from_expr(f.with_attr("global_symbol", "codegen_scalar_concat")) + @I.ir_module + class Module: + @T.prim_func + def main(x: T.int32, y: T.int32, buffer: T.Buffer((1,), "int32x2")): + buffer[0] = T.Shuffle([x, y], [0, 1]) # This will crash in LLVM codegen if CodeGenLLVM::CreateVecConcat doesn't convert # scalars to single-lane LLVM vectors. 
with tvm.transform.PassContext(config={"tir.disable_assert": True}): - m = tvm.compile(mod, target="llvm") + m = tvm.compile(Module, target="llvm") @tvm.testing.requires_llvm def test_raise_exception_during_codegen(): - @T.prim_func - def threadpool_nested_parallel_loop( - A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32") - ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - for i in T.parallel(4): - for j in T.parallel(4): - B[i, j] = A[i, j] * 2.0 + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")) -> None: + T.func_attr({"tir.noalias": True}) + for i in T.parallel(4): + for j in T.parallel(4): + B[i, j] = A[i, j] * 2.0 with pytest.raises(tvm.TVMError) as e: - tvm.compile(tvm.IRModule.from_expr(threadpool_nested_parallel_loop), target="llvm") + tvm.compile(Module, target="llvm") msg = str(e) assert msg.find("Nested parallel loop is not supported") != -1 @@ -905,20 +946,33 @@ def test_llvm_target_attributes(): """Check that when LLVM codegen creates new functions, they get the same target attributes as the original function. 
""" - n = te.var() - A = te.placeholder((n,), name="A", dtype="float32") - B = te.compute((n,), lambda i: A[i], name="B") - C = te.compute((n,), lambda i: B[i] + tvm.tir.const(1, A.dtype), name="C") - sch = tvm.s_tir.Schedule( - te.create_prim_func([A, B, C, n]).with_attr("global_symbol", "test_func") - ) - xo, xi = sch.split(sch.get_loops("C")[0], factors=[2, None]) - sch.parallel(xo) + @I.ir_module + class Module: + @T.prim_func + def test_func(var_A: T.handle, var_B: T.handle, var_C: T.handle, tindex: T.int32): + T.func_attr({"tir.noalias": True}) + A = T.match_buffer(var_A, (tindex,)) + B = T.match_buffer(var_B, (tindex,)) + C = T.match_buffer(var_C, (tindex,)) + for i in range(tindex): + with T.sblock("B"): + v_i = T.axis.spatial(tindex, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = A[v_i] + for i_0 in T.parallel(2): + for i_1 in range((tindex + 1) // 2): + with T.sblock("C"): + v_i = T.axis.spatial(tindex, i_0 * ((tindex + 1) // 2) + i_1) + T.where(i_0 * ((tindex + 1) // 2) + i_1 < tindex) + T.reads(B[v_i]) + T.writes(C[v_i]) + C[v_i] = B[v_i] + T.float32(1.0) target_llvm = "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake -mattr=+avx512f" target = tvm.target.Target(target_llvm, host=target_llvm) - module = tvm.tir.build(sch.mod, target=target) + module = tvm.tir.build(Module, target=target) llvm_ir = module.inspect_source() llvm_ir_lines = llvm_ir.split("\n") @@ -959,20 +1013,19 @@ def test_llvm_assume(): related instructions get removed during optimizations """ - @T.prim_func - def tir_assume_func(A: T.Buffer((4, 4), "int32"), B: T.Buffer((14,), "int32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_1 = T.Buffer((16,), "int32", data=A.data) - for axis0, axis1 in T.grid(4, 4): - T.assume(axis0 < 3 or axis1 < 2 or A_1[axis0 * 4 + axis1] == 0) - for i in range(14): - B_1 = T.Buffer((14,), "int32", data=B.data) - B_1[i] = A_1[i] * 2 + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((4, 4), "int32"), B: T.Buffer((14,), 
"int32")): + T.func_attr({"tir.noalias": True}) + A_1 = T.Buffer((16,), "int32", data=A.data) + for axis0, axis1 in T.grid(4, 4): + T.assume(axis0 < 3 or axis1 < 2 or A_1[axis0 * 4 + axis1] == 0) + for i in range(14): + B_1 = T.Buffer((14,), "int32", data=B.data) + B_1[i] = A_1[i] * 2 - mod = tvm.IRModule.from_expr(tir_assume_func) - inp = te.placeholder((4, 4), name="A", dtype="int32") - out = te.placeholder((14,), name="B", dtype="int32") - m = tvm.compile(mod, target="llvm") + m = tvm.compile(Module, target="llvm") @tvm.testing.requires_llvm @@ -984,38 +1037,39 @@ def test_debug_symbol_for_float64(): prevents lowering to the PackedFunc API. """ - @T.prim_func - def func(a: T.handle("float64"), b: T.handle("float64"), n: T.int64): - T.func_attr({"calling_conv": 2}) - A = T.Buffer(16, "float64", data=a) - B = T.Buffer(16, "float64", data=b) - for i in range(n): - B[i] = A[i] + @I.ir_module + class Module: + @T.prim_func + def main(a: T.handle("float64"), b: T.handle("float64"), n: T.int64): + T.func_attr({"calling_conv": 2}) + A = T.Buffer(16, "float64", data=a) + B = T.Buffer(16, "float64", data=b) + for i in range(n): + B[i] = A[i] - tvm.compile(func, target="llvm") + tvm.compile(Module, target="llvm") @tvm.testing.requires_llvm def test_subroutine_call(): @I.ir_module - class mod: + class Module: @T.prim_func def main(A: T.Buffer(1, dtype="float32")): - T.func_attr({"global_symbol": "main"}) - mod.subroutine(A.data) + Module.subroutine(A.data) @T.prim_func def subroutine(A_data: T.handle("float32")): # The calling_conv parameter is to prevent MakePackedAPI # from changing the call signature of the subroutine. 
- T.func_attr({"global_symbol": "subroutine", "calling_conv": -1}) + T.func_attr({"calling_conv": -1}) A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 42.0 target = "llvm" dev = tvm.cpu() - built = tvm.compile(mod) + built = tvm.compile(Module) arr = tvm.runtime.tensor(np.zeros([1], "float32"), device=dev) built["main"](arr) @@ -1038,17 +1092,19 @@ def test_call_packed_returning_void(): for the packed function call. """ - @T.prim_func - def func(): - T.Call( - "void", - tvm.ir.Op.get("tir.tvm_call_packed"), - ["dummy_function_name"], - ) + @I.ir_module + class Module: + @T.prim_func + def main(): + T.Call( + "void", + tvm.ir.Op.get("tir.tvm_call_packed"), + ["dummy_function_name"], + ) # Error occurred during build, as part of # CodeGenCPU::MakeCallPackedLowered. - built = tvm.compile(func, target="llvm") + built = tvm.compile(Module, target="llvm") @tvm.testing.requires_llvm @@ -1061,68 +1117,76 @@ def test_call_packed_without_string_arg(): a segfault during codegen. """ - @T.prim_func - def func(A: T.Buffer(1, "float32")): - T.func_attr({"global_symbol": "func"}) - T.Call("int32", tvm.ir.Op.get("tir.tvm_call_packed"), [A.data]) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.Call("int32", tvm.ir.Op.get("tir.tvm_call_packed"), [A.data]) with pytest.raises(tvm.TVMError): - built = tvm.compile(func, target="llvm") + built = tvm.compile(Module, target="llvm") @tvm.testing.requires_llvm def test_call_extern_returning_void(): """Like test_call_packed_returning_void, but for call_extern""" - @T.prim_func - def func(): - T.func_attr({"global_symbol": "func"}) - T.Call("void", tvm.ir.Op.get("tir.call_extern"), ["dummy_function_name"]) + @I.ir_module + class Module: + @T.prim_func + def main(): + T.Call("void", tvm.ir.Op.get("tir.call_extern"), ["dummy_function_name"]) - built = tvm.compile(func, target="llvm") + built = tvm.compile(Module, target="llvm") def test_invalid_volatile_masked_buffer_load(): - @T.prim_func - def 
func(b: T.handle): - B = T.match_buffer(b, [4]) - a = T.allocate([4], "float32", scope="global") - T.attr(a, "volatile_scope", 1) - A = T.Buffer([4], data=a) - B[0:4] = A.vload([T.Ramp(0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)) + @I.ir_module + class Module: + @T.prim_func + def main(b: T.handle): + B = T.match_buffer(b, [4]) + a = T.allocate([4], "float32", scope="global") + T.attr(a, "volatile_scope", 1) + A = T.Buffer([4], data=a) + B[0:4] = A.vload([T.Ramp(0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)) err_msg = "The masked load intrinsic does not support declaring load as volatile." with pytest.raises(tvm.TVMError, match=err_msg): with tvm.target.Target("llvm"): - tvm.compile(func) + tvm.compile(Module) def test_invalid_volatile_masked_buffer_store(): - @T.prim_func - def func(): - a = T.allocate([4], "float32", scope="global") - T.attr(a, "volatile_scope", 1) - A = T.Buffer([4], data=a) - A.vstore([T.Ramp(0, 1, 4)], T.Broadcast(0.0, 4), predicate=T.Broadcast(T.bool(True), 4)) + @I.ir_module + class Module: + @T.prim_func + def main(): + a = T.allocate([4], "float32", scope="global") + T.attr(a, "volatile_scope", 1) + A = T.Buffer([4], data=a) + A.vstore([T.Ramp(0, 1, 4)], T.Broadcast(0.0, 4), predicate=T.Broadcast(T.bool(True), 4)) err_msg = "The masked store intrinsic does not support declaring store as volatile." 
with pytest.raises(tvm.TVMError, match=err_msg): with tvm.target.Target("llvm"): - tvm.compile(func) + tvm.compile(Module) def test_int_parameter(): """Boolean may be passed to functions accepting int""" - @T.prim_func - def func(arg: T.int32) -> T.int32: - T.func_attr({"target": T.target("llvm")}) - if arg > 0: - return 10 - else: - return 20 - - built = tvm.compile(func) + @I.ir_module + class Module: + @T.prim_func + def main(arg: T.int32) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + if arg > 0: + return 10 + else: + return 20 + + built = tvm.compile(Module) output = built(True) assert output == 10 @@ -1133,15 +1197,17 @@ def func(arg: T.int32) -> T.int32: def test_bool_parameter(): """Integers may be passed to functions accepting bool""" - @T.prim_func - def func(arg: T.bool) -> T.int32: - T.func_attr({"target": T.target("llvm")}) - if arg: - return 10 - else: - return 20 - - built = tvm.compile(func) + @I.ir_module + class Module: + @T.prim_func + def main(arg: T.bool) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + if arg: + return 10 + else: + return 20 + + built = tvm.compile(Module) output = built(1) assert output == 10 @@ -1155,12 +1221,14 @@ def func(arg: T.bool) -> T.int32: def test_bool_return_value(): """Booleans may be returned from a PrimFunc""" - @T.prim_func - def func(value: T.int32) -> T.bool: - T.func_attr({"target": T.target("llvm")}) - return value < 10 + @I.ir_module + class Module: + @T.prim_func + def main(value: T.int32) -> T.bool: + T.func_attr({"target": T.target("llvm")}) + return value < 10 - built = tvm.compile(func) + built = tvm.compile(Module) assert isinstance(built(0), bool) assert built(0) @@ -1171,12 +1239,14 @@ def func(value: T.int32) -> T.bool: def test_invalid_arguments(): """Integers may be passed to functions accepting bool""" - @T.prim_func - def func(a0: T.bool, a1: T.Buffer([10], "float32")) -> T.int32: - T.func_attr({"target": T.target("llvm")}) - return 0 + @I.ir_module + class Module: + 
@T.prim_func + def main(a0: T.bool, a1: T.Buffer([10], "float32")) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + return 0 - built = tvm.compile(func) + built = tvm.compile(Module) with pytest.raises(RuntimeError): built(1, 1) diff --git a/tests/python/codegen/test_target_codegen_metal.py b/tests/python/codegen/test_target_codegen_metal.py index cb168c93422f..f6f3e3de0aab 100644 --- a/tests/python/codegen/test_target_codegen_metal.py +++ b/tests/python/codegen/test_target_codegen_metal.py @@ -17,10 +17,8 @@ import numpy as np import tvm -import tvm.script import tvm.testing -from tvm import te -from tvm.script import tir as T +from tvm.script import tir as T, ir as I @tvm.testing.requires_gpu @@ -29,16 +27,24 @@ def test_metal_inf_nan(): target = "metal" def check_inf_nan(dev, n, value, dtype): - A = te.placeholder((n,), name="A", dtype=dtype) - inf_value = tvm.tir.const(value, dtype=dtype) - C = te.compute((n,), lambda i: inf_value, name="C") - prim_func = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(prim_func) - (x,) = sch.get_loops(sch.get_sblock("C")) - sch.bind(x, "threadIdx.x") - fun = tvm.compile(sch.mod, target=target) - a = tvm.runtime.empty((n,), A.dtype, dev) - c = tvm.runtime.empty((n,), A.dtype, dev) + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((1,), dtype), + C: T.Buffer((1,), dtype), + ): + T.func_attr({"tir.noalias": True}) + for i in T.thread_binding(1, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(1, i) + T.reads() + T.writes(C[v_i]) + C[v_i] = T.Cast(dtype, value) + + fun = tvm.compile(Module, target=target) + a = tvm.runtime.empty((n,), dtype, dev) + c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here fun(a, c) @@ -83,15 +89,24 @@ def test_metal_erf(): target = "metal" def check_erf(dev, n, dtype): - A = te.placeholder((n,), name="A", dtype=dtype) - C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C") - func = te.create_prim_func([A, C]) - sch = 
tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("C")) - sch.bind(x, "threadIdx.x") - fun = tvm.compile(sch.mod, target=target) - a = tvm.runtime.empty((n,), A.dtype, dev) - c = tvm.runtime.empty((n,), A.dtype, dev) + @I.ir_module + class Module: + @T.prim_func + def main( + A: T.Buffer((1,), dtype), + C: T.Buffer((1,), dtype), + ): + T.func_attr({"tir.noalias": True}) + for i0 in T.thread_binding(1, thread="threadIdx.x"): + with T.sblock("C"): + v_i0 = T.axis.spatial(1, i0) + T.reads(A[v_i0]) + T.writes(C[v_i0]) + C[v_i0] = T.erf(A[v_i0]) + + fun = tvm.compile(Module, target=target) + a = tvm.runtime.empty((n,), dtype, dev) + c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here fun(a, c) diff --git a/tests/python/codegen/test_target_codegen_opencl.py b/tests/python/codegen/test_target_codegen_opencl.py index 131215303fbe..6493dbc27f51 100644 --- a/tests/python/codegen/test_target_codegen_opencl.py +++ b/tests/python/codegen/test_target_codegen_opencl.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm import te +from tvm.script import tir as T, ir as I target = "opencl" @@ -27,38 +27,54 @@ @tvm.testing.requires_opencl def test_opencl_ternary_expression(): def check_if_then_else(dev, n, dtype): - A = te.placeholder((n,), name="A", dtype=dtype) - true_value = tvm.tir.const(1, dtype=dtype) - false_value = tvm.tir.const(3, dtype=dtype) - max_lhs = tvm.tir.const(2, dtype=dtype) - max_rhs = tvm.tir.if_then_else(A[0] > 0, true_value, false_value) - C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C") - - func = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("C")) - sch.bind(x, "threadIdx.x") - fun = tvm.tir.build(sch.mod, target=target) - a = tvm.runtime.empty((n,), A.dtype, dev) - c = tvm.runtime.empty((n,), A.dtype, dev) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): + 
T.func_attr({"tir.noalias": True}) + for i in T.thread_binding(1, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(1, i) + T.reads(A[0]) + T.writes(C[v_i]) + C[v_i] = T.max( + T.Cast(dtype, 2), + T.if_then_else( + 0 < T.Cast("int32", A[0]), + T.Cast(dtype, 1), + T.Cast(dtype, 3), + ), + ) + + fun = tvm.tir.build(Module, target=target) + a = tvm.runtime.empty((n,), dtype, dev) + c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here fun(a, c) def check_select(dev, n, dtype): - A = te.placeholder((n,), name="A", dtype=dtype) - true_value = tvm.tir.const(1, dtype=dtype) - false_value = tvm.tir.const(3, dtype=dtype) - max_lhs = tvm.tir.const(2, dtype=dtype) - max_rhs = tvm.tir.Select(A[0] > 0, true_value, false_value) - C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C") - func = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("C")) - sch.bind(x, "threadIdx.x") - fun = tvm.tir.build(sch.mod, target=target) - - a = tvm.runtime.empty((n,), A.dtype, dev) - c = tvm.runtime.empty((n,), A.dtype, dev) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): + T.func_attr({"tir.noalias": True}) + for i in T.thread_binding(1, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(1, i) + T.reads(A[0]) + T.writes(C[v_i]) + C[v_i] = T.max( + T.Cast(dtype, 2), + T.Select( + 0 < T.Cast("int32", A[0]), + T.Cast(dtype, 1), + T.Cast(dtype, 3), + ), + ) + + fun = tvm.tir.build(Module, target=target) + a = tvm.runtime.empty((n,), dtype, dev) + c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here fun(a, c) @@ -78,16 +94,21 @@ def check_select(dev, n, dtype): @tvm.testing.requires_opencl def test_opencl_inf_nan(): def check_inf_nan(dev, n, value, dtype): - A = te.placeholder((n,), name="A", dtype=dtype) - inf_value = tvm.tir.const(value, dtype=dtype) - C = te.compute((n,), lambda i: inf_value, 
name="C") - func = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("C")) - sch.bind(x, "threadIdx.x") - fun = tvm.tir.build(sch.mod, target=target) - a = tvm.runtime.empty((n,), A.dtype, dev) - c = tvm.runtime.empty((n,), A.dtype, dev) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): + T.func_attr({"tir.noalias": True}) + for i in T.thread_binding(1, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(1, i) + T.reads() + T.writes(C[v_i]) + C[v_i] = T.Cast(dtype, value) + + fun = tvm.tir.build(Module, target=target) + a = tvm.runtime.empty((n,), dtype, dev) + c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here fun(a, c) @@ -105,18 +126,21 @@ def check_inf_nan(dev, n, value, dtype): @tvm.testing.requires_opencl def test_opencl_max(): def check_max(dev, n, dtype): - A = te.placeholder((n,), name="A", dtype=dtype) - max_lhs = A[0] + tvm.tir.const(1, dtype=dtype) - max_rhs = tvm.tir.const(0, dtype=dtype) - C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C") - func = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("C")) - sch.bind(x, "threadIdx.x") - fun = tvm.tir.build(sch.mod, target=target) - - a = tvm.runtime.empty((n,), A.dtype, dev) - c = tvm.runtime.empty((n,), A.dtype, dev) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): + T.func_attr({"tir.noalias": True}) + for i in T.thread_binding(1, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(1, i) + T.reads(A[0]) + T.writes(C[v_i]) + C[v_i] = T.max(A[0] + T.Cast(dtype, 1), T.Cast(dtype, 0)) + + fun = tvm.tir.build(Module, target=target) + a = tvm.runtime.empty((n,), dtype, dev) + c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here fun(a, c) @@ -132,13 +156,19 @@ def check_max(dev, n, dtype): def 
test_opencl_erf(): def check_erf(dev, n, dtype): - A = te.placeholder((n,), name="A", dtype=dtype) - C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C") - func = te.create_prim_func([A, C]) - sch = tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("C")) - sch.bind(x, "threadIdx.x") - fun = tvm.tir.build(sch.mod, target=target) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): + T.func_attr({"tir.noalias": True}) + for i0 in T.thread_binding(1, thread="threadIdx.x"): + with T.sblock("C"): + v_i0 = T.axis.spatial(1, i0) + T.reads(A[v_i0]) + T.writes(C[v_i0]) + C[v_i0] = T.erf(A[v_i0]) + + fun = tvm.tir.build(Module, target=target) source_str = fun.imports[0].inspect_source() matches = re.findall("erf", source_str) @@ -154,31 +184,23 @@ def check_erf(dev, n, dtype): @tvm.testing.requires_gpu @tvm.testing.requires_opencl def test_opencl_type_casting(): + @I.ir_module + class Module: + @T.prim_func + def main(C: T.Buffer((32,), "float32")): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(8, thread="threadIdx.x"): + for i_1 in T.vectorized(4): + with T.sblock("C"): + v_i = T.axis.spatial(32, i_0 * 4 + i_1) + T.reads() + T.writes(C[v_i]) + C[v_i] = T.Select( + v_i // 4 == 3 and v_i % 3 == 1, T.float32(1.0), T.float32(0.0) + ) + def check_type_casting(ctx, n, dtype): - block_size = 4 - C = te.compute( - (n,), - lambda i: tvm.tir.Select( - tvm.tir.all( - *[ - i // block_size == tvm.tir.const(3, "int32"), - i % 3 == tvm.tir.const(1, "int32"), - ] - ), - tvm.tir.const(1, dtype), - tvm.tir.const(0, dtype), - ), - name="C", - ) - # NOTE: test simple convert pattern - func = te.create_prim_func([C]) - sch = tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("C")) - tx, vx = sch.split(x, factors=[None, block_size]) - sch.bind(tx, "threadIdx.x") - sch.vectorize(vx) - - fun = tvm.tir.build(sch.mod, target=target) + fun = tvm.tir.build(Module, target=target) c = 
tvm.runtime.empty((n,), dtype, ctx) assembly = fun.imports[0].inspect_source() lcond = "convert_int4(((convert_uint4(((uint4)(((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3)))))" @@ -199,24 +221,27 @@ def check_type_casting(ctx, n, dtype): @tvm.testing.parametrize_targets("opencl", "opencl -device=adreno") def test_opencl_ceil_log2(target): def _check(target, n, dtype): - with tvm.target.Target(target): - C = te.compute( - (n,), - lambda i: tvm.topi.ceil_log2(i), - name="C", - ) - func = te.create_prim_func([C]) - sch = tvm.s_tir.Schedule(func) - (x,) = sch.get_loops(sch.get_sblock("C")) - sch.bind(x, "threadIdx.x") - - fun = tvm.tir.build(sch.mod, target=target) - assembly = fun.imports[0].inspect_source() - if "adreno" in target: - pattern = "convert_float" - else: - pattern = "convert_double" - assert assembly.count(pattern) != 0 + inter_dtype = "float32" if "adreno" in target else "float64" + + @I.ir_module + class Module: + @T.prim_func + def main(C: T.Buffer((n,), "int32")): + T.func_attr({"tir.noalias": True}) + for i in T.thread_binding(n, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(n, i) + T.reads() + T.writes(C[v_i]) + C[v_i] = T.Cast("int32", T.ceil(T.log2(T.Cast(inter_dtype, v_i)))) + + fun = tvm.tir.build(Module, target=target) + assembly = fun.imports[0].inspect_source() + if "adreno" in target: + pattern = "convert_float" + else: + pattern = "convert_double" + assert assembly.count(pattern) != 0 _check(target, 32, "float32") diff --git a/tests/python/codegen/test_target_codegen_rocm.py b/tests/python/codegen/test_target_codegen_rocm.py index a4ba6e12e50e..91bbeed56b66 100644 --- a/tests/python/codegen/test_target_codegen_rocm.py +++ b/tests/python/codegen/test_target_codegen_rocm.py @@ -16,24 +16,30 @@ # under the License. 
import tvm import tvm.testing -from tvm import te import numpy as np -from tvm.script import tir as T +from tvm.script import tir as T, ir as I @tvm.testing.requires_rocm def test_rocm_inf_nan(): def check_inf_nan(dev, n, value, dtype): - A = te.placeholder((n,), name="A", dtype=dtype) - inf_value = tvm.tir.const(value, dtype=dtype) - C = te.compute((n,), lambda i: inf_value, name="C") - sch = tvm.s_tir.Schedule(te.create_prim_func([A, C])) - xo, xi = sch.split(sch.get_loops("C")[0], factors=[None, 128]) - sch.bind(xo, "blockIdx.x") - sch.bind(xi, "threadIdx.x") - fun = tvm.compile(sch.mod, "rocm") - a = tvm.runtime.empty((n,), A.dtype, dev) - c = tvm.runtime.empty((n,), A.dtype, dev) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(1, thread="blockIdx.x"): + for i_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.sblock("C"): + v_i = T.axis.spatial(1, i_0 * 128 + i_1) + T.where(i_0 * 128 + i_1 < 1) + T.reads() + T.writes(C[v_i]) + C[v_i] = T.Cast(dtype, value) + + fun = tvm.compile(Module, "rocm") + a = tvm.runtime.empty((n,), dtype, dev) + c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here fun(a, c) @@ -50,10 +56,9 @@ def check_inf_nan(dev, n, value, dtype): @tvm.testing.requires_rocm def test_rocm_copy(): def check_rocm(dtype, n): - A = te.placeholder((n,), name="A", dtype=dtype) dev = tvm.rocm(0) - a_np = np.random.uniform(size=(n,)).astype(A.dtype) - a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(a_np) + a_np = np.random.uniform(size=(n,)).astype(dtype) + a = tvm.runtime.empty((n,), dtype, dev).copyfrom(a_np) b_np = a.numpy() tvm.testing.assert_allclose(a_np, b_np) tvm.testing.assert_allclose(a_np, a.numpy()) @@ -67,20 +72,28 @@ def check_rocm(dtype, n): @tvm.testing.requires_rocm def test_rocm_vectorize_add(): - num_thread = 8 - def check_rocm(dtype, n, lanes): - A = te.placeholder((n,), name="A", 
dtype="%sx%d" % (dtype, lanes)) - B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B") - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - xo, xi = sch.split(sch.get_loops("B")[0], factors=[None, 4]) - sch.bind(xo, "blockIdx.x") - sch.bind(xi, "threadIdx.x") - fun = tvm.compile(sch.mod, target="rocm") + vec_dtype = "%sx%d" % (dtype, lanes) + num_blocks = n // 4 + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): + T.func_attr({"tir.noalias": True}) + for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): + for i_1 in T.thread_binding(4, thread="threadIdx.x"): + with T.sblock("B"): + v_i = T.axis.spatial(n, i_0 * 4 + i_1) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = A[v_i] + T.Broadcast(T.Cast(dtype, 1), lanes) + + fun = tvm.compile(Module, target="rocm") dev = tvm.rocm(0) - a = tvm.runtime.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes))) - c = tvm.runtime.empty((n,), B.dtype, dev) + a = tvm.runtime.empty((n,), vec_dtype, dev).copyfrom(np.random.uniform(size=(n, lanes))) + c = tvm.runtime.empty((n,), vec_dtype, dev) fun(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index c615d302455c..d010da46da49 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -15,31 +15,17 @@ # specific language governing permissions and limitations # under the License. 
-import os -from posixpath import split -import random import re -import threading import numpy as np import pytest import tvm import tvm.testing -from tvm import te, tir -from tvm.topi.math import cast from tvm.script import tir as T, ir as I -from tvm.tir import TensorIntrin, IntImm, Cast -from tvm.s_tir.tensor_intrin.cuda import ( - WMMA_LOAD_16x16x16_F16_A_INTRIN, - WMMA_LOAD_16x16x16_F16_B_INTRIN, - WMMA_SYNC_16x16x16_f16f16f32_INTRIN, - WMMA_FILL_16x16x16_F32_INTRIN, - WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, - WMMA_SYNC_16x16x16_f16f16f16_INTRIN, - WMMA_FILL_16x16x16_F16_INTRIN, - WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, -) +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import ir as I_builder +from tvm.script.ir_builder import tir as T_builder dtype = tvm.testing.parameter("float32", "int32", "float16", "int8") @@ -62,27 +48,22 @@ ) def test_vector_comparison(target, dev, dtype): target = tvm.target.Target(target) - n = 1024 - A = te.placeholder((n,), dtype=dtype, name="A") - B = te.compute( - A.shape, - lambda i: tvm.tir.Select( - A[i] >= 0, A[i] + tvm.tir.const(1, dtype), tvm.tir.const(0, dtype) - ), - name="B", - ) + zero = tvm.tir.const(0, dtype) + one = tvm.tir.const(1, dtype) - # Create IRModule - mod = tvm.IRModule.from_expr(te.create_prim_func([A, B])) - sch = tvm.s_tir.Schedule(mod) - (bx, tx) = sch.split(sch.get_loops("B")[0], factors=[None, 128]) - (tx, vx) = sch.split(tx, factors=[None, 4]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vx) + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((1024,), dtype), B: T.Buffer((1024,), dtype)): + for i_0 in T.thread_binding(8, thread="blockIdx.x"): + for i_1 in T.thread_binding(32, thread="threadIdx.x"): + for i_2 in T.vectorized(4): + with T.sblock("B"): + v_i = T.axis.spatial(1024, i_0 * 128 + i_1 * 4 + i_2) + B[v_i] = T.Select(A[v_i] >= zero, A[v_i] + one, zero) # Build - f = tvm.tir.build(sch.mod, target=target) + f = 
tvm.tir.build(Module, target=target) # Verify we generate the boolx4 type declaration and the OpSelect # v4{float,half,int} instruction @@ -114,19 +95,25 @@ def test_array_vectorize_add(target, dev, dtype): if "opencl" in str(target) and dtype == "float16": pytest.xfail("Opencl target does not support float16") - A = te.placeholder((arr_size,), name="A", dtype="%sx%d" % (dtype, lanes)) - B = te.compute(A.shape, lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B") + vec_dtype = f"{dtype}x{lanes}" + one = tvm.tir.const(1, vec_dtype) + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((64,), vec_dtype), B: T.Buffer((64,), vec_dtype)): + for i_0 in T.thread_binding(16, thread="blockIdx.x"): + for i_1 in T.thread_binding(4, thread="threadIdx.x"): + with T.sblock("B"): + v_i = T.axis.spatial(64, i_0 * 4 + i_1) + B[v_i] = A[v_i] + one - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - xo, xi = sch.split(sch.get_loops("B")[0], factors=[None, 4]) - sch.bind(xo, "blockIdx.x") - sch.bind(xi, "threadIdx.x") - f = tvm.compile(sch.mod, target=target) + f = tvm.compile(Module, target=target) - a = tvm.runtime.empty((arr_size,), A.dtype, dev).copyfrom( + a = tvm.runtime.empty((arr_size,), vec_dtype, dev).copyfrom( np.random.uniform(size=(arr_size, lanes)) ) - c = tvm.runtime.empty((arr_size,), B.dtype, dev) + c = tvm.runtime.empty((arr_size,), vec_dtype, dev) f(a, c) tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1) @@ -135,16 +122,18 @@ def test_array_vectorize_add(target, dev, dtype): def test_vulkan_bool_load(target, dev): target = tvm.target.Target(target) arr_size = 1024 - A = te.placeholder((arr_size,), name="A", dtype="bool") - B = te.compute(A.shape, lambda i: A[i].astype("int32"), name="B") - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - xo, xi = sch.split(sch.get_loops("B")[0], factors=[None, 128]) - sch.bind(xo, "blockIdx.x") - sch.bind(xi, "threadIdx.x") + @I.ir_module + class Module: + @T.prim_func + def main(A: 
T.Buffer((1024,), "bool"), B: T.Buffer((1024,), "int32")): + for i_0 in T.thread_binding(8, thread="blockIdx.x"): + for i_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.sblock("B"): + v_i = T.axis.spatial(1024, i_0 * 128 + i_1) + B[v_i] = T.Cast("int32", A[v_i]) - # Build - f = tvm.compile(sch.mod, target=target) + f = tvm.compile(Module, target=target) a_np = np.random.uniform(size=arr_size) > 0.5 b_np = np.zeros((arr_size,), dtype="int32") @@ -183,25 +172,41 @@ def test_vulkan_constant_passing(target, dev, vulkan_parameter_impl, vulkan_para max_int_params_in_push = max_push_constants_size // 8 - 3 num_int_params = max_int_params_in_push + 1 - n = te.var("n") - scalars = [te.var("scale{}".format(i), dtype=dtype) for i in range(num_int_params)] - scalar_sum = scalars[0] - for s in scalars[1:]: - scalar_sum += s - - A = te.placeholder((n,), name="A", dtype=dtype) - B = te.compute(A.shape, lambda i: scalar_sum + A[i], name="B") - - sch = tvm.s_tir.Schedule(te.create_prim_func(scalars + [A, B])) - xo, xi = sch.split(sch.get_loops("B")[0], factors=[None, 64]) - sch.bind(xo, "blockIdx.x") - sch.bind(xi, "threadIdx.x") - f_add = tvm.compile(sch.mod, target=target) + # Build IRModule programmatically since num_int_params is dynamic + with IRBuilder() as ib: + with I_builder.ir_module(): + with T_builder.prim_func(): + T_builder.func_name("main") + scalar_vars = [] + for i in range(num_int_params): + v = T_builder.arg(f"scale{i}", tvm.tir.Var("", dtype)) + scalar_vars.append(v) + var_A = T_builder.arg("var_A", T_builder.handle()) + var_B = T_builder.arg("var_B", T_builder.handle()) + T_builder.func_attr({"tir.noalias": True}) + n_var = T_builder.int32(is_size_var=True) + A = T_builder.match_buffer(var_A, (n_var,), dtype) + B = T_builder.match_buffer(var_B, (n_var,), dtype) + scalar_sum = scalar_vars[0] + for s in scalar_vars[1:]: + scalar_sum = scalar_sum + s + with T_builder.thread_binding( + tvm.tir.ceildiv(n_var, 64), thread="blockIdx.x" + ) as i_0: + with 
T_builder.thread_binding(64, thread="threadIdx.x") as i_1: + with T_builder.sblock("B"): + v_i = T_builder.axis.spatial(n_var, i_0 * 64 + i_1) + T_builder.where(i_0 * 64 + i_1 < n_var) + T_builder.reads(A[v_i]) + T_builder.writes(B[v_i]) + T_builder.buffer_store(B, scalar_sum + A[v_i], [v_i]) + mod = ib.get() + f_add = tvm.compile(mod, target=target) n = 1024 - scalars = np.array([1 for _ in scalars]).astype(dtype) - a = tvm.runtime.tensor(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(n, dtype=B.dtype), dev) + scalars = np.array([1 for _ in range(num_int_params)]).astype(dtype) + a = tvm.runtime.tensor(np.random.uniform(size=n).astype(dtype), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype=dtype), dev) f_add(*scalars, a, b) tvm.testing.assert_allclose(a.numpy() + sum(scalars), b.numpy()) @@ -361,18 +366,25 @@ def test_negative_operand_divmod(target, dev): offset = 16 divisor = 5 - @T.prim_func - def func(A: T.Buffer((N, 2), "int32")): - for i in T.serial(N): - with T.sblock("A"): - v_i = T.axis.spatial(N, i) - A[v_i, 0] = T.floordiv(v_i - offset, divisor) - A[v_i, 1] = T.floormod(v_i - offset, divisor) - if "gpu" in tvm.target.Target(target).keys: - sch = tvm.s_tir.Schedule(func) - sch.bind(sch.get_loops("A")[0], "threadIdx.x") - func = sch.mod["main"] + + @T.prim_func + def func(A: T.Buffer((N, 2), "int32")): + for i in T.thread_binding(N, thread="threadIdx.x"): + with T.sblock("A"): + v_i = T.axis.spatial(N, i) + A[v_i, 0] = T.floordiv(v_i - offset, divisor) + A[v_i, 1] = T.floormod(v_i - offset, divisor) + + else: + + @T.prim_func + def func(A: T.Buffer((N, 2), "int32")): + for i in T.serial(N): + with T.sblock("A"): + v_i = T.axis.spatial(N, i) + A[v_i, 0] = T.floordiv(v_i - offset, divisor) + A[v_i, 1] = T.floormod(v_i - offset, divisor) built = tvm.compile(func, target=target) @@ -386,105 +398,91 @@ def func(A: T.Buffer((N, 2), "int32")): @pytest.mark.parametrize("out_dtype", ["float32", "float16"]) def 
test_cooperative_matrix(out_dtype): - def get_matmul(m, n, k, out_dtype="float32"): - X = te.placeholder((m, k), name="X", dtype="float16") - W = te.placeholder((k, n), name="W", dtype="float16") - ak = te.reduce_axis((0, k), name="k") - - if out_dtype == "float32": - matmul = te.compute( - (m, n), - lambda i, j: te.sum( - X[i, ak].astype("float32") * W[ak, j].astype("float32"), - axis=ak, - ), - name="compute", - ) - else: - matmul = te.compute( - (m, n), - lambda i, j: te.sum(X[i, ak] * W[ak, j], axis=ak), - name="compute", - ) - - return te.create_prim_func([X, W, matmul]) - M, N, K = 16, 16, 32 - func = get_matmul(M, N, K, out_dtype) - sch = tvm.s_tir.Schedule(func) - block = sch.get_sblock("compute") - - i, j, k = sch.get_loops(block) - i_outer, i_inner = sch.split(i, factors=[None, 16]) - j_outer, j_inner = sch.split(j, factors=[None, 16]) - k_outer, k_inner = sch.split(k, factors=[None, 16]) - sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner) - fused_outer = sch.fuse(i_outer, j_outer) - sch.bind(fused_outer, "blockIdx.x") - - def fetch_to_shared(block, idx): - block_read = sch.cache_read(block, idx, "shared") - sch.compute_at(block_read, k_outer) - warp_size = 32 - - fused = sch.fuse(*sch.get_loops(block_read)[-2:]) - - vector_size = 4 - _, f_2, f_3 = sch.split(fused, factors=[None, warp_size, vector_size]) - sch.bind(f_2, "threadIdx.x") - sch.vectorize(f_3) - - def tensorize_load(block, dim): - loops = sch.get_loops(block) - i, j = loops[-dim : (len(loops) - dim + 2)] - - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - return i1 - - fetch_to_shared(block, 0) - fetch_to_shared(block, 1) - - c_warp_scope = "wmma.accumulator" - a_warp_scope = "wmma.matrix_a" - b_warp_scope = "wmma.matrix_b" - - A_mat = sch.cache_read(block, 0, a_warp_scope) - B_mat = sch.cache_read(block, 1, b_warp_scope) - - loop_a = tensorize_load(A_mat, 2) - 
sch.tensorize(loop_a, WMMA_LOAD_16x16x16_F16_A_INTRIN) - - loop_b = tensorize_load(B_mat, 2) - sch.tensorize(loop_b, WMMA_LOAD_16x16x16_F16_B_INTRIN) - - store = sch.cache_write(block, 0, c_warp_scope) - sch.reverse_compute_at(store, fused_outer) - init = sch.decompose_reduction(block, sch.get_loops(block)[1]) - - intrin = WMMA_FILL_16x16x16_F32_INTRIN - if out_dtype == "float16": - intrin = WMMA_FILL_16x16x16_F16_INTRIN - sch.tensorize(sch.get_loops(init)[1], intrin) - - intrin = WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN - if out_dtype == "float16": - intrin = WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN - sch.tensorize(sch.get_loops(store)[1], intrin) - - intrin = WMMA_SYNC_16x16x16_f16f16f32_INTRIN - if out_dtype == "float16": - intrin = WMMA_SYNC_16x16x16_f16f16f16_INTRIN - sch.tensorize(sch.get_loops(block)[2], intrin) + + # fmt: off + @I.ir_module + class Module: + @T.prim_func + def main(X: T.Buffer((16, 32), "float16"), W: T.Buffer((32, 16), "float16"), compute: T.Buffer((16, 16), out_dtype)): + T.func_attr({"tir.noalias": True}) + X_shared = T.alloc_buffer((16, 32), "float16", scope="shared") + W_shared = T.alloc_buffer((32, 16), "float16", scope="shared") + X_shared_wmma_matrix_a = T.alloc_buffer((16, 32), "float16", scope="wmma.matrix_a") + W_shared_wmma_matrix_b = T.alloc_buffer((32, 16), "float16", scope="wmma.matrix_b") + compute_wmma_accumulator = T.alloc_buffer((16, 16), out_dtype, scope="wmma.accumulator") + for i_0_j_0_fused in T.thread_binding(1, thread="blockIdx.x"): + with T.sblock("compute_init_o"): + v_i_o = T.axis.spatial(1, 0) + v_j_o = T.axis.spatial(1, 0) + T.reads() + T.writes(compute_wmma_accumulator[0:16, 0:16]) + C = T.match_buffer(compute_wmma_accumulator[0:16, 0:16], (16, 16), out_dtype, strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) + T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0.0)) + for k_0 in range(2): + for 
ax0_ax1_fused_0 in range(2): + for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(4): + with T.sblock("X_shared"): + v0 = T.axis.spatial(16, (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 16) + v1 = T.axis.spatial(32, k_0 * 16 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 16) + T.reads(X[v0, v1]) + T.writes(X_shared[v0, v1]) + X_shared[v0, v1] = X[v0, v1] + for ax0_ax1_fused_0 in range(2): + for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(4): + with T.sblock("W_shared"): + v0 = T.axis.spatial(32, k_0 * 16 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 16) + v1 = T.axis.spatial(16, (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 16) + T.reads(W[v0, v1]) + T.writes(W_shared[v0, v1]) + W_shared[v0, v1] = W[v0, v1] + for ax0_0 in T.unroll(1): + for ax1_0 in T.unroll(1): + with T.sblock("X_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(1, ax0_0) + v1_o = T.axis.spatial(2, k_0 + ax1_0) + T.reads(X_shared[0:16, v1_o * 16:v1_o * 16 + 16]) + T.writes(X_shared_wmma_matrix_a[0:16, v1_o * 16:v1_o * 16 + 16]) + A = T.match_buffer(X_shared[0:16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=16) + C = T.match_buffer(X_shared_wmma_matrix_a[0:16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") + for ax0_0 in T.unroll(1): + for ax1_0 in T.unroll(1): + with T.sblock("W_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(2, k_0 + ax0_0) + v1_o = T.axis.spatial(1, ax1_0) + T.reads(W_shared[v0_o * 16:v0_o * 16 + 16, 
0:16]) + T.writes(W_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, 0:16]) + A = T.match_buffer(W_shared[v0_o * 16:v0_o * 16 + 16, 0:16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=16) + C = T.match_buffer(W_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, 0:16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") + with T.sblock("compute_update_o"): + v_i_o = T.axis.spatial(1, 0) + v_j_o = T.axis.spatial(1, 0) + v_k_o = T.axis.reduce(2, k_0) + T.reads(compute_wmma_accumulator[0:16, 0:16], X_shared_wmma_matrix_a[0:16, v_k_o * 16:v_k_o * 16 + 16], W_shared_wmma_matrix_b[v_k_o * 16:v_k_o * 16 + 16, 0:16]) + T.writes(compute_wmma_accumulator[0:16, 0:16]) + A = T.match_buffer(X_shared_wmma_matrix_a[0:16, v_k_o * 16:v_k_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) + B = T.match_buffer(W_shared_wmma_matrix_b[v_k_o * 16:v_k_o * 16 + 16, 0:16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) + C = T.match_buffer(compute_wmma_accumulator[0:16, 0:16], (16, 16), out_dtype, strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) + T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) + with T.sblock("compute_wmma.accumulator_o"): + v0_o = T.axis.spatial(1, 0) + v1_o = 
T.axis.spatial(1, 0) + T.reads(compute_wmma_accumulator[0:16, 0:16]) + T.writes(compute[0:16, 0:16]) + A = T.match_buffer(compute_wmma_accumulator[0:16, 0:16], (16, 16), out_dtype, strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) + C = T.match_buffer(compute[0:16, 0:16], (16, 16), out_dtype, strides=("C_s0", "C_s1"), offset_factor=16) + T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation(out_dtype), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") + # fmt: on target = "vulkan -from_device=0" tgt_attrs = tvm.target.Target(target).attrs if tgt_attrs.get("supports_cooperative_matrix"): - f = tvm.compile(sch.mod, target=target) + f = tvm.compile(Module, target=target) dev = tvm.device(target, 0) @@ -506,7 +504,7 @@ def test_codegen_decl_buffer(): """The codegen should accept DeclBuffer nodes in its input""" @I.ir_module - class mod: + class Module: @T.prim_func def kernel(): T.func_attr({"calling_conv": 2, "global_symbol": "kernel", "tir.noalias": True}) @@ -515,7 +513,7 @@ def kernel(): target = tvm.target.Target("vulkan") vulkan_codegen = tvm.get_global_func("target.build.vulkan") - vulkan_codegen(mod, target) + vulkan_codegen(Module, target) @tvm.testing.requires_gpu @@ -537,24 +535,28 @@ def test_unary(): ] def run_test(tvm_intrin, np_func): - m = te.var("m") - A = te.placeholder((m,), name="A", dtype="float32") - B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name="B") - - mod = te.create_prim_func([A, B]) - sch = tvm.s_tir.Schedule(mod) + n = 16 - block = sch.get_sblock("B") - loop = sch.get_loops(block)[0] - bx, tx = sch.split(loop, factors=[None, 64]) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle): + m = T.int32(is_size_var=True) + A = T.match_buffer(var_A, (m,), "float32") + B = 
T.match_buffer(var_B, (m,), "float32") + for i_0 in T.thread_binding((m + 63) // 64, thread="blockIdx.x"): + for i_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.sblock("B"): + v_i = T.axis.spatial(m, i_0 * 64 + i_1) + T.where(i_0 * 64 + i_1 < m) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = tvm_intrin(A[v_i]) target = tvm.target.Target("vulkan") dev = tvm.device(target.kind.name, 0) - func = tvm.compile(sch.mod, target=target) + func = tvm.compile(Module, target=target) - n = 16 if tvm_intrin in [tvm.tir.asin, tvm.tir.acos]: data = np.random.uniform(-1.0, 1.0, size=n) elif tvm_intrin == tvm.tir.atanh: @@ -564,8 +566,8 @@ def run_test(tvm_intrin, np_func): else: data = np.random.uniform(0.1, 0.9, size=n) - a = tvm.runtime.tensor(data.astype(A.dtype), dev) - b = tvm.runtime.tensor(np.zeros(n, dtype=A.dtype), dev) + a = tvm.runtime.tensor(data.astype("float32"), dev) + b = tvm.runtime.tensor(np.zeros(n, dtype="float32"), dev) func(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) diff --git a/tests/python/codegen/test_target_codegen_x86.py b/tests/python/codegen/test_target_codegen_x86.py index abed886ce6f1..cb3856e61ce2 100644 --- a/tests/python/codegen/test_target_codegen_x86.py +++ b/tests/python/codegen/test_target_codegen_x86.py @@ -20,7 +20,7 @@ import pytest import tvm -from tvm import te +from tvm.script import tir as T, ir as I llvm_version = tvm.target.codegen.llvm_version_major() machine = platform.machine() @@ -34,12 +34,24 @@ def test_fp16_to_fp32(): def fp16_to_fp32(target, width, match=None, not_match=None): elements = 64 - n = tvm.runtime.convert(elements) - A = te.placeholder((n, width), dtype="float16", name="A") - B = te.compute(A.shape, lambda *i: A(*i).astype("float32"), name="B") - sch = tvm.s_tir.Schedule(te.create_prim_func([A, B])) - sch.vectorize(sch.get_loops("B")[1]) - f = tvm.tir.build(sch.mod, target=target) + + @I.ir_module + class Module: + @T.prim_func + def main( + A: 
T.Buffer((elements, width), "float16"), + B: T.Buffer((elements, width), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0 in range(elements): + for i1 in T.vectorized(width): + with T.sblock("B"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(B[v_i0, v_i1]) + B[v_i0, v_i1] = T.Cast("float32", A[v_i0, v_i1]) + + f = tvm.tir.build(Module, target=target) assembly = f.inspect_source("asm").splitlines() if match: