diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 879fb31ca59..53af4ce211f 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -186,8 +186,10 @@ // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting. Don't use this hack unless absolutely necessary! // FIXME: make the behavior of buffer load a configurable (template) parameter for each usage +// FIX: Enable offset trick to prevent invalid loads from crashing on gfx906/MI50 +// Without this, invalid loads still execute and crash despite bounds checking #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 #endif #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp index c486b124237..da36a8f206f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -418,6 +418,12 @@ struct GridwiseGemmDlMultipleD_km_kn_mn auto c_thread_buf = make_static_buffer( c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + // FIX: Separate buffer for element op output with proper type (FloatC) + // This is needed when FloatAcc (e.g., int32) differs from FloatC (e.g., float) + // The element op e = static_cast(c) * d expects to write to FloatC, not FloatAcc + auto e_thread_buf = make_static_buffer( + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + // Initialize C c_thread_buf.Clear(); @@ -621,13 +627,22 @@ struct GridwiseGemmDlMultipleD_km_kn_mn Number{}); // get reference to dst data + // FIX: Use e_thread_buf (FloatC) for output, c_thread_buf + // (FloatAcc) for input. This fixes type mismatch when FloatAcc + // (int32) != FloatC (float) constexpr index_t c_offset = c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset( make_tuple(0, m10, m11, 0, n10, i)); - auto dst_data_refs = generate_tie( - // return type should be lvalue - [&](auto) -> auto& { return c_thread_buf(Number{}); }, - Number<2>{}); + // Element op signature: (E& e, const C& c, const D& d) + // - e (output): goes to e_thread_buf (FloatC type) + // - c (input): comes from c_thread_buf (FloatAcc type) + // Use tie() to create a tuple of references to different buffers + auto dst_data_refs = + tie(e_thread_buf( + Number{}), // E& e (output to FloatC buffer) + c_thread_buf( + Number{}) // C& c (input from FloatAcc buffer) + ); unpack2(cde_element_op, dst_data_refs, src_data_refs); }); @@ -653,8 +668,10 @@ struct GridwiseGemmDlMultipleD_km_kn_mn }); }); + // FIX: Transfer from e_thread_buf (FloatC) instead of c_thread_buf (FloatAcc) + // since element op output is now stored in e_thread_buf with proper type ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, + FloatC, // FIX: Source is now FloatC (e_thread_buf) FloatC, decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), @@ -680,7 +697,7 @@ struct GridwiseGemmDlMultipleD_km_kn_mn ck::tensor_operation::element_wise::PassThrough{}} .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, make_tuple(I0, I0, I0, I0, I0, I0), - c_thread_buf, + e_thread_buf, // FIX: Use e_thread_buf instead of c_thread_buf c_grid_desc_m0_m10_m11_n0_n10_n11, c_grid_buf); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index bce2d453dce..0e8f9b0fd94 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -177,8 +177,10 @@ struct ThreadwiseTensorSliceTransfer_v5r1 using src_vector_t = typename decltype(src_vector)::type; + // FIX: Use full bounds check including visible index to prevent OOB access + // when K0 coordinate exceeds tensor bounds const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + coordinate_has_valid_offset(src_desc, src_coord_); // copy data from src_buf to src_vector src_vector.template AsType()(I0) =