Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions aten/src/ATen/native/cuda/Reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ATen/OpMathType.h>
#include <c10/macros/Macros.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDADeviceAssertion.h>
#include <array>
#include <functional>
#include <iosfwd>
Expand Down Expand Up @@ -220,7 +221,9 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);

template<int nt, int output_vec_size, typename R>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void reduce_kernel(R reduction) {
__global__ void reduce_kernel(R reduction, TORCH_DSA_KERNEL_ARGS) {
reduction.assertions_data = assertions_data;
reduction.assertion_caller_id = assertion_caller_id;
reduction.template run<output_vec_size>();
}

Expand Down Expand Up @@ -367,6 +370,12 @@ struct ReduceOp {
bool final_output;
int noutputs;

// DSA context propagated from the launching kernel; using the exact names
// expected by CUDA_KERNEL_ASSERT2 so they resolve via implicit member access
// when the macro is used inside any ReduceOp member function.
c10::cuda::DeviceAssertionsData* assertions_data = nullptr;
uint32_t assertion_caller_id = 0;

ReduceOp(
ops_t ops,
ReduceConfig config,
Expand Down Expand Up @@ -479,7 +488,8 @@ struct ReduceOp {
template <int output_vec_size>
C10_DEVICE std::array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
if (config.vectorize_input) {
CUDA_KERNEL_ASSERT(output_vec_size == 1);
CUDA_KERNEL_ASSERT2_RET(
output_vec_size == 1, (std::array<arg_t, output_vec_size>{}));
// reduce at the header of input_slice where memory is not aligned,
// so that thread_reduce will have an aligned memory to work on.
return {input_vectorized_thread_reduce_impl(data)};
Expand Down Expand Up @@ -722,7 +732,7 @@ struct ReduceOp {
out_scalar_t* out, arg_t value,
typename std::enable_if_t<can_acc>* = nullptr
) const {
CUDA_KERNEL_ASSERT(!final_output);
CUDA_KERNEL_ASSERT2_RET(!final_output, out_scalar_t{});
return (out_scalar_t)value;
}

Expand All @@ -735,7 +745,7 @@ struct ReduceOp {
std::array<arg_t, output_vec_size>,
typename std::enable_if_t<!can_acc>* = nullptr
) const {
CUDA_KERNEL_ASSERT(false);
CUDA_KERNEL_ASSERT2_RET(false, (std::array<arg_t, output_vec_size>{}));
return {arg_t{}};
}

Expand All @@ -747,13 +757,13 @@ struct ReduceOp {
out_scalar_t* out, arg_t value,
typename std::enable_if_t<!can_acc>* = nullptr
) const {
CUDA_KERNEL_ASSERT(false);
CUDA_KERNEL_ASSERT2_RET(false, out_scalar_t{});
return *out;
}

template<class T>
C10_DEVICE void set_results(const T x, const index_t base_offset) const {
CUDA_KERNEL_ASSERT(noutputs == 1);
CUDA_KERNEL_ASSERT2(noutputs == 1);
auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
*res = x;
}
Expand All @@ -775,7 +785,7 @@ struct ReduceOp {

template <int output_vec_size>
C10_DEVICE void set_results_to_output(std::array<arg_t, output_vec_size> value, std::array<index_t, output_vec_size> base_offset) const {
CUDA_KERNEL_ASSERT(final_output);
CUDA_KERNEL_ASSERT2(final_output);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
set_results(ops.project(value[i]), base_offset[i]);
Expand Down Expand Up @@ -919,16 +929,19 @@ static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction)

switch(config.output_vec_size) {
case 4:
reduce_kernel<max_threads / 4, 4, R><<<grid, block, shared_memory, stream>>>(reduction);
C10_CUDA_KERNEL_LAUNCH_CHECK();
TORCH_DSA_KERNEL_LAUNCH_T(
(reduce_kernel<max_threads / 4, 4, R>),
grid, block, shared_memory, stream, reduction);
break;
case 2:
reduce_kernel<max_threads / 2, 2, R><<<grid, block, shared_memory, stream>>>(reduction);
C10_CUDA_KERNEL_LAUNCH_CHECK();
TORCH_DSA_KERNEL_LAUNCH_T(
(reduce_kernel<max_threads / 2, 2, R>),
grid, block, shared_memory, stream, reduction);
break;
default:
reduce_kernel<max_threads / 1, 1, R><<<grid, block, shared_memory, stream>>>(reduction);
C10_CUDA_KERNEL_LAUNCH_CHECK();
TORCH_DSA_KERNEL_LAUNCH_T(
(reduce_kernel<max_threads / 1, 1, R>),
grid, block, shared_memory, stream, reduction);
}
}

Expand Down
33 changes: 33 additions & 0 deletions c10/cuda/CUDADeviceAssertion.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,41 @@ C10_CLANG_DIAGNOSTIC_POP()
return; \
} \
} while (false)

// Variant of CUDA_KERNEL_ASSERT2 for use inside non-void-returning device
// functions. On assertion failure, records the failure into the DSA registry
// and returns `ret_expr` from the enclosing function. As with
// CUDA_KERNEL_ASSERT2, the kernel as a whole continues running; the host
// detects the failure via UVM and any data produced after the failure should
// be treated as garbage.
//
// NOTE: This assumes that `assertions_data` and `assertion_caller_id` are
// accessible at the call site (either as kernel parameters or via
// implicit member access in a member function).
#define CUDA_KERNEL_ASSERT2_RET(condition, ret_expr) \
do { \
if (C10_UNLIKELY(!(condition))) { \
c10::cuda::dsa_add_new_assertion_failure( \
assertions_data, \
C10_STRINGIZE(condition), \
__FILE__, \
__FUNCTION__, \
__LINE__, \
assertion_caller_id, \
blockIdx, \
threadIdx); \
return (ret_expr); \
} \
} while (false)
#else
#define CUDA_KERNEL_ASSERT2(condition) assert(condition)
#define CUDA_KERNEL_ASSERT2_RET(condition, ret_expr) \
do { \
if (C10_UNLIKELY(!(condition))) { \
assert(condition); \
return (ret_expr); \
} \
} while (false)
#endif

} // namespace c10::cuda
31 changes: 31 additions & 0 deletions c10/cuda/CUDAException.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,37 @@ class C10_CUDA_API CUDAError : public c10::Error {
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
} while (0)

// Variant of TORCH_DSA_KERNEL_LAUNCH that accepts a *parenthesized* kernel
// expression as the first argument. This lets you launch a kernel whose name
// contains commas at the top level (e.g. a function template-id with more
// than one template argument such as `foo<16, T, U>`), which the standard
// macro cannot accept because the preprocessor would split the template
// argument list on those commas.
//
// Example:
// TORCH_DSA_KERNEL_LAUNCH_T(
// (reduce_kernel<max_threads / 4, 4, R>),
// grid, block, shared_memory, stream, reduction);
#define TORCH_DSA_KERNEL_LAUNCH_T_UNWRAP(...) __VA_ARGS__
#define TORCH_DSA_KERNEL_LAUNCH_T( \
kernel_paren, blocks, threads, shared_mem, stream, ...) \
do { \
auto& launch_registry = \
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref(); \
TORCH_DSA_KERNEL_LAUNCH_T_UNWRAP kernel_paren \
<<<blocks, threads, shared_mem, stream>>>( \
__VA_ARGS__, \
launch_registry \
.get_uvm_assertions_ptr_for_current_device(), \
launch_registry.insert( \
__FILE__, \
__FUNCTION__, \
__LINE__, \
#kernel_paren, \
stream.id())); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
} while (0)

namespace c10::cuda {

/// In the event of a CUDA failure, formats a nice error message about that
Expand Down