Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions cunumeric/linalg/cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@
from ..runtime import Runtime


def get_gpu_lower_triangular(point):
    """Select the machine that should run the tile at a 2-D ``point``.

    The tile coordinates ``(i, j)`` are linearized as ``j + i * (i + 1) // 2``
    and wrapped modulo the number of GPUs in the current machine, so that
    consecutive linearized tiles are dealt out to GPUs round-robin.

    Parameters
    ----------
    point :
        A length-2 indexable holding the tile's ``(i, j)`` coordinates.

    Returns
    -------
    The entry of the GPU-only sub-machine chosen for this tile. Callers in
    this file use the result as a context manager around ``task.execute()``.
    When the current machine has no GPUs, the current machine itself is
    returned so that no re-targeting takes place.
    """
    assert len(point) == 2
    from legate.core import get_machine
    from legate.core.machine import ProcessorKind

    machine = get_machine()
    num_gpus = machine.count(ProcessorKind.GPU)
    if num_gpus == 0:
        # Without this guard the modulo below raises ZeroDivisionError on a
        # CPU-only run. NOTE(review): assumes scoping to the machine that is
        # already current is a no-op -- confirm against the legate version
        # in use.
        return machine

    gpus = machine.only(ProcessorKind.GPU)

    # The linearized block-cyclic lower-triangular-blocked decomposition
    # mapping of 2-d points to blocks.
    row, col = point[0], point[1]
    gpu_index = (col + (row * (row + 1)) // 2) % num_gpus
    return gpus[gpu_index]


def transpose_copy_single(
context: Context, input: Store, output: Store
) -> None:
Expand Down Expand Up @@ -83,7 +100,8 @@ def potrf(context: Context, p_output: StorePartition, i: int) -> None:
task.throws_exception(LinAlgError)
task.add_output(p_output)
task.add_input(p_output)
task.execute()
with get_gpu_lower_triangular((i, i)):
task.execute()


def trsm(
Expand All @@ -95,14 +113,16 @@ def trsm(
rhs = p_output.get_child_store(i, i)
lhs = p_output

launch_domain = Rect(lo=(lo, i), hi=(hi, i + 1))
task = context.create_manual_task(
CuNumericOpCode.TRSM, launch_domain=launch_domain
)
task.add_output(lhs)
task.add_input(rhs)
task.add_input(lhs)
task.execute()
for point in Rect(lo=(lo, i), hi=(hi, i + 1)):
task = context.create_manual_task(
CuNumericOpCode.TRSM,
launch_domain=Rect(lo=point, hi=point, exclusive=False),
)
task.add_output(lhs)
task.add_input(rhs)
task.add_input(lhs)
with get_gpu_lower_triangular(point):
task.execute()


def syrk(context: Context, p_output: StorePartition, k: int, i: int) -> None:
Expand All @@ -116,7 +136,8 @@ def syrk(context: Context, p_output: StorePartition, k: int, i: int) -> None:
task.add_output(lhs)
task.add_input(rhs)
task.add_input(lhs)
task.execute()
with get_gpu_lower_triangular((k, k)):
task.execute()


def gemm(
Expand All @@ -134,18 +155,20 @@ def gemm(
lhs = p_output
rhs1 = p_output

launch_domain = Rect(lo=(lo, k), hi=(hi, k + 1))
task = context.create_manual_task(
CuNumericOpCode.GEMM, launch_domain=launch_domain
)
task.add_output(lhs)
task.add_input(rhs1, proj=lambda p: (p[0], i))
task.add_input(rhs2)
task.add_input(lhs)
task.execute()
for point in Rect(lo=(lo, k), hi=(hi, k + 1)):
task = context.create_manual_task(
CuNumericOpCode.GEMM,
launch_domain=Rect(lo=point, hi=point, exclusive=False),
)
task.add_output(lhs)
task.add_input(rhs1, proj=lambda p: (p[0], i))
task.add_input(rhs2)
task.add_input(lhs)
with get_gpu_lower_triangular(point):
task.execute()


MIN_CHOLESKY_TILE_SIZE = 2048
MIN_CHOLESKY_TILE_SIZE = 4096
MIN_CHOLESKY_MATRIX_SIZE = 8192


Expand Down