From 74a82fc69475ef9ff27fb9c70782903884997b0d Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 10 Mar 2026 19:19:34 -0500 Subject: [PATCH 1/5] MXFP8 Cast Kernel Optimizations --- benchmarks/cpp/CMakeLists.txt | 86 ++ .../cpp/cast/bench_dequantize_mxfp8.cpp | 131 ++ benchmarks/cpp/cast/bench_gated_mxfp8.cpp | 214 +++ .../cpp/cast/bench_quantize_mxfp8_fused.cpp | 181 +++ benchmarks/cpp/run_benchmarks.sh | 97 ++ benchmarks/cpp/utils/benchmark_utils.h | 223 +++ benchmarks/cpp/utils/test_common.cpp | 1241 +++++++++++++++++ benchmarks/cpp/utils/test_common.h | 684 +++++++++ .../common/cast/mxfp8/gated_mxfp8.cuh | 21 +- .../common/cast/mxfp8/quantize_mxfp8.cuh | 99 +- .../cast/mxfp8/rocm_dequantize_mxfp8.cuh | 4 +- .../common/cast/mxfp8/rocm_gated_mxfp8.cuh | 650 +++++---- .../common/cast/mxfp8/rocm_quantize_mxfp8.cuh | 528 ++++--- transformer_engine/common/common.h | 15 + .../common/normalization/common.h | 6 +- transformer_engine/common/util/math.h | 27 +- 16 files changed, 3729 insertions(+), 478 deletions(-) create mode 100644 benchmarks/cpp/CMakeLists.txt create mode 100644 benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp create mode 100644 benchmarks/cpp/cast/bench_gated_mxfp8.cpp create mode 100644 benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp create mode 100755 benchmarks/cpp/run_benchmarks.sh create mode 100644 benchmarks/cpp/utils/benchmark_utils.h create mode 100644 benchmarks/cpp/utils/test_common.cpp create mode 100644 benchmarks/cpp/utils/test_common.h diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt new file mode 100644 index 000000000..e95b5e873 --- /dev/null +++ b/benchmarks/cpp/CMakeLists.txt @@ -0,0 +1,86 @@ +cmake_minimum_required(VERSION 3.18) + +if(NOT DEFINED CMAKE_CXX_COMPILER) + set(CMAKE_CXX_COMPILER hipcc) +endif() + +project(transformer_engine_benchmarks LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(HIP REQUIRED) + +include(FetchContent) +FetchContent_Declare( + benchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG v1.8.3 +) +set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE) +set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable gtest in benchmark" FORCE) +FetchContent_MakeAvailable(benchmark) + +include_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common/include + ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}/utils +) + +if(DEFINED ENV{NVTE_ROCM_ARCH}) + set(GPU_TARGETS $ENV{NVTE_ROCM_ARCH}) +else() + set(GPU_TARGETS "gfx942;gfx950") +endif() + +set(COMMON_COMPILE_OPTIONS + -Wall + -Wextra + -O3 + -DNDEBUG + -DUSE_ROCM + --offload-arch=${GPU_TARGETS} + -w +) + +find_library(TRANSFORMER_ENGINE_LIB + NAMES transformer_engine + PATHS ${CMAKE_CURRENT_SOURCE_DIR}/../.. + ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake + ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib + /usr/local/lib + $ENV{HOME}/.local/lib + NO_DEFAULT_PATH +) + +if(NOT TRANSFORMER_ENGINE_LIB) + message(WARNING "TransformerEngine library not found in expected paths. Trying system paths...") + find_library(TRANSFORMER_ENGINE_LIB NAMES transformer_engine) +endif() + +if(TRANSFORMER_ENGINE_LIB) + message(STATUS "Found TransformerEngine library: ${TRANSFORMER_ENGINE_LIB}") +else() + message(FATAL_ERROR "TransformerEngine library not found. Please build TransformerEngine first:\n" + " cd ${CMAKE_CURRENT_SOURCE_DIR}/../..\n" + " pip install -e . --no-build-isolation\n" + "Searched paths:\n" + " ${CMAKE_CURRENT_SOURCE_DIR}/../..\n" + " ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake\n" + " ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib") +endif() + +function(add_te_benchmark TARGET_NAME SOURCE_FILE) + add_executable(${TARGET_NAME} ${SOURCE_FILE} utils/test_common.cpp) + target_compile_options(${TARGET_NAME} PRIVATE ${COMMON_COMPILE_OPTIONS}) + target_link_libraries(${TARGET_NAME} PRIVATE + benchmark::benchmark + ${TRANSFORMER_ENGINE_LIB} + hiprand + ) + set_target_properties(${TARGET_NAME} PROPERTIES HIP_ARCHITECTURES "${GPU_TARGETS}") +endfunction() + +add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp) +add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp) +add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp) diff --git a/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp new file mode 100644 index 000000000..7dbfd834e --- /dev/null +++ b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp @@ -0,0 +1,131 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include "amd_detail/hip_float8.h" + +#include +#include + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 3584}) \ + ->Args({1024, 4096}) \ + ->Args({1024, 8192}) \ + ->Args({1024, 14336}) \ + ->Args({2048, 4096}) \ + ->Args({2048, 8192}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({4096, 4096}) \ + ->Args({4096, 8192}) \ + ->Args({4096, 16384}) \ + ->Args({4096, 28672}) \ + ->Args({8192, 8192}) \ + ->Args({8192, 16384}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 8192}) \ + ->Args({16384, 16384})\ + ->Args({32768, 8192}) + +template +static void BM_DequantizeMXFP8(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t scale_cols_row = USE_ROWWISE ? (cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? cols : 0; + + std::vector shape = {rows, cols}; + DType itype = std::is_same_v ? DType::kFloat8E4M3 : DType::kFloat8E5M2; + DType otype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + + test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, true, false, + NVTE_DELAYED_TENSOR_SCALING, false); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + DeviceBuffer temp_fp32(rows * cols); + fill_random_uniform_gpu(temp_fp32.get(), rows * cols, -2.0f, 1.0f, stream); + + void *input_data_ptr = USE_ROWWISE ? input_tensor.rowwise_dptr() : input_tensor.columnwise_dptr(); + size_t threads = 256; + size_t blocks = (rows * cols + threads - 1) / threads; + cast_fp32_kernel<<>>(temp_fp32.get(), static_cast(input_data_ptr), rows * cols); + + HIP_CHECK(hipStreamSynchronize(stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_read_data = rows * cols * sizeof(IType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + const size_t bytes_read_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); + const size_t bytes_write = rows * cols * sizeof(OType); + const size_t total_bytes = bytes_read_data + bytes_read_scales + bytes_write; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_DEQUANTIZE_ALL_CONFIGS(ITYPE, OTYPE, INAME, ONAME) \ + BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, __half, "E4M3", "FP16") +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, hip_bfloat16, "E4M3", "BF16") +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, float, "E4M3", "FP32") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/cast/bench_gated_mxfp8.cpp b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp new file mode 100644 index 000000000..7df895817 --- /dev/null +++ b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp @@ -0,0 +1,214 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include "amd_detail/hip_float8.h" + +#include +#include +#include + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +// SwiGLU shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 14336}) \ + ->Args({1024, 18944}) \ + ->Args({1024, 28672}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({2048, 29568}) \ + ->Args({4096, 14336}) \ + ->Args({4096, 28672}) \ + ->Args({4096, 53248}) \ + ->Args({8192, 14336}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 29568}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 28672}) \ + ->Args({16384, 53248}) \ + ->Args({32768, 28672}) \ + ->Args({32768, 53248}) + +template +static void BM_GatedMXFP8_Forward(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t input_cols = cols * 2; + const size_t output_cols = cols; + + const size_t scale_cols_row = USE_ROWWISE ? (output_cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? output_cols : 0; + + std::vector input_shape = {rows, input_cols}; + std::vector output_shape = {rows, output_cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &input_tensor = TensorCache::get_or_create("input", input_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", output_shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_swiglu(input_tensor.data(), output_tensor.data(), stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_write_data = rows * output_cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + const size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); + + const size_t bytes_read = rows * cols * sizeof(IType) * 2; + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +template +static void BM_GatedMXFP8_Backward(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t input_cols = cols * 2; + const size_t output_cols = cols * 2; + + const size_t scale_cols_row = USE_ROWWISE ? (output_cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? output_cols : 0; + + std::vector grad_shape = {rows, cols}; + std::vector input_shape = {rows, input_cols}; + std::vector output_shape = {rows, output_cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &grad_tensor = TensorCache::get_or_create("grad", grad_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &input_tensor = TensorCache::get_or_create("input", input_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", output_shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_dswiglu(grad_tensor.data(), input_tensor.data(), output_tensor.data(), stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_write_data = rows * output_cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + const size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); + + const size_t bytes_read = rows * cols * sizeof(IType) * 3; + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_GATED_ALL_CONFIGS(ITYPE, OTYPE, INAME, ONAME) \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 32, 32) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/both") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 32, 32) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/both") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_GATED_ALL_CONFIGS(__half, fp8_e4m3, "FP16", "E4M3") +REGISTER_GATED_ALL_CONFIGS(hip_bfloat16, fp8_e4m3, "BF16", "E4M3") +REGISTER_GATED_ALL_CONFIGS(float, fp8_e4m3, "FP32", "E4M3") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp new file mode 100644 index 000000000..4326a65ff --- /dev/null +++ b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp @@ -0,0 +1,181 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include +#include +#include + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 3584}) \ + ->Args({1024, 4096}) \ + ->Args({1024, 8192}) \ + ->Args({1024, 14336}) \ + ->Args({1024, 18944}) \ + ->Args({2048, 4096}) \ + ->Args({2048, 8192}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({2048, 29568}) \ + ->Args({4096, 4096}) \ + ->Args({4096, 8192}) \ + ->Args({4096, 16384}) \ + ->Args({4096, 14336}) \ + ->Args({4096, 28672}) \ + ->Args({8192, 8192}) \ + ->Args({8192, 16384}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 29568}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 8192}) \ + ->Args({16384, 16384})\ + ->Args({16384, 28672})\ + ->Args({32768, 8192}) \ + ->Args({32768, 16384}) + +template +static void BM_QuantizeMXFP8_Fused(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t scale_cols_row = USE_ROWWISE ? (cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? cols : 0; + + std::vector shape = {rows, cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + test::Tensor *grad_tensor_ptr = nullptr, *dbias_tensor_ptr = nullptr, *workspace_tensor_ptr = nullptr; + + if constexpr (PROC_METHOD == CAST_DBIAS || PROC_METHOD == CAST_DBIAS_DACT) { + std::vector bias_shape = {cols}; + dbias_tensor_ptr = &TensorCache::get_or_create("dbias", bias_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, false); + workspace_tensor_ptr = &TensorCache::get_or_create("workspace", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, false); + } + + if constexpr (PROC_METHOD == CAST_DBIAS_DACT || PROC_METHOD == CAST_DACT) { + grad_tensor_ptr = &TensorCache::get_or_create("grad", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + } + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + if constexpr (PROC_METHOD == CAST_ONLY) { + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); + } else if constexpr (PROC_METHOD == CAST_DBIAS) { + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor_ptr->data(), workspace_tensor_ptr->data(), stream); + } else if constexpr (PROC_METHOD == CAST_DBIAS_DACT) { + nvte_quantize_dbias_dgelu(grad_tensor_ptr->data(), input_tensor.data(), output_tensor.data(), dbias_tensor_ptr->data(), workspace_tensor_ptr->data(), stream); + } else if constexpr (PROC_METHOD == CAST_DACT) { + nvte_dgelu(grad_tensor_ptr->data(), input_tensor.data(), output_tensor.data(), stream); + } else if constexpr (PROC_METHOD == CAST_ACT) { + nvte_gelu(input_tensor.data(), output_tensor.data(), stream); + } + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + size_t bytes_write_data = rows * cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); + + size_t bytes_read = rows * cols * sizeof(IType); + if constexpr (PROC_METHOD == CAST_DBIAS_DACT || PROC_METHOD == CAST_DACT) { + bytes_read += rows * cols * sizeof(IType); + } + if constexpr (PROC_METHOD == CAST_DBIAS || PROC_METHOD == CAST_DBIAS_DACT) { + bytes_write_data += cols * sizeof(IType); + } + + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, METHOD, METHOD_NAME) \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 1, 32, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/rowwise/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 32, 1, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/colwise/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 32, 32, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/both/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#define REGISTER_ALL_METHODS(ITYPE, OTYPE, INAME, ONAME) \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_ONLY, "CastOnly") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DBIAS, "CastDBias") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DBIAS_DACT, "CastDBiasDACT") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DACT, "CastDACT") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_ACT, "CastACT") + +REGISTER_ALL_METHODS(__half, fp8_e4m3, "FP16", "E4M3") +REGISTER_ALL_METHODS(hip_bfloat16, fp8_e4m3, "BF16", "E4M3") +REGISTER_ALL_METHODS(float, fp8_e4m3, "FP32", "E4M3") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/run_benchmarks.sh b/benchmarks/cpp/run_benchmarks.sh new file mode 100755 index 000000000..6b9fb5806 --- /dev/null +++ b/benchmarks/cpp/run_benchmarks.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# Builds benchmarks, runs them, and consolidates results into a single CSV + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${SCRIPT_DIR}/build" +RESULTS_DIR="${SCRIPT_DIR}/results" + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +main() { + echo -e "${GREEN}=== MXFP8 Benchmark Suite ===${NC}" + + echo -e "\n${YELLOW}[1/3] Building benchmarks...${NC}" + cd "${SCRIPT_DIR}" + if ! cmake -GNinja -B"${BUILD_DIR}" . || ! cmake --build "${BUILD_DIR}"; then + echo -e "${RED}Build failed. Fix the build errors and try again.${NC}" + return + fi + echo -e "${GREEN}✓ Build complete${NC}" + + mkdir -p "${RESULTS_DIR}" + TIMESTAMP=$(date +%Y%m%d_%H%M%S) + RESULT_PREFIX="${RESULTS_DIR}/bench_${TIMESTAMP}" + + echo -e "\n${YELLOW}[2/3] Running benchmarks...${NC}" + + BENCHMARKS=( + "bench_quantize_mxfp8_fused" + "bench_dequantize_mxfp8" + "bench_gated_mxfp8" + ) + + FAILED_BENCHMARKS=() + for bench in "${BENCHMARKS[@]}"; do + if [ -f "${BUILD_DIR}/${bench}" ]; then + echo -e " Running ${bench}..." + if "${BUILD_DIR}/${bench}" \ + --benchmark_out="${RESULT_PREFIX}_${bench}.csv" \ + --benchmark_out_format=csv \ + --benchmark_min_time=0.2s; then + echo -e " ${GREEN}✓${NC} Saved to ${RESULT_PREFIX}_${bench}.csv" + else + echo -e " ${RED}✗${NC} ${bench} failed (exit code $?), continuing..." + FAILED_BENCHMARKS+=("${bench}") + fi + else + echo -e " ${RED}✗${NC} ${bench} not found, skipping" + fi + done + + echo -e "\n${YELLOW}[3/3] Consolidating results...${NC}" + + CONSOLIDATED_CSV="${RESULT_PREFIX}_all.csv" + FIRST_CSV=$(ls "${RESULT_PREFIX}"_*.csv 2>/dev/null | grep -v "_all.csv" | head -1) + + if [ -z "$FIRST_CSV" ]; then + echo -e "${RED}No CSV files found to consolidate${NC}" + return + fi + + head -1 "$FIRST_CSV" > "$CONSOLIDATED_CSV" + + for csv in "${RESULT_PREFIX}"_bench_*.csv; do + if [ "$csv" != "$CONSOLIDATED_CSV" ]; then + tail -n +2 "$csv" >> "$CONSOLIDATED_CSV" + fi + done + + echo -e "${GREEN}✓ Consolidated CSV: ${CONSOLIDATED_CSV}${NC}" + + echo -e "\n${GREEN}=== Summary ===${NC}" + TOTAL_ROWS=$(tail -n +2 "$CONSOLIDATED_CSV" | wc -l) + echo "Total benchmarks: $TOTAL_ROWS" + echo "Results saved to: ${RESULTS_DIR}/" + echo "" + echo "Files created:" + for bench in "${BENCHMARKS[@]}"; do + if [ -f "${RESULT_PREFIX}_${bench}.csv" ]; then + echo " - $(basename "${RESULT_PREFIX}_${bench}.csv")" + fi + done + echo " - $(basename "$CONSOLIDATED_CSV") (consolidated)" + echo "" + + if [ ${#FAILED_BENCHMARKS[@]} -gt 0 ]; then + echo -e "${RED}Failed benchmarks:${NC}" + for bench in "${FAILED_BENCHMARKS[@]}"; do + echo -e " ${RED}✗${NC} ${bench}" + done + echo "" + fi +} + +main diff --git a/benchmarks/cpp/utils/benchmark_utils.h b/benchmarks/cpp/utils/benchmark_utils.h new file mode 100644 index 000000000..35a857109 --- /dev/null +++ b/benchmarks/cpp/utils/benchmark_utils.h @@ -0,0 +1,223 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "test_common.h" + +namespace te_bench { + +#define HIP_CHECK(call) \ + do { \ + hipError_t err = call; \ + if (err != hipSuccess) { \ + fprintf(stderr, "HIP error at %s:%d: %s\n", __FILE__, __LINE__, \ + hipGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +template +class DeviceBuffer { + public: + DeviceBuffer(size_t count) : count_(count) { + HIP_CHECK(hipMalloc(&ptr_, count * sizeof(T))); + } + + ~DeviceBuffer() { + if (ptr_) { + hipError_t err = hipFree(ptr_); + (void)err; + } + } + + DeviceBuffer(const DeviceBuffer &) = delete; + DeviceBuffer &operator=(const DeviceBuffer &) = delete; + + DeviceBuffer(DeviceBuffer &&other) noexcept : ptr_(other.ptr_), count_(other.count_) { + other.ptr_ = nullptr; + other.count_ = 0; + } + + T *get() { return ptr_; } + const T *get() const { return ptr_; } + size_t count() const { return count_; } + size_t bytes() const { return count_ * sizeof(T); } + + void upload(const std::vector &host_data) { + if (host_data.size() != count_) { + throw std::runtime_error("Size mismatch in upload"); + } + HIP_CHECK(hipMemcpy(ptr_, host_data.data(), bytes(), hipMemcpyHostToDevice)); + } + + void download(std::vector &host_data) const { + host_data.resize(count_); + HIP_CHECK(hipMemcpy(host_data.data(), ptr_, bytes(), hipMemcpyDeviceToHost)); + } + + private: + T *ptr_ = nullptr; + size_t count_ = 0; +}; + +template +std::vector generate_random_data(size_t count, T min_val = -1.0, T max_val = 1.0) { + std::vector data(count); + std::mt19937 gen(42); + + if constexpr (std::is_floating_point_v) { + std::uniform_real_distribution dist(min_val, max_val); + for (auto &val : data) { + val = dist(gen); + } + } else { + std::uniform_int_distribution dist(static_cast(min_val), static_cast(max_val)); + for (auto &val : data) { + val = static_cast(dist(gen)); + } + } + + return data; +} + +__global__ void scale_shift_kernel(float *data, size_t count, float scale, float offset) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < count) { + data[idx] = data[idx] * scale + offset; + } +} + +inline void fill_random_uniform_gpu(float *dptr, size_t count, float min_val = -2.0f, float max_val = 1.0f, hipStream_t stream = 0) { + hiprandGenerator_t gen; + hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_DEFAULT); + hiprandSetPseudoRandomGeneratorSeed(gen, 42); + if (stream != 0) { + hiprandSetStream(gen, stream); + } + hiprandGenerateUniform(gen, dptr, count); + float scale = max_val - min_val; + float offset = min_val; + + size_t threads = 256; + size_t blocks = (count + threads - 1) / threads; + scale_shift_kernel<<>>(dptr, count, scale, offset); + + hiprandDestroyGenerator(gen); +} + +template +__global__ void cast_fp32_kernel(const float *in, T *out, size_t count) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < count) { + out[idx] = static_cast(in[idx]); + } +} + +template +inline void fill_random_uniform_gpu_typed(T *dptr, size_t count, float min_val = -2.0f, float max_val = 1.0f, hipStream_t stream = 0) { + if constexpr (std::is_same_v) { + fill_random_uniform_gpu(dptr, count, min_val, max_val, stream); + } else { + DeviceBuffer temp_fp32(count); + fill_random_uniform_gpu(temp_fp32.get(), count, min_val, max_val, stream); + + size_t threads = 256; + size_t blocks = (count + threads - 1) / threads; + cast_fp32_kernel<<>>(temp_fp32.get(), dptr, count); + } +} + +inline void warmup_gpu(int iterations = 10) { + DeviceBuffer dummy(1024); + for (int i = 0; i < iterations; ++i) { + HIP_CHECK(hipMemset(dummy.get(), 0, dummy.bytes())); + } + HIP_CHECK(hipDeviceSynchronize()); +} + +inline double calculate_bandwidth_gbps(size_t bytes, double time_ns) { + return (bytes / 1e9) / (time_ns / 1e9); +} + +inline void set_items_processed(benchmark::State &state, size_t items_per_iter) { + state.SetItemsProcessed(state.iterations() * items_per_iter); +} + +inline void set_bytes_processed(benchmark::State &state, size_t bytes_per_iter) { + state.SetBytesProcessed(state.iterations() * bytes_per_iter); +} + +class TensorCache { + public: + struct CacheKey { + std::string name; + size_t rows; + size_t cols; + transformer_engine::DType dtype; + bool rowwise; + bool colwise; + NVTEScalingMode scaling_mode; + + bool operator<(const CacheKey &other) const { + return std::tie(name, rows, cols, dtype, rowwise, colwise, scaling_mode) < + std::tie(other.name, other.rows, other.cols, other.dtype, other.rowwise, other.colwise, other.scaling_mode); + } + }; + + static test::Tensor &get_or_create(const std::string &name, + const std::vector &shape, + transformer_engine::DType dtype, + bool rowwise = true, + bool colwise = false, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING, + bool initialize_random = false) { + CacheKey key{name, shape[0], shape[1], dtype, rowwise, colwise, scaling_mode}; + + static auto* cache = new std::map>(); + + auto it = cache->find(key); + if (it == cache->end()) { + auto tensor_ptr = std::make_unique(name, shape, dtype, rowwise, colwise, scaling_mode); + + if (initialize_random && dtype != transformer_engine::DType::kFloat8E4M3 && + dtype != transformer_engine::DType::kFloat8E5M2) { + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + size_t count = shape[0] * shape[1]; + void *data_ptr = tensor_ptr->rowwise_dptr(); + + if (dtype == transformer_engine::DType::kFloat32) { + fill_random_uniform_gpu(static_cast(data_ptr), count, -2.0f, 1.0f, stream); + } else if (dtype == transformer_engine::DType::kFloat16) { + fill_random_uniform_gpu_typed<__half>(static_cast<__half*>(data_ptr), count, -2.0f, 1.0f, stream); + } else if (dtype == transformer_engine::DType::kBFloat16) { + fill_random_uniform_gpu_typed(static_cast(data_ptr), count, -2.0f, 1.0f, stream); + } + + HIP_CHECK(hipStreamSynchronize(stream)); + HIP_CHECK(hipStreamDestroy(stream)); + } + + (*cache)[key] = std::move(tensor_ptr); + it = cache->find(key); + } + + return *(it->second); + } +}; +} // namespace te_bench diff --git a/benchmarks/cpp/utils/test_common.cpp b/benchmarks/cpp/utils/test_common.cpp new file mode 100644 index 000000000..3caf1245c --- /dev/null +++ b/benchmarks/cpp/utils/test_common.cpp @@ -0,0 +1,1241 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + + +#include "test_common.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include "util/logging_hip.h" + +#include + +namespace test { + +size_t create_seed_from_tensor_name(const std::string& tensor_name) { + auto full_name = "benchmark/" + tensor_name; + return std::hash{}(full_name); +} + +std::vector all_fp_types = {DType::kFloat32, + DType::kFloat16, + DType::kBFloat16, + DType::kFloat8E5M2, + DType::kFloat8E4M3}; + +bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) { + if (s1.ndim != s2.ndim) return false; + + for (size_t i = 0; i < s1.ndim; ++i) { + if (s1.data[i] != s2.data[i]) return false; + } + + return true; +} + +size_t typeToNumBits(DType type) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, + { + return TypeInfo::size; + }); +} + +const std::string &typeName(DType type) { + static const std::unordered_map name_map = { + {DType::kByte, "byte"}, + {DType::kInt32, "int32"}, + {DType::kInt64, "int64"}, + {DType::kFloat32, "float32"}, + {DType::kFloat16, "float16"}, + {DType::kBFloat16, "bfloat16"}, + {DType::kFloat8E4M3, "float8e4m3"}, + {DType::kFloat8E5M2, "float8e5m2"}, + {DType::kFloat8E8M0, "float8e8m0"}, + {DType::kFloat4E2M1, "float4e2m1"}}; + return name_map.at(type); +} + +const std::string& caseName(InputsFillCase type) { + static const std::unordered_map name_map = { + {InputsFillCase::uniform, "uniform"}, + {InputsFillCase::zeros, "zeros"}, + {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"}, + {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"}, + {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"}}; + return name_map.at(type); +} + +size_t product(const NVTEShape &shape, size_t begin, size_t end) { + size_t ret = 1; + NVTE_CHECK(end <= shape.ndim); + for (size_t i = begin; i < end; ++i) { + ret *= shape.data[i]; + } + return ret; +} + +size_t product(const NVTEShape &shape) { + return product(shape, 0, shape.ndim); +} + +size_t product(const std::vector shape, size_t begin, size_t end) { + size_t ret = 1; + NVTE_CHECK(end <= shape.size()); + for (size_t i = begin; i < end; ++i) { + ret *= shape[i]; + } + return ret; +} + +size_t product(const std::vector& shape) { + return product(shape, 0, shape.size()); +} + +size_t DIVUP(const size_t &x, const size_t &y){ + return (((x) + ((y)-1)) / (y)); +} + +size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){ + return DIVUP(x, y) * y; +} + +struct scale_inv_meta { + std::vector shape; + DType type; + size_t type_size_bits; + size_t bytes() const noexcept { + return (product(shape) * type_size_bits) / 8; + } +}; + +size_t bytes(const NVTEShape& shape, const DType type) { + return (product(shape) * typeToNumBits(type)) / 8; +} + +NVTEShape convertShape(const std::vector& s) { + return nvte_make_shape(s.data(), s.size()); +} + +std::pair get_scales(const NVTEShape& shape, + const NVTEScalingMode scaling_mode) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + scale_inv_meta ret; + ret.shape = {1}; + ret.type = DType::kFloat32; + ret.type_size_bits = typeToNumBits(DType::kFloat32); + return {ret, ret}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + + ret_rowwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + ret_colwise.type = DType::kFloat8E8M0; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + NVTE_CHECK(last_dim % 32 == 0); + NVTE_CHECK(first_dim % 32 == 0); + + scale_inv_meta ret_rowwise, ret_colwise; + + size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y, scale_dim_X}; + + size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t}; + + ret_rowwise.type = DType::kFloat8E4M3; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + ret_colwise.type = DType::kFloat8E4M3; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + + ret_rowwise.type = DType::kFloat8E8M0; + ret_colwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_BLOCK_SCALING_2D) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_0 = DIVUP(first_dim, 128lu); +#ifdef __HIP_PLATFORM_AMD__ + auto scale_dim_1 = DIVUP(last_dim, 128lu); +#else + auto scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4; +#endif + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_0 = DIVUP(last_dim, 128lu); +#ifdef __HIP_PLATFORM_AMD__ + auto scale_dim_1 = DIVUP(first_dim, 128lu); +#else + auto scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4; +#endif + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_BLOCK_SCALING_1D) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_0 = DIVUP(last_dim, 128lu); +#ifdef __HIP_PLATFORM_AMD__ + auto scale_dim_1 = first_dim; +#else + auto scale_dim_1 = DIVUP(first_dim, 4) * 4; +#endif + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_0 = DIVUP(first_dim, 128lu); +#ifdef __HIP_PLATFORM_AMD__ + auto scale_dim_1 = last_dim; +#else + auto scale_dim_1 = DIVUP(last_dim, 4) * 4; +#endif + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32); + return {ret_rowwise, ret_colwise}; + } + + NVTE_ERROR("Invalid scaling mode!"); +} + +Tensor::Tensor(const std::string& name, + const NVTEShape &shape, const DType type, + const bool rowwise, const bool columnwise, + const NVTEScalingMode &scaling_mode) { + name_ = name; + const size_t seed = create_seed_from_tensor_name(name); + gen_.seed(seed); + rowwise_ = rowwise; + columnwise_ = columnwise; + size_t total_size = bytes(shape, type); + void *dptr_rowwise = nullptr; + void *dptr_columnwise = nullptr; + cpu_data_rowwise_ = nullptr; + cpu_data_columnwise_ = nullptr; + amax_cpu_data_ = nullptr; + scale_cpu_data_ = nullptr; + rowwise_scale_inv_cpu_data_ = nullptr; + columnwise_scale_inv_cpu_data_ = nullptr; + float *amax = nullptr, *scale = nullptr; + float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr; + if (columnwise) { + NVTE_CHECK(shape.ndim >= 2); + } + std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), + shape.data[shape.ndim - 1]}; + NVTEShape normalized_shape = convertShape(normalized_shape_v); + NVTEShape columnwise_shape = {}; + + std::vector columnwise_shape_vec; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING + || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { + // Transpose when tensor scaling + columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); + for (size_t i = 0; i < shape.ndim - 1; ++i) { + columnwise_shape_vec.emplace_back(shape.data[i]); + } + } else { + // Same shape for MX and NVFP4 + for (size_t i = 0; i < shape.ndim; ++i) { + columnwise_shape_vec.emplace_back(shape.data[i]); + } + } + + if (columnwise) { + columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size()); + } + + tensor_ = TensorWrapper(scaling_mode); + + if (total_size != 0) { + if (rowwise) { + (void)hipMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) + (void)hipMemset(dptr_rowwise, 0, total_size); + cpu_data_rowwise_ = std::make_unique(total_size); + std::fill_n(cpu_data_rowwise_.get(), total_size, 0); + } + if (columnwise) { + (void)hipMalloc((void**)&dptr_columnwise, total_size); // NOLINT(*) + (void)hipMemset(dptr_columnwise, 0, total_size); + cpu_data_columnwise_ = std::make_unique(total_size); + std::fill_n(cpu_data_columnwise_.get(), total_size, 0); + } + } + + const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape); + tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape); + + if (isFp8Type(type) || isFp4Type(type)) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + (void)hipMalloc((void**)&amax, sizeof(float)); // NOLINT(*) + (void)hipMemset(amax, 0, sizeof(float)); + (void)hipMalloc((void**)&scale, sizeof(float)); // NOLINT(*) + (void)hipMemset(scale, 0, sizeof(float)); + amax_cpu_data_ = std::make_shared(0); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_amax(amax, DType::kFloat32, std::vector{1}); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + (void)hipMalloc((void**)&rowwise_scale_inv, sizeof(float)); // NOLINT(*) + if (rowwise) { + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + rowwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + if (columnwise) { + tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + } else { + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + // Used for NVFP4 second stage scaling + hipMalloc((void**)&scale, sizeof(float)); // NOLINT(*) + hipMemset(scale, 0, sizeof(float)); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); + auto rowwise_scale_size = rowwise_scale_meta.bytes(); + auto columnwise_scale_size = colwise_scale_meta.bytes(); + auto scale_shape = rowwise_scale_meta.shape; + auto columnwise_scale_shape = colwise_scale_meta.shape; + if (rowwise) { + (void)hipMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + (void)hipMemset(rowwise_scale_inv, 0, rowwise_scale_size); + rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); + auto scale_dtype = rowwise_scale_meta.type; + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); + } + if (columnwise) { + (void)hipMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) + (void)hipMemset(columnwise_scale_inv, 0, columnwise_scale_size); + columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); + auto scale_dtype = colwise_scale_meta.type; + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); + } + } + } +} + +void Tensor::to_cpu() const { + const NVTEShape s = tensor_.shape(); + const size_t size = bytes(s, tensor_.dtype()); + if (rowwise_) { + (void)hipMemcpy(cpu_data_rowwise_.get(), + tensor_.get_rowwise_data().data_ptr, + size, + hipMemcpyDeviceToHost); + } + if (columnwise_) { + const DType colwise_type = tensor_.dtype(); + + const size_t colwise_size = bytes(s, colwise_type); + (void)hipMemcpy(cpu_data_columnwise_.get(), + tensor_.get_columnwise_data().data_ptr, + colwise_size, + hipMemcpyDeviceToHost); + } + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { + if (tensor_.amax() != nullptr){ + (void)hipMemcpy(amax_cpu_data_.get(), + tensor_.amax(), + sizeof(float), + hipMemcpyDeviceToHost); + } + (void)hipMemcpy(scale_cpu_data_.get(), + tensor_.scale(), + sizeof(float), + hipMemcpyDeviceToHost); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + if (rowwise_) { + auto scale_size = rowwise_scale_meta.bytes(); + (void)hipMemcpy(rowwise_scale_inv_cpu_data_.get(), + tensor_.get_rowwise_scale_inv().data_ptr, + scale_size, + hipMemcpyDeviceToHost); + } + if (columnwise_) { + auto scale_size = colwise_scale_meta.bytes(); + (void)hipMemcpy(columnwise_scale_inv_cpu_data_.get(), + tensor_.get_columnwise_scale_inv().data_ptr, + scale_size, + hipMemcpyDeviceToHost); + } + } +} + +void Tensor::from_cpu() const { + const NVTEShape s = tensor_.shape(); + const size_t size = bytes(s, tensor_.dtype()); + if (rowwise_) { + (void)hipMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size, + hipMemcpyHostToDevice); + } + if (columnwise_) { + (void)hipMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, + hipMemcpyHostToDevice); + } + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) { + if (tensor_.amax() != nullptr){ + (void)hipMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), hipMemcpyHostToDevice); + } + (void)hipMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), hipMemcpyHostToDevice); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + if (rowwise_) { + auto scale_size = rowwise_scale_meta.bytes(); + (void)hipMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, + rowwise_scale_inv_cpu_data_.get(), scale_size, + hipMemcpyHostToDevice); + } + if (columnwise_) { + auto scale_size = colwise_scale_meta.bytes(); + (void)hipMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, + columnwise_scale_inv_cpu_data_.get(), scale_size, + hipMemcpyHostToDevice); + } + } +} + +void Tensor::set_scale(float scale) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + NVTE_CHECK(scale_cpu_data_); + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + *scale_cpu_data_ = scale; + from_cpu(); + } + } +} + +void Tensor::set_scale_inv(float scale_inv) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if (rowwise_) { + NVTE_CHECK(rowwise_scale_inv_cpu_data_); + } + if (columnwise_) { + NVTE_CHECK(columnwise_scale_inv_cpu_data_); + } + + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); + if (rowwise_) { + auto num_scales = product(rowwise_scale_meta.shape); + if (num_scales == 1) { + rowwise_cpu_scale_inv_ptr()[0] = scale_inv; + } else { + std::uniform_int_distribution dis(0, 127); + auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { + scale_inv_ptr[i] = dis(gen_); + } + } + } + if (columnwise_) { + auto num_scales = product(colwise_scale_meta.shape); + if (num_scales == 1) { + columnwise_cpu_scale_inv_ptr()[0] = scale_inv; + } else { + std::uniform_int_distribution dis(0, 127); + if (rowwise_) { + from_cpu(); //Need it because scale_inv_ptr getting does to_cpu() + } + auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { + scale_inv_ptr[i] = dis(gen_); + } + } + } + from_cpu(); + } +} + +void Tensor::shareFP8Meta(const Tensor &other) { + if ((isFp8Type(dtype()) && isFp8Type(other.dtype())) + || isFp4Type(dtype()) && isFp4Type(other.dtype())) { + auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); + auto my_rowwise_data = tensor_.get_rowwise_data(); + new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), + my_rowwise_data.shape); + auto my_columnwise_data = tensor_.get_columnwise_data(); + new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, + static_cast(my_columnwise_data.dtype), + my_columnwise_data.shape); + auto other_amax = other.tensor_.get_amax(); + new_tensor.set_amax(other_amax.data_ptr, static_cast(other_amax.dtype), + other_amax.shape); + auto other_scale = other.tensor_.get_scale(); + new_tensor.set_scale(other_scale.data_ptr, static_cast(other_scale.dtype), + other_scale.shape); + auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); + new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, + static_cast(other_row_scale_inv.dtype), + other_row_scale_inv.shape); + auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv(); + new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr, + static_cast(other_col_scale_inv.dtype), + other_col_scale_inv.shape); + tensor_ = std::move(new_tensor); + to_cpu(); + } +} + +using std::to_string; + +template +std::string to_string(const std::vector &v) { + std::string s = "["; + for (const auto x : v) { + s += to_string(x) + ", "; + } + s.pop_back(); + s.pop_back(); + return s + "]"; +} + +std::vector unravel(const size_t i, const NVTEShape &shape) { + std::vector ret; + size_t current_i = i; + for (size_t current = shape.ndim - 1; current > 0; --current) { + ret.push_back(current_i % shape.data[current]); + current_i /= shape.data[current]; + } + ret.push_back(current_i); + std::reverse(ret.begin(), ret.end()); + return ret; +} + +#ifndef BENCHMARK_STATIC_DEFINE +void compareResults_sequential(const std::string &name, const Tensor &test, + const void *ref, const bool rowwise, + double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { + if (if_on_gpus) test.to_cpu(); + const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); + const size_t N = product(shape); + size_t mismatches_num = 0; + int first_mismatch_idx = -1; + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, + const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); + const T *ref_data = reinterpret_cast(ref); + for (size_t i = 0; i < N; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && test.dtype() == DType::kFloat32; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r)); + } + std::string direction = rowwise ? "rowwise" : "columnwise"; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape)) + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } + } + ); +} + +template +static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, + const size_t N, const double atol, const double rtol, + size_t& mismatches) { + int first_mismatch_idx = N; + + #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread) + { + size_t thread_mismatches = 0; + #pragma omp for schedule(static) + for (size_t i = 0; i < N; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && (data_type == DType::kFloat32); + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + if (i < first_mismatch_idx) { + first_mismatch_idx = i; + } + thread_mismatches++; + } + } + mismatches += thread_mismatches; + } + return first_mismatch_idx; +} + +void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { + if (if_on_gpus) test.to_cpu(); + const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); + const size_t N = product(shape); + size_t mismatches = 0; + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, + const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); + const T *ref_data = reinterpret_cast(ref); + + const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches); + if ((i != N) && (mismatches > tolerable_mismatches_limit)) { + const double t = static_cast(test_data[i]); + const double r = static_cast(ref_data[i]); + std::string direction = rowwise ? "rowwise" : "columnwise"; + + GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) + << " (" << std::to_string(i) << "): " << t << " vs " << r; + } + ); +} + +void compareResults(const std::string &name, const Tensor &test, const void *ref, + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { + constexpr bool sequential = false; + if constexpr (sequential) { + compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); + } else { + compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); + } +} + +void compareResults(const std::string &name, const float test, const float ref, + double atol, double rtol) { + double t = static_cast(test); + double r = static_cast(ref); + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + ASSERT_FALSE(mismatch) << "Error in " << name << std::endl + << "Mismatch: " << t << " vs " << r; + +} + + +void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, + size_t N, float mismatch_rate_tol) { + size_t max_mismatches = std::ceil(N * mismatch_rate_tol); + size_t n_mismatches = 0; + std::vector mismatch_indices; + for (int i = 0; i < N; i++){ + bool mismatch = test[i] != ref[i]; + if (mismatch){ + n_mismatches++; + mismatch_indices.push_back(i); + } + if (n_mismatches > max_mismatches){ + std::cout << "Error in " << name << std::endl; + for (auto &index : mismatch_indices) + std::cout << "Mismatch at (" << index << "):" << static_cast(test[i]) << " vs " + << static_cast(ref[i]) << std::endl; + GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol."; + } + } +} + +template +struct CastToType; + +template <> +struct CastToType { + using type = int; +}; + +template <> +struct CastToType { + using type = float; +}; + +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) +{ + using UpcastType = typename CastToType::type; + auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); + + + const size_t N = row_blocks * col_blocks; + const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, + std::floor(N * rel_tolerable_mismatches_limit)); + mismatches_num = 0; +#ifndef __HIP_PLATFORM_AMD__ + std::vector mismatch_indices; +#endif //#ifndef __HIP_PLATFORM_AMD__ + + for (int i = 0; i < row_blocks; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int idx = i * stride + j; + float t, r; + + bool assertion = false; + + if (std::is_same::value) { + t = static_cast(test[idx]); + r = static_cast(ref[idx]); + assertion = std::abs(t - r) > atol; + } else { + t = static_cast(*reinterpret_cast(&test[idx])); + r = static_cast(*reinterpret_cast(&ref[idx])); + const bool mismatch = (fabs(t - r) > atol_fp8e4m3) + && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3); + if (mismatch) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + } + if (assertion) { + mismatches_num++; + mismatch_indices.push_back(idx); + } + if (mismatches_num > tolerable_mismatches_limit) { + std::cout << "Error in " << name << std::endl; + for (const int index : mismatch_indices) { + std::cout << "Mismatch at (" << index << "):" + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; + } + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "."; + } + } + } +} + +#ifdef __HIP_PLATFORM_AMD__ +void adjust_ref_for_e8m0_scale_error(const std::string &name, + const std::vector &mismatch_idx, + const uint8_t *test_scale, const uint8_t *ref_scale, + const size_t scale_stride, const size_t rows, + const size_t cols, bool rowwise, void *ref_ptr, DType otype) { + if (mismatch_idx.size() == 0) { + return; + } + const size_t col_blocks_size = rowwise ? 32 : 1; + const size_t row_blocks_size = rowwise ? 1 : 32; + GTEST_LOG_(INFO) << "Adjusting reference data for " << mismatch_idx.size() + << " scale mismatches in tensor " << name << " " + << (rowwise ? "rowwise" : "colwise") << " direction." << std::endl; + for (const auto scale_idx : mismatch_idx) { + const int scale_diff = ref_scale[scale_idx] - test_scale[scale_idx]; + double scale_val; + if (scale_diff == 1) { + scale_val = 2.; + } else if (scale_diff == -1) { + scale_val = .5; + } else { + GTEST_FAIL() << "Error in " << name << ": mismatch " << test_scale[scale_idx] << " vs " + << ref_scale[scale_idx] << " at index " << scale_idx; + } + const int i = scale_idx / scale_stride; + const int j = scale_idx % scale_stride; + size_t ii_min = i * row_blocks_size; + const size_t ii_max = std::min(ii_min + row_blocks_size, rows); + for (; ii_min < ii_max; ii_min++) { + size_t jj_min = j * col_blocks_size; + const size_t jj_max = std::min(jj_min + col_blocks_size, cols); + for (; jj_min < jj_max; jj_min++) { + const size_t data_idx = ii_min * cols + jj_min; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(otype, T, { + T *ref_data = reinterpret_cast(ref_ptr); + ref_data[data_idx] = static_cast(static_cast(ref_data[data_idx]) * scale_val); + }); // NOLINT(*) + } + } + } +} +#endif // #ifdef __HIP_PLATFORM_AMD__ + +// Instantiate templates +template +void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + +template +void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + +#endif // #ifndef BENCHMARK_STATIC_DEFINE + + +std::pair getTolerances(const DType type) { + switch(type) { + case DType::kFloat32: + return {1e-6, 5e-6}; + case DType::kFloat16: + return {1e-5, 1e-3}; + case DType::kBFloat16: + return {1e-5, 1e-2}; + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kFloat8E8M0: + return {1e-2, 1e-2}; + default: + NVTE_ERROR("Invalid type!"); + } + return {0, 0}; +} + +#ifndef __HIP_PLATFORM_AMD__ +template +void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { + // Check how many RNG calls are required to generate one uniform random value + int rng_calls_per_val = 0; + { + std::mt19937 gen1 = *gen, gen2 = *gen; + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float _ = dis(gen1); + while (gen2 != gen1) { + auto _ = gen2(); + ++rng_calls_per_val; + } + } + + // Generate uniform random values in parallel + #pragma omp parallel proc_bind(spread) + { + std::mt19937 gen_local = *gen; + const int thread_ID = omp_get_thread_num(); + const int threads_num = omp_get_max_threads(); + const int chunk_size = (size + threads_num - 1) / threads_num; + const int idx_min = chunk_size * thread_ID; + const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast(size)); + gen_local.discard(idx_min * rng_calls_per_val); + std::uniform_real_distribution<> dis(-2.0, 1.0); + + for (int i = idx_min; i < idx_max; ++i) { + data[i] = static_cast(dis(gen_local)); + } + } + gen->discard(size * rng_calls_per_val); +} +#endif + +#ifdef __HIP_PLATFORM_AMD__ +template +__global__ void affine_transform_cast_signs(const float* __restrict__ in, + const float* __restrict__ signs, + T* __restrict__ out, + size_t n, double lo, double hi) { + // Map values in *in* from [0, 1) to [lo, hi) and cast to type *T* for *out*. + // Potentially flip signs if RandomSign==true. + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float val = lo + (hi - lo) * in[idx]; + + if constexpr (RandomSign) { + if (signs[idx] < 0.5f) + val = -val; + } + + out[idx] = static_cast(val); + } +} + +template +static void fillUniformLinearBufferDevice(T* dst_dev, + T* dst_cpu, // nullable + size_t N, + unsigned long long seed, + double lo, double hi, + bool random_sign=false) { + // Fill a linear device buffer with uniform randoms in [*lo*, *hi*] and cast them to *T*. + // Optionally mirror the result into a provided CPU pointer. + if (N == 0) + return; + + float* tmp = nullptr; + NVTE_CHECK_CUDA(hipMalloc(&tmp, N * sizeof(float))); + + float* tmp_sign = nullptr; + if (random_sign) { + NVTE_CHECK_CUDA(hipMalloc(&tmp_sign, N * sizeof(float))); + } + + hiprandGenerator_t gen; + NVTE_CHECK(hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_PHILOX4_32_10) == HIPRAND_STATUS_SUCCESS); + NVTE_CHECK(hiprandSetPseudoRandomGeneratorSeed(gen, seed) == HIPRAND_STATUS_SUCCESS); + NVTE_CHECK(hiprandGenerateUniform(gen, tmp, N) == HIPRAND_STATUS_SUCCESS); + + if (random_sign) { + NVTE_CHECK(hiprandGenerateUniform(gen, tmp_sign, N) == HIPRAND_STATUS_SUCCESS); + } + + dim3 block(256); + dim3 grid((N + block.x - 1) / block.x); + + if (random_sign) + hipLaunchKernelGGL(( affine_transform_cast_signs), dim3(grid), dim3(block), 0, 0, + tmp, tmp_sign, dst_dev, N, lo, hi); + else + hipLaunchKernelGGL(( affine_transform_cast_signs), dim3(grid), dim3(block), 0, 0, + tmp, nullptr, dst_dev, N, lo, hi); + + NVTE_CHECK_CUDA(hipGetLastError()); + + if (dst_cpu != nullptr) { + NVTE_CHECK_CUDA(hipMemcpy(dst_cpu, dst_dev, N * sizeof(T), hipMemcpyDeviceToHost)); + } + + NVTE_CHECK(hiprandDestroyGenerator(gen) == HIPRAND_STATUS_SUCCESS); + NVTE_CHECK_CUDA(hipFree(tmp)); + if (tmp_sign) + NVTE_CHECK_CUDA(hipFree(tmp_sign)); +} + +template +static void fillUniformTensorDevice(Tensor* t, double lo=-2.0f, + double hi=1.0f, bool random_sign=false) { + void* dst_dev_void = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); + const auto shape = t->rowwise() ? (t->rowwise_shape()) : (t->columnwise_shape()); + const size_t N = product(shape); + + // per-tensor deterministic seed + const unsigned long long seed = static_cast(t->gen()()); + + T* dst_dev = reinterpret_cast(dst_dev_void); + // Keep the CPU mirror in sync. We could use Tensor::to_cpu() here, + // but that does more than just copying the data. + T* dst_cpu = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); + fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, lo, hi, random_sign); +} +#endif + +void fillUniform(Tensor *t) { + if (t->rowwise()) { + const size_t size = product(t->rowwise_shape()); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, + { +#ifdef __HIP_PLATFORM_AMD__ + fillUniformTensorDevice(t); +#else + T *data = t->rowwise_cpu_dptr(); + generate_data_uniformly(data, size, &(t->gen())); +#endif + } + ); + } else { + const size_t size = product(t->columnwise_shape()); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, + { +#ifdef __HIP_PLATFORM_AMD__ + fillUniformTensorDevice(t); +#else + T *data = t->columnwise_cpu_dptr(); + generate_data_uniformly(data, size, &(t->gen())); +#endif + } + ); + } +#ifndef __HIP_PLATFORM_AMD__ + // Data is already on device on AMDGPU + t->from_cpu(); +#endif + std::uniform_real_distribution<> dis(-2.0, 1.0); + t->set_scale_inv(dis(t->gen())); +} + +template +void fillCase_special(Tensor *t) { + const size_t size = product(t->rowwise_shape()); + + if constexpr (Case == InputsFillCase::zeros) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { +#ifdef __HIP_PLATFORM_AMD__ + // Fill device and CPU mirror + void* dst_dev = t->rowwise_dptr(); + NVTE_CHECK_CUDA(hipMemset(dst_dev, 0, size * sizeof(InputType))); + InputType* dst_cpu = t->rowwise_cpu_dptr(); + std::fill_n(dst_cpu, size, static_cast(0)); +#else + InputType *data = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(0); + } +#endif + }); + } else { + double minAbs = -2.0; + double maxAbs = 1.0; + if constexpr (Case != InputsFillCase::uniform) { + minAbs = Quantized_Limits::ranges[Case]; + maxAbs = Quantized_Limits::ranges[Case + 1]; + } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { +#ifdef __HIP_PLATFORM_AMD__ + const unsigned long long seed = static_cast(t->gen()()); + InputType* dst_dev = static_cast(t->rowwise_dptr()); + InputType* dst_cpu = static_cast(t->rowwise_cpu_dptr()); + fillUniformLinearBufferDevice(dst_dev, dst_cpu, size, seed, + minAbs, maxAbs, /*random_sign=*/true); +#else + std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + InputType *data = t->rowwise_cpu_dptr(); + for (size_t idx = 0; idx < size; ++idx) { + const bool is_negative = (dis_sign(t->gen()) < 0.0); + double val = dis(t->gen()); + if (is_negative) { + val = -val; + } + data[idx] = static_cast(val); + } +#endif + }); + } + t->set_scale_inv(1.0); +#ifndef __HIP_PLATFORM_AMD__ + t->from_cpu(); +#endif +} + +template +void fillCase(Tensor *t, const InputsFillCase fill_case) { + switch (fill_case) { + case InputsFillCase::uniform: + fillCase_special(t); break; + case InputsFillCase::zeros: + fillCase_special(t); break; + case InputsFillCase::zero_to_minNorm: + fillCase_special(t); break; + case InputsFillCase::minNorm_to_maxNorm: + fillCase_special(t); break; + case InputsFillCase::maxNorm_to_inf: + fillCase_special(t); break; + } +} + +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +#if FP4_TYPE_SUPPORTED +template void fillCase(Tensor *t, const InputsFillCase fill_case); +#endif + +void setRandomScale(Tensor *t) { + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float scale = dis(t->gen()); + t->set_scale(scale); +} + +void setRandomScaleInv(Tensor *t) { + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float scale_inv = dis(t->gen()); + t->set_scale_inv(scale_inv); +} + +bool isFp8Type(DType type) { + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; +} + +bool isFp4Type(DType type) { + return type == DType::kFloat4E2M1; +} + +int32_t getDeviceComputeCapability() { + hipDeviceProp_t deviceProp; + hipGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; +} + +size_t first_dimension(const std::vector &shape) { + if (shape.size() == 0) return 1; + if (shape.size() == 1) return 1; + return product(shape, 0, shape.size() - 1); +} + +size_t last_dimension(const std::vector &shape) { + if (shape.size() == 0) return 1; + return shape[shape.size() - 1]; +} + +std::array get_scale_tensor_dims(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols) { + const bool is_rowwise = (block_size_rows == 1) + && ((block_size_cols == 32) || (block_size_cols == 16)); + + const size_t alignment_Y = is_rowwise + ? scale_tensor_alignment_Y_rowwise + : scale_tensor_alignment_Y_colwise; + const size_t alignment_X = is_rowwise + ? scale_tensor_alignment_X_rowwise + : scale_tensor_alignment_X_colwise; + + const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols); + + const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y); + const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X); + return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; +} + +} // namespace test diff --git a/benchmarks/cpp/utils/test_common.h b/benchmarks/cpp/utils/test_common.h new file mode 100644 index 000000000..50a9defab --- /dev/null +++ b/benchmarks/cpp/utils/test_common.h @@ -0,0 +1,684 @@ +// !!! This is a file automatically generated by hipify!!! +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#ifndef USE_ROCM +#define FP4_TYPE_SUPPORTED (TORCH_HIP_VERSION >= 12080) +#include +#include "common/amd_detail/hip_float8.h" +#if FP4_TYPE_SUPPORTED +#include +#endif +#else +#define FP4_TYPE_SUPPORTED (false) +#include +#include "amd_detail/hip_float8.h" +#endif +#include + +#include +#include "util/logging_hip.h" + +namespace test { +using namespace transformer_engine; + +template +struct BytesToType {}; + +template <> +struct BytesToType<1> { + using Type = uint8_t; +}; + +template <> +struct BytesToType<2> { + using Type = uint16_t; +}; + +template <> +struct BytesToType<4> { + using Type = uint32_t; +}; + +template <> +struct BytesToType<8> { + using Type = uint64_t; +}; + +using byte = uint8_t; +using int16 = int16_t; +using int32 = int32_t; +using int64 = int64_t; +using fp32 = float; +using fp16 = half; +#ifndef USE_ROCM +using bf16 = nv_bfloat16; +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; +#else +using bf16 = hip_bfloat16; +using fp8e4m3 = te_hip_fp8_e4m3; +using fp8e5m2 = te_hip_fp8_e5m2; +#endif //USE_ROCM +using fp8e8m0 = uint8_t; +#if FP4_TYPE_SUPPORTED +using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; +#endif + +template +struct BitsNumber; + +#if FP4_TYPE_SUPPORTED +template <> +struct BitsNumber { + static constexpr size_t num_bits = 4; +}; +#endif + +template +struct BitsNumber { + static constexpr size_t num_bits = 8 * sizeof(T); +}; + +template +struct TypeInfo { +#if FP4_TYPE_SUPPORTED + using types = std::tuple; +#else + using types = std::tuple; +#endif + + template + struct Helper { + constexpr static DType getType() { + constexpr int i = static_cast(current); + if (std::is_same::type>::value) { + return current; + } else { + return Helper(i + 1)>::getType(); + } + } + }; + + template + struct Helper { + constexpr static DType getType() { + return DType::kNumTypes; + } + }; + + template + constexpr static DType getType() { + return Helper::getType(); + } + + constexpr static DType dtype = getType(); + constexpr static size_t size = BitsNumber::num_bits;; +}; + +class Tensor { + public: + Tensor(const std::string& name, + const NVTEShape &shape, const DType type, + const bool rowwise = true, + const bool columnwise = false, + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); + + Tensor(const std::string& name, + const std::vector &shape, + const DType type, + const bool rowwise = true, + const bool columnwise = false, + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {} + + Tensor() {} + + Tensor& operator=(const Tensor &other) = delete; + Tensor(const Tensor &other) = delete; + + Tensor(Tensor &&other) = default; + Tensor& operator=(Tensor &&other) = default; + + ~Tensor() { + void *data_ptr = tensor_.dptr(); + void *scale_inv = tensor_.scale_inv(); + void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr; + void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr; + if (columnwise_data_ptr == data_ptr) { + columnwise_data_ptr = nullptr; + } + if (columnwise_scale_inv == scale_inv) { + columnwise_scale_inv = nullptr; + } + if (data_ptr != nullptr) { + (void)hipFree(data_ptr); + } + if (scale_inv != nullptr) { + (void)hipFree(scale_inv); + } + if (columnwise_data_ptr != nullptr){ + (void)hipFree(columnwise_data_ptr); + } + if (columnwise_scale_inv != nullptr){ + (void)hipFree(columnwise_scale_inv); + } + } + + NVTETensor data() const noexcept { return tensor_.data(); } + + NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; } + + NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; } + + NVTEShape rowwise_scale_inv_shape() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().shape; + } + + NVTEShape columnwise_scale_inv_shape() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().shape; + } + + NVTEScalingMode scaling_mode() const noexcept { + return tensor_.scaling_mode(); + } + + DType dtype() const noexcept { + return tensor_.dtype(); + } + + void *rowwise_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_data().data_ptr; + } + + void *columnwise_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_data().data_ptr; + } + + template + T *rowwise_cpu_dptr() const { + NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return reinterpret_cast(cpu_data_rowwise_.get()); + } + + template + T *columnwise_cpu_dptr() const { + NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return reinterpret_cast(cpu_data_columnwise_.get()); + } + + float amax() const { + if(amax_cpu_data_) { + to_cpu(); + return *amax_cpu_data_; + } else { + return 0; + } + } + + void *amax_dptr() const { + return tensor_.amax(); + } + + float scale() const { + if(scale_cpu_data_) { + NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING), + "Invalid scaling_mode!"); + to_cpu(); + return *scale_cpu_data_; + } else { + return 1; + } + } + + template + T *rowwise_cpu_scale_inv_ptr() const { + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); + } else { + NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); + } + to_cpu(); + return reinterpret_cast(rowwise_scale_inv_cpu_data_.get()); + } + + template + T *columnwise_cpu_scale_inv_ptr() const { + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); + } else { + NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); + } + to_cpu(); + return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); + } + + float rowwise_scale_inv() const { + if(rowwise_scale_inv_cpu_data_) { + float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; + return scale_inv; + } else { + return 1; + } + } + + bool rowwise() const { + return rowwise_; + } + + bool columnwise() const { + return columnwise_; + } + + void set_tensor_amax_nullptr(){ + tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); + } + + void to_cpu() const; + void from_cpu() const; + void set_scale(float scale); + void set_scale_inv(float scale_inv); + void shareFP8Meta(const Tensor &other); + + std::mt19937& gen() { return gen_; } + + void *rowwise_scale_inv_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.scale_inv(); // rowwise scale_inv backing storage + } + + void *columnwise_scale_inv_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().data_ptr; + } + + private: + TensorWrapper tensor_; + std::unique_ptr cpu_data_rowwise_; + std::unique_ptr cpu_data_columnwise_; + std::shared_ptr amax_cpu_data_; + std::shared_ptr scale_cpu_data_; + std::unique_ptr rowwise_scale_inv_cpu_data_; + std::unique_ptr columnwise_scale_inv_cpu_data_; + bool rowwise_; + bool columnwise_; + std::string name_; + std::mt19937 gen_; +}; + +constexpr uint32_t FP32_EXPONENT_BIAS = 127; +constexpr uint32_t FP32_MANTISSA_BITS = 23; + +// [128,4] rowwise and [4,128] colwise alignment requirement +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t scale_tensor_alignment_X_rowwise = 1; +constexpr size_t scale_tensor_alignment_Y_rowwise = 1; +constexpr size_t scale_tensor_alignment_X_colwise = 1; +constexpr size_t scale_tensor_alignment_Y_colwise = 1; +#else +constexpr size_t scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t scale_tensor_alignment_X_rowwise = 4; +constexpr size_t scale_tensor_alignment_Y_colwise = 4; +constexpr size_t scale_tensor_alignment_X_colwise = 128; +#endif + +inline size_t divide_round_up(const size_t N, const size_t M) { + return (N - 1 + M) / M; +} + +inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) { + return divide_round_up(N, M) * M; +} + +template +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0; + static constexpr double maxSubnorm = 1.0; + static constexpr double minNorm = 1.0; + static constexpr double maxNorm = 1.0; + static constexpr double artifInf = 1.0; + static constexpr int maxBiasedExponent = 1; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0 / static_cast(1 << 9); // std::pow(2.0, -9.0); + static constexpr double maxSubnorm = 0.875 / static_cast(1 << 6); // std::pow(2.0, -6.0); + static constexpr double minNorm = 1.0 / static_cast(1 << 6); // std::pow(2.0, -6.0); +#ifndef USE_ROCM + static constexpr double maxNorm = 448.0; +#else + static const double maxNorm; +#endif //USE_ROCM + static const double artifInf; // artificial Infinity + static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS; + static constexpr int maxUnbiasedExponentAsFP32 = 8; + static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; +}; + +#ifdef USE_ROCM +inline const double Numeric_Traits::maxNorm = te_fp8_fnuz() ? 240.0 : 448.0; +#endif + +inline const double Numeric_Traits::artifInf = 10.0 * Numeric_Traits::maxNorm; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0 / static_cast(1 << 16); // std::pow(2.0, -16.0); + static constexpr double maxSubnorm = 0.75 / static_cast(1 << 14); // std::pow(2.0, -14.0); + static constexpr double minNorm = 1.0 / static_cast(1 << 14); // std::pow(2.0, -14.0); + static constexpr double maxNorm = 57344.0; + static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity + static constexpr int maxBiasedExponentAsFP32 = 15 + FP32_EXPONENT_BIAS; + static constexpr int maxUnbiasedExponentAsFP32 = 15; + static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = std::numeric_limits::denorm_min(); // std::pow(2.0, -149.0); + static constexpr double maxSubnorm = std::numeric_limits::min() + - std::numeric_limits::denorm_min(); // minNormalized - minDenormalized + static constexpr double minNorm = std::numeric_limits::min(); // std::pow(2.0, -126.0); + static constexpr double maxNorm = std::numeric_limits::max(); // (1 - pow(2, -24)) * pow(2, 128) + static constexpr double artifInf = std::numeric_limits::infinity(); + static constexpr int maxBiasedExponentAsFP32 = 255; + static constexpr int maxUnbiasedExponentAsFP32 = 128; +}; + +template +struct Quantized_Limits { + static const double ranges[4]; + static constexpr inline fp32 max() { return static_cast(Numeric_Traits::maxNorm); } + static constexpr inline fp32 max_reciprocal() { return static_cast(1.0 / max()); } + static constexpr inline fp32 emax() { return static_cast(Numeric_Traits::maxExpNorm); } + static constexpr inline fp32 emax_reciprocal() { return static_cast(1.0 / emax()); } + static constexpr inline int max_norm_biased_exponent() { return Numeric_Traits::maxBiasedExponentAsFP32; } + static constexpr inline int max_norm_unbiased_exponent() { return Numeric_Traits::maxUnbiasedExponentAsFP32; } +}; + +template +inline const double Quantized_Limits::ranges[4] = { + 0.0, + Numeric_Traits::minNorm, + Numeric_Traits::maxNorm, + Numeric_Traits::artifInf +}; +// Input data filling cases +// Considering normal and subnormal magnitudes of E4M3 and E5M2 formats +// with nearest to even rounding per OFP8 specification +enum InputsFillCase { + zero_to_minNorm = 0, // [0, min_normal) + minNorm_to_maxNorm = 1, // [min_normal, max_normal) + maxNorm_to_inf = 2, // [max_normal, inf) + zeros = 3, // {0} + uniform = 4, // std::uniform_real_distribution<> dis(-2.0, 1.0) +}; + +inline fp8e8m0 float_to_e8m0(float val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (std::isnan(val)) { + return 0xFF; + } + if (std::isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + fp8e8m0 exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +} + +inline float exp2f_rcp(fp8e8m0 biased_exp) { + if (biased_exp == 0) { + return 1.0f; + } + int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127) + float fp32_val = *reinterpret_cast(&int_val); + return fp32_val; +} + +inline float identity(const float x) { return x; } +inline float gelu(const float x) { return x * (0.5f + 0.5f * tanhf(x * (0.79788456f + 0.03567741f * x * x))); } +inline float dgelu(const float x) { + const float tanh_out = tanhf(0.79788456f * x * (1 + 0.044715f * x * x)); + return 0.5f * x * ((1 - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + + 0.5f * (1 + tanh_out); +} +inline float sigmoid(const float x) { return 1 / (1 + expf(-x)); } +inline float dsigmoid(const float x) { return sigmoid(x) * (1 - sigmoid(x)); } +inline float qgelu(const float x) { return x * sigmoid(1.702f * x); } +inline float dqgelu(const float x) { return 1.702f * x * dsigmoid(1.702f * x) + sigmoid(1.702f * x); } +inline float relu(const float x) { return fmaxf(0, x); } +inline float drelu(const float x) { return x > 0 ? 1 : 0; } +inline float silu(const float x) { return x * sigmoid(x); } +inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } +inline float srelu(const float x) { return x > 0 ? x * x : 0; } +inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } + +size_t typeToNumBits(DType type); +size_t product(const NVTEShape &shape); +size_t product(const std::vector &shape); +size_t bytes(const NVTEShape& shape, const DType type); + +size_t first_dimension(const std::vector &shape); +size_t last_dimension(const std::vector &shape); + +bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); + +void compareResults(const std::string &name, const Tensor &test, const void *ref, + bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, + const size_t tolerable_mismatches_limit = 0); +void compareResults(const std::string &name, const float test, const float ref, + double atol = 1e-5, double rtol = 1e-8); +void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, + size_t N, float mismatch_rate_tol = 0.); +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef USE_ROCM + std::vector& mismatch_indices, +#endif //#ifdef USE_ROCM + size_t& mismatches_num, + const size_t scale_diff_abs_tolerance = 0, + const double abs_tolerable_mismatches_limit = 0, + const double rel_tolerable_mismatches_limit = 0); + +#ifdef USE_ROCM +void adjust_ref_for_e8m0_scale_error(const std::string &name, + const std::vector &mismatch_idx, + const uint8_t *test_scale, const uint8_t *ref_scale, + const size_t scale_stride, const size_t rows, + const size_t cols, bool rowwise, void *ref_ptr, DType otype); +#endif + +std::array get_scale_tensor_dims(const size_t rows, const size_t cols, + const size_t block_size_rows, const size_t block_size_cols); + +std::pair getTolerances(const DType type); + +void fillUniform(Tensor *t); + +template +void fillCase(Tensor *t, const InputsFillCase fill_case); + +void setRandomScale(Tensor *t); +void setRandomScaleInv(Tensor *t); + +constexpr int THREADS_PER_WARP = 32; + +const std::string &typeName(DType type); +const std::string& caseName(InputsFillCase type); + +extern std::vector all_fp_types; + +bool isFp8Type(DType type); +bool isFp4Type(DType type); + +int32_t getDeviceComputeCapability(); +constexpr int32_t hopperComputeCapability = 90; +constexpr int32_t blackwellComputeCapability = 100; + +} // namespace test + +#if FP4_TYPE_SUPPORTED +#define SWITCH_FP4_TYPE_HANDLE(type, ...) \ + case DType::kFloat4E2M1: { \ + using type = fp4e2m1; \ + { __VA_ARGS__ } \ + } break; +#else +#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing +#endif + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kByte: \ + { \ + using type = byte; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kInt32: \ + { \ + using type = int32; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kInt64: \ + { \ + using type = int64; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat32: \ + { \ + using type = float; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat16: \ + { \ + using type = fp16; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kBFloat16: \ + { \ + using type = bf16; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat8E4M3: \ + { \ + using type = fp8e4m3; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat8E5M2: \ + { \ + using type = fp8e5m2; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat8E8M0: \ + { \ + using type = fp8e8m0; \ + {__VA_ARGS__} \ + } \ + break; \ + SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ + default: \ + printf("dtype: %d\n", static_cast(dtype)); \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E4M3: \ + { \ + using type = fp8e4m3; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat8E5M2: \ + { \ + using type = fp8e5m2; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: \ + { \ + using type = float; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat16: \ + { \ + using type = fp16; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kBFloat16: \ + { \ + using type = bf16; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index a056cf50f..6a0d65685 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -807,20 +807,33 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu const size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); +#ifdef __HIP_PLATFORM_AMD__ + // rowwise-only uses direct global reads/writes, no shmem needed + size_t in_mem = 0; + size_t out_mem = 0; + if (USE_COLWISE_SCALING) { + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + in_mem = grad_mem + in_act_mem + in_gate_mem; + + out_mem = buff_size_aligned_out; + if constexpr (IS_BWD) { + out_mem += buff_size_aligned_out; + } + } +#else const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; -#ifdef __HIP_PLATFORM_AMD__ - const size_t out_gate_mem = buff_size_aligned_out; -#else const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); -#endif size_t out_mem = out_act_mem + out_gate_mem; if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } +#endif const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a69915c0d..4e4d8606a 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -576,9 +576,76 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t cols = input.flat_last_dim(); #ifdef __HIP_PLATFORM_AMD__ - constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; - constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; - constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; + // Choose chunk size based on tensor dimensions + const bool use_large_chunks = (rows * cols > 32 * 1024 * 1024); + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + TRANSFORMER_ENGINE_CHUNK_DIM_SWITCH( + use_large_chunks, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK, + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(cols % (32 * sizeof(IType))), IS_ALIGNED, + quantize_mxfp8_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + noop_ptr, workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + ))))); + + if constexpr (IS_DBIAS) { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + ); // NOLINT(*) + } + ); #else constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); @@ -590,7 +657,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; constexpr size_t BUFF_DIM_Y = THREADS_Y; constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; -#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); @@ -608,7 +674,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; -#ifndef __HIP_PLATFORM_AMD__ ScalingType scaling_type; if (use_rowwise_scaling && (!use_colwise_scaling)) { scaling_type = ScalingType::ROWWISE; @@ -617,7 +682,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } else if (use_rowwise_scaling && use_colwise_scaling) { scaling_type = ScalingType::BIDIMENSIONAL; } -#endif if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); @@ -639,27 +703,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, - -#ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - quantize_mxfp8_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - ))); // NOLINT(*) -#else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; alignas(64) CUtensorMap tensor_map_output_rowwise{}; @@ -749,12 +792,12 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } } -#endif // #ifdef __HIP_PLATFORM_AMD__ if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); }); // NOLINT(*) ); // NOLINT(*) +#endif // __HIP_PLATFORM_AMD__ } } // namespace mxfp8 diff --git a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh index aac39c6d2..147951899 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh @@ -41,7 +41,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t THREADS_PER_SCALE_X_ROWWISE = DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(IType); + constexpr size_t VECTOR_WIDTH = IS_ALIGNED ? 8 : 16; const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; @@ -68,7 +68,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); diff --git a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index eb389f9db..fee3db3e0 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -9,370 +9,526 @@ //#include "hip/hip_runtime.h" //dummy include to prevent hipification adding this header constexpr size_t ALIGNMENT_SIZE = 128; -// TODO: Identify optimal chunk/thread size for MI350+ constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; constexpr size_t THREADS_PER_CHUNK = 256; -constexpr size_t THREADS_PER_CHUNK_X = 64; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 256 / 64 constexpr size_t BUFFERS_NUM = 1; // No async load for HIP constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 2 static_assert(ITERATIONS >= 1); +constexpr size_t ELEMS_PER_THREAD = 8; +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 32 +static_assert(THREADS_PER_CHUNK_Y_ROWWISE <= BUFFER_DIM_Y); +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 64 +constexpr size_t THREADS_PER_CHUNK_Y_COLWISE = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_COLWISE; // 4 +constexpr size_t BUFFER_STAGES_NUM_COLWISE = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_COLWISE; // 8 +constexpr size_t THREADS_PER_CHUNK_X = THREADS_PER_CHUNK_X_COLWISE; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK_Y_COLWISE; +constexpr size_t BUFFER_STAGES_NUM = BUFFER_STAGES_NUM_COLWISE; + __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } +template +__device__ inline void compute_gated_activation( + float act_elt, float gate_elt, float grad_elt, + const ParamOP &p, float &result_act, float &result_gate) { + + bool dgate_elt_valid = true; + if constexpr (std::is_same::value) { + dgate_elt_valid = gate_elt <= p.limit && gate_elt >= -p.limit; + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } + + if constexpr (IS_DGATED) { + const float x = act_elt; + float act_x, dact_x; + if constexpr (std::is_same::value) { + const float cx = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * cx); + act_x = cx * s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * cx : 0.0f; + } else { + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } + } + result_act = dact_x * grad_elt * gate_elt; + result_gate = dgate_elt_valid ? act_x * grad_elt : 0.0f; + } else { + result_act = ActOP(act_elt, p) * gate_elt; + result_gate = 0.0f; + } + + if constexpr (!std::is_same_v) { + result_act = static_cast(static_cast(result_act)); + if constexpr (IS_DGATED) { + result_gate = static_cast(static_cast(result_gate)); + } + } +} + template __global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_gated_mxfp8_kernel( - const IType *grad_ptr, - const IType *input_act, - const IType *input_gate, - OType *output_act_rowwise, - OType *output_gate_rowwise, - OType *output_act_colwise, - OType *output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise, const ParamOP p) { + const IType *grad_ptr, const IType *input_act, const IType *input_gate, OType *output_act_rowwise, + OType *output_gate_rowwise, OType *output_act_colwise, OType *output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, + const size_t scale_stride_rowwise, const size_t scale_stride_colwise, const ParamOP p) { constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 4 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; - const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_X = blockIdx.x * CHUNK_DIM_X; const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; - constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType); + const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + constexpr size_t VECTOR_WIDTH_IN = IS_ALIGNED ? 8 : 16; + constexpr size_t VECTOR_WIDTH_OUT = 16; - const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + constexpr size_t ROWS_PER_THREAD = CHUNK_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; - const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_elems_total = BUFFERS_NUM * buff_elems; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + // ROWWISE-ONLY PATH: Direct global memory, no shared memory + if constexpr (USE_ROWWISE_SCALING && !USE_COLWISE_SCALING) { + const size_t col_start = chunk_offset_X + thread_offset_X_rowwise; + const bool col_valid = (col_start < cols); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); +#pragma unroll + for (size_t r = 0; r < ROWS_PER_THREAD; r++) { + const size_t row = chunk_offset_Y + tid_rowwise_Y + r * THREADS_PER_CHUNK_Y_ROWWISE; - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = in_act_mem + in_gate_mem; + const bool row_valid = (row < rows); - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - const size_t out_mem = out_act_mem + out_gate_mem; + Vec act_vec, gate_vec, grad_vec; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + act_vec.load_from(&input_act[row * 2*cols + col_start]); + gate_vec.load_from(&input_gate[row * 2*cols + col_start]); + if constexpr (IS_DGATED) { + grad_vec.load_from(&grad_ptr[row * cols + col_start]); + } + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + act_vec.data.elt[j] = (col_start + j < cols) ? input_act[row * 2*cols + col_start + j] : static_cast(0); + gate_vec.data.elt[j] = (col_start + j < cols) ? input_gate[row * 2*cols + col_start + j] : static_cast(0); + if constexpr (IS_DGATED) { + grad_vec.data.elt[j] = (col_start + j < cols) ? grad_ptr[row * cols + col_start + j] : static_cast(0); + } + } + } + } + + // Compute activations + float computed_act[ELEMS_PER_THREAD]; + float computed_gate[ELEMS_PER_THREAD]; + float act_amax = 0; + float gate_amax = 0; - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); + float act_elt = static_cast(act_vec.data.elt[j]); + float gate_elt = static_cast(gate_vec.data.elt[j]); + float grad_elt = IS_DGATED ? static_cast(grad_vec.data.elt[j]) : 0.0f; + + compute_gated_activation( + act_elt, gate_elt, grad_elt, p, computed_act[j], computed_gate[j]); + + if (!out_of_bounds) { + act_amax = fmaxf(act_amax, fabsf(computed_act[j])); + if constexpr (IS_DGATED) { + gate_amax = fmaxf(gate_amax, fabsf(computed_gate[j])); + } + } + } - OType *out_act_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + // --- Act rowwise quantization --- + { + __builtin_assume(act_amax >= 0); + const float scale_amax = subwarp_reduce_max_broadcast(act_amax); + const e8m0_t biased_exp = + ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); + const float scale_inv = ptx::exp2f_rcp(biased_exp); - OType *out_act_colwise_sh = out_act_rowwise_sh; - OType *out_gate_colwise_sh = out_gate_rowwise_sh; + Vec out_vec; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_vec.data.elt[j] = static_cast(computed_act[j] * scale_inv); + } - if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { - out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); - out_gate_colwise_sh = - reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_vec.store_to(&output_act_rowwise[row * output_cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_act_rowwise[row * output_cols + col_start + j] = out_vec.data.elt[j]; + } + } + } + } + + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exp; + } + } + + // --- Gate rowwise quantization (BWD only) --- + if constexpr (IS_DGATED) { + __builtin_assume(gate_amax >= 0); + const float scale_amax = subwarp_reduce_max_broadcast(gate_amax); + const e8m0_t biased_exp = + ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); + const float scale_inv = ptx::exp2f_rcp(biased_exp); + + Vec out_vec; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_vec.data.elt[j] = static_cast(computed_gate[j] * scale_inv); + } + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_vec.store_to(&output_gate_rowwise[row * output_cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_gate_rowwise[row * output_cols + col_start + j] = out_vec.data.elt[j]; + } + } + } + } + + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE + + DIVUP(cols, SCALE_DIM_X); + scales_rowwise[scale_idx] = biased_exp; + } + } + } } - __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + // COLWISE PATH: Shared memory for input + colwise output + if constexpr (USE_COLWISE_SCALING) { + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_elems_total = BUFFERS_NUM * buff_elems; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_mem = in_act_mem + buff_size_aligned_in; // act + gate + + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - __syncthreads(); + OType *out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + buff_size_aligned_out); + + // For colwise cross-thread Y reduction + __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y_COLWISE][CHUNK_DIM_X]; + + __syncthreads(); for (int it = 0; it < ITERATIONS; it++) { const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - const size_t row_base = chunk_it_offset_y; - // Initiate bulk tensor copy + // === Load input to shmem === if constexpr (IS_DGATED) { - copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, - cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + copy_2d_to_shared( + &in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, + cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } - - // Act - copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, - 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - - // Gate - copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, - 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + copy_2d_to_shared( + &in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, + 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + copy_2d_to_shared( + &in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, + 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); - const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; - const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; - - float after_dact_reg[BUFFER_STAGES_NUM]; - float after_dgate_reg[BUFFER_STAGES_NUM]; - float thread_Y_mx_block_amax = 0.0f; - float thread_Y_mx_block_amax_gate = 0.0f; + if constexpr (USE_ROWWISE_SCALING) { + const size_t row = chunk_it_offset_y + tid_rowwise_Y; + const size_t col_start = chunk_offset_X + thread_offset_X_rowwise; + + const bool row_valid = (row < rows); + const bool col_valid = (col_start < cols); + + const int shmem_base = tid_rowwise_Y * SHMEM_DIM_X + thread_offset_X_rowwise; + Vec act_vec, gate_vec; + act_vec.load_from(&in_act_sh[shmem_base]); + gate_vec.load_from(&in_gate_sh[shmem_base]); + Vec grad_vec; + if constexpr (IS_DGATED) { + grad_vec.load_from(&in_grad_sh[shmem_base]); + } - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + float computed_act[ELEMS_PER_THREAD]; + float computed_gate[ELEMS_PER_THREAD]; + float act_amax = 0; + float gate_amax = 0; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + float act_elt = static_cast(act_vec.data.elt[j]); + float gate_elt = static_cast(gate_vec.data.elt[j]); + float grad_elt = IS_DGATED ? static_cast(grad_vec.data.elt[j]) : 0.0f; - float act_elt = static_cast(in_act_sh[shmem_idx]); - float gate_elt = static_cast(in_gate_sh[shmem_idx]); + compute_gated_activation( + act_elt, gate_elt, grad_elt, p, computed_act[j], computed_gate[j]); - bool dgate_elt = true; // gating is ideally an identity function - if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values - dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp - gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + act_amax = fmaxf(act_amax, fabsf(computed_act[j])); + if constexpr (IS_DGATED) { + gate_amax = fmaxf(gate_amax, fabsf(computed_gate[j])); + } } - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh[shmem_idx]); - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr (std::is_same::value) { - const float x = min(act_elt, p.limit); - const float s = sigmoidf(p.alpha * x); - act_x = x * s; - dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; - } else { - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; + // --- Act rowwise quantization --- + { + __builtin_assume(act_amax >= 0); + const float scale_amax = subwarp_reduce_max_broadcast(act_amax); + const e8m0_t biased_exp = + ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); + const float scale_inv = ptx::exp2f_rcp(biased_exp); + + Vec out_vec; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_vec.data.elt[j] = static_cast(computed_act[j] * scale_inv); + } + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_vec.store_to(&output_act_rowwise[row * output_cols + col_start]); } else { - act_x = ActOP(x, p); - dact_x = DActOP(x, p); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_act_rowwise[row * output_cols + col_start + j] = out_vec.data.elt[j]; + } + } } } - - after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = dgate_elt ? act_x * grad_elt : 0.0f; - } else { - after_dact_reg[stage] = ActOP(act_elt, p) * gate_elt; - } - // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 - if constexpr (!std::is_same_v) { - after_dact_reg[stage] = static_cast(static_cast(after_dact_reg[stage])); - if constexpr (IS_DGATED) { - after_dgate_reg[stage] = static_cast(static_cast(after_dgate_reg[stage])); + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exp; } } - if constexpr (USE_ROWWISE_SCALING) { - if constexpr (IS_DGATED) { - // dgate - float amax = fabsf(after_dgate_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); - - out_gate_rowwise_sh[shmem_idx] = - static_cast(scale_reciprocal_X * after_dgate_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; + // --- Gate rowwise quantization (BWD only) --- + if constexpr (IS_DGATED) { + __builtin_assume(gate_amax >= 0); + const float scale_amax = subwarp_reduce_max_broadcast(gate_amax); + const e8m0_t biased_exp = + ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); + const float scale_inv = ptx::exp2f_rcp(biased_exp); + + Vec out_vec; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_vec.data.elt[j] = static_cast(computed_gate[j] * scale_inv); + } + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_vec.store_to(&output_gate_rowwise[row * output_cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_gate_rowwise[row * output_cols + col_start + j] = out_vec.data.elt[j]; + } + } } } - float amax = fabsf(after_dact_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); - - out_act_rowwise_sh[shmem_idx] = - static_cast(scale_reciprocal_X * after_dact_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; + + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE + + DIVUP(cols, SCALE_DIM_X); + scales_rowwise[scale_idx] = biased_exp; } } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool col_out_of_bounds = (chunk_offset_X + tid_colwise_X >= cols); + const size_t row_base = chunk_it_offset_y; + const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; + + float after_dact_reg[BUFFER_STAGES_NUM_COLWISE]; + float after_dgate_reg[BUFFER_STAGES_NUM_COLWISE]; + + float thread_Y_mx_block_amax = 0.0f; + float thread_Y_mx_block_amax_gate = 0.0f; + + // Compute activation and accumulate column amax + for (int stage = 0; stage < BUFFER_STAGES_NUM_COLWISE; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_COLWISE; + const int shmem_offset_y = tid_colwise_Y + stage_offset_Y; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + tid_colwise_X; + + float act_elt = static_cast(in_act_sh[shmem_idx]); + float gate_elt = static_cast(in_gate_sh[shmem_idx]); + float grad_elt = 0.0f; + if constexpr (IS_DGATED) { + grad_elt = static_cast(in_grad_sh[shmem_idx]); + } + + compute_gated_activation( + act_elt, gate_elt, grad_elt, p, after_dact_reg[stage], after_dgate_reg[stage]); - if constexpr (USE_COLWISE_SCALING) { __builtin_assume(thread_Y_mx_block_amax >= 0); - __builtin_assume(thread_Y_mx_block_amax_gate >= 0); thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); if constexpr (IS_DGATED) { + __builtin_assume(thread_Y_mx_block_amax_gate >= 0); thread_Y_mx_block_amax_gate = fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); } } - } - if constexpr (USE_COLWISE_SCALING) { const bool row_out_of_bounds = (row_base >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); if constexpr (IS_DGATED) { - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + if (tid_colwise_Y > 0) { + stage_amax_sh[tid_colwise_Y][tid_colwise_X] = thread_Y_mx_block_amax_gate; } __syncthreads(); - if (tid_Y == 0) { + if (tid_colwise_Y == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + for (int y = 1; y < THREADS_PER_CHUNK_Y_COLWISE; ++y) { thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_colwise_X]); } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + stage_amax_sh[0][tid_colwise_X] = thread_Y_mx_block_amax_gate; } __syncthreads(); - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax - - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); - } + const float mx_block_Y_amax = stage_amax_sh[0][tid_colwise_X]; + __builtin_assume(mx_block_Y_amax >= 0); const e8m0_t biased_exponent = ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { + if ((tid_colwise_Y == 0) && !out_of_bounds) { const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X + cols; const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; scales_colwise[scale_idx] = biased_exponent; } #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - + for (int stage = 0; stage < BUFFER_STAGES_NUM_COLWISE; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_COLWISE; + const int shmem_idx = (tid_colwise_Y + stage_offset_Y) * SHMEM_DIM_X + tid_colwise_X; out_gate_colwise_sh[shmem_idx] = static_cast(scale_reciprocal * after_dgate_reg[stage]); } } - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; - } - __syncthreads(); - if (tid_Y == 0) { + + { + if (tid_colwise_Y > 0) { + stage_amax_sh[tid_colwise_Y][tid_colwise_X] = thread_Y_mx_block_amax; + } + __syncthreads(); + if (tid_colwise_Y == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + for (int y = 1; y < THREADS_PER_CHUNK_Y_COLWISE; ++y) { + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_colwise_X]); + } + stage_amax_sh[0][tid_colwise_X] = thread_Y_mx_block_amax; } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax - } - __syncthreads(); - - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + __syncthreads(); - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { + const float mx_block_Y_amax = stage_amax_sh[0][tid_colwise_X]; __builtin_assume(mx_block_Y_amax >= 0); - } - const e8m0_t biased_exponent = - ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - } + const e8m0_t biased_exponent = + ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); + + if ((tid_colwise_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; + + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_act_colwise_sh[shmem_idx] = - static_cast(scale_reciprocal * after_dact_reg[stage]); + for (int stage = 0; stage < BUFFER_STAGES_NUM_COLWISE; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_COLWISE; + const int shmem_idx = (tid_colwise_Y + stage_offset_Y) * SHMEM_DIM_X + tid_colwise_X; + out_act_colwise_sh[shmem_idx] = + static_cast(scale_reciprocal * after_dact_reg[stage]); + } } } __syncthreads(); - if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - } - } - - if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - } + bulk_tensor_2d_shared_to_global( + &out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, + chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + if constexpr (IS_DGATED) { + bulk_tensor_2d_shared_to_global( + &out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, + chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } __syncthreads(); } + } } diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index 198728f34..40913cc03 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -9,30 +9,22 @@ constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; constexpr size_t ELEMS_PER_THREAD = 16; -constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported -constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 -constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 - -constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = - MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 -constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = - MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 -constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_BUFF_STAGES_NUM = - MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 -constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported + +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ +typedef short mxfp8_v2i16_t __attribute__((ext_vector_type(2))); +#endif template -__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + size_t SCALE_DIM_X, bool IS_ALIGNED, + size_t CHUNK_DIM_Y = 64, + size_t CHUNK_DIM_X = 64, + size_t THREADS_PER_CHUNK = 64> +__global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_mxfp8_kernel( const IType *input_ptr, const IType *act_input_ptr, @@ -47,29 +39,34 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = - SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = - SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; + constexpr size_t SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; + constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; + + constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; + constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; + constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; + + constexpr size_t BUFF_STAGES_NUM = MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; + constexpr size_t ITERATIONS = CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = - SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_X = - SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = CHUNK_DIM_Y; + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = CHUNK_DIM_X; constexpr size_t THREADS_PER_SCALE_X_ROWWISE = DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 - constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType); + // Cap vector width so each load/store is at most 16 bytes (AMD max: global_load_dwordx4) + constexpr size_t VECTOR_WIDTH_IN = 16 / sizeof(IType); // BF16/FP16: 8, FP32: 4 + constexpr size_t VECTOR_WIDTH_OUT = 16 / sizeof(OType); // FP8: 16 - const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; @@ -77,102 +74,208 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; - // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; - const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_Y = tid_rowwise_Y; const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; - // const int thread_offset_X_colwise = tid_colwise_X; - - const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; - const int dbias_rowwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; - const int dbias_colwise_offset_Y = blockIdx.y; - const int dbias_colwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; - const int dbias_stride = cols; - - Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; - float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + + const int dbias_rowwise_offset_Y = blockIdx.y + tid_rowwise_Y; + const int dbias_rowwise_block_offset_X = block_offset_X + thread_offset_X_rowwise; + const int dbias_colwise_offset_Y = blockIdx.y; + const int dbias_colwise_block_offset_X = block_offset_X + tid_colwise_X; + const int dbias_stride = cols; + + Vec partial_dbias_rowwise; + float partial_dbias_colwise = 0; if constexpr (IS_DBIAS) { if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { -#pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; i++) { - partial_dbias_rowwise[i].clear(); - } - } else { -#pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; i++) { - partial_dbias_colwise[i] = 0; - } + partial_dbias_rowwise.clear(); } } - // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - alignas(128) __shared__ IType in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(128) __shared__ IType act_in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(128) __shared__ OType out_rowwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(128) __shared__ OType out_colwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - float block_amax = 0; + constexpr size_t ROWS_PER_THREAD = CHUNK_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; + + if constexpr (USE_ROWWISE_SCALING && !USE_COLWISE_SCALING) { + const size_t col_start = block_offset_X + thread_offset_X_rowwise; + const bool col_valid = (col_start < cols); #pragma unroll - for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; chunk++) { - const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + for (size_t r = 0; r < ROWS_PER_THREAD; r++) { + const size_t row = block_offset_Y + tid_rowwise_Y + r * THREADS_PER_CHUNK_Y_ROWWISE; + const bool row_valid = (row < rows); - const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; - const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + Vec in; + Vec act_in; - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + in.load_from(&input_ptr[row * cols + col_start]); + if constexpr (IS_DACT) { + act_in.load_from(&act_input_ptr[row * cols + col_start]); + } + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + in.data.elt[j] = (col_start + j < cols) ? input_ptr[row * cols + col_start + j] + : static_cast(0); + } + if constexpr (IS_DACT) { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + act_in.data.elt[j] = (col_start + j < cols) ? act_input_ptr[row * cols + col_start + j] + : static_cast(0); + } + } + } + } - const int scales_rowwise_chunk_offset_Y = - scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = - scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = - scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = - scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; - __syncthreads(); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise.data.elt[j] += elt; + } + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute[j] = elt; + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = + ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + + { + constexpr size_t SCALES_PER_GROUP = THREADS_PER_CHUNK_X_ROWWISE / THREADS_PER_SCALE_X_ROWWISE; + uint32_t my_scale = static_cast(biased_exponent); + if constexpr (SCALES_PER_GROUP >= 4) { + uint32_t s1 = __shfl_down(my_scale, 1 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s2 = __shfl_down(my_scale, 2 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s3 = __shfl_down(my_scale, 3 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t packed = (my_scale & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24); + if (tid_rowwise_X == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + scales_rowwise_block_offset_X; + reinterpret_cast(&scales_rowwise[scale_idx])[0] = packed; + } + } else { + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = + row * scale_stride_rowwise + + scales_rowwise_block_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exponent; + } + } + } + Vec out_c; +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + { + const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); + union { + uint32_t packed[ELEMS_PER_THREAD / 4]; + mxfp8_v2i16_t v2i16[ELEMS_PER_THREAD / 4]; + } cvt_out{}; #pragma unroll - for (int iter = 0; iter < MXFP8_ITERATIONS; iter++) { - const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - const size_t row_base = chunk_it_offset_y; - if constexpr (IS_DACT) { - copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, - chunk_it_offset_x, chunk_it_offset_y, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); + for (int p = 0; p < ELEMS_PER_THREAD / 4; p++) { + if constexpr (std::is_same_v) { + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16[p], in_compute[p*4+0], in_compute[p*4+1], cvt_scale, false); + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16[p], in_compute[p*4+2], in_compute[p*4+3], cvt_scale, true); + } else { + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16[p], in_compute[p*4+0], in_compute[p*4+1], cvt_scale, false); + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16[p], in_compute[p*4+2], in_compute[p*4+3], cvt_scale, true); + } + } + memcpy(out_c.data.elt, cvt_out.packed, ELEMS_PER_THREAD * sizeof(OType)); } - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, - chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, rows, cols); - __syncthreads(); +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + } +#endif - if constexpr (USE_ROWWISE_SCALING) { - Vec in; - Vec act_in; - Vec out_c; + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_c.store_to(&output_rowwise[row * cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_rowwise[row * cols + col_start + j] = out_c.data.elt[j]; + } + } + } + } + } + } + + if constexpr (USE_COLWISE_SCALING) { + alignas(128) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(128) __shared__ IType act_in_sh[IS_DACT ? SHMEM_DIM_Y : 1][IS_DACT ? SHMEM_DIM_X : 1]; + alignas(128) __shared__ OType out_colwise_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; - const int iteration_scale_rowwise_offset_Y = - scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const size_t col = block_offset_X + tid_colwise_X; + const bool col_valid_colwise = (col < cols); #pragma unroll - for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; stage++) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X_rowwise; + for (int iter = 0; iter < ITERATIONS; iter++) { + const size_t row_base = block_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + if constexpr (IS_DACT) { + copy_2d_to_shared( + &act_in_sh[0][0], act_input_ptr, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + } + copy_2d_to_shared( + &in_sh[0][0], input_ptr, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + __syncthreads(); - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); + if constexpr (USE_ROWWISE_SCALING) { + const size_t col_start = block_offset_X + thread_offset_X_rowwise; + const bool col_valid = (col_start < cols); - in.load_from(&in_sh[shmem_offset_y][shmem_offset_x]); +#pragma unroll + for (int stage = 0; stage < BUFF_STAGES_NUM; stage++) { + const int shmem_y = thread_offset_Y + stage * THREADS_PER_CHUNK_Y_ROWWISE; + const size_t row = row_base + shmem_y; + const bool row_valid = (row < rows); + + Vec in; + Vec act_in; + in.load_from(&in_sh[shmem_y][thread_offset_X_rowwise]); if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_y][shmem_offset_x]); + act_in.load_from(&act_in_sh[shmem_y][thread_offset_X_rowwise]); } float thread_amax = 0; @@ -180,9 +283,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { - const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); - + const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); float elt = static_cast(in.data.elt[j]); if constexpr (IS_ACT) { elt = OP(elt, {}); @@ -193,10 +294,9 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { if (!out_of_bounds) { - partial_dbias_rowwise[chunk_X].data.elt[j] += elt; + partial_dbias_rowwise.data.elt[j] += elt; } } - // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 if constexpr (!std::is_same_v) { elt = static_cast(static_cast(elt)); } @@ -213,38 +313,84 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); const e8m0_t biased_exponent = ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); - // Only single thread writes the computed scaling factor - const bool col_out_of_bounds = dbias_rowwise_offset_X >= cols; - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && !(row_out_of_bounds || col_out_of_bounds)) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; + + { + constexpr size_t SCALES_PER_GROUP = THREADS_PER_CHUNK_X_ROWWISE / THREADS_PER_SCALE_X_ROWWISE; + uint32_t my_scale = static_cast(biased_exponent); + if constexpr (SCALES_PER_GROUP >= 4) { + uint32_t s1 = __shfl_down(my_scale, 1 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s2 = __shfl_down(my_scale, 2 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s3 = __shfl_down(my_scale, 3 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t packed = (my_scale & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24); + if (tid_rowwise_X == 0 && row_valid && col_valid) { + reinterpret_cast(&scales_rowwise[row * scale_stride_rowwise + scales_rowwise_block_offset_X])[0] = packed; + } + } else { + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_block_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exponent; + } + } } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + Vec out_c; +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + { + const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); + union { + uint32_t packed[ELEMS_PER_THREAD / 4]; + mxfp8_v2i16_t v2i16[ELEMS_PER_THREAD / 4]; + } cvt_out{}; +#pragma unroll + for (int p = 0; p < ELEMS_PER_THREAD / 4; p++) { + if constexpr (std::is_same_v) { + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16[p], in_compute[p*4+0], in_compute[p*4+1], cvt_scale, false); + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16[p], in_compute[p*4+2], in_compute[p*4+3], cvt_scale, true); + } else { + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16[p], in_compute[p*4+0], in_compute[p*4+1], cvt_scale, false); + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16[p], in_compute[p*4+2], in_compute[p*4+3], cvt_scale, true); + } + } + memcpy(out_c.data.elt, cvt_out.packed, ELEMS_PER_THREAD * sizeof(OType)); + } +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + } +#endif + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_c.store_to(&output_rowwise[row * cols + col_start]); + } else { #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_rowwise[row * cols + col_start + j] = out_c.data.elt[j]; + } + } + } } - out_c.store_to(&out_rowwise_sh[shmem_offset_y][shmem_offset_x]); } } - if constexpr (USE_COLWISE_SCALING) { - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + if (threadIdx.x < CHUNK_DIM_X) { float in_compute[SCALE_DIM_Y]; - float amax = 0; + #pragma unroll for (int i = 0; i < SCALE_DIM_Y; i++) { const size_t row = row_base + i; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + const bool out_of_bounds = (!col_valid_colwise || row >= rows); float elt = static_cast(in_sh[i][tid_colwise_X]); if constexpr (IS_ACT) { @@ -256,10 +402,9 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } if constexpr (IS_DBIAS) { if (!out_of_bounds) { - partial_dbias_colwise[chunk_X] += elt; + partial_dbias_colwise += elt; } } - // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 if constexpr (!std::is_same_v) { elt = static_cast(static_cast(elt)); } @@ -275,35 +420,56 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const e8m0_t biased_exponent = ptx::float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); - const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - const bool row_out_of_bounds = row_base >= rows; - if (!(row_out_of_bounds || col_out_of_bounds)) { + if (col_valid_colwise && row_base < rows) { + const int scale_idx = + (scales_colwise_block_offset_Y + iter) * scale_stride_colwise + col; scales_colwise[scale_idx] = biased_exponent; } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + { + const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); #pragma unroll - for (int i = 0; i < SCALE_DIM_Y; i++) { - out_colwise_sh[i][tid_colwise_X] = - static_cast(in_compute[i] * block_scale_inverse); + for (int i = 0; i < SCALE_DIM_Y; i += 2) { + union { + uint32_t packed; + mxfp8_v2i16_t v2i16; + uint8_t bytes[4]; + } cvt_out{}; + if constexpr (std::is_same_v) { + cvt_out.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16, in_compute[i], in_compute[i+1], cvt_scale, false); + } else { + cvt_out.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16, in_compute[i], in_compute[i+1], cvt_scale, false); + } + OType val0, val1; + memcpy(&val0, &cvt_out.bytes[0], sizeof(OType)); + memcpy(&val1, &cvt_out.bytes[1], sizeof(OType)); + out_colwise_sh[i][tid_colwise_X] = val0; + if (i + 1 < SCALE_DIM_Y) { + out_colwise_sh[i+1][tid_colwise_X] = val1; + } + } } +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; i++) { + out_colwise_sh[i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } +#endif } - + __syncthreads(); - if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, - chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, rows, cols); - } - if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, - chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, rows, cols); - } + bulk_tensor_2d_shared_to_global( + &out_colwise_sh[0][0], output_colwise, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); } @@ -311,58 +477,45 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) if constexpr (IS_DBIAS) { if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; - __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; + __shared__ float shmem_partial_dbias_rowwise[Y][X][ELEMS_PER_THREAD]; if (tid_rowwise_Y > 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; c++) { - partial_dbias_rowwise[c].store_to( - &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); - } + partial_dbias_rowwise.store_to( + &shmem_partial_dbias_rowwise[tid_rowwise_Y - 1][tid_rowwise_X]); } __syncthreads(); if (tid_rowwise_Y == 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; c++) { - Vec other_row_dbias; - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; - - const int left_bound = dbias_rowwise_offset_X; - const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + Vec other_row_dbias; + const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_block_offset_X; + const int left_bound = dbias_rowwise_block_offset_X; + const int right_bound = dbias_rowwise_block_offset_X + ELEMS_PER_THREAD - 1; #pragma unroll - for (int i = 0; i < Y; i++) { - other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); + for (int i = 0; i < Y; i++) { + other_row_dbias.load_from(&shmem_partial_dbias_rowwise[i][tid_rowwise_X]); #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; - } + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + partial_dbias_rowwise.data.elt[j] += other_row_dbias.data.elt[j]; } + } - // Vectorized store when all elements are inside the boundaries - if (right_bound < cols) { - partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); - } else if (left_bound < cols && right_bound >= cols) { - // Element-by-element store when some elements cross the boundaries - const int in_bound_elts_count = cols - left_bound; - partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, - in_bound_elts_count); - } + if (right_bound < cols) { + partial_dbias_rowwise.store_to(&dbias_workspace[dbias_offset]); + } else if (left_bound < cols && right_bound >= cols) { + const int in_bound_elts_count = cols - left_bound; + partial_dbias_rowwise.store_to_elts(&dbias_workspace[dbias_offset], 0, + in_bound_elts_count); } } } else { -#pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; i++) { - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + if (threadIdx.x < CHUNK_DIM_X) { + const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_block_offset_X; + const bool col_out_of_bounds = (dbias_colwise_block_offset_X >= cols); if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; + dbias_workspace[dbias_offset] = partial_dbias_colwise; } } } @@ -370,8 +523,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) if (amax_ptr != nullptr) { const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); + block_amax = reduce_max(block_amax, warp_id); } if (threadIdx.x == 0 && amax_ptr != nullptr) { diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 10a65bb1d..f26820ba1 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -736,6 +736,21 @@ struct TypeInfo { } \ } +#define TRANSFORMER_ENGINE_CHUNK_DIM_SWITCH(use_large, CHUNK_Y, CHUNK_X, THREADS, ...) \ + [&] { \ + if (use_large) { \ + constexpr size_t CHUNK_Y = 128; \ + constexpr size_t CHUNK_X = 128; \ + constexpr size_t THREADS = 256; \ + { __VA_ARGS__ } \ + } else { \ + constexpr size_t CHUNK_Y = 64; \ + constexpr size_t CHUNK_X = 64; \ + constexpr size_t THREADS = 128; \ + { __VA_ARGS__ } \ + } \ + }() + #define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ if (CONDITION) { \ constexpr bool FLAG = true; \ diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 9901a688e..70584fac3 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -447,10 +447,8 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) const size_t scale_dim_X_rowwise = 32; const size_t scale_dim_Y_colwise = launch_params.training ? 32 : 1; - const size_t chunks_Y = DIVUP(rows, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_X); - const size_t blocks_Y = DIVUP(chunks_Y, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_Y); - const size_t blocks_X = DIVUP(chunks_X, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_X); + const size_t blocks_Y = DIVUP(rows, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_X); const size_t scale_stride_rowwise = launch_params.z_tensor->scale_inv.shape[1]; const size_t scale_stride_colwise = launch_params.training ? launch_params.z_tensor->columnwise_scale_inv.shape[1] : 1; diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 05fe2f539..17648dade 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -9,6 +11,16 @@ namespace transformer_engine { +// AMD Fast tanh using hardware exp instruction +__device__ inline float fast_tanhf(float x) { +#ifdef __HIP_PLATFORM_AMD__ + float e2x = __expf(2.0f * x); + return (e2x - 1.0f) * __frcp_rn(e2x + 1.0f); +#else + return tanhf(x); +#endif +} + struct Empty {}; struct ClampedSwiGLUParam { @@ -19,13 +31,13 @@ struct ClampedSwiGLUParam { template __device__ inline OType gelu(const IType val, const Empty&) { const float cval = val; - return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval))); + return cval * (0.5F + 0.5F * fast_tanhf(cval * (0.79788456F + 0.03567741F * cval * cval))); } template __device__ inline OType dgelu(const IType val, const Empty&) { const float cval = val; - const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval)); + const float tanh_out = fast_tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval)); return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) + 0.5f * (1.f + tanh_out); } @@ -33,7 +45,11 @@ __device__ inline OType dgelu(const IType val, const Empty&) { template __device__ inline OType sigmoid(const IType val, const Empty&) { const float cval = val; +#ifdef __HIP_PLATFORM_AMD__ + return __frcp_rn(1.0f + __expf(-cval)); +#else return 1.f / (1.f + expf(-cval)); +#endif } __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } @@ -61,8 +77,8 @@ template __device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) { const float cval = val; Empty e = {}; - return alpha * cval * dsigmoid(alpha * cval, e) + - sigmoid(alpha * cval, e); + const float s = sigmoid(alpha * cval, e); + return s * (1.f + alpha * cval * (1.f - s)); } template @@ -85,7 +101,8 @@ __device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& template __device__ inline OType dsilu(const IType val, const Empty& e) { const float cval = val; - return cval * dsigmoid(cval, e) + sigmoid(cval, e); + const float s = sigmoid(cval, e); + return s * (1.f + cval * (1.f - s)); } template From d2351bca3804328edb21eed61db6b0f6a020f903 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Fri, 27 Mar 2026 09:54:13 -0500 Subject: [PATCH 2/5] cleanup and use of TRANSFORMER_ENGINE_SWITCH_CONDITION --- benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp | 2 -- benchmarks/cpp/cast/bench_gated_mxfp8.cpp | 2 -- .../common/cast/mxfp8/quantize_mxfp8.cuh | 8 ++++++-- transformer_engine/common/common.h | 15 --------------- 4 files changed, 6 insertions(+), 21 deletions(-) diff --git a/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp index 7dbfd834e..77222f23f 100644 --- a/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp +++ b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp @@ -12,8 +12,6 @@ #include "benchmark_utils.h" -#include "amd_detail/hip_float8.h" - #include #include diff --git a/benchmarks/cpp/cast/bench_gated_mxfp8.cpp b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp index 7df895817..04d3e06d6 100644 --- a/benchmarks/cpp/cast/bench_gated_mxfp8.cpp +++ b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp @@ -12,8 +12,6 @@ #include "benchmark_utils.h" -#include "amd_detail/hip_float8.h" - #include #include #include diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 4e4d8606a..ac32ceacd 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -591,8 +591,12 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const float *noop_ptr = reinterpret_cast(noop->data.dptr); float *const amax_ptr = reinterpret_cast(output->amax.dptr); - TRANSFORMER_ENGINE_CHUNK_DIM_SWITCH( - use_large_chunks, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_large_chunks, USE_LARGE_CHUNKS, + + constexpr size_t CHUNK_DIM_Y = USE_LARGE_CHUNKS ? 128 : 64; + constexpr size_t CHUNK_DIM_X = USE_LARGE_CHUNKS ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = USE_LARGE_CHUNKS ? 256 : 128; const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index f26820ba1..10a65bb1d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -736,21 +736,6 @@ struct TypeInfo { } \ } -#define TRANSFORMER_ENGINE_CHUNK_DIM_SWITCH(use_large, CHUNK_Y, CHUNK_X, THREADS, ...) \ - [&] { \ - if (use_large) { \ - constexpr size_t CHUNK_Y = 128; \ - constexpr size_t CHUNK_X = 128; \ - constexpr size_t THREADS = 256; \ - { __VA_ARGS__ } \ - } else { \ - constexpr size_t CHUNK_Y = 64; \ - constexpr size_t CHUNK_X = 64; \ - constexpr size_t THREADS = 128; \ - { __VA_ARGS__ } \ - } \ - }() - #define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ if (CONDITION) { \ constexpr bool FLAG = true; \ From 87e5752afd1b37fa001dfe8c876dc564c920781a Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Fri, 27 Mar 2026 11:43:29 -0500 Subject: [PATCH 3/5] Rework utils to use test/cpp/test_common* --- benchmarks/cpp/CMakeLists.txt | 6 +- benchmarks/cpp/run_benchmarks.sh | 25 + benchmarks/cpp/utils/benchmark_utils.h | 2 +- benchmarks/cpp/utils/test_common.cpp | 1241 ------------------------ benchmarks/cpp/utils/test_common.h | 684 ------------- tests/cpp/test_common.cu | 9 + 6 files changed, 40 insertions(+), 1927 deletions(-) delete mode 100644 benchmarks/cpp/utils/test_common.cpp delete mode 100644 benchmarks/cpp/utils/test_common.h diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index e95b5e873..c608eb54b 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -24,6 +24,7 @@ FetchContent_MakeAvailable(benchmark) include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common/include ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine ${CMAKE_CURRENT_SOURCE_DIR}/utils ) @@ -71,8 +72,11 @@ else() endif() function(add_te_benchmark TARGET_NAME SOURCE_FILE) - add_executable(${TARGET_NAME} ${SOURCE_FILE} utils/test_common.cpp) + set(TEST_COMMON_HIP ${CMAKE_CURRENT_SOURCE_DIR}/utils/test_common.hip) + set_source_files_properties(${TEST_COMMON_HIP} PROPERTIES LANGUAGE CXX) + add_executable(${TARGET_NAME} ${SOURCE_FILE} ${TEST_COMMON_HIP}) target_compile_options(${TARGET_NAME} PRIVATE ${COMMON_COMPILE_OPTIONS}) + target_compile_definitions(${TARGET_NAME} PRIVATE NVTE_ROCM_BENCHMARK) target_link_libraries(${TARGET_NAME} PRIVATE benchmark::benchmark ${TRANSFORMER_ENGINE_LIB} diff --git a/benchmarks/cpp/run_benchmarks.sh b/benchmarks/cpp/run_benchmarks.sh index 6b9fb5806..834fc3aec 100755 --- a/benchmarks/cpp/run_benchmarks.sh +++ b/benchmarks/cpp/run_benchmarks.sh @@ -10,9 +10,34 @@ GREEN='\033[0;32m' YELLOW='\033[1;33m' NC='\033[0m' +setup_test_common_symlinks() { + local utils_dir="${SCRIPT_DIR}/utils" + local test_common_hip="../../tests/cpp/test_common.hip" + local test_common_h="../../tests/cpp/test_common_hip.h" + + if [ ! -f "${SCRIPT_DIR}/${test_common_hip}" ] || [ ! -f "${SCRIPT_DIR}/${test_common_h}" ]; then + echo -e "${RED}Error: hipified test_common files not found. Build tests before running benchmarks." + return 1 + fi + + if [ ! -L "${utils_dir}/test_common.hip" ] || [ ! -e "${utils_dir}/test_common.hip" ]; then + ln -sf "../${test_common_hip}" "${utils_dir}/test_common.hip" + fi + + if [ ! -L "${utils_dir}/test_common_hip.h" ] || [ ! -e "${utils_dir}/test_common_hip.h" ]; then + ln -sf "../${test_common_h}" "${utils_dir}/test_common_hip.h" + fi + + return 0 +} + main() { echo -e "${GREEN}=== MXFP8 Benchmark Suite ===${NC}" + if ! setup_test_common_symlinks; then + return + fi + echo -e "\n${YELLOW}[1/3] Building benchmarks...${NC}" cd "${SCRIPT_DIR}" if ! cmake -GNinja -B"${BUILD_DIR}" . || ! cmake --build "${BUILD_DIR}"; then diff --git a/benchmarks/cpp/utils/benchmark_utils.h b/benchmarks/cpp/utils/benchmark_utils.h index 35a857109..bd2906b30 100644 --- a/benchmarks/cpp/utils/benchmark_utils.h +++ b/benchmarks/cpp/utils/benchmark_utils.h @@ -16,7 +16,7 @@ #include #include -#include "test_common.h" +#include "test_common_hip.h" namespace te_bench { diff --git a/benchmarks/cpp/utils/test_common.cpp b/benchmarks/cpp/utils/test_common.cpp deleted file mode 100644 index 3caf1245c..000000000 --- a/benchmarks/cpp/utils/test_common.cpp +++ /dev/null @@ -1,1241 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -#include "hip/hip_runtime.h" -/************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - - -#include "test_common.h" - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include "util/logging_hip.h" - -#include - -namespace test { - -size_t create_seed_from_tensor_name(const std::string& tensor_name) { - auto full_name = "benchmark/" + tensor_name; - return std::hash{}(full_name); -} - -std::vector all_fp_types = {DType::kFloat32, - DType::kFloat16, - DType::kBFloat16, - DType::kFloat8E5M2, - DType::kFloat8E4M3}; - -bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) { - if (s1.ndim != s2.ndim) return false; - - for (size_t i = 0; i < s1.ndim; ++i) { - if (s1.data[i] != s2.data[i]) return false; - } - - return true; -} - -size_t typeToNumBits(DType type) { - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, - { - return TypeInfo::size; - }); -} - -const std::string &typeName(DType type) { - static const std::unordered_map name_map = { - {DType::kByte, "byte"}, - {DType::kInt32, "int32"}, - {DType::kInt64, "int64"}, - {DType::kFloat32, "float32"}, - {DType::kFloat16, "float16"}, - {DType::kBFloat16, "bfloat16"}, - {DType::kFloat8E4M3, "float8e4m3"}, - {DType::kFloat8E5M2, "float8e5m2"}, - {DType::kFloat8E8M0, "float8e8m0"}, - {DType::kFloat4E2M1, "float4e2m1"}}; - return name_map.at(type); -} - -const std::string& caseName(InputsFillCase type) { - static const std::unordered_map name_map = { - {InputsFillCase::uniform, "uniform"}, - {InputsFillCase::zeros, "zeros"}, - {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"}, - {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"}, - {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"}}; - return name_map.at(type); -} - -size_t product(const NVTEShape &shape, size_t begin, size_t end) { - size_t ret = 1; - NVTE_CHECK(end <= shape.ndim); - for (size_t i = begin; i < end; ++i) { - ret *= shape.data[i]; - } - return ret; -} - -size_t product(const NVTEShape &shape) { - return product(shape, 0, shape.ndim); -} - -size_t product(const std::vector shape, size_t begin, size_t end) { - size_t ret = 1; - NVTE_CHECK(end <= shape.size()); - for (size_t i = begin; i < end; ++i) { - ret *= shape[i]; - } - return ret; -} - -size_t product(const std::vector& shape) { - return product(shape, 0, shape.size()); -} - -size_t DIVUP(const size_t &x, const size_t &y){ - return (((x) + ((y)-1)) / (y)); -} - -size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){ - return DIVUP(x, y) * y; -} - -struct scale_inv_meta { - std::vector shape; - DType type; - size_t type_size_bits; - size_t bytes() const noexcept { - return (product(shape) * type_size_bits) / 8; - } -}; - -size_t bytes(const NVTEShape& shape, const DType type) { - return (product(shape) * typeToNumBits(type)) / 8; -} - -NVTEShape convertShape(const std::vector& s) { - return nvte_make_shape(s.data(), s.size()); -} - -std::pair get_scales(const NVTEShape& shape, - const NVTEScalingMode scaling_mode) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - scale_inv_meta ret; - ret.shape = {1}; - ret.type = DType::kFloat32; - ret.type_size_bits = typeToNumBits(DType::kFloat32); - return {ret, ret}; - } - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - size_t first_dim = first_dimension(shape_vec); - size_t last_dim = last_dimension(shape_vec); - - scale_inv_meta ret_rowwise, ret_colwise; - - const size_t block_size_X_rowwise = 32; - size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); - size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); - ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; - - const size_t block_size_Y_colwise = 32; - size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); - size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); - ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; - - ret_rowwise.type = DType::kFloat8E8M0; - ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); - ret_colwise.type = DType::kFloat8E8M0; - ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); - - return {ret_rowwise, ret_colwise}; - } - if (scaling_mode == NVTE_NVFP4_1D_SCALING) { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - size_t first_dim = first_dimension(shape_vec); - size_t last_dim = last_dimension(shape_vec); - - NVTE_CHECK(last_dim % 32 == 0); - NVTE_CHECK(first_dim % 32 == 0); - - scale_inv_meta ret_rowwise, ret_colwise; - - size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); - size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise); - ret_rowwise.shape = {scale_dim_Y, scale_dim_X}; - - size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise); - size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise); - ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t}; - - ret_rowwise.type = DType::kFloat8E4M3; - ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); - ret_colwise.type = DType::kFloat8E4M3; - ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); - - return {ret_rowwise, ret_colwise}; - } - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - size_t first_dim = first_dimension(shape_vec); - size_t last_dim = last_dimension(shape_vec); - - scale_inv_meta ret_rowwise, ret_colwise; - - const size_t block_size_X_rowwise = 32; - size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); - size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); - ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; - - const size_t block_size_Y_colwise = 32; - size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); - size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); - ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; - - ret_rowwise.type = DType::kFloat8E8M0; - ret_colwise.type = DType::kFloat8E8M0; - ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); - ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); - - return {ret_rowwise, ret_colwise}; - } - if (scaling_mode == NVTE_BLOCK_SCALING_2D) { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - size_t first_dim = first_dimension(shape_vec); - size_t last_dim = last_dimension(shape_vec); - - scale_inv_meta ret_rowwise, ret_colwise; - - { - auto scale_dim_0 = DIVUP(first_dim, 128lu); -#ifdef __HIP_PLATFORM_AMD__ - auto scale_dim_1 = DIVUP(last_dim, 128lu); -#else - auto scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4; -#endif - ret_rowwise.shape = {scale_dim_0, scale_dim_1}; - } - { - auto scale_dim_0 = DIVUP(last_dim, 128lu); -#ifdef __HIP_PLATFORM_AMD__ - auto scale_dim_1 = DIVUP(first_dim, 128lu); -#else - auto scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4; -#endif - ret_colwise.shape = {scale_dim_0, scale_dim_1}; - } - ret_rowwise.type = DType::kFloat32; - ret_colwise.type = DType::kFloat32; - ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32); - ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32); - - return {ret_rowwise, ret_colwise}; - } - if (scaling_mode == NVTE_BLOCK_SCALING_1D) { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - size_t first_dim = first_dimension(shape_vec); - size_t last_dim = last_dimension(shape_vec); - scale_inv_meta ret_rowwise, ret_colwise; - - { - auto scale_dim_0 = DIVUP(last_dim, 128lu); -#ifdef __HIP_PLATFORM_AMD__ - auto scale_dim_1 = first_dim; -#else - auto scale_dim_1 = DIVUP(first_dim, 4) * 4; -#endif - ret_rowwise.shape = {scale_dim_0, scale_dim_1}; - } - { - auto scale_dim_0 = DIVUP(first_dim, 128lu); -#ifdef __HIP_PLATFORM_AMD__ - auto scale_dim_1 = last_dim; -#else - auto scale_dim_1 = DIVUP(last_dim, 4) * 4; -#endif - ret_colwise.shape = {scale_dim_0, scale_dim_1}; - } - ret_rowwise.type = DType::kFloat32; - ret_colwise.type = DType::kFloat32; - ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32); - ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32); - return {ret_rowwise, ret_colwise}; - } - - NVTE_ERROR("Invalid scaling mode!"); -} - -Tensor::Tensor(const std::string& name, - const NVTEShape &shape, const DType type, - const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode) { - name_ = name; - const size_t seed = create_seed_from_tensor_name(name); - gen_.seed(seed); - rowwise_ = rowwise; - columnwise_ = columnwise; - size_t total_size = bytes(shape, type); - void *dptr_rowwise = nullptr; - void *dptr_columnwise = nullptr; - cpu_data_rowwise_ = nullptr; - cpu_data_columnwise_ = nullptr; - amax_cpu_data_ = nullptr; - scale_cpu_data_ = nullptr; - rowwise_scale_inv_cpu_data_ = nullptr; - columnwise_scale_inv_cpu_data_ = nullptr; - float *amax = nullptr, *scale = nullptr; - float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr; - if (columnwise) { - NVTE_CHECK(shape.ndim >= 2); - } - std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), - shape.data[shape.ndim - 1]}; - NVTEShape normalized_shape = convertShape(normalized_shape_v); - NVTEShape columnwise_shape = {}; - - std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING - || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { - // Transpose when tensor scaling - columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); - for (size_t i = 0; i < shape.ndim - 1; ++i) { - columnwise_shape_vec.emplace_back(shape.data[i]); - } - } else { - // Same shape for MX and NVFP4 - for (size_t i = 0; i < shape.ndim; ++i) { - columnwise_shape_vec.emplace_back(shape.data[i]); - } - } - - if (columnwise) { - columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size()); - } - - tensor_ = TensorWrapper(scaling_mode); - - if (total_size != 0) { - if (rowwise) { - (void)hipMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) - (void)hipMemset(dptr_rowwise, 0, total_size); - cpu_data_rowwise_ = std::make_unique(total_size); - std::fill_n(cpu_data_rowwise_.get(), total_size, 0); - } - if (columnwise) { - (void)hipMalloc((void**)&dptr_columnwise, total_size); // NOLINT(*) - (void)hipMemset(dptr_columnwise, 0, total_size); - cpu_data_columnwise_ = std::make_unique(total_size); - std::fill_n(cpu_data_columnwise_.get(), total_size, 0); - } - } - - const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; - const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; - tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape); - tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape); - - if (isFp8Type(type) || isFp4Type(type)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - (void)hipMalloc((void**)&amax, sizeof(float)); // NOLINT(*) - (void)hipMemset(amax, 0, sizeof(float)); - (void)hipMalloc((void**)&scale, sizeof(float)); // NOLINT(*) - (void)hipMemset(scale, 0, sizeof(float)); - amax_cpu_data_ = std::make_shared(0); - scale_cpu_data_ = std::make_shared(0); - tensor_.set_amax(amax, DType::kFloat32, std::vector{1}); - tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); - (void)hipMalloc((void**)&rowwise_scale_inv, sizeof(float)); // NOLINT(*) - if (rowwise) { - tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32, - std::vector{1}); - rowwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); - std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0); - } - if (columnwise) { - tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, - std::vector{1}); - columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); - std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); - } - } else { - if (scaling_mode == NVTE_NVFP4_1D_SCALING) { - // Used for NVFP4 second stage scaling - hipMalloc((void**)&scale, sizeof(float)); // NOLINT(*) - hipMemset(scale, 0, sizeof(float)); - scale_cpu_data_ = std::make_shared(0); - tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); - } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); - auto rowwise_scale_size = rowwise_scale_meta.bytes(); - auto columnwise_scale_size = colwise_scale_meta.bytes(); - auto scale_shape = rowwise_scale_meta.shape; - auto columnwise_scale_shape = colwise_scale_meta.shape; - if (rowwise) { - (void)hipMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) - (void)hipMemset(rowwise_scale_inv, 0, rowwise_scale_size); - rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); - std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); - auto scale_dtype = rowwise_scale_meta.type; - tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); - } - if (columnwise) { - (void)hipMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) - (void)hipMemset(columnwise_scale_inv, 0, columnwise_scale_size); - columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); - std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); - auto scale_dtype = colwise_scale_meta.type; - tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); - } - } - } -} - -void Tensor::to_cpu() const { - const NVTEShape s = tensor_.shape(); - const size_t size = bytes(s, tensor_.dtype()); - if (rowwise_) { - (void)hipMemcpy(cpu_data_rowwise_.get(), - tensor_.get_rowwise_data().data_ptr, - size, - hipMemcpyDeviceToHost); - } - if (columnwise_) { - const DType colwise_type = tensor_.dtype(); - - const size_t colwise_size = bytes(s, colwise_type); - (void)hipMemcpy(cpu_data_columnwise_.get(), - tensor_.get_columnwise_data().data_ptr, - colwise_size, - hipMemcpyDeviceToHost); - } - if (isFp8Type(dtype()) || isFp4Type(dtype())) { - if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { - if (tensor_.amax() != nullptr){ - (void)hipMemcpy(amax_cpu_data_.get(), - tensor_.amax(), - sizeof(float), - hipMemcpyDeviceToHost); - } - (void)hipMemcpy(scale_cpu_data_.get(), - tensor_.scale(), - sizeof(float), - hipMemcpyDeviceToHost); - } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); - if (rowwise_) { - auto scale_size = rowwise_scale_meta.bytes(); - (void)hipMemcpy(rowwise_scale_inv_cpu_data_.get(), - tensor_.get_rowwise_scale_inv().data_ptr, - scale_size, - hipMemcpyDeviceToHost); - } - if (columnwise_) { - auto scale_size = colwise_scale_meta.bytes(); - (void)hipMemcpy(columnwise_scale_inv_cpu_data_.get(), - tensor_.get_columnwise_scale_inv().data_ptr, - scale_size, - hipMemcpyDeviceToHost); - } - } -} - -void Tensor::from_cpu() const { - const NVTEShape s = tensor_.shape(); - const size_t size = bytes(s, tensor_.dtype()); - if (rowwise_) { - (void)hipMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size, - hipMemcpyHostToDevice); - } - if (columnwise_) { - (void)hipMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, - hipMemcpyHostToDevice); - } - if (isFp8Type(dtype()) || isFp4Type(dtype())) { - if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) - || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) { - if (tensor_.amax() != nullptr){ - (void)hipMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), hipMemcpyHostToDevice); - } - (void)hipMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), hipMemcpyHostToDevice); - } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); - if (rowwise_) { - auto scale_size = rowwise_scale_meta.bytes(); - (void)hipMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, - rowwise_scale_inv_cpu_data_.get(), scale_size, - hipMemcpyHostToDevice); - } - if (columnwise_) { - auto scale_size = colwise_scale_meta.bytes(); - (void)hipMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, - columnwise_scale_inv_cpu_data_.get(), scale_size, - hipMemcpyHostToDevice); - } - } -} - -void Tensor::set_scale(float scale) { - if (isFp8Type(dtype()) || isFp4Type(dtype())) { - NVTE_CHECK(scale_cpu_data_); - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { - *scale_cpu_data_ = scale; - from_cpu(); - } - } -} - -void Tensor::set_scale_inv(float scale_inv) { - if (isFp8Type(dtype()) || isFp4Type(dtype())) { - if (rowwise_) { - NVTE_CHECK(rowwise_scale_inv_cpu_data_); - } - if (columnwise_) { - NVTE_CHECK(columnwise_scale_inv_cpu_data_); - } - - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); - if (rowwise_) { - auto num_scales = product(rowwise_scale_meta.shape); - if (num_scales == 1) { - rowwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else { - std::uniform_int_distribution dis(0, 127); - auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++) { - scale_inv_ptr[i] = dis(gen_); - } - } - } - if (columnwise_) { - auto num_scales = product(colwise_scale_meta.shape); - if (num_scales == 1) { - columnwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else { - std::uniform_int_distribution dis(0, 127); - if (rowwise_) { - from_cpu(); //Need it because scale_inv_ptr getting does to_cpu() - } - auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++) { - scale_inv_ptr[i] = dis(gen_); - } - } - } - from_cpu(); - } -} - -void Tensor::shareFP8Meta(const Tensor &other) { - if ((isFp8Type(dtype()) && isFp8Type(other.dtype())) - || isFp4Type(dtype()) && isFp4Type(other.dtype())) { - auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); - auto my_rowwise_data = tensor_.get_rowwise_data(); - new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), - my_rowwise_data.shape); - auto my_columnwise_data = tensor_.get_columnwise_data(); - new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, - static_cast(my_columnwise_data.dtype), - my_columnwise_data.shape); - auto other_amax = other.tensor_.get_amax(); - new_tensor.set_amax(other_amax.data_ptr, static_cast(other_amax.dtype), - other_amax.shape); - auto other_scale = other.tensor_.get_scale(); - new_tensor.set_scale(other_scale.data_ptr, static_cast(other_scale.dtype), - other_scale.shape); - auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); - new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, - static_cast(other_row_scale_inv.dtype), - other_row_scale_inv.shape); - auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv(); - new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr, - static_cast(other_col_scale_inv.dtype), - other_col_scale_inv.shape); - tensor_ = std::move(new_tensor); - to_cpu(); - } -} - -using std::to_string; - -template -std::string to_string(const std::vector &v) { - std::string s = "["; - for (const auto x : v) { - s += to_string(x) + ", "; - } - s.pop_back(); - s.pop_back(); - return s + "]"; -} - -std::vector unravel(const size_t i, const NVTEShape &shape) { - std::vector ret; - size_t current_i = i; - for (size_t current = shape.ndim - 1; current > 0; --current) { - ret.push_back(current_i % shape.data[current]); - current_i /= shape.data[current]; - } - ret.push_back(current_i); - std::reverse(ret.begin(), ret.end()); - return ret; -} - -#ifndef BENCHMARK_STATIC_DEFINE -void compareResults_sequential(const std::string &name, const Tensor &test, - const void *ref, const bool rowwise, - double atol, double rtol, bool if_on_gpus, - const size_t tolerable_mismatches_limit) { - if (if_on_gpus) test.to_cpu(); - const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); - const size_t N = product(shape); - size_t mismatches_num = 0; - int first_mismatch_idx = -1; - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, - const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); - const T *ref_data = reinterpret_cast(ref); - for (size_t i = 0; i < N; ++i) { - double t = static_cast(test_data[i]); - double r = static_cast(ref_data[i]); - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = mismatch && test.dtype() == DType::kFloat32; - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r)); - } - std::string direction = rowwise ? "rowwise" : "columnwise"; - if (assertion) { - mismatches_num++; - if (first_mismatch_idx == -1) { - first_mismatch_idx = i; - } - } - if (mismatches_num > tolerable_mismatches_limit) { - const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); - const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); - - GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " - << tolerable_mismatches_limit << "." << std::endl - << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape)) - << " (" << std::to_string(first_mismatch_idx) << "): " - << first_mismatch_t << " vs " << first_mismatch_r; - } - } - ); -} - -template -static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, - const size_t N, const double atol, const double rtol, - size_t& mismatches) { - int first_mismatch_idx = N; - - #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread) - { - size_t thread_mismatches = 0; - #pragma omp for schedule(static) - for (size_t i = 0; i < N; ++i) { - double t = static_cast(test_data[i]); - double r = static_cast(ref_data[i]); - - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = mismatch && (data_type == DType::kFloat32); - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); - } - if (assertion) { - if (i < first_mismatch_idx) { - first_mismatch_idx = i; - } - thread_mismatches++; - } - } - mismatches += thread_mismatches; - } - return first_mismatch_idx; -} - -void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus, - const size_t tolerable_mismatches_limit) { - if (if_on_gpus) test.to_cpu(); - const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); - const size_t N = product(shape); - size_t mismatches = 0; - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, - const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); - const T *ref_data = reinterpret_cast(ref); - - const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches); - if ((i != N) && (mismatches > tolerable_mismatches_limit)) { - const double t = static_cast(test_data[i]); - const double r = static_cast(ref_data[i]); - std::string direction = rowwise ? "rowwise" : "columnwise"; - - GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of " - << tolerable_mismatches_limit << "." << std::endl - << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; - } - ); -} - -void compareResults(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus, - const size_t tolerable_mismatches_limit) { - constexpr bool sequential = false; - if constexpr (sequential) { - compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); - } else { - compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); - } -} - -void compareResults(const std::string &name, const float test, const float ref, - double atol, double rtol) { - double t = static_cast(test); - double r = static_cast(ref); - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - ASSERT_FALSE(mismatch) << "Error in " << name << std::endl - << "Mismatch: " << t << " vs " << r; - -} - - -void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, - size_t N, float mismatch_rate_tol) { - size_t max_mismatches = std::ceil(N * mismatch_rate_tol); - size_t n_mismatches = 0; - std::vector mismatch_indices; - for (int i = 0; i < N; i++){ - bool mismatch = test[i] != ref[i]; - if (mismatch){ - n_mismatches++; - mismatch_indices.push_back(i); - } - if (n_mismatches > max_mismatches){ - std::cout << "Error in " << name << std::endl; - for (auto &index : mismatch_indices) - std::cout << "Mismatch at (" << index << "):" << static_cast(test[i]) << " vs " - << static_cast(ref[i]) << std::endl; - GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol."; - } - } -} - -template -struct CastToType; - -template <> -struct CastToType { - using type = int; -}; - -template <> -struct CastToType { - using type = float; -}; - -template -void compare_scaling_factors(const std::string &name, const T *test, const T *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, -#ifdef __HIP_PLATFORM_AMD__ - std::vector &mismatch_indices, -#endif //#ifdef __HIP_PLATFORM_AMD__ - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double rel_tolerable_mismatches_limit) -{ - using UpcastType = typename CastToType::type; - auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); - - - const size_t N = row_blocks * col_blocks; - const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, - std::floor(N * rel_tolerable_mismatches_limit)); - mismatches_num = 0; -#ifndef __HIP_PLATFORM_AMD__ - std::vector mismatch_indices; -#endif //#ifndef __HIP_PLATFORM_AMD__ - - for (int i = 0; i < row_blocks; ++i) { - for (int j = 0; j < col_blocks; ++j) { - const int idx = i * stride + j; - float t, r; - - bool assertion = false; - - if (std::is_same::value) { - t = static_cast(test[idx]); - r = static_cast(ref[idx]); - assertion = std::abs(t - r) > atol; - } else { - t = static_cast(*reinterpret_cast(&test[idx])); - r = static_cast(*reinterpret_cast(&ref[idx])); - const bool mismatch = (fabs(t - r) > atol_fp8e4m3) - && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3); - if (mismatch) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); - } - } - if (assertion) { - mismatches_num++; - mismatch_indices.push_back(idx); - } - if (mismatches_num > tolerable_mismatches_limit) { - std::cout << "Error in " << name << std::endl; - for (const int index : mismatch_indices) { - std::cout << "Mismatch at (" << index << "):" - << static_cast(test[index]) << " vs " - << static_cast(ref[index]) << std::endl; - } - GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " - << tolerable_mismatches_limit << "."; - } - } - } -} - -#ifdef __HIP_PLATFORM_AMD__ -void adjust_ref_for_e8m0_scale_error(const std::string &name, - const std::vector &mismatch_idx, - const uint8_t *test_scale, const uint8_t *ref_scale, - const size_t scale_stride, const size_t rows, - const size_t cols, bool rowwise, void *ref_ptr, DType otype) { - if (mismatch_idx.size() == 0) { - return; - } - const size_t col_blocks_size = rowwise ? 32 : 1; - const size_t row_blocks_size = rowwise ? 1 : 32; - GTEST_LOG_(INFO) << "Adjusting reference data for " << mismatch_idx.size() - << " scale mismatches in tensor " << name << " " - << (rowwise ? "rowwise" : "colwise") << " direction." << std::endl; - for (const auto scale_idx : mismatch_idx) { - const int scale_diff = ref_scale[scale_idx] - test_scale[scale_idx]; - double scale_val; - if (scale_diff == 1) { - scale_val = 2.; - } else if (scale_diff == -1) { - scale_val = .5; - } else { - GTEST_FAIL() << "Error in " << name << ": mismatch " << test_scale[scale_idx] << " vs " - << ref_scale[scale_idx] << " at index " << scale_idx; - } - const int i = scale_idx / scale_stride; - const int j = scale_idx % scale_stride; - size_t ii_min = i * row_blocks_size; - const size_t ii_max = std::min(ii_min + row_blocks_size, rows); - for (; ii_min < ii_max; ii_min++) { - size_t jj_min = j * col_blocks_size; - const size_t jj_max = std::min(jj_min + col_blocks_size, cols); - for (; jj_min < jj_max; jj_min++) { - const size_t data_idx = ii_min * cols + jj_min; - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(otype, T, { - T *ref_data = reinterpret_cast(ref_ptr); - ref_data[data_idx] = static_cast(static_cast(ref_data[data_idx]) * scale_val); - }); // NOLINT(*) - } - } - } -} -#endif // #ifdef __HIP_PLATFORM_AMD__ - -// Instantiate templates -template -void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, -#ifdef __HIP_PLATFORM_AMD__ - std::vector &mismatch_indices, -#endif //#ifdef __HIP_PLATFORM_AMD__ - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double rel_tolerable_mismatches_limit); - -template -void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, -#ifdef __HIP_PLATFORM_AMD__ - std::vector &mismatch_indices, -#endif //#ifdef __HIP_PLATFORM_AMD__ - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double rel_tolerable_mismatches_limit); - -#endif // #ifndef BENCHMARK_STATIC_DEFINE - - -std::pair getTolerances(const DType type) { - switch(type) { - case DType::kFloat32: - return {1e-6, 5e-6}; - case DType::kFloat16: - return {1e-5, 1e-3}; - case DType::kBFloat16: - return {1e-5, 1e-2}; - case DType::kFloat8E4M3: - case DType::kFloat8E5M2: - case DType::kFloat8E8M0: - return {1e-2, 1e-2}; - default: - NVTE_ERROR("Invalid type!"); - } - return {0, 0}; -} - -#ifndef __HIP_PLATFORM_AMD__ -template -void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { - // Check how many RNG calls are required to generate one uniform random value - int rng_calls_per_val = 0; - { - std::mt19937 gen1 = *gen, gen2 = *gen; - std::uniform_real_distribution<> dis(-2.0, 1.0); - const float _ = dis(gen1); - while (gen2 != gen1) { - auto _ = gen2(); - ++rng_calls_per_val; - } - } - - // Generate uniform random values in parallel - #pragma omp parallel proc_bind(spread) - { - std::mt19937 gen_local = *gen; - const int thread_ID = omp_get_thread_num(); - const int threads_num = omp_get_max_threads(); - const int chunk_size = (size + threads_num - 1) / threads_num; - const int idx_min = chunk_size * thread_ID; - const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast(size)); - gen_local.discard(idx_min * rng_calls_per_val); - std::uniform_real_distribution<> dis(-2.0, 1.0); - - for (int i = idx_min; i < idx_max; ++i) { - data[i] = static_cast(dis(gen_local)); - } - } - gen->discard(size * rng_calls_per_val); -} -#endif - -#ifdef __HIP_PLATFORM_AMD__ -template -__global__ void affine_transform_cast_signs(const float* __restrict__ in, - const float* __restrict__ signs, - T* __restrict__ out, - size_t n, double lo, double hi) { - // Map values in *in* from [0, 1) to [lo, hi) and cast to type *T* for *out*. - // Potentially flip signs if RandomSign==true. - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - float val = lo + (hi - lo) * in[idx]; - - if constexpr (RandomSign) { - if (signs[idx] < 0.5f) - val = -val; - } - - out[idx] = static_cast(val); - } -} - -template -static void fillUniformLinearBufferDevice(T* dst_dev, - T* dst_cpu, // nullable - size_t N, - unsigned long long seed, - double lo, double hi, - bool random_sign=false) { - // Fill a linear device buffer with uniform randoms in [*lo*, *hi*] and cast them to *T*. - // Optionally mirror the result into a provided CPU pointer. - if (N == 0) - return; - - float* tmp = nullptr; - NVTE_CHECK_CUDA(hipMalloc(&tmp, N * sizeof(float))); - - float* tmp_sign = nullptr; - if (random_sign) { - NVTE_CHECK_CUDA(hipMalloc(&tmp_sign, N * sizeof(float))); - } - - hiprandGenerator_t gen; - NVTE_CHECK(hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_PHILOX4_32_10) == HIPRAND_STATUS_SUCCESS); - NVTE_CHECK(hiprandSetPseudoRandomGeneratorSeed(gen, seed) == HIPRAND_STATUS_SUCCESS); - NVTE_CHECK(hiprandGenerateUniform(gen, tmp, N) == HIPRAND_STATUS_SUCCESS); - - if (random_sign) { - NVTE_CHECK(hiprandGenerateUniform(gen, tmp_sign, N) == HIPRAND_STATUS_SUCCESS); - } - - dim3 block(256); - dim3 grid((N + block.x - 1) / block.x); - - if (random_sign) - hipLaunchKernelGGL(( affine_transform_cast_signs), dim3(grid), dim3(block), 0, 0, - tmp, tmp_sign, dst_dev, N, lo, hi); - else - hipLaunchKernelGGL(( affine_transform_cast_signs), dim3(grid), dim3(block), 0, 0, - tmp, nullptr, dst_dev, N, lo, hi); - - NVTE_CHECK_CUDA(hipGetLastError()); - - if (dst_cpu != nullptr) { - NVTE_CHECK_CUDA(hipMemcpy(dst_cpu, dst_dev, N * sizeof(T), hipMemcpyDeviceToHost)); - } - - NVTE_CHECK(hiprandDestroyGenerator(gen) == HIPRAND_STATUS_SUCCESS); - NVTE_CHECK_CUDA(hipFree(tmp)); - if (tmp_sign) - NVTE_CHECK_CUDA(hipFree(tmp_sign)); -} - -template -static void fillUniformTensorDevice(Tensor* t, double lo=-2.0f, - double hi=1.0f, bool random_sign=false) { - void* dst_dev_void = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); - const auto shape = t->rowwise() ? (t->rowwise_shape()) : (t->columnwise_shape()); - const size_t N = product(shape); - - // per-tensor deterministic seed - const unsigned long long seed = static_cast(t->gen()()); - - T* dst_dev = reinterpret_cast(dst_dev_void); - // Keep the CPU mirror in sync. We could use Tensor::to_cpu() here, - // but that does more than just copying the data. - T* dst_cpu = t->rowwise() ? t->rowwise_cpu_dptr() : t->columnwise_cpu_dptr(); - fillUniformLinearBufferDevice(dst_dev, dst_cpu, N, seed, lo, hi, random_sign); -} -#endif - -void fillUniform(Tensor *t) { - if (t->rowwise()) { - const size_t size = product(t->rowwise_shape()); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, - { -#ifdef __HIP_PLATFORM_AMD__ - fillUniformTensorDevice(t); -#else - T *data = t->rowwise_cpu_dptr(); - generate_data_uniformly(data, size, &(t->gen())); -#endif - } - ); - } else { - const size_t size = product(t->columnwise_shape()); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, - { -#ifdef __HIP_PLATFORM_AMD__ - fillUniformTensorDevice(t); -#else - T *data = t->columnwise_cpu_dptr(); - generate_data_uniformly(data, size, &(t->gen())); -#endif - } - ); - } -#ifndef __HIP_PLATFORM_AMD__ - // Data is already on device on AMDGPU - t->from_cpu(); -#endif - std::uniform_real_distribution<> dis(-2.0, 1.0); - t->set_scale_inv(dis(t->gen())); -} - -template -void fillCase_special(Tensor *t) { - const size_t size = product(t->rowwise_shape()); - - if constexpr (Case == InputsFillCase::zeros) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { -#ifdef __HIP_PLATFORM_AMD__ - // Fill device and CPU mirror - void* dst_dev = t->rowwise_dptr(); - NVTE_CHECK_CUDA(hipMemset(dst_dev, 0, size * sizeof(InputType))); - InputType* dst_cpu = t->rowwise_cpu_dptr(); - std::fill_n(dst_cpu, size, static_cast(0)); -#else - InputType *data = t->rowwise_cpu_dptr(); - for (size_t i = 0; i < size; ++i) { - data[i] = static_cast(0); - } -#endif - }); - } else { - double minAbs = -2.0; - double maxAbs = 1.0; - if constexpr (Case != InputsFillCase::uniform) { - minAbs = Quantized_Limits::ranges[Case]; - maxAbs = Quantized_Limits::ranges[Case + 1]; - } - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { -#ifdef __HIP_PLATFORM_AMD__ - const unsigned long long seed = static_cast(t->gen()()); - InputType* dst_dev = static_cast(t->rowwise_dptr()); - InputType* dst_cpu = static_cast(t->rowwise_cpu_dptr()); - fillUniformLinearBufferDevice(dst_dev, dst_cpu, size, seed, - minAbs, maxAbs, /*random_sign=*/true); -#else - std::uniform_real_distribution<> dis(minAbs, maxAbs); - std::uniform_real_distribution<> dis_sign(-1.0, 1.0); - InputType *data = t->rowwise_cpu_dptr(); - for (size_t idx = 0; idx < size; ++idx) { - const bool is_negative = (dis_sign(t->gen()) < 0.0); - double val = dis(t->gen()); - if (is_negative) { - val = -val; - } - data[idx] = static_cast(val); - } -#endif - }); - } - t->set_scale_inv(1.0); -#ifndef __HIP_PLATFORM_AMD__ - t->from_cpu(); -#endif -} - -template -void fillCase(Tensor *t, const InputsFillCase fill_case) { - switch (fill_case) { - case InputsFillCase::uniform: - fillCase_special(t); break; - case InputsFillCase::zeros: - fillCase_special(t); break; - case InputsFillCase::zero_to_minNorm: - fillCase_special(t); break; - case InputsFillCase::minNorm_to_maxNorm: - fillCase_special(t); break; - case InputsFillCase::maxNorm_to_inf: - fillCase_special(t); break; - } -} - -template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); -#if FP4_TYPE_SUPPORTED -template void fillCase(Tensor *t, const InputsFillCase fill_case); -#endif - -void setRandomScale(Tensor *t) { - std::uniform_real_distribution<> dis(-2.0, 1.0); - const float scale = dis(t->gen()); - t->set_scale(scale); -} - -void setRandomScaleInv(Tensor *t) { - std::uniform_real_distribution<> dis(-2.0, 1.0); - const float scale_inv = dis(t->gen()); - t->set_scale_inv(scale_inv); -} - -bool isFp8Type(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; -} - -bool isFp4Type(DType type) { - return type == DType::kFloat4E2M1; -} - -int32_t getDeviceComputeCapability() { - hipDeviceProp_t deviceProp; - hipGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; -} - -size_t first_dimension(const std::vector &shape) { - if (shape.size() == 0) return 1; - if (shape.size() == 1) return 1; - return product(shape, 0, shape.size() - 1); -} - -size_t last_dimension(const std::vector &shape) { - if (shape.size() == 0) return 1; - return shape[shape.size() - 1]; -} - -std::array get_scale_tensor_dims(const size_t rows, - const size_t cols, - const size_t block_size_rows, - const size_t block_size_cols) { - const bool is_rowwise = (block_size_rows == 1) - && ((block_size_cols == 32) || (block_size_cols == 16)); - - const size_t alignment_Y = is_rowwise - ? scale_tensor_alignment_Y_rowwise - : scale_tensor_alignment_Y_colwise; - const size_t alignment_X = is_rowwise - ? scale_tensor_alignment_X_rowwise - : scale_tensor_alignment_X_colwise; - - const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows); - const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols); - - const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y); - const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X); - return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; -} - -} // namespace test diff --git a/benchmarks/cpp/utils/test_common.h b/benchmarks/cpp/utils/test_common.h deleted file mode 100644 index 50a9defab..000000000 --- a/benchmarks/cpp/utils/test_common.h +++ /dev/null @@ -1,684 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -/************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#pragma once - -#include -#include -#include -#include - -#ifndef USE_ROCM -#define FP4_TYPE_SUPPORTED (TORCH_HIP_VERSION >= 12080) -#include -#include "common/amd_detail/hip_float8.h" -#if FP4_TYPE_SUPPORTED -#include -#endif -#else -#define FP4_TYPE_SUPPORTED (false) -#include -#include "amd_detail/hip_float8.h" -#endif -#include - -#include -#include "util/logging_hip.h" - -namespace test { -using namespace transformer_engine; - -template -struct BytesToType {}; - -template <> -struct BytesToType<1> { - using Type = uint8_t; -}; - -template <> -struct BytesToType<2> { - using Type = uint16_t; -}; - -template <> -struct BytesToType<4> { - using Type = uint32_t; -}; - -template <> -struct BytesToType<8> { - using Type = uint64_t; -}; - -using byte = uint8_t; -using int16 = int16_t; -using int32 = int32_t; -using int64 = int64_t; -using fp32 = float; -using fp16 = half; -#ifndef USE_ROCM -using bf16 = nv_bfloat16; -using fp8e4m3 = __nv_fp8_e4m3; -using fp8e5m2 = __nv_fp8_e5m2; -#else -using bf16 = hip_bfloat16; -using fp8e4m3 = te_hip_fp8_e4m3; -using fp8e5m2 = te_hip_fp8_e5m2; -#endif //USE_ROCM -using fp8e8m0 = uint8_t; -#if FP4_TYPE_SUPPORTED -using fp4e2m1 = __nv_fp4_e2m1; -using fp4e2m1x2 = __nv_fp4x2_e2m1; -using fp4e2m1x4 = __nv_fp4x4_e2m1; -#endif - -template -struct BitsNumber; - -#if FP4_TYPE_SUPPORTED -template <> -struct BitsNumber { - static constexpr size_t num_bits = 4; -}; -#endif - -template -struct BitsNumber { - static constexpr size_t num_bits = 8 * sizeof(T); -}; - -template -struct TypeInfo { -#if FP4_TYPE_SUPPORTED - using types = std::tuple; -#else - using types = std::tuple; -#endif - - template - struct Helper { - constexpr static DType getType() { - constexpr int i = static_cast(current); - if (std::is_same::type>::value) { - return current; - } else { - return Helper(i + 1)>::getType(); - } - } - }; - - template - struct Helper { - constexpr static DType getType() { - return DType::kNumTypes; - } - }; - - template - constexpr static DType getType() { - return Helper::getType(); - } - - constexpr static DType dtype = getType(); - constexpr static size_t size = BitsNumber::num_bits;; -}; - -class Tensor { - public: - Tensor(const std::string& name, - const NVTEShape &shape, const DType type, - const bool rowwise = true, - const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); - - Tensor(const std::string& name, - const std::vector &shape, - const DType type, - const bool rowwise = true, - const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : - Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {} - - Tensor() {} - - Tensor& operator=(const Tensor &other) = delete; - Tensor(const Tensor &other) = delete; - - Tensor(Tensor &&other) = default; - Tensor& operator=(Tensor &&other) = default; - - ~Tensor() { - void *data_ptr = tensor_.dptr(); - void *scale_inv = tensor_.scale_inv(); - void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr; - void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr; - if (columnwise_data_ptr == data_ptr) { - columnwise_data_ptr = nullptr; - } - if (columnwise_scale_inv == scale_inv) { - columnwise_scale_inv = nullptr; - } - if (data_ptr != nullptr) { - (void)hipFree(data_ptr); - } - if (scale_inv != nullptr) { - (void)hipFree(scale_inv); - } - if (columnwise_data_ptr != nullptr){ - (void)hipFree(columnwise_data_ptr); - } - if (columnwise_scale_inv != nullptr){ - (void)hipFree(columnwise_scale_inv); - } - } - - NVTETensor data() const noexcept { return tensor_.data(); } - - NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; } - - NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; } - - NVTEShape rowwise_scale_inv_shape() const { - NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); - return tensor_.get_rowwise_scale_inv().shape; - } - - NVTEShape columnwise_scale_inv_shape() const { - NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); - return tensor_.get_columnwise_scale_inv().shape; - } - - NVTEScalingMode scaling_mode() const noexcept { - return tensor_.scaling_mode(); - } - - DType dtype() const noexcept { - return tensor_.dtype(); - } - - void *rowwise_dptr() const { - NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); - return tensor_.get_rowwise_data().data_ptr; - } - - void *columnwise_dptr() const { - NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); - return tensor_.get_columnwise_data().data_ptr; - } - - template - T *rowwise_cpu_dptr() const { - NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); - NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); - return reinterpret_cast(cpu_data_rowwise_.get()); - } - - template - T *columnwise_cpu_dptr() const { - NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); - NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); - return reinterpret_cast(cpu_data_columnwise_.get()); - } - - float amax() const { - if(amax_cpu_data_) { - to_cpu(); - return *amax_cpu_data_; - } else { - return 0; - } - } - - void *amax_dptr() const { - return tensor_.amax(); - } - - float scale() const { - if(scale_cpu_data_) { - NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) - || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING), - "Invalid scaling_mode!"); - to_cpu(); - return *scale_cpu_data_; - } else { - return 1; - } - } - - template - T *rowwise_cpu_scale_inv_ptr() const { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ - NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); - } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { - NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); - } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { - NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); - } else { - NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); - } - to_cpu(); - return reinterpret_cast(rowwise_scale_inv_cpu_data_.get()); - } - - template - T *columnwise_cpu_scale_inv_ptr() const { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ - NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); - } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { - NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); - } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { - NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); - } else { - NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); - } - to_cpu(); - return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); - } - - float rowwise_scale_inv() const { - if(rowwise_scale_inv_cpu_data_) { - float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; - return scale_inv; - } else { - return 1; - } - } - - bool rowwise() const { - return rowwise_; - } - - bool columnwise() const { - return columnwise_; - } - - void set_tensor_amax_nullptr(){ - tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); - } - - void to_cpu() const; - void from_cpu() const; - void set_scale(float scale); - void set_scale_inv(float scale_inv); - void shareFP8Meta(const Tensor &other); - - std::mt19937& gen() { return gen_; } - - void *rowwise_scale_inv_dptr() const { - NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); - return tensor_.scale_inv(); // rowwise scale_inv backing storage - } - - void *columnwise_scale_inv_dptr() const { - NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); - return tensor_.get_columnwise_scale_inv().data_ptr; - } - - private: - TensorWrapper tensor_; - std::unique_ptr cpu_data_rowwise_; - std::unique_ptr cpu_data_columnwise_; - std::shared_ptr amax_cpu_data_; - std::shared_ptr scale_cpu_data_; - std::unique_ptr rowwise_scale_inv_cpu_data_; - std::unique_ptr columnwise_scale_inv_cpu_data_; - bool rowwise_; - bool columnwise_; - std::string name_; - std::mt19937 gen_; -}; - -constexpr uint32_t FP32_EXPONENT_BIAS = 127; -constexpr uint32_t FP32_MANTISSA_BITS = 23; - -// [128,4] rowwise and [4,128] colwise alignment requirement -#ifdef __HIP_PLATFORM_AMD__ -constexpr size_t scale_tensor_alignment_X_rowwise = 1; -constexpr size_t scale_tensor_alignment_Y_rowwise = 1; -constexpr size_t scale_tensor_alignment_X_colwise = 1; -constexpr size_t scale_tensor_alignment_Y_colwise = 1; -#else -constexpr size_t scale_tensor_alignment_Y_rowwise = 128; -constexpr size_t scale_tensor_alignment_X_rowwise = 4; -constexpr size_t scale_tensor_alignment_Y_colwise = 4; -constexpr size_t scale_tensor_alignment_X_colwise = 128; -#endif - -inline size_t divide_round_up(const size_t N, const size_t M) { - return (N - 1 + M) / M; -} - -inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) { - return divide_round_up(N, M) * M; -} - -template -struct Numeric_Traits { - static constexpr double minSubnorm = 1.0; - static constexpr double maxSubnorm = 1.0; - static constexpr double minNorm = 1.0; - static constexpr double maxNorm = 1.0; - static constexpr double artifInf = 1.0; - static constexpr int maxBiasedExponent = 1; -}; - -template <> -struct Numeric_Traits { - static constexpr double minSubnorm = 1.0 / static_cast(1 << 9); // std::pow(2.0, -9.0); - static constexpr double maxSubnorm = 0.875 / static_cast(1 << 6); // std::pow(2.0, -6.0); - static constexpr double minNorm = 1.0 / static_cast(1 << 6); // std::pow(2.0, -6.0); -#ifndef USE_ROCM - static constexpr double maxNorm = 448.0; -#else - static const double maxNorm; -#endif //USE_ROCM - static const double artifInf; // artificial Infinity - static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS; - static constexpr int maxUnbiasedExponentAsFP32 = 8; - static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; -}; - -#ifdef USE_ROCM -inline const double Numeric_Traits::maxNorm = te_fp8_fnuz() ? 240.0 : 448.0; -#endif - -inline const double Numeric_Traits::artifInf = 10.0 * Numeric_Traits::maxNorm; - -template <> -struct Numeric_Traits { - static constexpr double minSubnorm = 1.0 / static_cast(1 << 16); // std::pow(2.0, -16.0); - static constexpr double maxSubnorm = 0.75 / static_cast(1 << 14); // std::pow(2.0, -14.0); - static constexpr double minNorm = 1.0 / static_cast(1 << 14); // std::pow(2.0, -14.0); - static constexpr double maxNorm = 57344.0; - static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity - static constexpr int maxBiasedExponentAsFP32 = 15 + FP32_EXPONENT_BIAS; - static constexpr int maxUnbiasedExponentAsFP32 = 15; - static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; -}; - -template <> -struct Numeric_Traits { - static constexpr double minSubnorm = std::numeric_limits::denorm_min(); // std::pow(2.0, -149.0); - static constexpr double maxSubnorm = std::numeric_limits::min() - - std::numeric_limits::denorm_min(); // minNormalized - minDenormalized - static constexpr double minNorm = std::numeric_limits::min(); // std::pow(2.0, -126.0); - static constexpr double maxNorm = std::numeric_limits::max(); // (1 - pow(2, -24)) * pow(2, 128) - static constexpr double artifInf = std::numeric_limits::infinity(); - static constexpr int maxBiasedExponentAsFP32 = 255; - static constexpr int maxUnbiasedExponentAsFP32 = 128; -}; - -template -struct Quantized_Limits { - static const double ranges[4]; - static constexpr inline fp32 max() { return static_cast(Numeric_Traits::maxNorm); } - static constexpr inline fp32 max_reciprocal() { return static_cast(1.0 / max()); } - static constexpr inline fp32 emax() { return static_cast(Numeric_Traits::maxExpNorm); } - static constexpr inline fp32 emax_reciprocal() { return static_cast(1.0 / emax()); } - static constexpr inline int max_norm_biased_exponent() { return Numeric_Traits::maxBiasedExponentAsFP32; } - static constexpr inline int max_norm_unbiased_exponent() { return Numeric_Traits::maxUnbiasedExponentAsFP32; } -}; - -template -inline const double Quantized_Limits::ranges[4] = { - 0.0, - Numeric_Traits::minNorm, - Numeric_Traits::maxNorm, - Numeric_Traits::artifInf -}; -// Input data filling cases -// Considering normal and subnormal magnitudes of E4M3 and E5M2 formats -// with nearest to even rounding per OFP8 specification -enum InputsFillCase { - zero_to_minNorm = 0, // [0, min_normal) - minNorm_to_maxNorm = 1, // [min_normal, max_normal) - maxNorm_to_inf = 2, // [max_normal, inf) - zeros = 3, // {0} - uniform = 4, // std::uniform_real_distribution<> dis(-2.0, 1.0) -}; - -inline fp8e8m0 float_to_e8m0(float val) { - // TODO: nan/inf needs to be set for any value - // of nan/inf in input not just amax. - if (std::isnan(val)) { - return 0xFF; - } - if (std::isinf(val)) { - return 0xFE; - } - if (val == 0.0f) { - return 0x00; - } - uint32_t val_u32 = *reinterpret_cast(&val); - fp8e8m0 exponent = (val_u32 >> FP32_MANTISSA_BITS); - uint32_t mantissa = val_u32 & 0x7FFFFF; - // Round up exponent and deal with satfinite. - if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { - ++exponent; - } - return exponent; -} - -inline float exp2f_rcp(fp8e8m0 biased_exp) { - if (biased_exp == 0) { - return 1.0f; - } - int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127) - float fp32_val = *reinterpret_cast(&int_val); - return fp32_val; -} - -inline float identity(const float x) { return x; } -inline float gelu(const float x) { return x * (0.5f + 0.5f * tanhf(x * (0.79788456f + 0.03567741f * x * x))); } -inline float dgelu(const float x) { - const float tanh_out = tanhf(0.79788456f * x * (1 + 0.044715f * x * x)); - return 0.5f * x * ((1 - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) - + 0.5f * (1 + tanh_out); -} -inline float sigmoid(const float x) { return 1 / (1 + expf(-x)); } -inline float dsigmoid(const float x) { return sigmoid(x) * (1 - sigmoid(x)); } -inline float qgelu(const float x) { return x * sigmoid(1.702f * x); } -inline float dqgelu(const float x) { return 1.702f * x * dsigmoid(1.702f * x) + sigmoid(1.702f * x); } -inline float relu(const float x) { return fmaxf(0, x); } -inline float drelu(const float x) { return x > 0 ? 1 : 0; } -inline float silu(const float x) { return x * sigmoid(x); } -inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } -inline float srelu(const float x) { return x > 0 ? x * x : 0; } -inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } - -size_t typeToNumBits(DType type); -size_t product(const NVTEShape &shape); -size_t product(const std::vector &shape); -size_t bytes(const NVTEShape& shape, const DType type); - -size_t first_dimension(const std::vector &shape); -size_t last_dimension(const std::vector &shape); - -bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); - -void compareResults(const std::string &name, const Tensor &test, const void *ref, - bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, - const size_t tolerable_mismatches_limit = 0); -void compareResults(const std::string &name, const float test, const float ref, - double atol = 1e-5, double rtol = 1e-8); -void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, - size_t N, float mismatch_rate_tol = 0.); -template -void compare_scaling_factors(const std::string &name, const T *test, const T *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, -#ifdef USE_ROCM - std::vector& mismatch_indices, -#endif //#ifdef USE_ROCM - size_t& mismatches_num, - const size_t scale_diff_abs_tolerance = 0, - const double abs_tolerable_mismatches_limit = 0, - const double rel_tolerable_mismatches_limit = 0); - -#ifdef USE_ROCM -void adjust_ref_for_e8m0_scale_error(const std::string &name, - const std::vector &mismatch_idx, - const uint8_t *test_scale, const uint8_t *ref_scale, - const size_t scale_stride, const size_t rows, - const size_t cols, bool rowwise, void *ref_ptr, DType otype); -#endif - -std::array get_scale_tensor_dims(const size_t rows, const size_t cols, - const size_t block_size_rows, const size_t block_size_cols); - -std::pair getTolerances(const DType type); - -void fillUniform(Tensor *t); - -template -void fillCase(Tensor *t, const InputsFillCase fill_case); - -void setRandomScale(Tensor *t); -void setRandomScaleInv(Tensor *t); - -constexpr int THREADS_PER_WARP = 32; - -const std::string &typeName(DType type); -const std::string& caseName(InputsFillCase type); - -extern std::vector all_fp_types; - -bool isFp8Type(DType type); -bool isFp4Type(DType type); - -int32_t getDeviceComputeCapability(); -constexpr int32_t hopperComputeCapability = 90; -constexpr int32_t blackwellComputeCapability = 100; - -} // namespace test - -#if FP4_TYPE_SUPPORTED -#define SWITCH_FP4_TYPE_HANDLE(type, ...) \ - case DType::kFloat4E2M1: { \ - using type = fp4e2m1; \ - { __VA_ARGS__ } \ - } break; -#else -#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing -#endif - -#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kByte: \ - { \ - using type = byte; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kInt32: \ - { \ - using type = int32; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kInt64: \ - { \ - using type = int64; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat32: \ - { \ - using type = float; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat16: \ - { \ - using type = fp16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kBFloat16: \ - { \ - using type = bf16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E4M3: \ - { \ - using type = fp8e4m3; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E5M2: \ - { \ - using type = fp8e5m2; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E8M0: \ - { \ - using type = fp8e8m0; \ - {__VA_ARGS__} \ - } \ - break; \ - SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ - default: \ - printf("dtype: %d\n", static_cast(dtype)); \ - NVTE_ERROR("Invalid type."); \ - } - -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat8E4M3: \ - { \ - using type = fp8e4m3; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat8E5M2: \ - { \ - using type = fp8e5m2; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - NVTE_ERROR("Invalid type."); \ - } - -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ - default: \ - NVTE_ERROR("Invalid type."); \ - } - -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: \ - { \ - using type = float; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kFloat16: \ - { \ - using type = fp16; \ - {__VA_ARGS__} \ - } \ - break; \ - case DType::kBFloat16: \ - { \ - using type = bf16; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - NVTE_ERROR("Invalid type."); \ - } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 29fd3a66d..c7d20778e 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -17,7 +17,9 @@ #include #include +#ifndef NVTE_ROCM_BENCHMARK #include +#endif #include #include @@ -28,9 +30,13 @@ namespace test { size_t create_seed_from_tensor_name(const std::string& tensor_name) { +#ifndef NVTE_ROCM_BENCHMARK auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) + "/" + tensor_name; return std::hash{}(full_name); +#else + return std::hash{}(tensor_name); +#endif } std::vector all_fp_types = {DType::kFloat32, @@ -619,6 +625,8 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { return ret; } +#ifndef NVTE_ROCM_BENCHMARK + void compareResults_sequential(const std::string &name, const Tensor &test, const void *ref, const bool rowwise, double atol, double rtol, bool if_on_gpus, @@ -923,6 +931,7 @@ void compare_scaling_factors(const std::string &name, const fp8e4m3 *te const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit); +#endif // NVTE_ROCM_BENCHMARK std::pair getTolerances(const DType type) { switch(type) { From 91f8441b277c5cdee7d4863caf408a27cefa27e9 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Fri, 27 Mar 2026 13:31:46 -0500 Subject: [PATCH 4/5] Fix clamping with fast_tanh --- transformer_engine/common/util/math.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 17648dade..d7b97723d 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -14,6 +14,8 @@ namespace transformer_engine { // AMD Fast tanh using hardware exp instruction __device__ inline float fast_tanhf(float x) { #ifdef __HIP_PLATFORM_AMD__ + // tanh(x) saturates at ±1 for |x| > 20 + x = fmaxf(fminf(x, 20.f), -20.f); float e2x = __expf(2.0f * x); return (e2x - 1.0f) * __frcp_rn(e2x + 1.0f); #else From 330c40e947e389b646b52d3fdc537a2fa79afb6b Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Fri, 27 Mar 2026 16:18:45 -0500 Subject: [PATCH 5/5] Formatting and comments --- benchmarks/cpp/CMakeLists.txt | 1 - .../cpp/cast/bench_dequantize_mxfp8.cpp | 1 + benchmarks/cpp/cast/bench_gated_mxfp8.cpp | 2 ++ .../cpp/cast/bench_quantize_mxfp8_fused.cpp | 1 + benchmarks/cpp/run_benchmarks.sh | 35 ++++++++----------- 5 files changed, 19 insertions(+), 21 deletions(-) diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index c608eb54b..7e7af41af 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -41,7 +41,6 @@ set(COMMON_COMPILE_OPTIONS -DNDEBUG -DUSE_ROCM --offload-arch=${GPU_TARGETS} - -w ) find_library(TRANSFORMER_ENGINE_LIB diff --git a/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp index 77222f23f..6a082a026 100644 --- a/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp +++ b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp @@ -100,6 +100,7 @@ static void BM_DequantizeMXFP8(benchmark::State &state) { const size_t bytes_read_data = rows * cols * sizeof(IType) * ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type const size_t bytes_read_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); const size_t bytes_write = rows * cols * sizeof(OType); diff --git a/benchmarks/cpp/cast/bench_gated_mxfp8.cpp b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp index 04d3e06d6..ce168ebc1 100644 --- a/benchmarks/cpp/cast/bench_gated_mxfp8.cpp +++ b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp @@ -94,6 +94,7 @@ static void BM_GatedMXFP8_Forward(benchmark::State &state) { const size_t bytes_write_data = rows * output_cols * sizeof(OType) * ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type const size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); @@ -162,6 +163,7 @@ static void BM_GatedMXFP8_Backward(benchmark::State &state) { const size_t bytes_write_data = rows * output_cols * sizeof(OType) * ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type const size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); diff --git a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp index 4326a65ff..119d43f86 100644 --- a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp +++ b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp @@ -132,6 +132,7 @@ static void BM_QuantizeMXFP8_Fused(benchmark::State &state) { size_t bytes_write_data = rows * cols * sizeof(OType) * ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); diff --git a/benchmarks/cpp/run_benchmarks.sh b/benchmarks/cpp/run_benchmarks.sh index 834fc3aec..03d8b5716 100755 --- a/benchmarks/cpp/run_benchmarks.sh +++ b/benchmarks/cpp/run_benchmarks.sh @@ -5,18 +5,13 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BUILD_DIR="${SCRIPT_DIR}/build" RESULTS_DIR="${SCRIPT_DIR}/results" -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -NC='\033[0m' - setup_test_common_symlinks() { local utils_dir="${SCRIPT_DIR}/utils" local test_common_hip="../../tests/cpp/test_common.hip" local test_common_h="../../tests/cpp/test_common_hip.h" if [ ! -f "${SCRIPT_DIR}/${test_common_hip}" ] || [ ! -f "${SCRIPT_DIR}/${test_common_h}" ]; then - echo -e "${RED}Error: hipified test_common files not found. Build tests before running benchmarks." + echo -e "Error: hipified test_common files not found. Build tests before running benchmarks." return 1 fi @@ -32,25 +27,25 @@ setup_test_common_symlinks() { } main() { - echo -e "${GREEN}=== MXFP8 Benchmark Suite ===${NC}" + echo -e "=== MXFP8 Benchmark Suite ===" if ! setup_test_common_symlinks; then return fi - echo -e "\n${YELLOW}[1/3] Building benchmarks...${NC}" + echo -e "\n[1/3] Building benchmarks..." cd "${SCRIPT_DIR}" if ! cmake -GNinja -B"${BUILD_DIR}" . || ! cmake --build "${BUILD_DIR}"; then - echo -e "${RED}Build failed. Fix the build errors and try again.${NC}" + echo -e "Build failed. Fix the build errors and try again." return fi - echo -e "${GREEN}✓ Build complete${NC}" + echo -e "✓ Build complete" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) RESULT_PREFIX="${RESULTS_DIR}/bench_${TIMESTAMP}" - echo -e "\n${YELLOW}[2/3] Running benchmarks...${NC}" + echo -e "\n[2/3] Running benchmarks..." BENCHMARKS=( "bench_quantize_mxfp8_fused" @@ -66,23 +61,23 @@ main() { --benchmark_out="${RESULT_PREFIX}_${bench}.csv" \ --benchmark_out_format=csv \ --benchmark_min_time=0.2s; then - echo -e " ${GREEN}✓${NC} Saved to ${RESULT_PREFIX}_${bench}.csv" + echo -e " ✓ Saved to ${RESULT_PREFIX}_${bench}.csv" else - echo -e " ${RED}✗${NC} ${bench} failed (exit code $?), continuing..." + echo -e " ✗ ${bench} failed (exit code $?), continuing..." FAILED_BENCHMARKS+=("${bench}") fi else - echo -e " ${RED}✗${NC} ${bench} not found, skipping" + echo -e " ✗ ${bench} not found, skipping" fi done - echo -e "\n${YELLOW}[3/3] Consolidating results...${NC}" + echo -e "\n[3/3] Consolidating results..." CONSOLIDATED_CSV="${RESULT_PREFIX}_all.csv" FIRST_CSV=$(ls "${RESULT_PREFIX}"_*.csv 2>/dev/null | grep -v "_all.csv" | head -1) if [ -z "$FIRST_CSV" ]; then - echo -e "${RED}No CSV files found to consolidate${NC}" + echo -e "No CSV files found to consolidate" return fi @@ -94,9 +89,9 @@ main() { fi done - echo -e "${GREEN}✓ Consolidated CSV: ${CONSOLIDATED_CSV}${NC}" + echo -e "✓ Consolidated CSV: ${CONSOLIDATED_CSV}" - echo -e "\n${GREEN}=== Summary ===${NC}" + echo -e "\n=== Summary ===" TOTAL_ROWS=$(tail -n +2 "$CONSOLIDATED_CSV" | wc -l) echo "Total benchmarks: $TOTAL_ROWS" echo "Results saved to: ${RESULTS_DIR}/" @@ -111,9 +106,9 @@ main() { echo "" if [ ${#FAILED_BENCHMARKS[@]} -gt 0 ]; then - echo -e "${RED}Failed benchmarks:${NC}" + echo -e "Failed benchmarks:" for bench in "${FAILED_BENCHMARKS[@]}"; do - echo -e " ${RED}✗${NC} ${bench}" + echo -e " ✗ ${bench}" done echo "" fi