From 14e3b75f3f76117c3862c319a8ea624b17908c77 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 16 Jan 2026 10:53:45 -0600 Subject: [PATCH 1/4] GEMMTestSuite: use rocrand for input data generation --- tests/cpp/operator/CMakeLists.txt | 2 +- tests/cpp/test_common.cu | 50 +++++++++++++++++++++++++++++++ tests/cpp/util/CMakeLists.txt | 2 +- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e3af4a360..fa51bee19 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -74,7 +74,7 @@ if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) else() - target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX) + target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX rocrand) endif() target_compile_options(test_operator PRIVATE -O2 -fopenmp) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index a608f6ef2..f29d5b673 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -822,21 +822,71 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { #endif } +#ifdef __HIP_PLATFORM_AMD__ +#include + +template +__global__ void affine_transform_and_cast(float* __restrict__ in, T* __restrict__ out, size_t n, float lo, float hi) { + // Clamp values in *in* to [lo, hi] and cast to type *T* for *out*. + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + in[idx] = lo + (hi - lo) * in[idx]; + out[idx] = static_cast(in[idx]); + } +} + +void fillUniformDevice(Tensor* t) { + void* dst = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); + const auto shape = t->rowwise() ? t->rowwise_shape() : t->columnwise_shape(); + const size_t N = product(shape); + + float* tmp = nullptr; + hipMalloc(&tmp, N * sizeof(float)); + + // per-tensor deterministic seed + const unsigned long long seed = static_cast(t->gen()()); + rocrand_generator gen; + rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_PHILOX4_32_10); + rocrand_set_seed(gen, seed); + + rocrand_generate_uniform(gen, tmp, N); + + // map to [-2, 1] (like generate_data_uniformly) and cast into tensor dtype + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { + dim3 block(256); + dim3 grid((N + block.x - 1) / block.x); + hipLaunchKernelGGL(affine_transform_and_cast, grid, block, 0, 0, + tmp, reinterpret_cast(dst), N, -2.0f, 1.0f); + }); + + rocrand_destroy_generator(gen); + hipFree(tmp); +} +#endif + void fillUniform(Tensor *t) { if (t->rowwise()) { const size_t size = product(t->rowwise_shape()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { +#ifdef __HIP_PLATFORM_AMD__ + fillUniformDevice(t); +#else T *data = t->rowwise_cpu_dptr(); generate_data_uniformly(data, size, &(t->gen())); +#endif } ); } else { const size_t size = product(t->columnwise_shape()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { +#ifdef __HIP_PLATFORM_AMD__ + fillUniformDevice(t); +#else T *data = t->columnwise_cpu_dptr(); generate_data_uniformly(data, size, &(t->gen())); +#endif } ); } diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index e80ebffbc..5d494a6d4 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -21,7 +21,7 @@ find_package(OpenMP REQUIRED) if(USE_CUDA) target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX) else() -target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX) +target_link_libraries(test_util PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX rocrand) endif() target_compile_options(test_util PRIVATE -O2 -fopenmp) From 0d4d62fbd7ed7b9ed1c775060b25371a9ad4e4c7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 16 Jan 2026 11:19:25 -0600 Subject: [PATCH 2/4] adjust comments --- tests/cpp/test_common.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index f29d5b673..636b6e516 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -783,7 +783,6 @@ std::pair getTolerances(const DType type) { template void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { #ifdef __HIP_PLATFORM_AMD__ - // TODO: Introduce a parallel RNG library (Random123, PCG, rocRAND) std::uniform_real_distribution<> dis(-2.0, 1.0); for (int i = 0; i < size; i++) { data[i] = static_cast(dis(*gen)); @@ -851,7 +850,7 @@ void fillUniformDevice(Tensor* t) { rocrand_generate_uniform(gen, tmp, N); - // map to [-2, 1] (like generate_data_uniformly) and cast into tensor dtype + // map to [-2.0, 1.0] (like generate_data_uniformly) and cast into tensor dtype TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { dim3 block(256); dim3 grid((N + block.x - 1) / block.x); From 3f10ed3a87ac605c5969f332f4debfdf6d112808 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 19 Jan 2026 16:57:05 +0000 Subject: [PATCH 3/4] skip copying to device --- tests/cpp/test_common.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 0f78a0419..4f13409b3 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -895,7 +895,10 @@ void fillUniform(Tensor *t) { } ); } +#ifndef __HIP_PLATFORM_AMD__ +// Data is already on device on AMDGPU t->from_cpu(); +#endif std::uniform_real_distribution<> dis(-2.0, 1.0); t->set_scale_inv(dis(t->gen())); } From 0f008e9859e385ad0b40510ffc8dedc962b1526d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 19 Jan 2026 17:04:03 -0600 Subject: [PATCH 4/4] move include, use hipify more, fix CPU copy --- tests/cpp/test_common.cu | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 4f13409b3..e0614b191 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -23,6 +23,10 @@ #include #include "util/logging.h" +#ifdef __HIP_PLATFORM_AMD__ +#include +#endif + namespace test { size_t create_seed_from_tensor_name(const std::string& tensor_name) { @@ -828,8 +832,6 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { } #ifdef __HIP_PLATFORM_AMD__ -#include - template __global__ void affine_transform_and_cast(float* __restrict__ in, T* __restrict__ out, size_t n, float lo, float hi) { // Clamp values in *in* to [lo, hi] and cast to type *T* for *out*. @@ -846,7 +848,7 @@ void fillUniformDevice(Tensor* t) { const size_t N = product(shape); float* tmp = nullptr; - hipMalloc(&tmp, N * sizeof(float)); + cudaMalloc(&tmp, N * sizeof(float)); // per-tensor deterministic seed const unsigned long long seed = static_cast(t->gen()()); @@ -860,12 +862,17 @@ void fillUniformDevice(Tensor* t) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { dim3 block(256); dim3 grid((N + block.x - 1) / block.x); - hipLaunchKernelGGL(affine_transform_and_cast, grid, block, 0, 0, - tmp, reinterpret_cast(dst), N, -2.0f, 1.0f); + affine_transform_and_cast<<>>( + tmp, reinterpret_cast(dst), N, -2.0f, 1.0f); + + // Copy into the CPU mirror. We could use Tensor::to_cpu() here, + // but that does more than just copying the data. + T* cpu_dst = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); + cudaMemcpy(cpu_dst, dst, N * sizeof(T), hipMemcpyDeviceToHost); }); rocrand_destroy_generator(gen); - hipFree(tmp); + cudaFree(tmp); } #endif @@ -896,7 +903,7 @@ void fillUniform(Tensor *t) { ); } #ifndef __HIP_PLATFORM_AMD__ -// Data is already on device on AMDGPU + // Data is already on device on AMDGPU t->from_cpu(); #endif std::uniform_real_distribution<> dis(-2.0, 1.0);