diff --git a/cunumeric/linalg/cholesky.py b/cunumeric/linalg/cholesky.py
index 9bba033619..4cad697442 100644
--- a/cunumeric/linalg/cholesky.py
+++ b/cunumeric/linalg/cholesky.py
@@ -32,6 +32,38 @@
     from ..runtime import Runtime
 
 
+def get_gpu_lower_triangular(point):
+    """Pick the machine scope for the task that works on tile ``point``.
+
+    ``point`` is a 2-d tile index ``(i, j)``.  Tiles are numbered
+    ``j + i * (i + 1) // 2`` (row-major over the lower triangle) and
+    that number, modulo the GPU count, selects one GPU of the current
+    machine.
+
+    The result is intended to be used as ``with ...:`` around
+    ``task.execute()``.  If the current machine has no GPUs there is
+    nothing to select (and the modulo would divide by zero), so the
+    current machine is returned unchanged.
+    """
+    assert len(point) == 2
+    from legate.core import get_machine
+    from legate.core.machine import ProcessorKind
+
+    machine = get_machine()
+    num_gpus = machine.count(ProcessorKind.GPU)
+    if num_gpus == 0:
+        # CPU-only run: leave the launch unrestricted.
+        return machine
+    gpus = machine.only(ProcessorKind.GPU)
+
+    # The linearized block-cyclic lower-triangular-blocked decomposition
+    # mapping of 2-d points to blocks is:
+    def mapping(i, j):
+        return (j + (i * (i + 1)) // 2) % num_gpus
+
+    return gpus[mapping(point[0], point[1])]
+
+
 def transpose_copy_single(
     context: Context, input: Store, output: Store
 ) -> None:
@@ -83,7 +115,8 @@ def potrf(context: Context, p_output: StorePartition, i: int) -> None:
     task.throws_exception(LinAlgError)
     task.add_output(p_output)
     task.add_input(p_output)
-    task.execute()
+    with get_gpu_lower_triangular((i, i)):
+        task.execute()
 
 
 def trsm(
@@ -95,14 +128,17 @@ def trsm(
     rhs = p_output.get_child_store(i, i)
     lhs = p_output
 
-    launch_domain = Rect(lo=(lo, i), hi=(hi, i + 1))
-    task = context.create_manual_task(
-        CuNumericOpCode.TRSM, launch_domain=launch_domain
-    )
-    task.add_output(lhs)
-    task.add_input(rhs)
-    task.add_input(lhs)
-    task.execute()
+    # One single-point task per tile, so each can get its own scope.
+    for point in Rect(lo=(lo, i), hi=(hi, i + 1)):
+        task = context.create_manual_task(
+            CuNumericOpCode.TRSM,
+            launch_domain=Rect(lo=point, hi=point, exclusive=False),
+        )
+        task.add_output(lhs)
+        task.add_input(rhs)
+        task.add_input(lhs)
+        with get_gpu_lower_triangular(point):
+            task.execute()
 
 
 def syrk(context: Context, p_output: StorePartition, k: int, i: int) -> None:
@@ -116,7 +152,8 @@ def syrk(context: Context, p_output: StorePartition, k: int, i: int) -> None:
     task.add_output(lhs)
     task.add_input(rhs)
     task.add_input(lhs)
-    task.execute()
+    with get_gpu_lower_triangular((k, k)):
+        task.execute()
 
 
 def gemm(
@@ -134,18 +171,21 @@ def gemm(
     lhs = p_output
     rhs1 = p_output
 
-    launch_domain = Rect(lo=(lo, k), hi=(hi, k + 1))
-    task = context.create_manual_task(
-        CuNumericOpCode.GEMM, launch_domain=launch_domain
-    )
-    task.add_output(lhs)
-    task.add_input(rhs1, proj=lambda p: (p[0], i))
-    task.add_input(rhs2)
-    task.add_input(lhs)
-    task.execute()
+    # One single-point task per tile, so each can get its own scope.
+    for point in Rect(lo=(lo, k), hi=(hi, k + 1)):
+        task = context.create_manual_task(
+            CuNumericOpCode.GEMM,
+            launch_domain=Rect(lo=point, hi=point, exclusive=False),
+        )
+        task.add_output(lhs)
+        task.add_input(rhs1, proj=lambda p: (p[0], i))
+        task.add_input(rhs2)
+        task.add_input(lhs)
+        with get_gpu_lower_triangular(point):
+            task.execute()
 
 
-MIN_CHOLESKY_TILE_SIZE = 2048
+MIN_CHOLESKY_TILE_SIZE = 4096
 MIN_CHOLESKY_MATRIX_SIZE = 8192
 
 