Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 45 additions & 55 deletions tests/python/codegen/test_gpu_codegen_allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,49 @@
import tvm_ffi
import tvm.testing
import numpy as np
from tvm.script import tir as T
from tvm.script import tir as T, ir as I

import pytest


@T.prim_func
def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
A = T.match_buffer(a, [1, d1, d2, d3])
B = T.match_buffer(b, [1, d1, d2])

for i, j, k, l in T.grid(1, d1, d2, d3):
with T.sblock("reduce"):
vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
with T.init():
B[vi, vj, vk] = 0.0
B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]


@T.prim_func
def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
A = T.match_buffer(a, [1, d1, d2, d3])
B = T.match_buffer(b, [1, d1, d2])

for i, j, k, l in T.grid(1, d1, d2, d3):
with T.sblock("reduce"):
vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
with T.init():
B[vi, vj, vk] = T.float32(-3.4028234663852886e38)
B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])
def _reduce_sum_module(d1, d2, d3):
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((1, d1, d2, d3), "float32"), B: T.Buffer((1, d1, d2), "float32")):
for i in T.thread_binding(1, thread="blockIdx.x"):
for j in T.thread_binding(d1, thread="threadIdx.z"):
for k in T.thread_binding(d2, thread="threadIdx.y"):
for l in T.thread_binding(d3, thread="threadIdx.x"):
with T.sblock("reduce"):
vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
T.reads(A[vi, vj, vk, vl])
T.writes(B[vi, vj, vk])
with T.init():
B[vi, vj, vk] = T.float32(0.0)
B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]

return Module


def _reduce_max_module(d1, d2, d3):
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((1, d1, d2, d3), "float32"), B: T.Buffer((1, d1, d2), "float32")):
for i in T.thread_binding(1, thread="blockIdx.x"):
for j in T.thread_binding(d1, thread="threadIdx.z"):
for k in T.thread_binding(d2, thread="threadIdx.y"):
for l in T.thread_binding(d3, thread="threadIdx.x"):
with T.sblock("reduce"):
vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
T.reads(A[vi, vj, vk, vl])
T.writes(B[vi, vj, vk])
with T.init():
B[vi, vj, vk] = T.float32(-3.4028234663852886e38)
B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])

return Module


def generate_param_sets():
Expand All @@ -63,16 +77,8 @@ def generate_param_sets():
@tvm.testing.parametrize_targets("cuda", "metal")
def test_allreduce_sum(dims, target, dev):
d1, d2, d3 = dims
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
sch = tvm.s_tir.Schedule(mod)
blk = sch.get_sblock("reduce")
i, j, k, l = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.z")
sch.bind(k, "threadIdx.y")
sch.bind(l, "threadIdx.x")
f = tvm.compile(sch.mod["main"], target=target)
mod = _reduce_sum_module(d1, d2, d3)
f = tvm.compile(mod, target=target)

# prepare input and output array
a_np = np.random.rand(1, d1, d2, d3).astype("float32")
Expand Down Expand Up @@ -117,31 +123,15 @@ def test_allreduce_sum_compile(optional_metal_compile_callback):
target = "metal"

d1, d2, d3 = dims
_, _, _d1, _d2, _d3 = reduce.params
mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
sch = tvm.s_tir.Schedule(mod)
blk = sch.get_sblock("reduce")
i, j, k, l = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.z")
sch.bind(k, "threadIdx.y")
sch.bind(l, "threadIdx.x")
tvm.compile(sch.mod["main"], target=target)
mod = _reduce_sum_module(d1, d2, d3)
tvm.compile(mod, target=target)


@tvm.testing.parametrize_targets("cuda", "metal")
def test_allreduce_max(dims, target, dev):
d1, d2, d3 = dims
_, _, _d1, _d2, _d3 = reduce_max.params
mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
sch = tvm.s_tir.Schedule(mod)
blk = sch.get_sblock("reduce")
i, j, k, l = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.z")
sch.bind(k, "threadIdx.y")
sch.bind(l, "threadIdx.x")
f = tvm.compile(sch.mod["main"], target=target)
mod = _reduce_max_module(d1, d2, d3)
f = tvm.compile(mod, target=target)

# prepare input and output array
a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
Expand Down
Loading
Loading