From 4027de49e53f64365e5c2d45d09aa5d37a5657fc Mon Sep 17 00:00:00 2001 From: Roozbeh Karimi Date: Wed, 31 Aug 2022 16:29:56 -0700 Subject: [PATCH 1/5] Replacing the Thrust based CUDA scan with a custom kernel (work efficient parallel scan) to optimize performance especially in scan along axis for ND-arrays. Implementation not finished, likely buggy and also need to add NAN handling variant. --- src/cunumeric/scan/scan_local.cu | 170 +++++++++++++++++++++++++++++-- 1 file changed, 159 insertions(+), 11 deletions(-) diff --git a/src/cunumeric/scan/scan_local.cu b/src/cunumeric/scan/scan_local.cu index ddc0739514..949090f2ee 100644 --- a/src/cunumeric/scan/scan_local.cu +++ b/src/cunumeric/scan/scan_local.cu @@ -38,6 +38,159 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) sum_val[0] = out[0]; } +template +static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + partition_sum(RES* out, RES* sum_val, const Pitches pitches, uint64_t len, uint64_t stride) +{ + unsigned int tid = threadIdx.x; + uint64_t blid = blockIdx.x * blockDim.x; + + uint64_t index = (blid + tid) * stride; + + if(index < len){ + auto sum_valp = pitches.unflatten(index, Point::ZEROES()); + sum_valp[DIM - 1] = 0; + sum_val[sum_valp] = out[index + stride - 1] + } +} + +template +static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) +__global__ void cuda_add(, RES*B, uint64_t len, uint64_t stride, OP func, RES *block_sum) +{ + unsigned int tid = threadIdx.x; + unsigned int blid = blockIdx.x * blockDim.x; + + uint64_t pad_stride = stride; + bool must_copy = true; + if(stride & (stride-1)){ + pad_stride = 1 << (32 - __clz(stride)); + must_copy = (tid & (pad_stride-1)) < stride; + } + uint64_t blocks_per_batch; + bool last_block; + bool first_block; + + blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + last_block = (blockIdx.x + 1) % blocks_per_batch == 0; + first_block = (blockIdx.x) % blocks_per_batch == 0; + int remaining_batch = stride % THREADS_PER_BLOCK; + if(remaining_batch == 0){ + remaining_batch = THREADS_PER_BLOCK; + } + must_copy = !last_block || (tid < remaining_batch); + + int pad_per_batch = pad_stride-stride; + + uint64_t idx0 = tid + blid; + + uint64_t batch_id = idx0 / pad_stride; + idx0 = idx0 - pad_per_batch * batch_id; + + if(idx0 < len && must_copy && !first_block){ + B[idx0] = func(block_sum[blockIdx.x - 1], B[idx0]); + } +} + +template +static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + batch_scan_cuda(RES *A, RES*B, uint64_t len, uint64_t stride, OP func, RES identity, RES *block_sum) +{ + __shared__ RES temp[THREADS_PER_BLOCK]; + + unsigned int tid = threadIdx.x; + unsigned int blid = blockIdx.x * blockDim.x; + + uint64_t pad_stride = stride; + bool must_copy = true; + if(stride & (stride-1)){ + pad_stride = 1 << (32 - __clz(stride)); + must_copy = (tid & (pad_stride-1)) < stride; + } + bool last_block; + if(pad_stride > THREADS_PER_BLOCK){ + uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + last_block = (blockIdx.x + 1) % blocks_per_batch == 0; + int remaining_batch = stride % THREADS_PER_BLOCK; + if(remaining_batch == 0){ + remaining_batch = THREADS_PER_BLOCK; + } + must_copy = !last_block || (tid < remaining_batch); + } + + int pad_per_batch = pad_stride-stride; + int n_batches_block = THREADS_PER_BLOCK / pad_stride; + + uint64_t idx0 = tid + blid; + + uint64_t batch_id 
= idx0 / pad_stride; + idx0 = idx0 - pad_per_batch * batch_id; + + if(idx0 < len){ + temp[tid] = (must_copy) ? A[idx0] : identity; + __syncthreads(); + if (!n_batches_block) { + n_batches_block = 1; + pad_stride = THREADS_PER_BLOCK; + } + for (int j = 0; j < n_batches_block; j++) { + int offset = j * pad_stride; + for (int i = 1; i <= pad_stride; i <<= 1) { + int index = ((tid + 1) * 2 * i - 1); + int index_block = offset + index; + if (index < (pad_stride)){ + temp[index_block] = func(temp[index_block - i], temp[index_block]); + } + __syncthreads(); + } + for(int i = pad_stride >> 1; i > 0; i >>= 1){ + int index = ((tid + 1) * 2 * i - 1); + int index_block = offset + index; + if((index + i) < (pad_stride)){ + temp[index_block + i] = func(temp[index_block], temp[index_block + i]); + } + __syncthreads(); + } + } + if(must_copy){ + B[idx0] = temp[tid]; + } + if(block_sum != nullptr && tid == THREADS_PER_BLOCK-1 && !last_block){ + block_sum[blockIdx.x] = temp[tid]; + } + } +} + +template +void cuda_scan(RES *A, RES*B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) +{ + assert(stride != 0); + uint64_t pad_stride = 1 << (32 - __builtin_clz(stride)); + if(pad_stride > THREADS_PER_BLOCK){ + uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + } + uint64_t pad_len = (len / stride) * pad_stride; + uint64_t grid_dim = (pad_len-1) / THREADS_PER_BLOCK + 1; + + RES *blocked_sum = nullptr; + uint64_t blocked_len, blocked_stride; + if(stride > THREADS_PER_BLOCK){ + blocked_len = grid_dim; + blocked_stride = grid_dim / (len / stride); + CE( cudaMalloc(&blocked_sum, blocked_len * sizeof(RES)) ); + } + + batch_scan_cuda<<>>(A, B, len, stride, func, identity, blocked_sum); + + if(stride > THREADS_PER_BLOCK){ + cuda_scan(blocked_sum, blocked_sum, blocked_len, blocked_stride, func, identity, stream); + cuda_add<<>>(B, len, stride, func, blocked_sum); + } +} + template struct ScanLocalImplBody { using OP = ScanOp; @@ -63,17 +216,12 @@ struct ScanLocalImplBody { auto sum_valsptr = sum_vals.create_output_buffer(extents, true); - for (uint64_t index = 0; index < volume; index += stride) { - thrust::inclusive_scan( - thrust::cuda::par.on(stream), inptr + index, inptr + index + stride, outptr + index, func); - // get the corresponding ND index with base zero to use for sum_val - auto sum_valp = pitches.unflatten(index, Point::ZEROES()); - // only one element on scan axis - sum_valp[DIM - 1] = 0; - // write out the partition sum - lazy_kernel<<<1, THREADS_PER_BLOCK, 0, stream>>>(&outptr[index + stride - 1], - &sum_valsptr[sum_valp]); - } + VAL identity = (VAL)ScanOp::nan_identity; + + cuda_scan(inptr, outptr, volume, stride, func, identity, stream); + + uint64_t grid_dim = ((volume / stride) - 1) / THREADS_PER_BLOCK + 1; + partition_sum<<>>(outptr, sum_valsptr, Pitches, volume, stride); CHECK_CUDA_STREAM(stream); } }; From 4ba5cfd6d1a21bd8781a99d0585352b103835b10 Mon Sep 17 00:00:00 2001 From: Roozbeh Karimi Date: Thu, 1 Sep 2022 14:45:05 -0700 Subject: [PATCH 2/5] bugfixes. The optimized scan builds and runs without errors. nancumsum and nancumprod still are still thrust based and need to be updated. 
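For context on the kernels this series adds: they implement a standard two-level blocked scan. batch_scan_cuda performs a work-efficient scan over each padded batch in shared memory and records each block's last element in block_sum; cuda_scan then scans those block sums (recursing when the scan axis spans multiple blocks) and cuda_add folds each preceding block's total back into the blocks that follow. A minimal serial C++ sketch of the same three phases for a single flat array is below; blocked_inclusive_scan, BLOCK, and op are illustrative names only (not identifiers from this patch), and the real kernels additionally pad every scan-axis batch up to a multiple of THREADS_PER_BLOCK so that many batches can be scanned in one launch.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch only: serial equivalent of batch_scan_cuda + cuda_scan + cuda_add.
template <typename T, typename OP>
void blocked_inclusive_scan(std::vector<T>& a, OP op, std::size_t BLOCK)
{
  std::size_t nblocks = (a.size() + BLOCK - 1) / BLOCK;
  std::vector<T> block_sum(nblocks);
  // Phase 1: independent inclusive scan inside each block
  // (what batch_scan_cuda does in shared memory).
  for (std::size_t b = 0; b < nblocks; ++b) {
    std::size_t lo = b * BLOCK, hi = std::min(a.size(), lo + BLOCK);
    for (std::size_t i = lo + 1; i < hi; ++i) a[i] = op(a[i - 1], a[i]);
    block_sum[b] = a[hi - 1];
  }
  // Phase 2: inclusive scan of the per-block totals
  // (cuda_scan recurses to do this on the device).
  for (std::size_t b = 1; b < nblocks; ++b) block_sum[b] = op(block_sum[b - 1], block_sum[b]);
  // Phase 3: add the preceding blocks' total into every later block (what cuda_add does).
  for (std::size_t b = 1; b < nblocks; ++b) {
    std::size_t lo = b * BLOCK, hi = std::min(a.size(), lo + BLOCK);
    for (std::size_t i = lo; i < hi; ++i) a[i] = op(block_sum[b - 1], a[i]);
  }
}

int main()
{
  std::vector<int> v{1, 2, 3, 4, 5, 6, 7};
  blocked_inclusive_scan(v, [](int x, int y) { return x + y; }, 4);
  for (int x : v) std::printf("%d ", x);  // prints: 1 3 6 10 15 21 28
  std::printf("\n");
  return 0;
}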
--- src/cunumeric/scan/scan_local.cu | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/cunumeric/scan/scan_local.cu b/src/cunumeric/scan/scan_local.cu index 949090f2ee..3ea908b12b 100644 --- a/src/cunumeric/scan/scan_local.cu +++ b/src/cunumeric/scan/scan_local.cu @@ -40,7 +40,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) template static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - partition_sum(RES* out, RES* sum_val, const Pitches pitches, uint64_t len, uint64_t stride) + partition_sum(RES* out, Buffer sum_val, const Pitches pitches, uint64_t len, uint64_t stride) { unsigned int tid = threadIdx.x; uint64_t blid = blockIdx.x * blockDim.x; @@ -50,13 +50,13 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) if(index < len){ auto sum_valp = pitches.unflatten(index, Point::ZEROES()); sum_valp[DIM - 1] = 0; - sum_val[sum_valp] = out[index + stride - 1] + sum_val[sum_valp] = out[index + stride - 1]; } } template static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) -__global__ void cuda_add(, RES*B, uint64_t len, uint64_t stride, OP func, RES *block_sum) + cuda_add(RES *B, uint64_t len, uint64_t stride, OP func, RES *block_sum) { unsigned int tid = threadIdx.x; unsigned int blid = blockIdx.x * blockDim.x; @@ -95,7 +95,7 @@ __global__ void cuda_add(, RES*B, uint64_t len, uint64_t stride, OP func, RES *b template static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - batch_scan_cuda(RES *A, RES*B, uint64_t len, uint64_t stride, OP func, RES identity, RES *block_sum) + batch_scan_cuda(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, RES *block_sum) { __shared__ RES temp[THREADS_PER_BLOCK]; @@ -164,7 +164,7 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } template -void cuda_scan(RES *A, RES*B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) +void cuda_scan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) { assert(stride != 0); uint64_t pad_stride = 1 << (32 - __builtin_clz(stride)); @@ -180,10 +180,10 @@ void cuda_scan(RES *A, RES*B, uint64_t len, uint64_t stride, OP func, RES identi if(stride > THREADS_PER_BLOCK){ blocked_len = grid_dim; blocked_stride = grid_dim / (len / stride); - CE( cudaMalloc(&blocked_sum, blocked_len * sizeof(RES)) ); + CHECK_CUDA( cudaMalloc(&blocked_sum, blocked_len * sizeof(RES)) ); } - batch_scan_cuda<<>>(A, B, len, stride, func, identity, blocked_sum); + batch_scan_cuda<<>>(A, B, len, stride, func, identity, blocked_sum); if(stride > THREADS_PER_BLOCK){ cuda_scan(blocked_sum, blocked_sum, blocked_len, blocked_stride, func, identity, stream); @@ -197,7 +197,7 @@ struct ScanLocalImplBody { using VAL = legate_type_of; void operator()(OP func, - const AccessorWO& out, + AccessorWO& out, const AccessorRO& in, Array& sum_vals, const Pitches& pitches, @@ -218,10 +218,10 @@ struct ScanLocalImplBody { VAL identity = (VAL)ScanOp::nan_identity; - cuda_scan(inptr, outptr, volume, stride, func, identity, stream); + cuda_scan(inptr, outptr, volume, stride, func, identity, stream); uint64_t grid_dim = ((volume / stride) - 1) / THREADS_PER_BLOCK + 1; - partition_sum<<>>(outptr, sum_valsptr, Pitches, volume, stride); + partition_sum<<>>(outptr, sum_valsptr, pitches, volume, stride); CHECK_CUDA_STREAM(stream); } }; From c28637c4174cb745227797dcdd576e9a4b318878 Mon Sep 17 
00:00:00 2001 From: Roozbeh Karimi Date: Thu, 1 Sep 2022 16:00:32 -0700 Subject: [PATCH 3/5] Bringing back thrust call for the 1D case (depending on thrust version can be slightly faster on 1D) Minor bugfixes and adding error checks. --- src/cunumeric/scan/scan_local.cu | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/cunumeric/scan/scan_local.cu b/src/cunumeric/scan/scan_local.cu index 3ea908b12b..3fe14d185a 100644 --- a/src/cunumeric/scan/scan_local.cu +++ b/src/cunumeric/scan/scan_local.cu @@ -184,10 +184,16 @@ void cuda_scan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES } batch_scan_cuda<<>>(A, B, len, stride, func, identity, blocked_sum); + CHECK_CUDA_STREAM(stream); if(stride > THREADS_PER_BLOCK){ cuda_scan(blocked_sum, blocked_sum, blocked_len, blocked_stride, func, identity, stream); cuda_add<<>>(B, len, stride, func, blocked_sum); + CHECK_CUDA_STREAM(stream); + } + + if(stride > THREADS_PER_BLOCK){ + CHECK_CUDA( cudaFree(blocked_sum) ); } } @@ -218,8 +224,13 @@ struct ScanLocalImplBody { VAL identity = (VAL)ScanOp::nan_identity; - cuda_scan(inptr, outptr, volume, stride, func, identity, stream); - + if(volume == stride){ + // Thrust is slightly faster for the 1D case + thrust::inclusive_scan(thrust::cuda::par.on(stream), inptr, inptr + stride, outptr, func); + } else { + cuda_scan(inptr, outptr, volume, stride, func, identity, stream); + } + uint64_t grid_dim = ((volume / stride) - 1) / THREADS_PER_BLOCK + 1; partition_sum<<>>(outptr, sum_valsptr, pitches, volume, stride); CHECK_CUDA_STREAM(stream); From 30cbd9867b7e9acf0d9411cadeaafc272bf7bbb2 Mon Sep 17 00:00:00 2001 From: Roozbeh Karimi Date: Thu, 1 Sep 2022 16:35:16 -0700 Subject: [PATCH 4/5] Changed cumsum/cumrprod implementation to use the optimized cuda ND-scan version. --- src/cunumeric/scan/scan_local.cu | 130 ++++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 12 deletions(-) diff --git a/src/cunumeric/scan/scan_local.cu b/src/cunumeric/scan/scan_local.cu index 3fe14d185a..928508903f 100644 --- a/src/cunumeric/scan/scan_local.cu +++ b/src/cunumeric/scan/scan_local.cu @@ -163,6 +163,77 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } } +template +static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + batch_scan_cuda_nan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, RES *block_sum) +{ + __shared__ RES temp[THREADS_PER_BLOCK]; + + unsigned int tid = threadIdx.x; + unsigned int blid = blockIdx.x * blockDim.x; + + uint64_t pad_stride = stride; + bool must_copy = true; + if(stride & (stride-1)){ + pad_stride = 1 << (32 - __clz(stride)); + must_copy = (tid & (pad_stride-1)) < stride; + } + bool last_block; + if(pad_stride > THREADS_PER_BLOCK){ + uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + last_block = (blockIdx.x + 1) % blocks_per_batch == 0; + int remaining_batch = stride % THREADS_PER_BLOCK; + if(remaining_batch == 0){ + remaining_batch = THREADS_PER_BLOCK; + } + must_copy = !last_block || (tid < remaining_batch); + } + + int pad_per_batch = pad_stride-stride; + int n_batches_block = THREADS_PER_BLOCK / pad_stride; + + uint64_t idx0 = tid + blid; + + uint64_t batch_id = idx0 / pad_stride; + idx0 = idx0 - pad_per_batch * batch_id; + + if(idx0 < len){ + RES val = (must_copy) ? A[idx0] : identity; + temp[tid] = cunumeric::is_nan(val) ? 
identity : val; + __syncthreads(); + if (!n_batches_block) { + n_batches_block = 1; + pad_stride = THREADS_PER_BLOCK; + } + for (int j = 0; j < n_batches_block; j++) { + int offset = j * pad_stride; + for (int i = 1; i <= pad_stride; i <<= 1) { + int index = ((tid + 1) * 2 * i - 1); + int index_block = offset + index; + if (index < (pad_stride)){ + temp[index_block] = func(temp[index_block - i], temp[index_block]); + } + __syncthreads(); + } + for(int i = pad_stride >> 1; i > 0; i >>= 1){ + int index = ((tid + 1) * 2 * i - 1); + int index_block = offset + index; + if((index + i) < (pad_stride)){ + temp[index_block + i] = func(temp[index_block], temp[index_block + i]); + } + __syncthreads(); + } + } + if(must_copy){ + B[idx0] = temp[tid]; + } + if(block_sum != nullptr && tid == THREADS_PER_BLOCK-1 && !last_block){ + block_sum[blockIdx.x] = temp[tid]; + } + } +} + template void cuda_scan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) { @@ -197,6 +268,40 @@ void cuda_scan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES } } +template +void cuda_scan_nan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) +{ + assert(stride != 0); + uint64_t pad_stride = 1 << (32 - __builtin_clz(stride)); + if(pad_stride > THREADS_PER_BLOCK){ + uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + } + uint64_t pad_len = (len / stride) * pad_stride; + uint64_t grid_dim = (pad_len-1) / THREADS_PER_BLOCK + 1; + + RES *blocked_sum = nullptr; + uint64_t blocked_len, blocked_stride; + if(stride > THREADS_PER_BLOCK){ + blocked_len = grid_dim; + blocked_stride = grid_dim / (len / stride); + CHECK_CUDA( cudaMalloc(&blocked_sum, blocked_len * sizeof(RES)) ); + } + + batch_scan_cuda_nan<<>>(A, B, len, stride, func, identity, blocked_sum); + CHECK_CUDA_STREAM(stream); + + if(stride > THREADS_PER_BLOCK){ + cuda_scan(blocked_sum, blocked_sum, blocked_len, blocked_stride, func, identity, stream); + cuda_add<<>>(B, len, stride, func, blocked_sum); + CHECK_CUDA_STREAM(stream); + } + + if(stride > THREADS_PER_BLOCK){ + CHECK_CUDA( cudaFree(blocked_sum) ); + } +} + template struct ScanLocalImplBody { using OP = ScanOp; @@ -250,7 +355,7 @@ struct ScanLocalNanImplBody { }; void operator()(OP func, - const AccessorWO& out, + AccessorWO& out, const AccessorRO& in, Array& sum_vals, const Pitches& pitches, @@ -269,21 +374,22 @@ struct ScanLocalNanImplBody { auto sum_valsptr = sum_vals.create_output_buffer(extents, true); - for (uint64_t index = 0; index < volume; index += stride) { + VAL identity = (VAL)ScanOp::nan_identity; + + if(volume == stride){ + // Thrust is slightly faster for the 1D case thrust::inclusive_scan( thrust::cuda::par.on(stream), - thrust::make_transform_iterator(inptr + index, convert_nan_func()), - thrust::make_transform_iterator(inptr + index + stride, convert_nan_func()), - outptr + index, + thrust::make_transform_iterator(inptr, convert_nan_func()), + thrust::make_transform_iterator(inptr + stride, convert_nan_func()), + outptr, func); - // get the corresponding ND index with base zero to use for sum_val - auto sum_valp = pitches.unflatten(index, Point::ZEROES()); - // only one element on scan axis - sum_valp[DIM - 1] = 0; - // write out the partition sum - lazy_kernel<<<1, THREADS_PER_BLOCK, 0, stream>>>(&outptr[index + stride - 1], - &sum_valsptr[sum_valp]); + } else { + cuda_scan_nan(inptr, outptr, volume, stride, func, identity, 
stream); } + + uint64_t grid_dim = ((volume / stride) - 1) / THREADS_PER_BLOCK + 1; + partition_sum<<>>(outptr, sum_valsptr, pitches, volume, stride); CHECK_CUDA_STREAM(stream); } }; From f3893b52a90db9f0c8ae1bd1ec8421a8461a2011 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 Sep 2022 23:40:08 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/cunumeric/scan/scan_local.cu | 259 +++++++++++++++---------------- 1 file changed, 125 insertions(+), 134 deletions(-) diff --git a/src/cunumeric/scan/scan_local.cu b/src/cunumeric/scan/scan_local.cu index 928508903f..078beae594 100644 --- a/src/cunumeric/scan/scan_local.cu +++ b/src/cunumeric/scan/scan_local.cu @@ -39,16 +39,16 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) } template -static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - partition_sum(RES* out, Buffer sum_val, const Pitches pitches, uint64_t len, uint64_t stride) +static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) partition_sum( + RES* out, Buffer sum_val, const Pitches pitches, uint64_t len, uint64_t stride) { unsigned int tid = threadIdx.x; - uint64_t blid = blockIdx.x * blockDim.x; + uint64_t blid = blockIdx.x * blockDim.x; uint64_t index = (blid + tid) * stride; - if(index < len){ - auto sum_valp = pitches.unflatten(index, Point::ZEROES()); + if (index < len) { + auto sum_valp = pitches.unflatten(index, Point::ZEROES()); sum_valp[DIM - 1] = 0; sum_val[sum_valp] = out[index + stride - 1]; } @@ -56,252 +56,242 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) template static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - cuda_add(RES *B, uint64_t len, uint64_t stride, OP func, RES *block_sum) + cuda_add(RES* B, uint64_t len, uint64_t stride, OP func, RES* block_sum) { - unsigned int tid = threadIdx.x; + unsigned int tid = threadIdx.x; unsigned int blid = blockIdx.x * blockDim.x; uint64_t pad_stride = stride; - bool must_copy = true; - if(stride & (stride-1)){ + bool must_copy = true; + if (stride & (stride - 1)) { pad_stride = 1 << (32 - __clz(stride)); - must_copy = (tid & (pad_stride-1)) < stride; + must_copy = (tid & (pad_stride - 1)) < stride; } uint64_t blocks_per_batch; bool last_block; bool first_block; - blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; - pad_stride = blocks_per_batch * THREADS_PER_BLOCK; - last_block = (blockIdx.x + 1) % blocks_per_batch == 0; - first_block = (blockIdx.x) % blocks_per_batch == 0; + blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + last_block = (blockIdx.x + 1) % blocks_per_batch == 0; + first_block = (blockIdx.x) % blocks_per_batch == 0; int remaining_batch = stride % THREADS_PER_BLOCK; - if(remaining_batch == 0){ - remaining_batch = THREADS_PER_BLOCK; - } + if (remaining_batch == 0) { remaining_batch = THREADS_PER_BLOCK; } must_copy = !last_block || (tid < remaining_batch); - int pad_per_batch = pad_stride-stride; + int pad_per_batch = pad_stride - stride; uint64_t idx0 = tid + blid; uint64_t batch_id = idx0 / pad_stride; - idx0 = idx0 - pad_per_batch * batch_id; + idx0 = idx0 - pad_per_batch * batch_id; - if(idx0 < len && must_copy && !first_block){ + if (idx0 < len && must_copy && !first_block) { B[idx0] = func(block_sum[blockIdx.x - 1], B[idx0]); } } template -static 
__global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - batch_scan_cuda(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, RES *block_sum) +static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) batch_scan_cuda( + const RES* A, RES* B, uint64_t len, uint64_t stride, OP func, RES identity, RES* block_sum) { __shared__ RES temp[THREADS_PER_BLOCK]; - unsigned int tid = threadIdx.x; + unsigned int tid = threadIdx.x; unsigned int blid = blockIdx.x * blockDim.x; uint64_t pad_stride = stride; - bool must_copy = true; - if(stride & (stride-1)){ + bool must_copy = true; + if (stride & (stride - 1)) { pad_stride = 1 << (32 - __clz(stride)); - must_copy = (tid & (pad_stride-1)) < stride; + must_copy = (tid & (pad_stride - 1)) < stride; } bool last_block; - if(pad_stride > THREADS_PER_BLOCK){ + if (pad_stride > THREADS_PER_BLOCK) { uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; - pad_stride = blocks_per_batch * THREADS_PER_BLOCK; - last_block = (blockIdx.x + 1) % blocks_per_batch == 0; - int remaining_batch = stride % THREADS_PER_BLOCK; - if(remaining_batch == 0){ - remaining_batch = THREADS_PER_BLOCK; - } + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + last_block = (blockIdx.x + 1) % blocks_per_batch == 0; + int remaining_batch = stride % THREADS_PER_BLOCK; + if (remaining_batch == 0) { remaining_batch = THREADS_PER_BLOCK; } must_copy = !last_block || (tid < remaining_batch); } - int pad_per_batch = pad_stride-stride; + int pad_per_batch = pad_stride - stride; int n_batches_block = THREADS_PER_BLOCK / pad_stride; uint64_t idx0 = tid + blid; uint64_t batch_id = idx0 / pad_stride; - idx0 = idx0 - pad_per_batch * batch_id; + idx0 = idx0 - pad_per_batch * batch_id; - if(idx0 < len){ + if (idx0 < len) { temp[tid] = (must_copy) ? 
A[idx0] : identity; __syncthreads(); if (!n_batches_block) { n_batches_block = 1; - pad_stride = THREADS_PER_BLOCK; + pad_stride = THREADS_PER_BLOCK; } for (int j = 0; j < n_batches_block; j++) { int offset = j * pad_stride; for (int i = 1; i <= pad_stride; i <<= 1) { - int index = ((tid + 1) * 2 * i - 1); - int index_block = offset + index; - if (index < (pad_stride)){ - temp[index_block] = func(temp[index_block - i], temp[index_block]); - } - __syncthreads(); + int index = ((tid + 1) * 2 * i - 1); + int index_block = offset + index; + if (index < (pad_stride)) { + temp[index_block] = func(temp[index_block - i], temp[index_block]); + } + __syncthreads(); } - for(int i = pad_stride >> 1; i > 0; i >>= 1){ - int index = ((tid + 1) * 2 * i - 1); - int index_block = offset + index; - if((index + i) < (pad_stride)){ - temp[index_block + i] = func(temp[index_block], temp[index_block + i]); - } - __syncthreads(); + for (int i = pad_stride >> 1; i > 0; i >>= 1) { + int index = ((tid + 1) * 2 * i - 1); + int index_block = offset + index; + if ((index + i) < (pad_stride)) { + temp[index_block + i] = func(temp[index_block], temp[index_block + i]); + } + __syncthreads(); } } - if(must_copy){ - B[idx0] = temp[tid]; - } - if(block_sum != nullptr && tid == THREADS_PER_BLOCK-1 && !last_block){ + if (must_copy) { B[idx0] = temp[tid]; } + if (block_sum != nullptr && tid == THREADS_PER_BLOCK - 1 && !last_block) { block_sum[blockIdx.x] = temp[tid]; } } } template -static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - batch_scan_cuda_nan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, RES *block_sum) +static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) batch_scan_cuda_nan( + const RES* A, RES* B, uint64_t len, uint64_t stride, OP func, RES identity, RES* block_sum) { __shared__ RES temp[THREADS_PER_BLOCK]; - unsigned int tid = threadIdx.x; + unsigned int tid = threadIdx.x; unsigned int blid = blockIdx.x * blockDim.x; uint64_t pad_stride = stride; - bool must_copy = true; - if(stride & (stride-1)){ + bool must_copy = true; + if (stride & (stride - 1)) { pad_stride = 1 << (32 - __clz(stride)); - must_copy = (tid & (pad_stride-1)) < stride; + must_copy = (tid & (pad_stride - 1)) < stride; } bool last_block; - if(pad_stride > THREADS_PER_BLOCK){ + if (pad_stride > THREADS_PER_BLOCK) { uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; - pad_stride = blocks_per_batch * THREADS_PER_BLOCK; - last_block = (blockIdx.x + 1) % blocks_per_batch == 0; - int remaining_batch = stride % THREADS_PER_BLOCK; - if(remaining_batch == 0){ - remaining_batch = THREADS_PER_BLOCK; - } + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + last_block = (blockIdx.x + 1) % blocks_per_batch == 0; + int remaining_batch = stride % THREADS_PER_BLOCK; + if (remaining_batch == 0) { remaining_batch = THREADS_PER_BLOCK; } must_copy = !last_block || (tid < remaining_batch); } - int pad_per_batch = pad_stride-stride; + int pad_per_batch = pad_stride - stride; int n_batches_block = THREADS_PER_BLOCK / pad_stride; uint64_t idx0 = tid + blid; uint64_t batch_id = idx0 / pad_stride; - idx0 = idx0 - pad_per_batch * batch_id; + idx0 = idx0 - pad_per_batch * batch_id; - if(idx0 < len){ - RES val = (must_copy) ? A[idx0] : identity; + if (idx0 < len) { + RES val = (must_copy) ? A[idx0] : identity; temp[tid] = cunumeric::is_nan(val) ? 
identity : val; __syncthreads(); if (!n_batches_block) { n_batches_block = 1; - pad_stride = THREADS_PER_BLOCK; + pad_stride = THREADS_PER_BLOCK; } for (int j = 0; j < n_batches_block; j++) { int offset = j * pad_stride; for (int i = 1; i <= pad_stride; i <<= 1) { - int index = ((tid + 1) * 2 * i - 1); - int index_block = offset + index; - if (index < (pad_stride)){ - temp[index_block] = func(temp[index_block - i], temp[index_block]); - } - __syncthreads(); + int index = ((tid + 1) * 2 * i - 1); + int index_block = offset + index; + if (index < (pad_stride)) { + temp[index_block] = func(temp[index_block - i], temp[index_block]); + } + __syncthreads(); } - for(int i = pad_stride >> 1; i > 0; i >>= 1){ - int index = ((tid + 1) * 2 * i - 1); - int index_block = offset + index; - if((index + i) < (pad_stride)){ - temp[index_block + i] = func(temp[index_block], temp[index_block + i]); - } - __syncthreads(); + for (int i = pad_stride >> 1; i > 0; i >>= 1) { + int index = ((tid + 1) * 2 * i - 1); + int index_block = offset + index; + if ((index + i) < (pad_stride)) { + temp[index_block + i] = func(temp[index_block], temp[index_block + i]); + } + __syncthreads(); } } - if(must_copy){ - B[idx0] = temp[tid]; - } - if(block_sum != nullptr && tid == THREADS_PER_BLOCK-1 && !last_block){ + if (must_copy) { B[idx0] = temp[tid]; } + if (block_sum != nullptr && tid == THREADS_PER_BLOCK - 1 && !last_block) { block_sum[blockIdx.x] = temp[tid]; } } } template -void cuda_scan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) +void cuda_scan( + const RES* A, RES* B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) { assert(stride != 0); uint64_t pad_stride = 1 << (32 - __builtin_clz(stride)); - if(pad_stride > THREADS_PER_BLOCK){ + if (pad_stride > THREADS_PER_BLOCK) { uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; - pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; } - uint64_t pad_len = (len / stride) * pad_stride; - uint64_t grid_dim = (pad_len-1) / THREADS_PER_BLOCK + 1; + uint64_t pad_len = (len / stride) * pad_stride; + uint64_t grid_dim = (pad_len - 1) / THREADS_PER_BLOCK + 1; - RES *blocked_sum = nullptr; + RES* blocked_sum = nullptr; uint64_t blocked_len, blocked_stride; - if(stride > THREADS_PER_BLOCK){ - blocked_len = grid_dim; + if (stride > THREADS_PER_BLOCK) { + blocked_len = grid_dim; blocked_stride = grid_dim / (len / stride); - CHECK_CUDA( cudaMalloc(&blocked_sum, blocked_len * sizeof(RES)) ); + CHECK_CUDA(cudaMalloc(&blocked_sum, blocked_len * sizeof(RES))); } - batch_scan_cuda<<>>(A, B, len, stride, func, identity, blocked_sum); + batch_scan_cuda + <<>>(A, B, len, stride, func, identity, blocked_sum); CHECK_CUDA_STREAM(stream); - if(stride > THREADS_PER_BLOCK){ + if (stride > THREADS_PER_BLOCK) { cuda_scan(blocked_sum, blocked_sum, blocked_len, blocked_stride, func, identity, stream); cuda_add<<>>(B, len, stride, func, blocked_sum); CHECK_CUDA_STREAM(stream); } - - if(stride > THREADS_PER_BLOCK){ - CHECK_CUDA( cudaFree(blocked_sum) ); - } + + if (stride > THREADS_PER_BLOCK) { CHECK_CUDA(cudaFree(blocked_sum)); } } - + template -void cuda_scan_nan(const RES *A, RES *B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) +void cuda_scan_nan( + const RES* A, RES* B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream) { assert(stride != 0); uint64_t pad_stride = 1 << (32 - __builtin_clz(stride)); - 
if(pad_stride > THREADS_PER_BLOCK){ + if (pad_stride > THREADS_PER_BLOCK) { uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1; - pad_stride = blocks_per_batch * THREADS_PER_BLOCK; + pad_stride = blocks_per_batch * THREADS_PER_BLOCK; } - uint64_t pad_len = (len / stride) * pad_stride; - uint64_t grid_dim = (pad_len-1) / THREADS_PER_BLOCK + 1; + uint64_t pad_len = (len / stride) * pad_stride; + uint64_t grid_dim = (pad_len - 1) / THREADS_PER_BLOCK + 1; - RES *blocked_sum = nullptr; + RES* blocked_sum = nullptr; uint64_t blocked_len, blocked_stride; - if(stride > THREADS_PER_BLOCK){ - blocked_len = grid_dim; + if (stride > THREADS_PER_BLOCK) { + blocked_len = grid_dim; blocked_stride = grid_dim / (len / stride); - CHECK_CUDA( cudaMalloc(&blocked_sum, blocked_len * sizeof(RES)) ); + CHECK_CUDA(cudaMalloc(&blocked_sum, blocked_len * sizeof(RES))); } - batch_scan_cuda_nan<<>>(A, B, len, stride, func, identity, blocked_sum); + batch_scan_cuda_nan + <<>>(A, B, len, stride, func, identity, blocked_sum); CHECK_CUDA_STREAM(stream); - if(stride > THREADS_PER_BLOCK){ + if (stride > THREADS_PER_BLOCK) { cuda_scan(blocked_sum, blocked_sum, blocked_len, blocked_stride, func, identity, stream); cuda_add<<>>(B, len, stride, func, blocked_sum); CHECK_CUDA_STREAM(stream); } - - if(stride > THREADS_PER_BLOCK){ - CHECK_CUDA( cudaFree(blocked_sum) ); - } + + if (stride > THREADS_PER_BLOCK) { CHECK_CUDA(cudaFree(blocked_sum)); } } - + template struct ScanLocalImplBody { using OP = ScanOp; @@ -329,15 +319,16 @@ struct ScanLocalImplBody { VAL identity = (VAL)ScanOp::nan_identity; - if(volume == stride){ + if (volume == stride) { // Thrust is slightly faster for the 1D case - thrust::inclusive_scan(thrust::cuda::par.on(stream), inptr, inptr + stride, outptr, func); + thrust::inclusive_scan(thrust::cuda::par.on(stream), inptr, inptr + stride, outptr, func); } else { cuda_scan(inptr, outptr, volume, stride, func, identity, stream); } - + uint64_t grid_dim = ((volume / stride) - 1) / THREADS_PER_BLOCK + 1; - partition_sum<<>>(outptr, sum_valsptr, pitches, volume, stride); + partition_sum<<>>( + outptr, sum_valsptr, pitches, volume, stride); CHECK_CUDA_STREAM(stream); } }; @@ -376,20 +367,20 @@ struct ScanLocalNanImplBody { VAL identity = (VAL)ScanOp::nan_identity; - if(volume == stride){ + if (volume == stride) { // Thrust is slightly faster for the 1D case - thrust::inclusive_scan( - thrust::cuda::par.on(stream), - thrust::make_transform_iterator(inptr, convert_nan_func()), - thrust::make_transform_iterator(inptr + stride, convert_nan_func()), - outptr, - func); + thrust::inclusive_scan(thrust::cuda::par.on(stream), + thrust::make_transform_iterator(inptr, convert_nan_func()), + thrust::make_transform_iterator(inptr + stride, convert_nan_func()), + outptr, + func); } else { cuda_scan_nan(inptr, outptr, volume, stride, func, identity, stream); } - + uint64_t grid_dim = ((volume / stride) - 1) / THREADS_PER_BLOCK + 1; - partition_sum<<>>(outptr, sum_valsptr, pitches, volume, stride); + partition_sum<<>>( + outptr, sum_valsptr, pitches, volume, stride); CHECK_CUDA_STREAM(stream); } };
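For readers unfamiliar with the shared-memory pass used inside batch_scan_cuda and batch_scan_cuda_nan, the loops above are the classic work-efficient (Blelloch-style) scan: an up-sweep combines elements at power-of-two strides, then a down-sweep propagates the accumulated prefixes back down, yielding an inclusive scan. A minimal, self-contained CUDA sketch of that pattern for a single power-of-two-sized block follows; it is illustrative only (N, block_inclusive_scan, and the use of float are not identifiers from these patches) and omits the padding, batching along the scan axis, block sums, and NaN handling that the real kernels layer on top.

#include <cstdio>

constexpr int N = 8;  // assumed: power of two, fits in one block

__global__ void block_inclusive_scan(const float* in, float* out)
{
  __shared__ float temp[N];
  int tid = threadIdx.x;
  temp[tid] = in[tid];
  __syncthreads();
  // Up-sweep: build partial sums at strides 1, 2, 4, ...
  for (int i = 1; i < N; i <<= 1) {
    int idx = (tid + 1) * 2 * i - 1;
    if (idx < N) temp[idx] += temp[idx - i];
    __syncthreads();
  }
  // Down-sweep: distribute the accumulated prefixes to the remaining positions.
  for (int i = N >> 1; i > 0; i >>= 1) {
    int idx = (tid + 1) * 2 * i - 1;
    if (idx + i < N) temp[idx + i] += temp[idx];
    __syncthreads();
  }
  out[tid] = temp[tid];
}

int main()
{
  float h_in[N] = {1, 2, 3, 4, 5, 6, 7, 8}, h_out[N];
  float *d_in, *d_out;
  cudaMalloc(&d_in, N * sizeof(float));
  cudaMalloc(&d_out, N * sizeof(float));
  cudaMemcpy(d_in, h_in, N * sizeof(float), cudaMemcpyHostToDevice);
  block_inclusive_scan<<<1, N>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, N * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < N; i++) std::printf("%g ", h_out[i]);  // expect: 1 3 6 10 15 21 28 36
  std::printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Built with nvcc and run on one block of N threads, this prints the inclusive prefix sums of the input; the kernels in this series apply the same two loops per batch, with func in place of + and identity used to fill padded slots of partial batches.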