Description
When a per-dimension term (lengths[i] - 1) * strides[i] of a tensor's element_space_size (= 1 + the sum of these terms over all dimensions) no longer fits in 32 bits, calculate_element_space_size_impl overflows, because lengths and strides are index_t = int32_t. The wrapped value corrupts buffer_size_ in buffer_view, which causes AMD GPU buffer loads (issued through the SRD) to silently return zeros for valid memory accesses.
This was discovered in the mha_batch_prefill kernel via ROCm/aiter#2517, but the bug is in CK's core tensor infrastructure.
Affected Code
include/ck_tile/core/tensor/tensor_descriptor.hpp, calculate_element_space_size_impl:
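```cpp
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
// (lengths[i] - number<1>{}) is int32, strides[i] is int32 → the int32 product overflows!
```

The accumulator acc_old starts as long_number<1> (int64), but (lengths[i] - 1) * strides[i] is computed as int32 × int32 = int32 before being promoted to int64. Signed int32 multiplication overflows as soon as the product exceeds INT32_MAX (2^31 - 1), which is undefined behavior in C++; in practice the result wraps modulo 2^32, so a product just above 2^32 becomes a small positive value.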
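To make the wrap concrete, here is a minimal host-side C++ sketch (not CK code) that reproduces the accumulation for the shape from the Minimal Trigger section. The function names element_space_size_wrapped and element_space_size_int64 are hypothetical, and the index_t/long_index_t aliases are assumed to match the int32/int64 types described above. The 32-bit product is emulated with uint32_t arithmetic so that the sketch itself avoids signed-overflow undefined behavior while producing the two's-complement wrap the issue describes.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumed to match ck_tile's aliases as described in this issue.
using index_t      = std::int32_t;
using long_index_t = std::int64_t;

// Emulates the current accumulation: each (lengths[i] - 1) * strides[i] term is
// formed in 32 bits and wraps modulo 2^32 before it is widened into the int64
// accumulator. Unsigned arithmetic keeps the emulation free of signed-overflow
// UB; converting a value above INT32_MAX back to int32 is only guaranteed to be
// modular from C++20 on (it does not happen for the example shape).
template <std::size_t N>
long_index_t element_space_size_wrapped(const std::array<index_t, N>& lengths,
                                        const std::array<index_t, N>& strides)
{
    long_index_t acc = 1; // mirrors the long_number<1> starting value
    for (std::size_t i = 0; i < N; ++i)
    {
        const std::uint32_t term = static_cast<std::uint32_t>(lengths[i] - 1) *
                                   static_cast<std::uint32_t>(strides[i]);
        acc += static_cast<index_t>(term);
    }
    return acc;
}

// Same accumulation with both operands widened to 64 bits before multiplying,
// as in the suggested fix.
template <std::size_t N>
long_index_t element_space_size_int64(const std::array<index_t, N>& lengths,
                                      const std::array<index_t, N>& strides)
{
    long_index_t acc = 1;
    for (std::size_t i = 0; i < N; ++i)
    {
        acc += static_cast<long_index_t>(lengths[i] - 1) *
               static_cast<long_index_t>(strides[i]);
    }
    return acc;
}
```

For lengths = {8401239, 8, 64} and strides = {512, 64, 1}, the first function returns 6,467,072 and the second returns 4,301,434,368.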
Impact Chain
- element_space_size wraps to a small value (e.g., 6.5M instead of 4.3B).
- buffer_view::buffer_size_ stores the wrong value.
- SRD range = buffer_size_ × sizeof(T) = 12 MB instead of 8.6 GB (for a 2-byte element type).
- AMD GPU: any buffer load at a byte offset > 12 MB silently returns zero (out-of-bounds behavior).
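The impact chain reduces to a 32-bit range check. The following simplified model shows why a load from the valid part of the tensor lands outside the range derived from the wrapped size. It assumes a 2-byte element type (which is what the 12 MB / 8.6 GB figures imply), and its helpers srd_range_bytes_truncated, load_in_range and byte_offset_3d are hypothetical: they are not CK or ISA definitions, and they ignore SRD stride and swizzle settings.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model: the SRD range field holds a 32-bit byte count, so the
// 64-bit product element_space_size * elem_bytes is truncated modulo 2^32.
inline std::uint32_t srd_range_bytes_truncated(std::int64_t element_space_size,
                                               std::int64_t elem_bytes)
{
    return static_cast<std::uint32_t>(element_space_size * elem_bytes);
}

// Simplified bounds check: a load of elem_bytes at byte_offset counts as in
// range only if it ends within range_bytes. Per this issue, out-of-range
// buffer loads return zero instead of the data in memory.
inline bool load_in_range(std::uint64_t byte_offset, std::uint64_t elem_bytes,
                          std::uint32_t range_bytes)
{
    return byte_offset + elem_bytes <= range_bytes;
}

// Byte offset of element (i0, i1, i2) for strides given in elements.
inline std::uint64_t byte_offset_3d(std::uint64_t i0, std::uint64_t i1, std::uint64_t i2,
                                    std::uint64_t s0, std::uint64_t s1, std::uint64_t s2,
                                    std::uint64_t elem_bytes)
{
    return (i0 * s0 + i1 * s1 + i2 * s2) * elem_bytes;
}
```

With the wrapped size of 6,467,072 elements, the modeled range is 12,934,144 bytes. Row 10 (byte offset 10,240) is still in range, but row 100,000 (byte offset 102,400,000) is not, even though it lies well inside the real 8.6 GB tensor.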
Minimal Trigger
Any 3D tensor with (dim0 - 1) × stride0 ≥ 2^32:
- Example: shape = [8401239, 8, 64], strides = [512, 64, 1] → 8,401,238 × 512 = 4,301,433,856 > 2^32 = 4,294,967,296
- Threshold: dim0 > 2^32 / stride0 (e.g., dim0 > 8,388,608 for stride0 = 512)
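The threshold can be computed directly. This sketch uses a hypothetical helper, min_dim0_reaching, which returns the smallest dim0 for which (dim0 - 1) × stride0 reaches a given limit. Two limits are relevant: 2^32, where the 32-bit product wraps to a small value as in the example, and 2^31, the first product past INT32_MAX, where the signed multiplication already overflows and wraps to a negative value instead.

```cpp
#include <cassert>
#include <cstdint>

// Smallest dim0 such that (dim0 - 1) * stride0 >= limit, i.e.
// dim0 - 1 == ceil(limit / stride0). Everything is done in 64 bits, so the
// computation itself cannot overflow for limits up to 2^32. Requires stride0 > 0.
inline std::uint64_t min_dim0_reaching(std::uint64_t limit, std::uint64_t stride0)
{
    return (limit + stride0 - 1) / stride0 + 1;
}
```

For stride0 = 512 this gives 8,388,609 for the 2^32 limit (matching dim0 > 8,388,608) and 4,194,305 for the 2^31 limit. The example's dim0 = 8,401,239 is above both.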
Suggested Fix
Cast to long_index_t before multiplying:

```cpp
auto acc_new = acc_old + static_cast<long_index_t>(lengths[i] - number<1>{})
                       * static_cast<long_index_t>(strides[i]);
```

Also, buffer_view::init_raw() and the non-raw load paths in amd_buffer_addressing.hpp pass element_space_size * sizeof(T) into the uint32_t SRD range field. That value must be capped at 0xFFFFFFFF when the byte size exceeds 4 GB.
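The capping step can be sketched as a saturating conversion. srd_range_bytes_clamped is a hypothetical helper, not an existing CK function. It forms the byte count in 64 bits and saturates it, rather than relying on the narrowing conversion to uint32_t.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: computes the SRD byte range in 64 bits and saturates it
// at 0xFFFFFFFF instead of letting the conversion to uint32_t wrap modulo 2^32.
inline std::uint32_t srd_range_bytes_clamped(std::int64_t element_space_size,
                                             std::int64_t elem_bytes)
{
    constexpr std::int64_t kMaxRange = 0xFFFFFFFFLL; // largest 32-bit range value
    const std::int64_t bytes = element_space_size * elem_bytes;
    if (bytes <= 0)
        return 0u;
    if (bytes > kMaxRange)
        return static_cast<std::uint32_t>(kMaxRange);
    return static_cast<std::uint32_t>(bytes);
}
```

This is why the int64 fix alone is not enough for tensors this large. With a 2-byte element type, the correct size of 4,301,434,368 elements is 8,602,868,736 bytes, which still truncates modulo 2^32 to 12,934,144 bytes, the same range the wrapped size produced. Clamping yields 0xFFFFFFFF instead. The sketch does not model whether byte offsets beyond 4 GB are reachable through the buffer path at all.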
Versions Affected
Verified on CK commit eb033ef20 (aiter main) and tag rocm-7.2.1.