-
Notifications
You must be signed in to change notification settings - Fork 25
Mxfp8 cast optimization #507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
alextmagro
wants to merge
5
commits into
dev
Choose a base branch
from
mxfp8_cast_optimization
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
74a82fc
MXFP8 Cast Kernel Optimizations
alextmagro d2351bc
cleanup and use of TRANSFORMER_ENGINE_SWITCH_CONDITION
alextmagro 87e5752
Rework utils to use test/cpp/test_common*
alextmagro 91f8441
Fix clamping with fast_tanh
alextmagro 330c40e
Formatting and comments
alextmagro File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
cmake_minimum_required(VERSION 3.18)

# Default to hipcc unless the caller already selected a C++ compiler.
if(NOT DEFINED CMAKE_CXX_COMPILER)
  set(CMAKE_CXX_COMPILER hipcc)
endif()

project(transformer_engine_benchmarks LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(HIP REQUIRED)

# Google Benchmark, fetched at configure time; its own tests and gtest
# integration are disabled so only the library is built.
include(FetchContent)
FetchContent_Declare(
  benchmark
  GIT_REPOSITORY https://github.com/google/benchmark.git
  GIT_TAG v1.8.3
)
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable gtest in benchmark" FORCE)
FetchContent_MakeAvailable(benchmark)

# TransformerEngine headers plus the local benchmark utils. Applied per
# target inside add_te_benchmark() instead of via directory-scoped
# include_directories(), so the flags do not leak into the fetched
# benchmark targets.
set(TE_BENCH_INCLUDE_DIRS
  ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common/include
  ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common
  ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine
  ${CMAKE_CURRENT_SOURCE_DIR}/utils
)

# GPU architectures: semicolon-separated NVTE_ROCM_ARCH env var, or defaults.
if(DEFINED ENV{NVTE_ROCM_ARCH})
  set(GPU_TARGETS $ENV{NVTE_ROCM_ARCH})
else()
  set(GPU_TARGETS "gfx942;gfx950")
endif()

# Build one --offload-arch= flag per architecture. Expanding the list into a
# single flag (--offload-arch=${GPU_TARGETS}) would produce
# "--offload-arch=gfx942" plus a stray bare "gfx950" argument, so only the
# first architecture would actually be targeted.
set(OFFLOAD_ARCH_OPTIONS ${GPU_TARGETS})
list(TRANSFORM OFFLOAD_ARCH_OPTIONS PREPEND "--offload-arch=")

set(COMMON_COMPILE_OPTIONS
  -Wall
  -Wextra
  -O3
  -DNDEBUG
  -DUSE_ROCM
  ${OFFLOAD_ARCH_OPTIONS}
)

# Locate a prebuilt TransformerEngine library next to the source tree first;
# NO_DEFAULT_PATH keeps this pass from picking up an unrelated system copy.
find_library(TRANSFORMER_ENGINE_LIB
  NAMES transformer_engine
  PATHS ${CMAKE_CURRENT_SOURCE_DIR}/../..
        ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake
        ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib
        /usr/local/lib
        $ENV{HOME}/.local/lib
  NO_DEFAULT_PATH
)

# Second pass: fall back to the default system search paths.
if(NOT TRANSFORMER_ENGINE_LIB)
  message(WARNING "TransformerEngine library not found in expected paths. Trying system paths...")
  find_library(TRANSFORMER_ENGINE_LIB NAMES transformer_engine)
endif()

if(TRANSFORMER_ENGINE_LIB)
  message(STATUS "Found TransformerEngine library: ${TRANSFORMER_ENGINE_LIB}")
else()
  message(FATAL_ERROR "TransformerEngine library not found. Please build TransformerEngine first:\n"
                      "  cd ${CMAKE_CURRENT_SOURCE_DIR}/../..\n"
                      "  pip install -e . --no-build-isolation\n"
                      "Searched paths:\n"
                      "  ${CMAKE_CURRENT_SOURCE_DIR}/../..\n"
                      "  ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake\n"
                      "  ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib")
endif()

# add_te_benchmark(<target> <source>)
# Creates one benchmark executable from <source> plus the shared
# utils/test_common.hip translation unit, compiled with the common flags and
# linked against Google Benchmark, TransformerEngine, and hipRAND.
function(add_te_benchmark TARGET_NAME SOURCE_FILE)
  set(test_common_hip ${CMAKE_CURRENT_SOURCE_DIR}/utils/test_common.hip)
  # hipcc compiles .hip sources as C++; mark the language explicitly so CMake
  # does not reject the unknown extension.
  set_source_files_properties(${test_common_hip} PROPERTIES LANGUAGE CXX)
  add_executable(${TARGET_NAME} ${SOURCE_FILE} ${test_common_hip})
  target_include_directories(${TARGET_NAME} PRIVATE ${TE_BENCH_INCLUDE_DIRS})
  target_compile_options(${TARGET_NAME} PRIVATE ${COMMON_COMPILE_OPTIONS})
  target_compile_definitions(${TARGET_NAME} PRIVATE NVTE_ROCM_BENCHMARK)
  target_link_libraries(${TARGET_NAME} PRIVATE
    benchmark::benchmark
    ${TRANSFORMER_ENGINE_LIB}
    hiprand
  )
  set_target_properties(${TARGET_NAME} PROPERTIES HIP_ARCHITECTURES "${GPU_TARGETS}")
endfunction()

add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp)
add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp)
add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * License for AMD contributions = MIT. See LICENSE for more information | ||
| ************************************************************************/ | ||
|
|
||
| #include <benchmark/benchmark.h> | ||
| #include <hip/hip_runtime.h> | ||
| #include <hip/hip_fp16.h> | ||
| #include <hip/hip_bfloat16.h> | ||
| #include "amd_detail/hip_float8.h" | ||
|
|
||
| #include "benchmark_utils.h" | ||
|
|
||
| #include <transformer_engine/cast_hip.h> | ||
| #include <transformer_engine/transformer_engine_hip.h> | ||
|
|
||
| using namespace te_bench; | ||
| using namespace transformer_engine; | ||
| using fp8_e4m3 = test::fp8e4m3; | ||
|
|
||
// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B)
// Appended to a BENCHMARK_TEMPLATE(...) registration chain; each Args entry
// supplies {rows, cols} read back via state.range(0)/state.range(1).
#define COMMON_SHAPES \
->Args({1024, 3584}) \
->Args({1024, 4096}) \
->Args({1024, 8192}) \
->Args({1024, 14336}) \
->Args({2048, 4096}) \
->Args({2048, 8192}) \
->Args({2048, 14336}) \
->Args({2048, 28672}) \
->Args({4096, 4096}) \
->Args({4096, 8192}) \
->Args({4096, 16384}) \
->Args({4096, 28672}) \
->Args({8192, 8192}) \
->Args({8192, 16384}) \
->Args({8192, 28672}) \
->Args({8192, 53248}) \
->Args({16384, 8192}) \
->Args({16384, 16384})\
->Args({32768, 8192})
|
|
||
| template <typename IType, typename OType, int SCALE_DIM_Y, int SCALE_DIM_X> | ||
| static void BM_DequantizeMXFP8(benchmark::State &state) { | ||
| const size_t rows = state.range(0); | ||
| const size_t cols = state.range(1); | ||
|
|
||
| constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; | ||
| constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; | ||
|
|
||
| const size_t scale_cols_row = USE_ROWWISE ? (cols + 31) / 32 : 0; | ||
| const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; | ||
| const size_t scale_cols_col = USE_COLWISE ? cols : 0; | ||
|
|
||
| std::vector<size_t> shape = {rows, cols}; | ||
| DType itype = std::is_same_v<IType, fp8_e4m3> ? DType::kFloat8E4M3 : DType::kFloat8E5M2; | ||
| DType otype = std::is_same_v<OType, __half> ? DType::kFloat16 : | ||
| (std::is_same_v<OType, hip_bfloat16> ? DType::kBFloat16 : DType::kFloat32); | ||
|
|
||
| test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, USE_ROWWISE, USE_COLWISE, | ||
| NVTE_MXFP8_1D_SCALING, false); | ||
| test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, true, false, | ||
| NVTE_DELAYED_TENSOR_SCALING, false); | ||
|
|
||
| hipStream_t stream; | ||
| HIP_CHECK(hipStreamCreate(&stream)); | ||
|
|
||
| DeviceBuffer<float> temp_fp32(rows * cols); | ||
| fill_random_uniform_gpu(temp_fp32.get(), rows * cols, -2.0f, 1.0f, stream); | ||
|
|
||
| void *input_data_ptr = USE_ROWWISE ? input_tensor.rowwise_dptr() : input_tensor.columnwise_dptr(); | ||
| size_t threads = 256; | ||
| size_t blocks = (rows * cols + threads - 1) / threads; | ||
| cast_fp32_kernel<<<blocks, threads, 0, stream>>>(temp_fp32.get(), static_cast<IType*>(input_data_ptr), rows * cols); | ||
|
|
||
| HIP_CHECK(hipStreamSynchronize(stream)); | ||
|
|
||
| hipEvent_t start, stop; | ||
| HIP_CHECK(hipEventCreate(&start)); | ||
| HIP_CHECK(hipEventCreate(&stop)); | ||
|
|
||
| warmup_gpu(); | ||
|
|
||
| for (auto _ : state) { | ||
| HIP_CHECK(hipEventRecord(start, stream)); | ||
|
|
||
| nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); | ||
|
|
||
| HIP_CHECK(hipEventRecord(stop, stream)); | ||
| HIP_CHECK(hipEventSynchronize(stop)); | ||
|
|
||
| float ms = 0; | ||
| HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); | ||
| state.SetIterationTime(ms / 1000.0); | ||
| } | ||
|
|
||
| HIP_CHECK(hipEventDestroy(start)); | ||
| HIP_CHECK(hipEventDestroy(stop)); | ||
|
|
||
| const size_t bytes_read_data = rows * cols * sizeof(IType) * | ||
| ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); | ||
| // Scales are single byte, E8M0 type | ||
| const size_t bytes_read_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + | ||
| (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); | ||
| const size_t bytes_write = rows * cols * sizeof(OType); | ||
| const size_t total_bytes = bytes_read_data + bytes_read_scales + bytes_write; | ||
|
|
||
| set_bytes_processed(state, total_bytes); | ||
|
|
||
| HIP_CHECK(hipStreamDestroy(stream)); | ||
| } | ||
|
|
||
// Registers both scaling layouts of BM_DequantizeMXFP8 for one type pair:
// rowwise (1x32 scale blocks) and colwise (32x1 scale blocks), each over all
// COMMON_SHAPES. INAME/ONAME are string literals spliced into the benchmark
// name. UseManualTime() makes the framework trust SetIterationTime() rather
// than wall-clock timing around the loop body.
#define REGISTER_DEQUANTIZE_ALL_CONFIGS(ITYPE, OTYPE, INAME, ONAME) \
BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 1, 32) \
->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/rowwise") \
COMMON_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 32, 1) \
->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/colwise") \
COMMON_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();
|
|
||
// Instantiate rowwise + colwise benchmarks for E4M3 input against each
// supported output precision.
REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, __half, "E4M3", "FP16")
REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, hip_bfloat16, "E4M3", "BF16")
REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, float, "E4M3", "FP32")

BENCHMARK_MAIN();
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this file hipified?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is written in hip, but references the hipified headers in te. Would you prefer we write this in CUDA then hipify it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is OK to run it directly with HIP. The question is whether hipify is called on it — from the headers it seems it is not. BTW, why is the <> notation used for TE headers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the benchmarks are built separate from TE library, best practice is to use <> to ensure we are using the installed library's headers.