diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt
new file mode 100644
index 000000000..7e7af41af
--- /dev/null
+++ b/benchmarks/cpp/CMakeLists.txt
@@ -0,0 +1,89 @@
+cmake_minimum_required(VERSION 3.18)
+
+if(NOT DEFINED CMAKE_CXX_COMPILER)
+  set(CMAKE_CXX_COMPILER hipcc)
+endif()
+
+project(transformer_engine_benchmarks LANGUAGES CXX)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+find_package(HIP REQUIRED)
+
+include(FetchContent)
+FetchContent_Declare(
+  benchmark
+  GIT_REPOSITORY https://github.com/google/benchmark.git
+  GIT_TAG v1.8.3
+)
+set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE)
+set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable gtest in benchmark" FORCE)
+FetchContent_MakeAvailable(benchmark)
+
+include_directories(
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common/include
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine
+  ${CMAKE_CURRENT_SOURCE_DIR}/utils
+)
+
+if(DEFINED ENV{NVTE_ROCM_ARCH})
+  set(GPU_TARGETS $ENV{NVTE_ROCM_ARCH})
+else()
+  set(GPU_TARGETS "gfx942;gfx950")
+endif()
+
+set(COMMON_COMPILE_OPTIONS
+  -Wall
+  -Wextra
+  -O3
+  -DNDEBUG
+  -DUSE_ROCM
+)
+# GPU_TARGETS is a CMake list; emit one --offload-arch flag per architecture so a
+# multi-arch default (e.g. "gfx942;gfx950") is not passed to hipcc as a bare token.
+foreach(arch IN LISTS GPU_TARGETS)
+  list(APPEND COMMON_COMPILE_OPTIONS --offload-arch=${arch})
+endforeach()
+
+find_library(TRANSFORMER_ENGINE_LIB
+  NAMES transformer_engine
+  PATHS ${CMAKE_CURRENT_SOURCE_DIR}/../..
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib
+        /usr/local/lib
+        $ENV{HOME}/.local/lib
+  NO_DEFAULT_PATH
+)
+
+if(NOT TRANSFORMER_ENGINE_LIB)
+  message(WARNING "TransformerEngine library not found in expected paths. Trying system paths...")
+  find_library(TRANSFORMER_ENGINE_LIB NAMES transformer_engine)
+endif()
+
+if(TRANSFORMER_ENGINE_LIB)
+  message(STATUS "Found TransformerEngine library: ${TRANSFORMER_ENGINE_LIB}")
+else()
+  message(FATAL_ERROR "TransformerEngine library not found. Please build TransformerEngine first:\n"
+                      " cd ${CMAKE_CURRENT_SOURCE_DIR}/../..\n"
+                      " pip install -e . --no-build-isolation\n"
+                      "Searched paths:\n"
+                      " ${CMAKE_CURRENT_SOURCE_DIR}/../..\n"
+                      " ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake\n"
+                      " ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib")
+endif()
+
+function(add_te_benchmark TARGET_NAME SOURCE_FILE)
+  set(TEST_COMMON_HIP ${CMAKE_CURRENT_SOURCE_DIR}/utils/test_common.hip)
+  set_source_files_properties(${TEST_COMMON_HIP} PROPERTIES LANGUAGE CXX)
+  add_executable(${TARGET_NAME} ${SOURCE_FILE} ${TEST_COMMON_HIP})
+  target_compile_options(${TARGET_NAME} PRIVATE ${COMMON_COMPILE_OPTIONS})
+  target_compile_definitions(${TARGET_NAME} PRIVATE NVTE_ROCM_BENCHMARK)
+  target_link_libraries(${TARGET_NAME} PRIVATE
+    benchmark::benchmark
+    ${TRANSFORMER_ENGINE_LIB}
+    hiprand
+  )
+  set_target_properties(${TARGET_NAME} PROPERTIES HIP_ARCHITECTURES "${GPU_TARGETS}")
+endfunction()
+
+add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp)
+add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp)
+add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp)
diff --git a/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp
new file mode 100644
index 000000000..6a082a026
--- /dev/null
+++ b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp
@@ -0,0 +1,130 @@
+/*************************************************************************
+ * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * License for AMD contributions = MIT. See LICENSE for more information
+ ************************************************************************/
+
+#include
+#include
+#include
+#include
+#include "amd_detail/hip_float8.h"
+
+#include "benchmark_utils.h"
+
+#include
+#include
+
+using namespace te_bench;
+using namespace transformer_engine;
+using fp8_e4m3 = test::fp8e4m3;
+
+// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B)
+#define COMMON_SHAPES \
+  ->Args({1024, 3584}) \
+  ->Args({1024, 4096}) \
+  ->Args({1024, 8192}) \
+  ->Args({1024, 14336}) \
+  ->Args({2048, 4096}) \
+  ->Args({2048, 8192}) \
+  ->Args({2048, 14336}) \
+  ->Args({2048, 28672}) \
+  ->Args({4096, 4096}) \
+  ->Args({4096, 8192}) \
+  ->Args({4096, 16384}) \
+  ->Args({4096, 28672}) \
+  ->Args({8192, 8192}) \
+  ->Args({8192, 16384}) \
+  ->Args({8192, 28672}) \
+  ->Args({8192, 53248}) \
+  ->Args({16384, 8192}) \
+  ->Args({16384, 16384}) \
+  ->Args({32768, 8192})
+
+template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
+static void BM_DequantizeMXFP8(benchmark::State &state) {
+  const size_t rows = state.range(0);
+  const size_t cols = state.range(1);
+
+  constexpr bool USE_ROWWISE = SCALE_DIM_X > 1;
+  constexpr bool USE_COLWISE = SCALE_DIM_Y > 1;
+
+  const size_t scale_cols_row = USE_ROWWISE ? (cols + 31) / 32 : 0;
+  const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0;
+  const size_t scale_cols_col = USE_COLWISE ? cols : 0;
+
+  std::vector<size_t> shape = {rows, cols};
+  DType itype = std::is_same_v<IType, fp8_e4m3> ? DType::kFloat8E4M3 : DType::kFloat8E5M2;
+  DType otype = std::is_same_v<OType, __half> ? DType::kFloat16 :
+                (std::is_same_v<OType, hip_bfloat16> ? DType::kBFloat16 : DType::kFloat32);
+
+  test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, USE_ROWWISE, USE_COLWISE,
+                                                          NVTE_MXFP8_1D_SCALING, false);
+  test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, true, false,
+                                                           NVTE_DELAYED_TENSOR_SCALING, false);
+
+  hipStream_t stream;
+  HIP_CHECK(hipStreamCreate(&stream));
+
+  // Fill the FP8 input: generate uniform FP32 data on the GPU, then cast it to the input type.
+  DeviceBuffer<float> temp_fp32(rows * cols);
+  fill_random_uniform_gpu(temp_fp32.get(), rows * cols, -2.0f, 1.0f, stream);
+
+  void *input_data_ptr = USE_ROWWISE ? input_tensor.rowwise_dptr() : input_tensor.columnwise_dptr();
+  size_t threads = 256;
+  size_t blocks = (rows * cols + threads - 1) / threads;
+  cast_fp32_kernel<<<blocks, threads, 0, stream>>>(temp_fp32.get(), static_cast<IType*>(input_data_ptr), rows * cols);
+
+  HIP_CHECK(hipStreamSynchronize(stream));
+
+  hipEvent_t start, stop;
+  HIP_CHECK(hipEventCreate(&start));
+  HIP_CHECK(hipEventCreate(&stop));
+
+  warmup_gpu();
+
+  for (auto _ : state) {
+    HIP_CHECK(hipEventRecord(start, stream));
+
+    nvte_dequantize(input_tensor.data(), output_tensor.data(), stream);
+
+    HIP_CHECK(hipEventRecord(stop, stream));
+    HIP_CHECK(hipEventSynchronize(stop));
+
+    float ms = 0;
+    HIP_CHECK(hipEventElapsedTime(&ms, start, stop));
+    state.SetIterationTime(ms / 1000.0);
+  }
+
+  HIP_CHECK(hipEventDestroy(start));
+  HIP_CHECK(hipEventDestroy(stop));
+
+  const size_t bytes_read_data = rows * cols * sizeof(IType) *
+                                 ((USE_ROWWISE ? 1 : 0) + (USE_COLWISE ? 1 : 0));
+  // Scales are single byte, E8M0 type
+  const size_t bytes_read_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) +
+                                   (USE_COLWISE ? 
scale_rows_col * scale_cols_col : 0); + const size_t bytes_write = rows * cols * sizeof(OType); + const size_t total_bytes = bytes_read_data + bytes_read_scales + bytes_write; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_DEQUANTIZE_ALL_CONFIGS(ITYPE, OTYPE, INAME, ONAME) \ + BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, __half, "E4M3", "FP16") +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, hip_bfloat16, "E4M3", "BF16") +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, float, "E4M3", "FP32") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/cast/bench_gated_mxfp8.cpp b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp new file mode 100644 index 000000000..ce168ebc1 --- /dev/null +++ b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp @@ -0,0 +1,214 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include +#include +#include + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +// SwiGLU shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 14336}) \ + ->Args({1024, 18944}) \ + ->Args({1024, 28672}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({2048, 29568}) \ + ->Args({4096, 14336}) \ + ->Args({4096, 28672}) \ + ->Args({4096, 53248}) \ + ->Args({8192, 14336}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 29568}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 28672}) \ + ->Args({16384, 53248}) \ + ->Args({32768, 28672}) \ + ->Args({32768, 53248}) + +template +static void BM_GatedMXFP8_Forward(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t input_cols = cols * 2; + const size_t output_cols = cols; + + const size_t scale_cols_row = USE_ROWWISE ? (output_cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? output_cols : 0; + + std::vector input_shape = {rows, input_cols}; + std::vector output_shape = {rows, output_cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? 
DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &input_tensor = TensorCache::get_or_create("input", input_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", output_shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_swiglu(input_tensor.data(), output_tensor.data(), stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_write_data = rows * output_cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type + const size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); + + const size_t bytes_read = rows * cols * sizeof(IType) * 2; + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +template +static void BM_GatedMXFP8_Backward(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t input_cols = cols * 2; + const size_t output_cols = cols * 2; + + const size_t scale_cols_row = USE_ROWWISE ? (output_cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? output_cols : 0; + + std::vector grad_shape = {rows, cols}; + std::vector input_shape = {rows, input_cols}; + std::vector output_shape = {rows, output_cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? 
DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &grad_tensor = TensorCache::get_or_create("grad", grad_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &input_tensor = TensorCache::get_or_create("input", input_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", output_shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_dswiglu(grad_tensor.data(), input_tensor.data(), output_tensor.data(), stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_write_data = rows * output_cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type + const size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); + + const size_t bytes_read = rows * cols * sizeof(IType) * 3; + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_GATED_ALL_CONFIGS(ITYPE, OTYPE, INAME, ONAME) \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 32, 32) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/both") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 32, 32) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/both") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_GATED_ALL_CONFIGS(__half, fp8_e4m3, "FP16", "E4M3") +REGISTER_GATED_ALL_CONFIGS(hip_bfloat16, fp8_e4m3, "BF16", "E4M3") +REGISTER_GATED_ALL_CONFIGS(float, fp8_e4m3, "FP32", "E4M3") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp new file mode 100644 index 000000000..119d43f86 --- /dev/null +++ b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp @@ -0,0 +1,182 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include +#include +#include + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 3584}) \ + ->Args({1024, 4096}) \ + ->Args({1024, 8192}) \ + ->Args({1024, 14336}) \ + ->Args({1024, 18944}) \ + ->Args({2048, 4096}) \ + ->Args({2048, 8192}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({2048, 29568}) \ + ->Args({4096, 4096}) \ + ->Args({4096, 8192}) \ + ->Args({4096, 16384}) \ + ->Args({4096, 14336}) \ + ->Args({4096, 28672}) \ + ->Args({8192, 8192}) \ + ->Args({8192, 16384}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 29568}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 8192}) \ + ->Args({16384, 16384})\ + ->Args({16384, 28672})\ + ->Args({32768, 8192}) \ + ->Args({32768, 16384}) + +template +static void BM_QuantizeMXFP8_Fused(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t scale_cols_row = USE_ROWWISE ? (cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? cols : 0; + + std::vector shape = {rows, cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? 
DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + test::Tensor *grad_tensor_ptr = nullptr, *dbias_tensor_ptr = nullptr, *workspace_tensor_ptr = nullptr; + + if constexpr (PROC_METHOD == CAST_DBIAS || PROC_METHOD == CAST_DBIAS_DACT) { + std::vector bias_shape = {cols}; + dbias_tensor_ptr = &TensorCache::get_or_create("dbias", bias_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, false); + workspace_tensor_ptr = &TensorCache::get_or_create("workspace", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, false); + } + + if constexpr (PROC_METHOD == CAST_DBIAS_DACT || PROC_METHOD == CAST_DACT) { + grad_tensor_ptr = &TensorCache::get_or_create("grad", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + } + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + if constexpr (PROC_METHOD == CAST_ONLY) { + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); + } else if constexpr (PROC_METHOD == CAST_DBIAS) { + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor_ptr->data(), workspace_tensor_ptr->data(), stream); + } else if constexpr (PROC_METHOD == CAST_DBIAS_DACT) { + nvte_quantize_dbias_dgelu(grad_tensor_ptr->data(), input_tensor.data(), output_tensor.data(), dbias_tensor_ptr->data(), workspace_tensor_ptr->data(), stream); + } else if constexpr (PROC_METHOD == CAST_DACT) { + nvte_dgelu(grad_tensor_ptr->data(), input_tensor.data(), output_tensor.data(), stream); + } else if constexpr (PROC_METHOD == CAST_ACT) { + nvte_gelu(input_tensor.data(), output_tensor.data(), stream); + } + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + size_t bytes_write_data = rows * cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type + size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? 
scale_rows_col * scale_cols_col : 0); + + size_t bytes_read = rows * cols * sizeof(IType); + if constexpr (PROC_METHOD == CAST_DBIAS_DACT || PROC_METHOD == CAST_DACT) { + bytes_read += rows * cols * sizeof(IType); + } + if constexpr (PROC_METHOD == CAST_DBIAS || PROC_METHOD == CAST_DBIAS_DACT) { + bytes_write_data += cols * sizeof(IType); + } + + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, METHOD, METHOD_NAME) \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 1, 32, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/rowwise/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 32, 1, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/colwise/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 32, 32, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/both/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#define REGISTER_ALL_METHODS(ITYPE, OTYPE, INAME, ONAME) \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_ONLY, "CastOnly") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DBIAS, "CastDBias") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DBIAS_DACT, "CastDBiasDACT") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DACT, "CastDACT") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_ACT, "CastACT") + +REGISTER_ALL_METHODS(__half, fp8_e4m3, "FP16", "E4M3") +REGISTER_ALL_METHODS(hip_bfloat16, fp8_e4m3, "BF16", "E4M3") +REGISTER_ALL_METHODS(float, fp8_e4m3, "FP32", "E4M3") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/run_benchmarks.sh b/benchmarks/cpp/run_benchmarks.sh new file mode 100755 index 000000000..03d8b5716 --- /dev/null +++ b/benchmarks/cpp/run_benchmarks.sh @@ -0,0 +1,117 @@ +#!/bin/bash +# Builds benchmarks, runs them, and consolidates results into a single CSV + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${SCRIPT_DIR}/build" +RESULTS_DIR="${SCRIPT_DIR}/results" + +setup_test_common_symlinks() { + local utils_dir="${SCRIPT_DIR}/utils" + local test_common_hip="../../tests/cpp/test_common.hip" + local test_common_h="../../tests/cpp/test_common_hip.h" + + if [ ! -f "${SCRIPT_DIR}/${test_common_hip}" ] || [ ! -f "${SCRIPT_DIR}/${test_common_h}" ]; then + echo -e "Error: hipified test_common files not found. Build tests before running benchmarks." + return 1 + fi + + if [ ! -L "${utils_dir}/test_common.hip" ] || [ ! -e "${utils_dir}/test_common.hip" ]; then + ln -sf "../${test_common_hip}" "${utils_dir}/test_common.hip" + fi + + if [ ! -L "${utils_dir}/test_common_hip.h" ] || [ ! -e "${utils_dir}/test_common_hip.h" ]; then + ln -sf "../${test_common_h}" "${utils_dir}/test_common_hip.h" + fi + + return 0 +} + +main() { + echo -e "=== MXFP8 Benchmark Suite ===" + + if ! setup_test_common_symlinks; then + return + fi + + echo -e "\n[1/3] Building benchmarks..." + cd "${SCRIPT_DIR}" + if ! cmake -GNinja -B"${BUILD_DIR}" . || ! cmake --build "${BUILD_DIR}"; then + echo -e "Build failed. Fix the build errors and try again." 
+ return + fi + echo -e "✓ Build complete" + + mkdir -p "${RESULTS_DIR}" + TIMESTAMP=$(date +%Y%m%d_%H%M%S) + RESULT_PREFIX="${RESULTS_DIR}/bench_${TIMESTAMP}" + + echo -e "\n[2/3] Running benchmarks..." + + BENCHMARKS=( + "bench_quantize_mxfp8_fused" + "bench_dequantize_mxfp8" + "bench_gated_mxfp8" + ) + + FAILED_BENCHMARKS=() + for bench in "${BENCHMARKS[@]}"; do + if [ -f "${BUILD_DIR}/${bench}" ]; then + echo -e " Running ${bench}..." + if "${BUILD_DIR}/${bench}" \ + --benchmark_out="${RESULT_PREFIX}_${bench}.csv" \ + --benchmark_out_format=csv \ + --benchmark_min_time=0.2s; then + echo -e " ✓ Saved to ${RESULT_PREFIX}_${bench}.csv" + else + echo -e " ✗ ${bench} failed (exit code $?), continuing..." + FAILED_BENCHMARKS+=("${bench}") + fi + else + echo -e " ✗ ${bench} not found, skipping" + fi + done + + echo -e "\n[3/3] Consolidating results..." + + CONSOLIDATED_CSV="${RESULT_PREFIX}_all.csv" + FIRST_CSV=$(ls "${RESULT_PREFIX}"_*.csv 2>/dev/null | grep -v "_all.csv" | head -1) + + if [ -z "$FIRST_CSV" ]; then + echo -e "No CSV files found to consolidate" + return + fi + + head -1 "$FIRST_CSV" > "$CONSOLIDATED_CSV" + + for csv in "${RESULT_PREFIX}"_bench_*.csv; do + if [ "$csv" != "$CONSOLIDATED_CSV" ]; then + tail -n +2 "$csv" >> "$CONSOLIDATED_CSV" + fi + done + + echo -e "✓ Consolidated CSV: ${CONSOLIDATED_CSV}" + + echo -e "\n=== Summary ===" + TOTAL_ROWS=$(tail -n +2 "$CONSOLIDATED_CSV" | wc -l) + echo "Total benchmarks: $TOTAL_ROWS" + echo "Results saved to: ${RESULTS_DIR}/" + echo "" + echo "Files created:" + for bench in "${BENCHMARKS[@]}"; do + if [ -f "${RESULT_PREFIX}_${bench}.csv" ]; then + echo " - $(basename "${RESULT_PREFIX}_${bench}.csv")" + fi + done + echo " - $(basename "$CONSOLIDATED_CSV") (consolidated)" + echo "" + + if [ ${#FAILED_BENCHMARKS[@]} -gt 0 ]; then + echo -e "Failed benchmarks:" + for bench in "${FAILED_BENCHMARKS[@]}"; do + echo -e " ✗ ${bench}" + done + echo "" + fi +} + +main diff --git a/benchmarks/cpp/utils/benchmark_utils.h b/benchmarks/cpp/utils/benchmark_utils.h new file mode 100644 index 000000000..bd2906b30 --- /dev/null +++ b/benchmarks/cpp/utils/benchmark_utils.h @@ -0,0 +1,223 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information
+ ************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include "test_common_hip.h"
+
+namespace te_bench {
+
+#define HIP_CHECK(call)                                               \
+  do {                                                                \
+    hipError_t err = call;                                            \
+    if (err != hipSuccess) {                                          \
+      fprintf(stderr, "HIP error at %s:%d: %s\n", __FILE__, __LINE__, \
+              hipGetErrorString(err));                                \
+      exit(EXIT_FAILURE);                                             \
+    }                                                                 \
+  } while (0)
+
+// RAII wrapper around a raw device allocation.
+template <typename T>
+class DeviceBuffer {
+ public:
+  DeviceBuffer(size_t count) : count_(count) {
+    HIP_CHECK(hipMalloc(&ptr_, count * sizeof(T)));
+  }
+
+  ~DeviceBuffer() {
+    if (ptr_) {
+      hipError_t err = hipFree(ptr_);
+      (void)err;
+    }
+  }
+
+  DeviceBuffer(const DeviceBuffer &) = delete;
+  DeviceBuffer &operator=(const DeviceBuffer &) = delete;
+
+  DeviceBuffer(DeviceBuffer &&other) noexcept : ptr_(other.ptr_), count_(other.count_) {
+    other.ptr_ = nullptr;
+    other.count_ = 0;
+  }
+
+  T *get() { return ptr_; }
+  const T *get() const { return ptr_; }
+  size_t count() const { return count_; }
+  size_t bytes() const { return count_ * sizeof(T); }
+
+  void upload(const std::vector<T> &host_data) {
+    if (host_data.size() != count_) {
+      throw std::runtime_error("Size mismatch in upload");
+    }
+    HIP_CHECK(hipMemcpy(ptr_, host_data.data(), bytes(), hipMemcpyHostToDevice));
+  }
+
+  void download(std::vector<T> &host_data) const {
+    host_data.resize(count_);
+    HIP_CHECK(hipMemcpy(host_data.data(), ptr_, bytes(), hipMemcpyDeviceToHost));
+  }
+
+ private:
+  T *ptr_ = nullptr;
+  size_t count_ = 0;
+};
+
+template <typename T>
+std::vector<T> generate_random_data(size_t count, T min_val = -1.0, T max_val = 1.0) {
+  std::vector<T> data(count);
+  std::mt19937 gen(42);
+
+  if constexpr (std::is_floating_point_v<T>) {
+    std::uniform_real_distribution<T> dist(min_val, max_val);
+    for (auto &val : data) {
+      val = dist(gen);
+    }
+  } else {
+    std::uniform_int_distribution<int> dist(static_cast<int>(min_val), static_cast<int>(max_val));
+    for (auto &val : data) {
+      val = static_cast<T>(dist(gen));
+    }
+  }
+
+  return data;
+}
+
+__global__ void scale_shift_kernel(float *data, size_t count, float scale, float offset) {
+  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < count) {
+    data[idx] = data[idx] * scale + offset;
+  }
+}
+
+inline void fill_random_uniform_gpu(float *dptr, size_t count, float min_val = -2.0f, float max_val = 1.0f, hipStream_t stream = 0) {
+  hiprandGenerator_t gen;
+  hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_DEFAULT);
+  hiprandSetPseudoRandomGeneratorSeed(gen, 42);
+  if (stream != 0) {
+    hiprandSetStream(gen, stream);
+  }
+  hiprandGenerateUniform(gen, dptr, count);
+  // hiprandGenerateUniform produces values in (0, 1]; rescale into the [min_val, max_val] range.
+  float scale = max_val - min_val;
+  float offset = min_val;
+
+  size_t threads = 256;
+  size_t blocks = (count + threads - 1) / threads;
+  scale_shift_kernel<<<blocks, threads, 0, stream>>>(dptr, count, scale, offset);
+
+  hiprandDestroyGenerator(gen);
+}
+
+template <typename T>
+__global__ void cast_fp32_kernel(const float *in, T *out, size_t count) {
+  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < count) {
+    out[idx] = static_cast<T>(in[idx]);
+  }
+}
+
+template <typename T>
+inline void fill_random_uniform_gpu_typed(T *dptr, size_t count, float min_val = -2.0f, float max_val = 1.0f, hipStream_t stream = 0) {
+  if constexpr (std::is_same_v<T, float>) {
+    fill_random_uniform_gpu(dptr, count, min_val, max_val, stream);
+  } else {
+    DeviceBuffer<float> temp_fp32(count);
+    fill_random_uniform_gpu(temp_fp32.get(), count, min_val, max_val, stream);
+
+    size_t threads = 256;
+    size_t blocks = (count + threads - 1) / threads;
+    cast_fp32_kernel<<<blocks, threads, 0, stream>>>(temp_fp32.get(), dptr, count);
+  }
+}
+
+inline void warmup_gpu(int iterations = 10) {
+  DeviceBuffer<float> dummy(1024);
+  for (int i = 0; i < iterations; ++i) {
+    HIP_CHECK(hipMemset(dummy.get(), 0, dummy.bytes()));
+  }
+  HIP_CHECK(hipDeviceSynchronize());
+}
+
+inline double calculate_bandwidth_gbps(size_t bytes, double time_ns) {
+  return (bytes / 1e9) / (time_ns / 1e9);
+}
+
+inline void set_items_processed(benchmark::State &state, size_t items_per_iter) {
+  state.SetItemsProcessed(state.iterations() * items_per_iter);
+}
+
+inline void set_bytes_processed(benchmark::State &state, size_t bytes_per_iter) {
+  state.SetBytesProcessed(state.iterations() * bytes_per_iter);
+}
+
+// Caches test::Tensor objects keyed by name/shape/dtype/layout so repeated benchmark
+// runs with the same configuration reuse allocations instead of reallocating each time.
+class TensorCache {
+ public:
+  struct CacheKey {
+    std::string name;
+    size_t rows;
+    size_t cols;
+    transformer_engine::DType dtype;
+    bool rowwise;
+    bool colwise;
+    NVTEScalingMode scaling_mode;
+
+    bool operator<(const CacheKey &other) const {
+      return std::tie(name, rows, cols, dtype, rowwise, colwise, scaling_mode) <
+             std::tie(other.name, other.rows, other.cols, other.dtype, other.rowwise, other.colwise, other.scaling_mode);
+    }
+  };
+
+  static test::Tensor &get_or_create(const std::string &name,
+                                     const std::vector<size_t> &shape,
+                                     transformer_engine::DType dtype,
+                                     bool rowwise = true,
+                                     bool colwise = false,
+                                     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING,
+                                     bool initialize_random = false) {
+    CacheKey key{name, shape[0], shape[1], dtype, rowwise, colwise, scaling_mode};
+
+    static auto* cache = new std::map<CacheKey, std::unique_ptr<test::Tensor>>();
+
+    auto it = cache->find(key);
+    if (it == cache->end()) {
+      auto tensor_ptr = std::make_unique<test::Tensor>(name, shape, dtype, rowwise, colwise, scaling_mode);
+
+      if (initialize_random && dtype != transformer_engine::DType::kFloat8E4M3 &&
+          dtype != transformer_engine::DType::kFloat8E5M2) {
+        hipStream_t stream;
+        HIP_CHECK(hipStreamCreate(&stream));
+
+        size_t count = shape[0] * shape[1];
+        void *data_ptr = tensor_ptr->rowwise_dptr();
+
+        if (dtype == transformer_engine::DType::kFloat32) {
+          fill_random_uniform_gpu(static_cast<float*>(data_ptr), count, -2.0f, 1.0f, stream);
+        } else if (dtype == transformer_engine::DType::kFloat16) {
+          fill_random_uniform_gpu_typed<__half>(static_cast<__half*>(data_ptr), count, -2.0f, 1.0f, stream);
+        } else if (dtype == transformer_engine::DType::kBFloat16) {
+          fill_random_uniform_gpu_typed<hip_bfloat16>(static_cast<hip_bfloat16*>(data_ptr), count, -2.0f, 1.0f, stream);
+        }
+
+        HIP_CHECK(hipStreamSynchronize(stream));
+        HIP_CHECK(hipStreamDestroy(stream));
+      }
+
+      (*cache)[key] = std::move(tensor_ptr);
+      it = cache->find(key);
+    }
+
+    return *(it->second);
+  }
+};
+} // namespace te_bench
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 29fd3a66d..c7d20778e 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -17,7 +17,9 @@
 #include
 #include
+#ifndef NVTE_ROCM_BENCHMARK
 #include <gtest/gtest.h>
+#endif
 #include
 #include
@@ -28,9 +30,13 @@
 namespace test {
 
 size_t create_seed_from_tensor_name(const std::string& tensor_name) {
+#ifndef NVTE_ROCM_BENCHMARK
   auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) +
                    "/" + tensor_name;
   return std::hash<std::string>{}(full_name);
+#else
+  return std::hash<std::string>{}(tensor_name);
+#endif
 }
 
 std::vector<DType> all_fp_types = {DType::kFloat32,
@@ -619,6 +625,8 @@ std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
   return ret;
 }
 
+#ifndef NVTE_ROCM_BENCHMARK
+
 void compareResults_sequential(const std::string &name, const Tensor &test,
                                const void *ref, const bool rowwise,
                                double atol, double rtol, 
bool if_on_gpus, @@ -923,6 +931,7 @@ void compare_scaling_factors(const std::string &name, const fp8e4m3 *te const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit); +#endif // NVTE_ROCM_BENCHMARK std::pair getTolerances(const DType type) { switch(type) { diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index a056cf50f..6a0d65685 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -807,20 +807,33 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu const size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); +#ifdef __HIP_PLATFORM_AMD__ + // rowwise-only uses direct global reads/writes, no shmem needed + size_t in_mem = 0; + size_t out_mem = 0; + if (USE_COLWISE_SCALING) { + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + in_mem = grad_mem + in_act_mem + in_gate_mem; + + out_mem = buff_size_aligned_out; + if constexpr (IS_BWD) { + out_mem += buff_size_aligned_out; + } + } +#else const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; -#ifdef __HIP_PLATFORM_AMD__ - const size_t out_gate_mem = buff_size_aligned_out; -#else const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); -#endif size_t out_mem = out_act_mem + out_gate_mem; if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } +#endif const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a69915c0d..ac32ceacd 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -576,9 +576,80 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t cols = input.flat_last_dim(); #ifdef __HIP_PLATFORM_AMD__ - constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; - constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; - constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; + // Choose chunk size based on tensor dimensions + const bool use_large_chunks = (rows * cols > 32 * 1024 * 1024); + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_large_chunks, USE_LARGE_CHUNKS, + + constexpr size_t CHUNK_DIM_Y = USE_LARGE_CHUNKS ? 128 : 64; + constexpr size_t CHUNK_DIM_X = USE_LARGE_CHUNKS ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = USE_LARGE_CHUNKS ? 
256 : 128; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(cols % (32 * sizeof(IType))), IS_ALIGNED, + quantize_mxfp8_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + noop_ptr, workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + ))))); + + if constexpr (IS_DBIAS) { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + ); // NOLINT(*) + } + ); #else constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); @@ -590,7 +661,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; constexpr size_t BUFF_DIM_Y = THREADS_Y; constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; -#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); @@ -608,7 +678,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; -#ifndef __HIP_PLATFORM_AMD__ ScalingType scaling_type; if (use_rowwise_scaling && (!use_colwise_scaling)) { scaling_type = ScalingType::ROWWISE; @@ -617,7 +686,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } else if (use_rowwise_scaling && use_colwise_scaling) { scaling_type = ScalingType::BIDIMENSIONAL; } -#endif if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); @@ -639,27 +707,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, - -#ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - quantize_mxfp8_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - ))); // NOLINT(*) -#else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; alignas(64) CUtensorMap tensor_map_output_rowwise{}; @@ -749,12 +796,12 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } } -#endif // #ifdef __HIP_PLATFORM_AMD__ if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); }); // NOLINT(*) ); // NOLINT(*) +#endif // __HIP_PLATFORM_AMD__ } } // namespace mxfp8 diff --git a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh index aac39c6d2..147951899 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh @@ -41,7 +41,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t THREADS_PER_SCALE_X_ROWWISE = DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(IType); + constexpr size_t VECTOR_WIDTH = IS_ALIGNED ? 8 : 16; const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; @@ -68,7 +68,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); diff --git a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index eb389f9db..fee3db3e0 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -9,370 +9,526 @@ //#include "hip/hip_runtime.h" //dummy include to prevent hipification adding this header constexpr size_t ALIGNMENT_SIZE = 128; -// TODO: Identify optimal chunk/thread size for MI350+ constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; constexpr size_t THREADS_PER_CHUNK = 256; -constexpr size_t THREADS_PER_CHUNK_X = 64; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 256 / 64 constexpr size_t BUFFERS_NUM = 1; // No async load for HIP constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 2 static_assert(ITERATIONS >= 1); +constexpr size_t ELEMS_PER_THREAD = 8; +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 +constexpr size_t 
THREADS_PER_CHUNK_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 32 +static_assert(THREADS_PER_CHUNK_Y_ROWWISE <= BUFFER_DIM_Y); +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 64 +constexpr size_t THREADS_PER_CHUNK_Y_COLWISE = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_COLWISE; // 4 +constexpr size_t BUFFER_STAGES_NUM_COLWISE = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_COLWISE; // 8 +constexpr size_t THREADS_PER_CHUNK_X = THREADS_PER_CHUNK_X_COLWISE; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK_Y_COLWISE; +constexpr size_t BUFFER_STAGES_NUM = BUFFER_STAGES_NUM_COLWISE; + __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } +template +__device__ inline void compute_gated_activation( + float act_elt, float gate_elt, float grad_elt, + const ParamOP &p, float &result_act, float &result_gate) { + + bool dgate_elt_valid = true; + if constexpr (std::is_same::value) { + dgate_elt_valid = gate_elt <= p.limit && gate_elt >= -p.limit; + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } + + if constexpr (IS_DGATED) { + const float x = act_elt; + float act_x, dact_x; + if constexpr (std::is_same::value) { + const float cx = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * cx); + act_x = cx * s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * cx : 0.0f; + } else { + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } + } + result_act = dact_x * grad_elt * gate_elt; + result_gate = dgate_elt_valid ? act_x * grad_elt : 0.0f; + } else { + result_act = ActOP(act_elt, p) * gate_elt; + result_gate = 0.0f; + } + + if constexpr (!std::is_same_v) { + result_act = static_cast(static_cast(result_act)); + if constexpr (IS_DGATED) { + result_gate = static_cast(static_cast(result_gate)); + } + } +} + template __global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_gated_mxfp8_kernel( - const IType *grad_ptr, - const IType *input_act, - const IType *input_gate, - OType *output_act_rowwise, - OType *output_gate_rowwise, - OType *output_act_colwise, - OType *output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise, const ParamOP p) { + const IType *grad_ptr, const IType *input_act, const IType *input_gate, OType *output_act_rowwise, + OType *output_gate_rowwise, OType *output_act_colwise, OType *output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, + const size_t scale_stride_rowwise, const size_t scale_stride_colwise, const ParamOP p) { constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 4 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr 
size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; - const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_X = blockIdx.x * CHUNK_DIM_X; const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; - constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType); + const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + constexpr size_t VECTOR_WIDTH_IN = IS_ALIGNED ? 8 : 16; + constexpr size_t VECTOR_WIDTH_OUT = 16; - const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + constexpr size_t ROWS_PER_THREAD = CHUNK_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; - const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_elems_total = BUFFERS_NUM * buff_elems; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + // ROWWISE-ONLY PATH: Direct global memory, no shared memory + if constexpr (USE_ROWWISE_SCALING && !USE_COLWISE_SCALING) { + const size_t col_start = chunk_offset_X + thread_offset_X_rowwise; + const bool col_valid = (col_start < cols); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); +#pragma unroll + for (size_t r = 0; r < ROWS_PER_THREAD; r++) { + const size_t row = chunk_offset_Y + tid_rowwise_Y + r * THREADS_PER_CHUNK_Y_ROWWISE; - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = in_act_mem + in_gate_mem; + const bool row_valid = (row < rows); - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - const size_t out_mem = out_act_mem + out_gate_mem; + Vec act_vec, gate_vec, grad_vec; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + act_vec.load_from(&input_act[row * 2*cols + col_start]); + gate_vec.load_from(&input_gate[row * 2*cols + col_start]); + if constexpr (IS_DGATED) { + grad_vec.load_from(&grad_ptr[row * cols + col_start]); + } + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + act_vec.data.elt[j] = (col_start + j < cols) ? 
input_act[row * 2*cols + col_start + j] : static_cast(0); + gate_vec.data.elt[j] = (col_start + j < cols) ? input_gate[row * 2*cols + col_start + j] : static_cast(0); + if constexpr (IS_DGATED) { + grad_vec.data.elt[j] = (col_start + j < cols) ? grad_ptr[row * cols + col_start + j] : static_cast(0); + } + } + } + } + + // Compute activations + float computed_act[ELEMS_PER_THREAD]; + float computed_gate[ELEMS_PER_THREAD]; + float act_amax = 0; + float gate_amax = 0; - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); + float act_elt = static_cast(act_vec.data.elt[j]); + float gate_elt = static_cast(gate_vec.data.elt[j]); + float grad_elt = IS_DGATED ? static_cast(grad_vec.data.elt[j]) : 0.0f; + + compute_gated_activation( + act_elt, gate_elt, grad_elt, p, computed_act[j], computed_gate[j]); + + if (!out_of_bounds) { + act_amax = fmaxf(act_amax, fabsf(computed_act[j])); + if constexpr (IS_DGATED) { + gate_amax = fmaxf(gate_amax, fabsf(computed_gate[j])); + } + } + } - OType *out_act_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + // --- Act rowwise quantization --- + { + __builtin_assume(act_amax >= 0); + const float scale_amax = subwarp_reduce_max_broadcast(act_amax); + const e8m0_t biased_exp = + ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); + const float scale_inv = ptx::exp2f_rcp(biased_exp); - OType *out_act_colwise_sh = out_act_rowwise_sh; - OType *out_gate_colwise_sh = out_gate_rowwise_sh; + Vec out_vec; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_vec.data.elt[j] = static_cast(computed_act[j] * scale_inv); + } - if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { - out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); - out_gate_colwise_sh = - reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_vec.store_to(&output_act_rowwise[row * output_cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_act_rowwise[row * output_cols + col_start + j] = out_vec.data.elt[j]; + } + } + } + } + + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exp; + } + } + + // --- Gate rowwise quantization (BWD only) --- + if constexpr (IS_DGATED) { + __builtin_assume(gate_amax >= 0); + const float scale_amax = subwarp_reduce_max_broadcast(gate_amax); + const e8m0_t biased_exp = + ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); + const float scale_inv = ptx::exp2f_rcp(biased_exp); + + Vec out_vec; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_vec.data.elt[j] = static_cast(computed_gate[j] * scale_inv); + } + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_vec.store_to(&output_gate_rowwise[row * 
output_cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_gate_rowwise[row * output_cols + col_start + j] = out_vec.data.elt[j]; + } + } + } + } + + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE + + DIVUP(cols, SCALE_DIM_X); + scales_rowwise[scale_idx] = biased_exp; + } + } + } } - __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + // COLWISE PATH: Shared memory for input + colwise output + if constexpr (USE_COLWISE_SCALING) { + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_elems_total = BUFFERS_NUM * buff_elems; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_mem = in_act_mem + buff_size_aligned_in; // act + gate + + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - __syncthreads(); + OType *out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + buff_size_aligned_out); + + // For colwise cross-thread Y reduction + __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y_COLWISE][CHUNK_DIM_X]; + + __syncthreads(); for (int it = 0; it < ITERATIONS; it++) { const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - const size_t row_base = chunk_it_offset_y; - // Initiate bulk tensor copy + // === Load input to shmem === if constexpr (IS_DGATED) { - copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, - cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + copy_2d_to_shared( + &in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, + cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } - - // Act - copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, - 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - - // Gate - copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, - 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + copy_2d_to_shared( + &in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, + 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + copy_2d_to_shared( + &in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, + 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); - const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; - const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; - - float after_dact_reg[BUFFER_STAGES_NUM]; - float after_dgate_reg[BUFFER_STAGES_NUM]; - float thread_Y_mx_block_amax = 0.0f; - float thread_Y_mx_block_amax_gate = 0.0f; + if 
constexpr (USE_ROWWISE_SCALING) { + const size_t row = chunk_it_offset_y + tid_rowwise_Y; + const size_t col_start = chunk_offset_X + thread_offset_X_rowwise; + + const bool row_valid = (row < rows); + const bool col_valid = (col_start < cols); + + const int shmem_base = tid_rowwise_Y * SHMEM_DIM_X + thread_offset_X_rowwise; + Vec act_vec, gate_vec; + act_vec.load_from(&in_act_sh[shmem_base]); + gate_vec.load_from(&in_gate_sh[shmem_base]); + Vec grad_vec; + if constexpr (IS_DGATED) { + grad_vec.load_from(&in_grad_sh[shmem_base]); + } - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + float computed_act[ELEMS_PER_THREAD]; + float computed_gate[ELEMS_PER_THREAD]; + float act_amax = 0; + float gate_amax = 0; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + float act_elt = static_cast(act_vec.data.elt[j]); + float gate_elt = static_cast(gate_vec.data.elt[j]); + float grad_elt = IS_DGATED ? static_cast(grad_vec.data.elt[j]) : 0.0f; - float act_elt = static_cast(in_act_sh[shmem_idx]); - float gate_elt = static_cast(in_gate_sh[shmem_idx]); + compute_gated_activation( + act_elt, gate_elt, grad_elt, p, computed_act[j], computed_gate[j]); - bool dgate_elt = true; // gating is ideally an identity function - if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values - dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp - gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + act_amax = fmaxf(act_amax, fabsf(computed_act[j])); + if constexpr (IS_DGATED) { + gate_amax = fmaxf(gate_amax, fabsf(computed_gate[j])); + } } - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh[shmem_idx]); - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr (std::is_same::value) { - const float x = min(act_elt, p.limit); - const float s = sigmoidf(p.alpha * x); - act_x = x * s; - dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; - } else { - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; + // --- Act rowwise quantization --- + { + __builtin_assume(act_amax >= 0); + const float scale_amax = subwarp_reduce_max_broadcast(act_amax); + const e8m0_t biased_exp = + ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); + const float scale_inv = ptx::exp2f_rcp(biased_exp); + + Vec out_vec; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_vec.data.elt[j] = static_cast(computed_act[j] * scale_inv); + } + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_vec.store_to(&output_act_rowwise[row * output_cols + col_start]); } else { - act_x = ActOP(x, p); - dact_x = DActOP(x, p); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_act_rowwise[row * output_cols + col_start + j] = out_vec.data.elt[j]; + } + } } } - - after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = dgate_elt ? 
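/* [Reviewer sketch — annotation, not part of this diff]
   Reference for the MX recipe applied to the act values above: one power-of-two
   scale per 32-element block, chosen so the block amax fits the FP8 max-normal.
   A hypothetical host-side analogue of float_to_e8m0(amax * max_norm_rcp),
   assuming that helper rounds the exponent up:

   #include <cmath>
   #include <cstdint>

   inline uint8_t mx_biased_exponent(float amax, float fp8_max_norm) {
     if (amax == 0.0f) return 0;              // degenerate block, ignored here
     int exp = 0;
     std::frexp(amax / fp8_max_norm, &exp);   // ratio lies in [2^(exp-1), 2^exp)
     return static_cast<uint8_t>(exp + 127);  // E8M0 bias
   }
   // The kernel then scales by exp2f_rcp(biased_exp) == 2^-(biased_exp - 127)
   // before casting each element to FP8.
*/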
act_x * grad_elt : 0.0f; - } else { - after_dact_reg[stage] = ActOP(act_elt, p) * gate_elt; - } - // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 - if constexpr (!std::is_same_v) { - after_dact_reg[stage] = static_cast(static_cast(after_dact_reg[stage])); - if constexpr (IS_DGATED) { - after_dgate_reg[stage] = static_cast(static_cast(after_dgate_reg[stage])); + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exp; } } - if constexpr (USE_ROWWISE_SCALING) { - if constexpr (IS_DGATED) { - // dgate - float amax = fabsf(after_dgate_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); - - out_gate_rowwise_sh[shmem_idx] = - static_cast(scale_reciprocal_X * after_dgate_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; + // --- Gate rowwise quantization (BWD only) --- + if constexpr (IS_DGATED) { + __builtin_assume(gate_amax >= 0); + const float scale_amax = subwarp_reduce_max_broadcast(gate_amax); + const e8m0_t biased_exp = + ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); + const float scale_inv = ptx::exp2f_rcp(biased_exp); + + Vec out_vec; +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_vec.data.elt[j] = static_cast(computed_gate[j] * scale_inv); + } + + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_vec.store_to(&output_gate_rowwise[row * output_cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_gate_rowwise[row * output_cols + col_start + j] = out_vec.data.elt[j]; + } + } } } - float amax = fabsf(after_dact_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); - - out_act_rowwise_sh[shmem_idx] = - static_cast(scale_reciprocal_X * after_dact_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; + + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE + + DIVUP(cols, SCALE_DIM_X); + scales_rowwise[scale_idx] = 
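/* [Reviewer sketch — annotation, not part of this diff]
   Scale layout implied by the index math above: each row of scales_rowwise holds
   the ceil(cols/32) act-block scales first, then the gate-block scales, so a gate
   entry sits DIVUP(cols, SCALE_DIM_X) slots to the right of its act counterpart.
   In isolation (helper names are hypothetical):

   #include <cstddef>

   inline size_t act_scale_idx(size_t row, size_t blk, size_t stride) {
     return row * stride + blk;                     // blk = 32-column block index
   }
   inline size_t gate_scale_idx(size_t row, size_t blk, size_t stride, size_t cols) {
     return row * stride + blk + (cols + 31) / 32;  // + DIVUP(cols, SCALE_DIM_X)
   }
*/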
biased_exp; } } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool col_out_of_bounds = (chunk_offset_X + tid_colwise_X >= cols); + const size_t row_base = chunk_it_offset_y; + const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; + + float after_dact_reg[BUFFER_STAGES_NUM_COLWISE]; + float after_dgate_reg[BUFFER_STAGES_NUM_COLWISE]; + + float thread_Y_mx_block_amax = 0.0f; + float thread_Y_mx_block_amax_gate = 0.0f; + + // Compute activation and accumulate column amax + for (int stage = 0; stage < BUFFER_STAGES_NUM_COLWISE; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_COLWISE; + const int shmem_offset_y = tid_colwise_Y + stage_offset_Y; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + tid_colwise_X; + + float act_elt = static_cast(in_act_sh[shmem_idx]); + float gate_elt = static_cast(in_gate_sh[shmem_idx]); + float grad_elt = 0.0f; + if constexpr (IS_DGATED) { + grad_elt = static_cast(in_grad_sh[shmem_idx]); + } + + compute_gated_activation( + act_elt, gate_elt, grad_elt, p, after_dact_reg[stage], after_dgate_reg[stage]); - if constexpr (USE_COLWISE_SCALING) { __builtin_assume(thread_Y_mx_block_amax >= 0); - __builtin_assume(thread_Y_mx_block_amax_gate >= 0); thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); if constexpr (IS_DGATED) { + __builtin_assume(thread_Y_mx_block_amax_gate >= 0); thread_Y_mx_block_amax_gate = fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); } } - } - if constexpr (USE_COLWISE_SCALING) { const bool row_out_of_bounds = (row_base >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); if constexpr (IS_DGATED) { - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + if (tid_colwise_Y > 0) { + stage_amax_sh[tid_colwise_Y][tid_colwise_X] = thread_Y_mx_block_amax_gate; } __syncthreads(); - if (tid_Y == 0) { + if (tid_colwise_Y == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + for (int y = 1; y < THREADS_PER_CHUNK_Y_COLWISE; ++y) { thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_colwise_X]); } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + stage_amax_sh[0][tid_colwise_X] = thread_Y_mx_block_amax_gate; } __syncthreads(); - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax - - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); - } + const float mx_block_Y_amax = stage_amax_sh[0][tid_colwise_X]; + __builtin_assume(mx_block_Y_amax >= 0); const e8m0_t biased_exponent = ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { + if ((tid_colwise_Y == 0) && !out_of_bounds) { const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X + cols; const int scale_idx = 
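/* [Reviewer sketch — annotation, not part of this diff]
   The gate amax handling above is a two-phase shared-memory max-reduction over
   the thread rows that share one column. Generic form of the pattern (assumes
   all threads in the block reach both __syncthreads() calls):

   template <int ROWS, int COLS>
   __device__ float column_block_max(float v, int ty, int tx, float (&sh)[ROWS][COLS]) {
     if (ty > 0) sh[ty][tx] = v;     // rows 1..ROWS-1 publish their partial maxima
     __syncthreads();
     if (ty == 0) {
       for (int y = 1; y < ROWS; ++y) v = fmaxf(v, sh[y][tx]);
       sh[0][tx] = v;                // row 0 folds and republishes the result
     }
     __syncthreads();
     return sh[0][tx];               // every thread reads the column-block maximum
   }
*/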
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; scales_colwise[scale_idx] = biased_exponent; } #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - + for (int stage = 0; stage < BUFFER_STAGES_NUM_COLWISE; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_COLWISE; + const int shmem_idx = (tid_colwise_Y + stage_offset_Y) * SHMEM_DIM_X + tid_colwise_X; out_gate_colwise_sh[shmem_idx] = static_cast(scale_reciprocal * after_dgate_reg[stage]); } } - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; - } - __syncthreads(); - if (tid_Y == 0) { + + { + if (tid_colwise_Y > 0) { + stage_amax_sh[tid_colwise_Y][tid_colwise_X] = thread_Y_mx_block_amax; + } + __syncthreads(); + if (tid_colwise_Y == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + for (int y = 1; y < THREADS_PER_CHUNK_Y_COLWISE; ++y) { + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_colwise_X]); + } + stage_amax_sh[0][tid_colwise_X] = thread_Y_mx_block_amax; } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax - } - __syncthreads(); - - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + __syncthreads(); - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { + const float mx_block_Y_amax = stage_amax_sh[0][tid_colwise_X]; __builtin_assume(mx_block_Y_amax >= 0); - } - const e8m0_t biased_exponent = - ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - } + const e8m0_t biased_exponent = + ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); + + if ((tid_colwise_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; + + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_act_colwise_sh[shmem_idx] = - static_cast(scale_reciprocal * after_dact_reg[stage]); + for (int stage = 0; stage < BUFFER_STAGES_NUM_COLWISE; ++stage) { + const int stage_offset_Y = 
stage * THREADS_PER_CHUNK_Y_COLWISE; + const int shmem_idx = (tid_colwise_Y + stage_offset_Y) * SHMEM_DIM_X + tid_colwise_X; + out_act_colwise_sh[shmem_idx] = + static_cast(scale_reciprocal * after_dact_reg[stage]); + } } } __syncthreads(); - if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - } - } - - if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - } + bulk_tensor_2d_shared_to_global( + &out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, + chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + if constexpr (IS_DGATED) { + bulk_tensor_2d_shared_to_global( + &out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, + chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } __syncthreads(); } + } } diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index 198728f34..40913cc03 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -9,30 +9,22 @@ constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; constexpr size_t ELEMS_PER_THREAD = 16; -constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported -constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 -constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 - -constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = - MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 -constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = - MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 -constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_BUFF_STAGES_NUM = - MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 -constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported + +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ +typedef short mxfp8_v2i16_t __attribute__((ext_vector_type(2))); +#endif template -__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + size_t SCALE_DIM_X, bool IS_ALIGNED, + size_t CHUNK_DIM_Y = 64, + size_t CHUNK_DIM_X = 64, + size_t THREADS_PER_CHUNK = 64> +__global__ void __launch_bounds__(THREADS_PER_CHUNK) quantize_mxfp8_kernel( const IType *input_ptr, const IType *act_input_ptr, @@ -47,29 +39,34 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } constexpr bool USE_ROWWISE_SCALING 
= SCALE_DIM_X > 1;
   constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
+  constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING;
 
-  constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y;  // 2 = 64 / 32
-  constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X;  // 64 = 64 / 1
-  constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y =
-      SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y;  // 2 = 2 * 1
-  constexpr size_t SCALES_ROWWISE_PER_BLOCK_X =
-      SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X;  // 64 = 64 * 1
+  constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X;
+  constexpr size_t SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y;
+  constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X;
+
+  constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD;
+  constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE;
+  constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X;
+
+  constexpr size_t BUFF_STAGES_NUM = MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE;
+  constexpr size_t ITERATIONS = CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y;
 
-  constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y;  // 2 = 64 / 32
-  constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X;  // 64 = 64 / 1
-  constexpr size_t SCALES_COLWISE_PER_BLOCK_Y =
-      SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y;  // 2 = 2 * 1
-  constexpr size_t SCALES_COLWISE_PER_BLOCK_X =
-      SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X;  // 64 = 64 * 1
+  constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = CHUNK_DIM_Y;
+  constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = CHUNK_DIM_X / SCALE_DIM_X;
+  constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = CHUNK_DIM_Y / SCALE_DIM_Y;
+  constexpr size_t SCALES_COLWISE_PER_BLOCK_X = CHUNK_DIM_X;
 
   constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
       DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD);  // 2 = 32 / 16
   constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE;  // 2
-  constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType);
+  // Cap vector width so each load/store is at most 16 bytes (AMD max: global_load_dwordx4)
+  constexpr size_t VECTOR_WIDTH_IN = 16 / sizeof(IType);   // BF16/FP16: 8, FP32: 4
+  constexpr size_t VECTOR_WIDTH_OUT = 16 / sizeof(OType);  // FP8: 16
 
-  const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y;
-  const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X;
+  const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
+  const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
   const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y;
   const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X;
   const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y;
@@ -77,102 +74,208 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
   const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE;
   const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE;
-  // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE;
   const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE;
 
-  const int thread_offset_Y = tid_rowwise_Y;
+  const int thread_offset_Y = tid_rowwise_Y;
   const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD;
-  // const int thread_offset_X_colwise = tid_colwise_X;
-
-  const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y;
-  const int dbias_rowwise_block_offset_X =
-      blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise;
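/* [Reviewer sketch — annotation, not part of this diff]
   The 16-byte cap above in concrete numbers: one global_load_dwordx4 moves
   16 bytes, so the lane count is simply 16 / sizeof(element):

   #include <cstddef>

   constexpr size_t max_vec_elems(size_t elem_bytes) { return 16 / elem_bytes; }
   static_assert(max_vec_elems(2) == 8,  "BF16/FP16: 8 elements per access");
   static_assert(max_vec_elems(4) == 4,  "FP32: 4 elements per access");
   static_assert(max_vec_elems(1) == 16, "FP8: 16 elements per access");
*/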
-  const int dbias_colwise_offset_Y = blockIdx.y;
-  const int dbias_colwise_block_offset_X =
-      blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X;
-  const int dbias_stride = cols;
-
-  Vec<float, ELEMS_PER_THREAD> partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X];
-  float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X];
+
+  const int dbias_rowwise_offset_Y = blockIdx.y + tid_rowwise_Y;
+  const int dbias_rowwise_block_offset_X = block_offset_X + thread_offset_X_rowwise;
+  const int dbias_colwise_offset_Y = blockIdx.y;
+  const int dbias_colwise_block_offset_X = block_offset_X + tid_colwise_X;
+  const int dbias_stride = cols;
+
+  Vec<float, ELEMS_PER_THREAD> partial_dbias_rowwise;
+  float partial_dbias_colwise = 0;
 
   if constexpr (IS_DBIAS) {
     if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) {
-#pragma unroll
-      for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; i++) {
-        partial_dbias_rowwise[i].clear();
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; i++) {
-        partial_dbias_colwise[i] = 0;
-      }
+      partial_dbias_rowwise.clear();
     }
   }
 
-  // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned
-  alignas(128) __shared__ IType in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
-  alignas(128) __shared__ IType act_in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
-  alignas(128) __shared__ OType out_rowwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
-  alignas(128) __shared__ OType out_colwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
   float block_amax = 0;
 
+  constexpr size_t ROWS_PER_THREAD = CHUNK_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE;
+
+  if constexpr (USE_ROWWISE_SCALING && !USE_COLWISE_SCALING) {
+    const size_t col_start = block_offset_X + thread_offset_X_rowwise;
+    const bool col_valid = (col_start < cols);
 
 #pragma unroll
-  for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; chunk++) {
-    const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X;
-    const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X;
+    for (size_t r = 0; r < ROWS_PER_THREAD; r++) {
+      const size_t row = block_offset_Y + tid_rowwise_Y + r * THREADS_PER_CHUNK_Y_ROWWISE;
+      const bool row_valid = (row < rows);
 
-    const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y;
-    const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
+      Vec<IType, ELEMS_PER_THREAD> in;
+      Vec<IType, ELEMS_PER_THREAD> act_in;
 
-    const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
-    const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
+      if (row_valid && col_valid) {
+        if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) {
+          in.load_from(&input_ptr[row * cols + col_start]);
+          if constexpr (IS_DACT) {
+            act_in.load_from(&act_input_ptr[row * cols + col_start]);
+          }
+        } else {
+#pragma unroll
+          for (int j = 0; j < ELEMS_PER_THREAD; j++) {
+            in.data.elt[j] = (col_start + j < cols) ? input_ptr[row * cols + col_start + j]
+                                                    : static_cast<IType>(0);
+          }
+          if constexpr (IS_DACT) {
+#pragma unroll
+            for (int j = 0; j < ELEMS_PER_THREAD; j++) {
+              act_in.data.elt[j] = (col_start + j < cols) ?
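/* [Reviewer sketch — annotation, not part of this diff]
   The masked-tail idiom used here, isolated: take the single vectorized
   transaction whenever the 16-element span is provably in bounds (IS_ALIGNED or
   the runtime check), otherwise fall back to scalar loads with zero fill. Vec is
   the project's vector wrapper; the helper name is hypothetical:

   template <typename T, int N>
   __device__ void load_with_tail(Vec<T, N> &v, const T *row_ptr, size_t col, size_t cols) {
     if (col + N <= cols) {
       v.load_from(row_ptr + col);     // one wide load, no per-element branching
     } else {
       for (int j = 0; j < N; ++j)     // boundary-crossing tail, zero-filled
         v.data.elt[j] = (col + j < cols) ? row_ptr[col + j] : static_cast<T>(0);
     }
   }
*/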
act_input_ptr[row * cols + col_start + j] + : static_cast(0); + } + } + } + } - const int scales_rowwise_chunk_offset_Y = - scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = - scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = - scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = - scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; - __syncthreads(); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise.data.elt[j] += elt; + } + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute[j] = elt; + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = + ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + + { + constexpr size_t SCALES_PER_GROUP = THREADS_PER_CHUNK_X_ROWWISE / THREADS_PER_SCALE_X_ROWWISE; + uint32_t my_scale = static_cast(biased_exponent); + if constexpr (SCALES_PER_GROUP >= 4) { + uint32_t s1 = __shfl_down(my_scale, 1 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s2 = __shfl_down(my_scale, 2 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s3 = __shfl_down(my_scale, 3 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t packed = (my_scale & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24); + if (tid_rowwise_X == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + scales_rowwise_block_offset_X; + reinterpret_cast(&scales_rowwise[scale_idx])[0] = packed; + } + } else { + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = + row * scale_stride_rowwise + + scales_rowwise_block_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exponent; + } + } + } + Vec out_c; +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + { + const float cvt_scale = (biased_exponent == 0) ? 
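/* [Reviewer sketch — annotation, not part of this diff]
   What the __shfl_down sequence above buys: the four E8M0 scales produced by one
   subwarp group are merged into a single 32-bit store instead of four one-byte
   stores (this assumes scales_rowwise is 4-byte aligned at that index). The
   packing itself:

   #include <cstdint>

   __device__ inline uint32_t pack4_e8m0(uint32_t s0, uint32_t s1,
                                         uint32_t s2, uint32_t s3) {
     return (s0 & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24);
   }
   // s1..s3 arrive via __shfl_down from lanes 1*W, 2*W and 3*W ahead of the group
   // leader (W = THREADS_PER_SCALE_X_ROWWISE); only the leader stores the word.
*/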
1.0f : ptx::exp2f(biased_exponent); + union { + uint32_t packed[ELEMS_PER_THREAD / 4]; + mxfp8_v2i16_t v2i16[ELEMS_PER_THREAD / 4]; + } cvt_out{}; #pragma unroll - for (int iter = 0; iter < MXFP8_ITERATIONS; iter++) { - const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - const size_t row_base = chunk_it_offset_y; - if constexpr (IS_DACT) { - copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, - chunk_it_offset_x, chunk_it_offset_y, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); + for (int p = 0; p < ELEMS_PER_THREAD / 4; p++) { + if constexpr (std::is_same_v) { + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16[p], in_compute[p*4+0], in_compute[p*4+1], cvt_scale, false); + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16[p], in_compute[p*4+2], in_compute[p*4+3], cvt_scale, true); + } else { + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16[p], in_compute[p*4+0], in_compute[p*4+1], cvt_scale, false); + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16[p], in_compute[p*4+2], in_compute[p*4+3], cvt_scale, true); + } + } + memcpy(out_c.data.elt, cvt_out.packed, ELEMS_PER_THREAD * sizeof(OType)); } - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, - chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, rows, cols); - __syncthreads(); +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + } +#endif - if constexpr (USE_ROWWISE_SCALING) { - Vec in; - Vec act_in; - Vec out_c; + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_c.store_to(&output_rowwise[row * cols + col_start]); + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_rowwise[row * cols + col_start + j] = out_c.data.elt[j]; + } + } + } + } + } + } + + if constexpr (USE_COLWISE_SCALING) { + alignas(128) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(128) __shared__ IType act_in_sh[IS_DACT ? SHMEM_DIM_Y : 1][IS_DACT ? 
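/* [Reviewer note — annotation, not part of this diff]
   Intent of the gfx950 branch above, as far as the code shows: each
   __builtin_amdgcn_cvt_scalef32_pk_fp8_f32 (or _bf8_) call narrows a pair of
   floats to two packed FP8 bytes with the scale operand applied in hardware, and
   the trailing bool selects the low/high half of the v2i16 destination, so
   ELEMS_PER_THREAD == 16 values take eight calls. The #else branch states the
   same math portably:

   // out[j] = cast_to_fp8(in_compute[j] * exp2f_rcp(biased_exponent))
   //        = cast_to_fp8(in_compute[j] / cvt_scale)   with cvt_scale = 2^(e - 127)
*/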
SHMEM_DIM_X : 1]; + alignas(128) __shared__ OType out_colwise_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; - const int iteration_scale_rowwise_offset_Y = - scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const size_t col = block_offset_X + tid_colwise_X; + const bool col_valid_colwise = (col < cols); #pragma unroll - for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; stage++) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X_rowwise; + for (int iter = 0; iter < ITERATIONS; iter++) { + const size_t row_base = block_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + if constexpr (IS_DACT) { + copy_2d_to_shared( + &act_in_sh[0][0], act_input_ptr, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + } + copy_2d_to_shared( + &in_sh[0][0], input_ptr, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + __syncthreads(); - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); + if constexpr (USE_ROWWISE_SCALING) { + const size_t col_start = block_offset_X + thread_offset_X_rowwise; + const bool col_valid = (col_start < cols); - in.load_from(&in_sh[shmem_offset_y][shmem_offset_x]); +#pragma unroll + for (int stage = 0; stage < BUFF_STAGES_NUM; stage++) { + const int shmem_y = thread_offset_Y + stage * THREADS_PER_CHUNK_Y_ROWWISE; + const size_t row = row_base + shmem_y; + const bool row_valid = (row < rows); + + Vec in; + Vec act_in; + in.load_from(&in_sh[shmem_y][thread_offset_X_rowwise]); if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_y][shmem_offset_x]); + act_in.load_from(&act_in_sh[shmem_y][thread_offset_X_rowwise]); } float thread_amax = 0; @@ -180,9 +283,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { - const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); - + const bool out_of_bounds = (!row_valid || !col_valid || col_start + j >= cols); float elt = static_cast(in.data.elt[j]); if constexpr (IS_ACT) { elt = OP(elt, {}); @@ -193,10 +294,9 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { if (!out_of_bounds) { - partial_dbias_rowwise[chunk_X].data.elt[j] += elt; + partial_dbias_rowwise.data.elt[j] += elt; } } - // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 if constexpr (!std::is_same_v) { elt = static_cast(static_cast(elt)); } @@ -213,38 +313,84 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); const e8m0_t biased_exponent = ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); - // Only single thread writes the computed scaling factor - const bool col_out_of_bounds = dbias_rowwise_offset_X >= cols; - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && !(row_out_of_bounds || col_out_of_bounds)) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; + + { + constexpr size_t SCALES_PER_GROUP = 
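/* [Reviewer sketch — annotation, not part of this diff]
   Geometry of the staged loop above for the default 64x64 chunk and 64 threads:

   THREADS_PER_CHUNK_X_ROWWISE = 64 / 16 = 4    // threads spanning one row
   THREADS_PER_CHUNK_Y_ROWWISE = 64 / 4  = 16   // thread rows per pass
   BUFF_STAGES_NUM             = 32 / 16 = 2    // passes over a 32-row buffer
   ITERATIONS                  = 64 / 32 = 2    // buffers per 64-row chunk

   so each thread handles 2 x 2 = 4 distinct rows of 16 elements per chunk.
*/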
THREADS_PER_CHUNK_X_ROWWISE / THREADS_PER_SCALE_X_ROWWISE; + uint32_t my_scale = static_cast(biased_exponent); + if constexpr (SCALES_PER_GROUP >= 4) { + uint32_t s1 = __shfl_down(my_scale, 1 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s2 = __shfl_down(my_scale, 2 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t s3 = __shfl_down(my_scale, 3 * THREADS_PER_SCALE_X_ROWWISE, THREADS_PER_CHUNK_X_ROWWISE); + uint32_t packed = (my_scale & 0xFF) | ((s1 & 0xFF) << 8) | ((s2 & 0xFF) << 16) | ((s3 & 0xFF) << 24); + if (tid_rowwise_X == 0 && row_valid && col_valid) { + reinterpret_cast(&scales_rowwise[row * scale_stride_rowwise + scales_rowwise_block_offset_X])[0] = packed; + } + } else { + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && row_valid && col_valid) { + const int scale_idx = row * scale_stride_rowwise + + scales_rowwise_block_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + scales_rowwise[scale_idx] = biased_exponent; + } + } } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + Vec out_c; +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + { + const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); + union { + uint32_t packed[ELEMS_PER_THREAD / 4]; + mxfp8_v2i16_t v2i16[ELEMS_PER_THREAD / 4]; + } cvt_out{}; +#pragma unroll + for (int p = 0; p < ELEMS_PER_THREAD / 4; p++) { + if constexpr (std::is_same_v) { + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16[p], in_compute[p*4+0], in_compute[p*4+1], cvt_scale, false); + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16[p], in_compute[p*4+2], in_compute[p*4+3], cvt_scale, true); + } else { + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16[p], in_compute[p*4+0], in_compute[p*4+1], cvt_scale, false); + cvt_out.v2i16[p] = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16[p], in_compute[p*4+2], in_compute[p*4+3], cvt_scale, true); + } + } + memcpy(out_c.data.elt, cvt_out.packed, ELEMS_PER_THREAD * sizeof(OType)); + } +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + } +#endif + if (row_valid && col_valid) { + if (IS_ALIGNED || col_start + ELEMS_PER_THREAD <= cols) { + out_c.store_to(&output_rowwise[row * cols + col_start]); + } else { #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + for (int j = 0; j < ELEMS_PER_THREAD; j++) { + if (col_start + j < cols) { + output_rowwise[row * cols + col_start + j] = out_c.data.elt[j]; + } + } + } } - out_c.store_to(&out_rowwise_sh[shmem_offset_y][shmem_offset_x]); } } - if constexpr (USE_COLWISE_SCALING) { - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + if (threadIdx.x < CHUNK_DIM_X) { float in_compute[SCALE_DIM_Y]; - float amax = 0; + #pragma unroll for (int i = 0; i < SCALE_DIM_Y; i++) { const size_t row = row_base + i; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + const bool out_of_bounds = (!col_valid_colwise || row >= rows); float elt = static_cast(in_sh[i][tid_colwise_X]); if constexpr (IS_ACT) { @@ -256,10 +402,9 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } if constexpr (IS_DBIAS) { if (!out_of_bounds) { - 
partial_dbias_colwise[chunk_X] += elt; + partial_dbias_colwise += elt; } } - // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 if constexpr (!std::is_same_v) { elt = static_cast(static_cast(elt)); } @@ -275,35 +420,56 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const e8m0_t biased_exponent = ptx::float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); - const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - const bool row_out_of_bounds = row_base >= rows; - if (!(row_out_of_bounds || col_out_of_bounds)) { + if (col_valid_colwise && row_base < rows) { + const int scale_idx = + (scales_colwise_block_offset_Y + iter) * scale_stride_colwise + col; scales_colwise[scale_idx] = biased_exponent; } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + { + const float cvt_scale = (biased_exponent == 0) ? 1.0f : ptx::exp2f(biased_exponent); #pragma unroll - for (int i = 0; i < SCALE_DIM_Y; i++) { - out_colwise_sh[i][tid_colwise_X] = - static_cast(in_compute[i] * block_scale_inverse); + for (int i = 0; i < SCALE_DIM_Y; i += 2) { + union { + uint32_t packed; + mxfp8_v2i16_t v2i16; + uint8_t bytes[4]; + } cvt_out{}; + if constexpr (std::is_same_v) { + cvt_out.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( + cvt_out.v2i16, in_compute[i], in_compute[i+1], cvt_scale, false); + } else { + cvt_out.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( + cvt_out.v2i16, in_compute[i], in_compute[i+1], cvt_scale, false); + } + OType val0, val1; + memcpy(&val0, &cvt_out.bytes[0], sizeof(OType)); + memcpy(&val1, &cvt_out.bytes[1], sizeof(OType)); + out_colwise_sh[i][tid_colwise_X] = val0; + if (i + 1 < SCALE_DIM_Y) { + out_colwise_sh[i+1][tid_colwise_X] = val1; + } + } } +#else + { + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; i++) { + out_colwise_sh[i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } +#endif } - + __syncthreads(); - if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, - chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, rows, cols); - } - if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, - chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, rows, cols); - } + bulk_tensor_2d_shared_to_global( + &out_colwise_sh[0][0], output_colwise, + block_offset_X, row_base, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); } @@ -311,58 +477,45 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) if constexpr (IS_DBIAS) { if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; - __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; + __shared__ float shmem_partial_dbias_rowwise[Y][X][ELEMS_PER_THREAD]; if (tid_rowwise_Y > 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; c++) { - partial_dbias_rowwise[c].store_to( - &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); - } + partial_dbias_rowwise.store_to( 
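/* [Reviewer sketch — annotation, not part of this diff]
   Shape of the colwise pass above: one thread per chunk column walks the 32
   staged rows, so the MX block runs along Y. Skeleton using the kernel's own
   helpers (compute/quantize stand in for the activation math and the FP8 cast):

   float vals[32];                        // SCALE_DIM_Y values of one column
   float amax = 0.0f;
   for (int i = 0; i < 32; ++i) {
     vals[i] = compute(in_sh[i][tid_colwise_X]);
     if (row_base + i < rows) amax = fmaxf(amax, fabsf(vals[i]));
   }
   const e8m0_t e = ptx::float_to_e8m0(amax * Quantized_Limits<OType>::max_norm_rcp);
   const float inv = ptx::exp2f_rcp(e);   // one scale for the whole column block
   for (int i = 0; i < 32; ++i)
     out_colwise_sh[i][tid_colwise_X] = quantize(vals[i] * inv);
*/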
+          &shmem_partial_dbias_rowwise[tid_rowwise_Y - 1][tid_rowwise_X]);
     }
     __syncthreads();
 
     if (tid_rowwise_Y == 0) {
-#pragma unroll
-      for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; c++) {
-        Vec<float, ELEMS_PER_THREAD> other_row_dbias;
-        const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X;
-        const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X;
-
-        const int left_bound = dbias_rowwise_offset_X;
-        const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1;
+      Vec<float, ELEMS_PER_THREAD> other_row_dbias;
+      const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_block_offset_X;
+      const int left_bound = dbias_rowwise_block_offset_X;
+      const int right_bound = dbias_rowwise_block_offset_X + ELEMS_PER_THREAD - 1;
 
 #pragma unroll
-        for (int i = 0; i < Y; i++) {
-          other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]);
+      for (int i = 0; i < Y; i++) {
+        other_row_dbias.load_from(&shmem_partial_dbias_rowwise[i][tid_rowwise_X]);
 #pragma unroll
-          for (int j = 0; j < ELEMS_PER_THREAD; j++) {
-            partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j];
-          }
+        for (int j = 0; j < ELEMS_PER_THREAD; j++) {
+          partial_dbias_rowwise.data.elt[j] += other_row_dbias.data.elt[j];
         }
       }
 
-        // Vectorized store when all elements are inside the boundaries
-        if (right_bound < cols) {
-          partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]);
-        } else if (left_bound < cols && right_bound >= cols) {
-          // Element-by-element store when some elements cross the boundaries
-          const int in_bound_elts_count = cols - left_bound;
-          partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0,
-                                                 in_bound_elts_count);
-        }
+      if (right_bound < cols) {
+        partial_dbias_rowwise.store_to(&dbias_workspace[dbias_offset]);
+      } else if (left_bound < cols && right_bound >= cols) {
+        const int in_bound_elts_count = cols - left_bound;
+        partial_dbias_rowwise.store_to_elts(&dbias_workspace[dbias_offset], 0,
+                                            in_bound_elts_count);
      }
    }
  } else {
-#pragma unroll
-    for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; i++) {
-      const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X;
-      const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X;
-      const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols);
+    if (threadIdx.x < CHUNK_DIM_X) {
+      const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_block_offset_X;
+      const bool col_out_of_bounds = (dbias_colwise_block_offset_X >= cols);
       if (!col_out_of_bounds) {
-        dbias_workspace[dbias_offset] = partial_dbias_colwise[i];
+        dbias_workspace[dbias_offset] = partial_dbias_colwise;
       }
     }
   }
@@ -370,8 +523,7 @@
   if (amax_ptr != nullptr) {
     const int warp_id = threadIdx.x / THREADS_PER_WARP;
-    // Reduce the amax over the block
-    block_amax = reduce_max<MXFP8_THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
+    block_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
   }
 
   if (threadIdx.x == 0 && amax_ptr != nullptr) {
diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h
index 9901a688e..70584fac3 100644
--- a/transformer_engine/common/normalization/common.h
+++ b/transformer_engine/common/normalization/common.h
@@ -447,10 +447,8 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params)
   const size_t scale_dim_X_rowwise = 32;
   const size_t scale_dim_Y_colwise = launch_params.training ?
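/* [Reviewer sketch — annotation, not part of this diff]
   Effect of the launch-geometry simplification just below (both CHUNKS_PER_BLOCK
   factors were 1, so the nested DIVUP collapses): e.g. rows = 8192, cols = 14336
   with 64x64 chunks gives

   blocks_Y = DIVUP(8192, 64)  = 128   // was DIVUP(DIVUP(8192, 64), 1)  = 128
   blocks_X = DIVUP(14336, 64) = 224   // was DIVUP(DIVUP(14336, 64), 1) = 224

   i.e. the grid is unchanged; only the dead indirection goes away.
*/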
32 : 1;
 
-  const size_t chunks_Y = DIVUP(rows, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_Y);
-  const size_t chunks_X = DIVUP(cols, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_X);
-  const size_t blocks_Y = DIVUP(chunks_Y, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_Y);
-  const size_t blocks_X = DIVUP(chunks_X, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_X);
+  const size_t blocks_Y = DIVUP(rows, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_Y);
+  const size_t blocks_X = DIVUP(cols, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_X);
 
   const size_t scale_stride_rowwise = launch_params.z_tensor->scale_inv.shape[1];
   const size_t scale_stride_colwise = launch_params.training ? launch_params.z_tensor->columnwise_scale_inv.shape[1] : 1;
diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h
index 05fe2f539..d7b97723d 100644
--- a/transformer_engine/common/util/math.h
+++ b/transformer_engine/common/util/math.h
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -9,6 +11,18 @@
 
 namespace transformer_engine {
 
+// AMD fast tanh using the hardware exp instruction
+__device__ inline float fast_tanhf(float x) {
+#ifdef __HIP_PLATFORM_AMD__
+  // tanhf(x) already rounds to ±1 in float well below |x| = 20, so the clamp is lossless
+  x = fmaxf(fminf(x, 20.f), -20.f);
+  float e2x = __expf(2.0f * x);
+  return (e2x - 1.0f) * __frcp_rn(e2x + 1.0f);
+#else
+  return tanhf(x);
+#endif
+}
+
 struct Empty {};
 
 struct ClampedSwiGLUParam {
@@ -19,13 +33,13 @@ struct ClampedSwiGLUParam {
 template <typename OType, typename IType>
 __device__ inline OType gelu(const IType val, const Empty&) {
   const float cval = val;
-  return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
+  return cval * (0.5F + 0.5F * fast_tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
 }
 
 template <typename OType, typename IType>
 __device__ inline OType dgelu(const IType val, const Empty&) {
   const float cval = val;
-  const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
+  const float tanh_out = fast_tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
   return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
          0.5f * (1.f + tanh_out);
 }
@@ -33,7 +47,11 @@ __device__ inline OType dgelu(const IType val, const Empty&) {
 template <typename OType, typename IType>
 __device__ inline OType sigmoid(const IType val, const Empty&) {
   const float cval = val;
+#ifdef __HIP_PLATFORM_AMD__
+  return __frcp_rn(1.0f + __expf(-cval));
+#else
   return 1.f / (1.f + expf(-cval));
+#endif
 }
 
 __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }
@@ -61,8 +79,8 @@ template <typename OType, typename IType>
 __device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
   const float cval = val;
   Empty e = {};
-  return alpha * cval * dsigmoid<float, float>(alpha * cval, e) +
-         sigmoid<float, float>(alpha * cval, e);
+  const float s = sigmoid<float, float>(alpha * cval, e);
+  return s * (1.f + alpha * cval * (1.f - s));
 }
 
 template <typename OType, typename IType>
@@ -85,7 +103,8 @@ __device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam&
 template <typename OType, typename IType>
 __device__ inline OType dsilu(const IType val, const Empty& e) {
   const float cval = val;
-  return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
+  const float s = sigmoid<float, float>(cval, e);
+  return s * (1.f + cval * (1.f - s));
 }
 
 template
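/* [Reviewer sketch — annotation, not part of this diff]
   Algebra behind the rewritten derivatives above: with s = sigmoid(x) and
   s' = s * (1 - s),

     d/dx [x * s(x)] = s + x * s' = s * (1 + x * (1 - s)),

   which is what dsilu now computes from a single sigmoid evaluation; the
   alpha-scaled form in dqgelu_with_alpha follows by the chain rule. A host-side
   spot check one could run:

   #include <cassert>
   #include <cmath>

   int main() {
     for (float x : {-3.f, -0.5f, 0.f, 1.f, 4.f}) {
       const float s = 1.f / (1.f + std::exp(-x));
       const float ref   = s + x * s * (1.f - s);      // textbook d/dx of x*sigmoid(x)
       const float fused = s * (1.f + x * (1.f - s));  // rewritten form
       assert(std::fabs(ref - fused) < 1e-6f);
     }
     return 0;
   }
*/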