Skip to content

Commit 74a82fc

Browse files
committed
MXFP8 Cast Kernel Optimizations
1 parent 98ccd2e commit 74a82fc

16 files changed

Lines changed: 3729 additions & 478 deletions

benchmarks/cpp/CMakeLists.txt

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Standalone benchmark build for Transformer Engine MXFP8 cast kernels
# (ROCm/HIP). Fetches google-benchmark from source and links against a
# pre-built libtransformer_engine.
cmake_minimum_required(VERSION 3.18)

# hipcc is needed to compile the HIP kernels embedded in the benchmark
# sources; honor an explicitly supplied compiler.
if(NOT DEFINED CMAKE_CXX_COMPILER)
  set(CMAKE_CXX_COMPILER hipcc)
endif()

project(transformer_engine_benchmarks LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(HIP REQUIRED)

# Fetch google-benchmark at configure time; its own tests are not needed.
include(FetchContent)
FetchContent_Declare(
  benchmark
  GIT_REPOSITORY https://github.com/google/benchmark.git
  GIT_TAG v1.8.3
)
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable gtest in benchmark" FORCE)
FetchContent_MakeAvailable(benchmark)

include_directories(
  ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common/include
  ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common
  ${CMAKE_CURRENT_SOURCE_DIR}/utils
)

# Target GPU architectures; override with the NVTE_ROCM_ARCH env var
# (semicolon-separated, e.g. "gfx942;gfx950").
if(DEFINED ENV{NVTE_ROCM_ARCH})
  set(GPU_TARGETS $ENV{NVTE_ROCM_ARCH})
else()
  set(GPU_TARGETS "gfx942;gfx950")
endif()

# BUGFIX: "--offload-arch=${GPU_TARGETS}" expanded the list in place, so
# "gfx942;gfx950" became the two compiler arguments "--offload-arch=gfx942"
# and a bare "gfx950". Emit one --offload-arch flag per architecture.
set(OFFLOAD_ARCH_FLAGS "")
foreach(_arch IN LISTS GPU_TARGETS)
  list(APPEND OFFLOAD_ARCH_FLAGS "--offload-arch=${_arch}")
endforeach()

# NOTE(review): -Wall/-Wextra were previously listed here but nullified by
# the trailing -w (suppress all warnings); the dead flags are dropped.
set(COMMON_COMPILE_OPTIONS
  -O3
  -DNDEBUG
  -DUSE_ROCM
  ${OFFLOAD_ARCH_FLAGS}
  -w
)

# Locate the pre-built TransformerEngine shared library, preferring the
# in-tree build outputs before falling back to system paths.
find_library(TRANSFORMER_ENGINE_LIB
  NAMES transformer_engine
  PATHS ${CMAKE_CURRENT_SOURCE_DIR}/../..
        ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake
        ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib
        /usr/local/lib
        $ENV{HOME}/.local/lib
  NO_DEFAULT_PATH
)

if(NOT TRANSFORMER_ENGINE_LIB)
  message(WARNING "TransformerEngine library not found in expected paths. Trying system paths...")
  find_library(TRANSFORMER_ENGINE_LIB NAMES transformer_engine)
endif()

if(TRANSFORMER_ENGINE_LIB)
  message(STATUS "Found TransformerEngine library: ${TRANSFORMER_ENGINE_LIB}")
else()
  message(FATAL_ERROR "TransformerEngine library not found. Please build TransformerEngine first:\n"
                      " cd ${CMAKE_CURRENT_SOURCE_DIR}/../..\n"
                      " pip install -e . --no-build-isolation\n"
                      "Searched paths:\n"
                      " ${CMAKE_CURRENT_SOURCE_DIR}/../..\n"
                      " ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake\n"
                      " ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib")
endif()

# Declares one benchmark executable built from SOURCE_FILE plus the shared
# helpers, linked against google-benchmark, TransformerEngine, and hiprand.
function(add_te_benchmark TARGET_NAME SOURCE_FILE)
  add_executable(${TARGET_NAME} ${SOURCE_FILE} utils/test_common.cpp)
  target_compile_options(${TARGET_NAME} PRIVATE ${COMMON_COMPILE_OPTIONS})
  target_link_libraries(${TARGET_NAME} PRIVATE
    benchmark::benchmark
    ${TRANSFORMER_ENGINE_LIB}
    hiprand
  )
  set_target_properties(${TARGET_NAME} PROPERTIES HIP_ARCHITECTURES "${GPU_TARGETS}")
endfunction()

add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp)
add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp)
add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp)
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*************************************************************************
2+
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
3+
*
4+
* License for AMD contributions = MIT. See LICENSE for more information
5+
************************************************************************/
6+
7+
#include <benchmark/benchmark.h>
8+
#include <hip/hip_runtime.h>
9+
#include <hip/hip_fp16.h>
10+
#include <hip/hip_bfloat16.h>
11+
#include "amd_detail/hip_float8.h"
12+
13+
#include "benchmark_utils.h"
14+
15+
#include "amd_detail/hip_float8.h"
16+
17+
#include <transformer_engine/cast_hip.h>
18+
#include <transformer_engine/transformer_engine_hip.h>
19+
20+
using namespace te_bench;
21+
using namespace transformer_engine;
22+
using fp8_e4m3 = test::fp8e4m3;
23+
24+
// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B)
// Expands to a chain of ->Args({rows, cols}) calls, so it must be placed
// directly after a BENCHMARK_TEMPLATE(...) registration expression.
// (No comments inside the macro: '//' before a '\' would break the
// line continuation.)
#define COMMON_SHAPES \
->Args({1024, 3584}) \
->Args({1024, 4096}) \
->Args({1024, 8192}) \
->Args({1024, 14336}) \
->Args({2048, 4096}) \
->Args({2048, 8192}) \
->Args({2048, 14336}) \
->Args({2048, 28672}) \
->Args({4096, 4096}) \
->Args({4096, 8192}) \
->Args({4096, 16384}) \
->Args({4096, 28672}) \
->Args({8192, 8192}) \
->Args({8192, 16384}) \
->Args({8192, 28672}) \
->Args({8192, 53248}) \
->Args({16384, 8192}) \
->Args({16384, 16384})\
->Args({32768, 8192})
45+
46+
template <typename IType, typename OType, int SCALE_DIM_Y, int SCALE_DIM_X>
47+
static void BM_DequantizeMXFP8(benchmark::State &state) {
48+
const size_t rows = state.range(0);
49+
const size_t cols = state.range(1);
50+
51+
constexpr bool USE_ROWWISE = SCALE_DIM_X > 1;
52+
constexpr bool USE_COLWISE = SCALE_DIM_Y > 1;
53+
54+
const size_t scale_cols_row = USE_ROWWISE ? (cols + 31) / 32 : 0;
55+
const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0;
56+
const size_t scale_cols_col = USE_COLWISE ? cols : 0;
57+
58+
std::vector<size_t> shape = {rows, cols};
59+
DType itype = std::is_same_v<IType, fp8_e4m3> ? DType::kFloat8E4M3 : DType::kFloat8E5M2;
60+
DType otype = std::is_same_v<OType, __half> ? DType::kFloat16 :
61+
(std::is_same_v<OType, hip_bfloat16> ? DType::kBFloat16 : DType::kFloat32);
62+
63+
test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, USE_ROWWISE, USE_COLWISE,
64+
NVTE_MXFP8_1D_SCALING, false);
65+
test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, true, false,
66+
NVTE_DELAYED_TENSOR_SCALING, false);
67+
68+
hipStream_t stream;
69+
HIP_CHECK(hipStreamCreate(&stream));
70+
71+
DeviceBuffer<float> temp_fp32(rows * cols);
72+
fill_random_uniform_gpu(temp_fp32.get(), rows * cols, -2.0f, 1.0f, stream);
73+
74+
void *input_data_ptr = USE_ROWWISE ? input_tensor.rowwise_dptr() : input_tensor.columnwise_dptr();
75+
size_t threads = 256;
76+
size_t blocks = (rows * cols + threads - 1) / threads;
77+
cast_fp32_kernel<<<blocks, threads, 0, stream>>>(temp_fp32.get(), static_cast<IType*>(input_data_ptr), rows * cols);
78+
79+
HIP_CHECK(hipStreamSynchronize(stream));
80+
81+
hipEvent_t start, stop;
82+
HIP_CHECK(hipEventCreate(&start));
83+
HIP_CHECK(hipEventCreate(&stop));
84+
85+
warmup_gpu();
86+
87+
for (auto _ : state) {
88+
HIP_CHECK(hipEventRecord(start, stream));
89+
90+
nvte_dequantize(input_tensor.data(), output_tensor.data(), stream);
91+
92+
HIP_CHECK(hipEventRecord(stop, stream));
93+
HIP_CHECK(hipEventSynchronize(stop));
94+
95+
float ms = 0;
96+
HIP_CHECK(hipEventElapsedTime(&ms, start, stop));
97+
state.SetIterationTime(ms / 1000.0);
98+
}
99+
100+
HIP_CHECK(hipEventDestroy(start));
101+
HIP_CHECK(hipEventDestroy(stop));
102+
103+
const size_t bytes_read_data = rows * cols * sizeof(IType) *
104+
((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0));
105+
const size_t bytes_read_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) +
106+
(USE_COLWISE ? scale_rows_col * scale_cols_col : 0);
107+
const size_t bytes_write = rows * cols * sizeof(OType);
108+
const size_t total_bytes = bytes_read_data + bytes_read_scales + bytes_write;
109+
110+
set_bytes_processed(state, total_bytes);
111+
112+
HIP_CHECK(hipStreamDestroy(stream));
113+
}
114+
115+
// Registers both MXFP8 scaling configurations of BM_DequantizeMXFP8 for one
// input/output type pair: rowwise (1x32 blocks) and colwise (32x1 blocks),
// each over all COMMON_SHAPES, reported in microseconds. UseManualTime() is
// required because the benchmark times the kernel with HIP events and calls
// state.SetIterationTime() itself.
#define REGISTER_DEQUANTIZE_ALL_CONFIGS(ITYPE, OTYPE, INAME, ONAME) \
BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 1, 32) \
->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/rowwise") \
COMMON_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 32, 1) \
->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/colwise") \
COMMON_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

// Only E4M3 inputs are registered here; BM_DequantizeMXFP8 maps any other
// IType to kFloat8E5M2 if E5M2 coverage is added later.
REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, __half, "E4M3", "FP16")
REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, hip_bfloat16, "E4M3", "BF16")
REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, float, "E4M3", "FP32")

BENCHMARK_MAIN();

0 commit comments

Comments
 (0)