diff --git a/src/cunumeric/scan/scan_local.cu b/src/cunumeric/scan/scan_local.cu
index ddc0739514..078beae594 100644
--- a/src/cunumeric/scan/scan_local.cu
+++ b/src/cunumeric/scan/scan_local.cu
@@ -38,13 +38,267 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   sum_val[0] = out[0];
 }
 
+template <typename RES, int DIM>
+static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) partition_sum(
+  RES* out, Buffer<RES, DIM> sum_val, const Pitches<DIM - 1> pitches, uint64_t len, uint64_t stride)
+{
+  unsigned int tid = threadIdx.x;
+  uint64_t blid    = blockIdx.x * blockDim.x;
+
+  uint64_t index = (blid + tid) * stride;
+
+  if (index < len) {
+    auto sum_valp     = pitches.unflatten(index, Point<DIM>::ZEROES());
+    sum_valp[DIM - 1] = 0;
+    sum_val[sum_valp] = out[index + stride - 1];
+  }
+}
+
+template <typename RES, typename OP>
+static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
+  cuda_add(RES* B, uint64_t len, uint64_t stride, OP func, RES* block_sum)
+{
+  unsigned int tid  = threadIdx.x;
+  unsigned int blid = blockIdx.x * blockDim.x;
+
+  uint64_t pad_stride = stride;
+  bool must_copy      = true;
+  if (stride & (stride - 1)) {
+    pad_stride = 1 << (32 - __clz(stride));
+    must_copy  = (tid & (pad_stride - 1)) < stride;
+  }
+  uint64_t blocks_per_batch;
+  bool last_block;
+  bool first_block;
+
+  blocks_per_batch    = (stride - 1) / THREADS_PER_BLOCK + 1;
+  pad_stride          = blocks_per_batch * THREADS_PER_BLOCK;
+  last_block          = (blockIdx.x + 1) % blocks_per_batch == 0;
+  first_block         = (blockIdx.x) % blocks_per_batch == 0;
+  int remaining_batch = stride % THREADS_PER_BLOCK;
+  if (remaining_batch == 0) { remaining_batch = THREADS_PER_BLOCK; }
+  must_copy = !last_block || (tid < remaining_batch);
+
+  int pad_per_batch = pad_stride - stride;
+
+  uint64_t idx0 = tid + blid;
+
+  uint64_t batch_id = idx0 / pad_stride;
+  idx0              = idx0 - pad_per_batch * batch_id;
+
+  if (idx0 < len && must_copy && !first_block) {
+    B[idx0] = func(block_sum[blockIdx.x - 1], B[idx0]);
+  }
+}
+
+template <typename RES, typename OP>
+static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) batch_scan_cuda(
+  const RES* A, RES* B, uint64_t len, uint64_t stride, OP func, RES identity, RES* block_sum)
+{
+  __shared__ RES temp[THREADS_PER_BLOCK];
+
+  unsigned int tid  = threadIdx.x;
+  unsigned int blid = blockIdx.x * blockDim.x;
+
+  uint64_t pad_stride = stride;
+  bool must_copy      = true;
+  if (stride & (stride - 1)) {
+    pad_stride = 1 << (32 - __clz(stride));
+    must_copy  = (tid & (pad_stride - 1)) < stride;
+  }
+  bool last_block;
+  if (pad_stride > THREADS_PER_BLOCK) {
+    uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1;
+    pad_stride                = blocks_per_batch * THREADS_PER_BLOCK;
+    last_block                = (blockIdx.x + 1) % blocks_per_batch == 0;
+    int remaining_batch       = stride % THREADS_PER_BLOCK;
+    if (remaining_batch == 0) { remaining_batch = THREADS_PER_BLOCK; }
+    must_copy = !last_block || (tid < remaining_batch);
+  }
+
+  int pad_per_batch   = pad_stride - stride;
+  int n_batches_block = THREADS_PER_BLOCK / pad_stride;
+
+  uint64_t idx0 = tid + blid;
+
+  uint64_t batch_id = idx0 / pad_stride;
+  idx0              = idx0 - pad_per_batch * batch_id;
+
+  if (idx0 < len) {
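+    // Load this thread's element into shared memory (threads in the padded region
+    // load the identity), then run a work-efficient up-sweep/down-sweep inclusive
+    // scan over each padded batch held by this block.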
+    temp[tid] = (must_copy) ? A[idx0] : identity;
+    __syncthreads();
+    if (!n_batches_block) {
+      n_batches_block = 1;
+      pad_stride      = THREADS_PER_BLOCK;
+    }
+    for (int j = 0; j < n_batches_block; j++) {
+      int offset = j * pad_stride;
+      for (int i = 1; i <= pad_stride; i <<= 1) {
+        int index       = ((tid + 1) * 2 * i - 1);
+        int index_block = offset + index;
+        if (index < (pad_stride)) {
+          temp[index_block] = func(temp[index_block - i], temp[index_block]);
+        }
+        __syncthreads();
+      }
+      for (int i = pad_stride >> 1; i > 0; i >>= 1) {
+        int index       = ((tid + 1) * 2 * i - 1);
+        int index_block = offset + index;
+        if ((index + i) < (pad_stride)) {
+          temp[index_block + i] = func(temp[index_block], temp[index_block + i]);
+        }
+        __syncthreads();
+      }
+    }
+    if (must_copy) { B[idx0] = temp[tid]; }
+    if (block_sum != nullptr && tid == THREADS_PER_BLOCK - 1 && !last_block) {
+      block_sum[blockIdx.x] = temp[tid];
+    }
+  }
+}
+
+template <typename RES, typename OP>
+static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) batch_scan_cuda_nan(
+  const RES* A, RES* B, uint64_t len, uint64_t stride, OP func, RES identity, RES* block_sum)
+{
+  __shared__ RES temp[THREADS_PER_BLOCK];
+
+  unsigned int tid  = threadIdx.x;
+  unsigned int blid = blockIdx.x * blockDim.x;
+
+  uint64_t pad_stride = stride;
+  bool must_copy      = true;
+  if (stride & (stride - 1)) {
+    pad_stride = 1 << (32 - __clz(stride));
+    must_copy  = (tid & (pad_stride - 1)) < stride;
+  }
+  bool last_block;
+  if (pad_stride > THREADS_PER_BLOCK) {
+    uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1;
+    pad_stride                = blocks_per_batch * THREADS_PER_BLOCK;
+    last_block                = (blockIdx.x + 1) % blocks_per_batch == 0;
+    int remaining_batch       = stride % THREADS_PER_BLOCK;
+    if (remaining_batch == 0) { remaining_batch = THREADS_PER_BLOCK; }
+    must_copy = !last_block || (tid < remaining_batch);
+  }
+
+  int pad_per_batch   = pad_stride - stride;
+  int n_batches_block = THREADS_PER_BLOCK / pad_stride;
+
+  uint64_t idx0 = tid + blid;
+
+  uint64_t batch_id = idx0 / pad_stride;
+  idx0              = idx0 - pad_per_batch * batch_id;
+
+  if (idx0 < len) {
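+    // Same as batch_scan_cuda, except NaN inputs are replaced with the identity
+    // value before being scanned.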
+    RES val   = (must_copy) ? A[idx0] : identity;
+    temp[tid] = cunumeric::is_nan(val) ? identity : val;
+    __syncthreads();
+    if (!n_batches_block) {
+      n_batches_block = 1;
+      pad_stride      = THREADS_PER_BLOCK;
+    }
+    for (int j = 0; j < n_batches_block; j++) {
+      int offset = j * pad_stride;
+      for (int i = 1; i <= pad_stride; i <<= 1) {
+        int index       = ((tid + 1) * 2 * i - 1);
+        int index_block = offset + index;
+        if (index < (pad_stride)) {
+          temp[index_block] = func(temp[index_block - i], temp[index_block]);
+        }
+        __syncthreads();
+      }
+      for (int i = pad_stride >> 1; i > 0; i >>= 1) {
+        int index       = ((tid + 1) * 2 * i - 1);
+        int index_block = offset + index;
+        if ((index + i) < (pad_stride)) {
+          temp[index_block + i] = func(temp[index_block], temp[index_block + i]);
+        }
+        __syncthreads();
+      }
+    }
+    if (must_copy) { B[idx0] = temp[tid]; }
+    if (block_sum != nullptr && tid == THREADS_PER_BLOCK - 1 && !last_block) {
+      block_sum[blockIdx.x] = temp[tid];
+    }
+  }
+}
+
+template <typename RES, typename OP>
+void cuda_scan(
+  const RES* A, RES* B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream)
+{
+  assert(stride != 0);
+  uint64_t pad_stride = 1 << (32 - __builtin_clz(stride));
+  if (pad_stride > THREADS_PER_BLOCK) {
+    uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1;
+    pad_stride                = blocks_per_batch * THREADS_PER_BLOCK;
+  }
+  uint64_t pad_len  = (len / stride) * pad_stride;
+  uint64_t grid_dim = (pad_len - 1) / THREADS_PER_BLOCK + 1;
+
+  RES* blocked_sum = nullptr;
+  uint64_t blocked_len, blocked_stride;
+  if (stride > THREADS_PER_BLOCK) {
+    blocked_len    = grid_dim;
+    blocked_stride = grid_dim / (len / stride);
+    CHECK_CUDA(cudaMalloc(&blocked_sum, blocked_len * sizeof(RES)));
+  }
+
+  batch_scan_cuda
+    <<<grid_dim, THREADS_PER_BLOCK, 0, stream>>>(A, B, len, stride, func, identity, blocked_sum);
+  CHECK_CUDA_STREAM(stream);
+
+  if (stride > THREADS_PER_BLOCK) {
+    cuda_scan(blocked_sum, blocked_sum, blocked_len, blocked_stride, func, identity, stream);
+    cuda_add<<<grid_dim, THREADS_PER_BLOCK, 0, stream>>>(B, len, stride, func, blocked_sum);
+    CHECK_CUDA_STREAM(stream);
+  }
+
+  if (stride > THREADS_PER_BLOCK) { CHECK_CUDA(cudaFree(blocked_sum)); }
+}
+
+template <typename RES, typename OP>
+void cuda_scan_nan(
+  const RES* A, RES* B, uint64_t len, uint64_t stride, OP func, RES identity, cudaStream_t stream)
+{
+  assert(stride != 0);
+  uint64_t pad_stride = 1 << (32 - __builtin_clz(stride));
+  if (pad_stride > THREADS_PER_BLOCK) {
+    uint64_t blocks_per_batch = (stride - 1) / THREADS_PER_BLOCK + 1;
+    pad_stride                = blocks_per_batch * THREADS_PER_BLOCK;
+  }
+  uint64_t pad_len  = (len / stride) * pad_stride;
+  uint64_t grid_dim = (pad_len - 1) / THREADS_PER_BLOCK + 1;
+
+  RES* blocked_sum = nullptr;
+  uint64_t blocked_len, blocked_stride;
+  if (stride > THREADS_PER_BLOCK) {
+    blocked_len    = grid_dim;
+    blocked_stride = grid_dim / (len / stride);
+    CHECK_CUDA(cudaMalloc(&blocked_sum, blocked_len * sizeof(RES)));
+  }
+
+  batch_scan_cuda_nan
+    <<<grid_dim, THREADS_PER_BLOCK, 0, stream>>>(A, B, len, stride, func, identity, blocked_sum);
+  CHECK_CUDA_STREAM(stream);
+
+  if (stride > THREADS_PER_BLOCK) {
+    cuda_scan(blocked_sum, blocked_sum, blocked_len, blocked_stride, func, identity, stream);
+    cuda_add<<<grid_dim, THREADS_PER_BLOCK, 0, stream>>>(B, len, stride, func, blocked_sum);
+    CHECK_CUDA_STREAM(stream);
+  }
+
+  if (stride > THREADS_PER_BLOCK) { CHECK_CUDA(cudaFree(blocked_sum)); }
+}
+
 template <ScanCode OP_CODE, LegateTypeCode CODE, int DIM>
 struct ScanLocalImplBody<VariantKind::GPU, OP_CODE, CODE, DIM> {
   using OP  = ScanOp<OP_CODE, CODE>;
   using VAL = legate_type_of<CODE>;
 
   void operator()(OP func,
-                  const AccessorWO<VAL, DIM>& out,
+                  AccessorWO<VAL, DIM>& out,
                   const AccessorRO<VAL, DIM>& in,
                   Array& sum_vals,
                   const Pitches<DIM - 1>& pitches,
@@ -63,17 +317,18 @@ struct ScanLocalImplBody<VariantKind::GPU, OP_CODE, CODE, DIM> {
 
     auto sum_valsptr = sum_vals.create_output_buffer<VAL, DIM>(extents, true);
 
-    for (uint64_t index = 0; index < volume; index += stride) {
-      thrust::inclusive_scan(
-        thrust::cuda::par.on(stream), inptr + index, inptr + index + stride, outptr + index, func);
-      // get the corresponding ND index with base zero to use for sum_val
-      auto sum_valp = pitches.unflatten(index, Point<DIM>::ZEROES());
-      // only one element on scan axis
-      sum_valp[DIM - 1] = 0;
-      // write out the partition sum
-      lazy_kernel<<<1, THREADS_PER_BLOCK, 0, stream>>>(&outptr[index + stride - 1],
-                                                       &sum_valsptr[sum_valp]);
+    VAL identity = (VAL)ScanOp<OP_CODE, CODE>::nan_identity;
+
+    if (volume == stride) {
+      // Thrust is slightly faster for the 1D case
+      thrust::inclusive_scan(thrust::cuda::par.on(stream), inptr, inptr + stride, outptr, func);
+    } else {
+      cuda_scan(inptr, outptr, volume, stride, func, identity, stream);
     }
+
+    uint64_t grid_dim = ((volume / stride) - 1) / THREADS_PER_BLOCK + 1;
+    partition_sum<<<grid_dim, THREADS_PER_BLOCK, 0, stream>>>(
+      outptr, sum_valsptr, pitches, volume, stride);
     CHECK_CUDA_STREAM(stream);
   }
 };
@@ -91,7 +346,7 @@ struct ScanLocalNanImplBody<VariantKind::GPU, OP_CODE, CODE, DIM> {
   };
 
   void operator()(OP func,
-                  const AccessorWO<VAL, DIM>& out,
+                  AccessorWO<VAL, DIM>& out,
                   const AccessorRO<VAL, DIM>& in,
                   Array& sum_vals,
                   const Pitches<DIM - 1>& pitches,
@@ -110,21 +365,22 @@ struct ScanLocalNanImplBody<VariantKind::GPU, OP_CODE, CODE, DIM> {
 
     auto sum_valsptr = sum_vals.create_output_buffer<VAL, DIM>(extents, true);
 
-    for (uint64_t index = 0; index < volume; index += stride) {
-      thrust::inclusive_scan(
-        thrust::cuda::par.on(stream),
-        thrust::make_transform_iterator(inptr + index, convert_nan_func()),
-        thrust::make_transform_iterator(inptr + index + stride, convert_nan_func()),
-        outptr + index,
-        func);
-      // get the corresponding ND index with base zero to use for sum_val
-      auto sum_valp = pitches.unflatten(index, Point<DIM>::ZEROES());
-      // only one element on scan axis
-      sum_valp[DIM - 1] = 0;
-      // write out the partition sum
-      lazy_kernel<<<1, THREADS_PER_BLOCK, 0, stream>>>(&outptr[index + stride - 1],
-                                                       &sum_valsptr[sum_valp]);
+    VAL identity = (VAL)ScanOp<OP_CODE, CODE>::nan_identity;
+
+    if (volume == stride) {
+      // Thrust is slightly faster for the 1D case
+      thrust::inclusive_scan(thrust::cuda::par.on(stream),
+                             thrust::make_transform_iterator(inptr, convert_nan_func()),
+                             thrust::make_transform_iterator(inptr + stride, convert_nan_func()),
+                             outptr,
+                             func);
+    } else {
+      cuda_scan_nan(inptr, outptr, volume, stride, func, identity, stream);
     }
+
+    uint64_t grid_dim = ((volume / stride) - 1) / THREADS_PER_BLOCK + 1;
+    partition_sum<<<grid_dim, THREADS_PER_BLOCK, 0, stream>>>(
+      outptr, sum_valsptr, pitches, volume, stride);
     CHECK_CUDA_STREAM(stream);
   }
 };