Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions .github/scripts/aiter_prebuild_upload.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,8 @@ if [[ "${1:-}" == "--build" ]]; then
fi

# Ensure built libs exist
if [[ ! -f "${EXTRACT_DIR}/libmha_fwd.so" ]]; then
echo "[AITER-PREBUILT] Missing libmha_fwd.so in ${EXTRACT_DIR}" >&2
exit 1
fi
if [[ ! -f "${EXTRACT_DIR}/libmha_bwd.so" ]]; then
echo "[AITER-PREBUILT] Missing libmha_bwd.so in ${EXTRACT_DIR}" >&2
if [[ ! -f "${EXTRACT_DIR}/libmha.a" ]]; then
echo "[AITER-PREBUILT] Missing libmha.a in ${EXTRACT_DIR}" >&2
exit 1
fi

Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 0 files
25 changes: 18 additions & 7 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ set(CMAKE_CXX_STANDARD 17)
project(ck_fused_attn LANGUAGES HIP CXX)


set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE")
set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha install prefix in TE")

set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter")
set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel")
Expand Down Expand Up @@ -56,13 +56,13 @@ list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR)

if(DEFINED AITER_MHA_PATH)
message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}")
# use pre-built libmha_fwd.so libmha_bwd.so
# use pre-built libraries
set(__AITER_MHA_PATH ${AITER_MHA_PATH})
else()
set(__AITER_MHA_PATH "")
include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake")
get_prebuilt_aiter(__AITER_MHA_PATH)

if(__AITER_MHA_PATH STREQUAL "")
# If not available, fallback: Build from source
message(STATUS "[AITER-BUILD] Building aiter from source.")
Expand All @@ -73,10 +73,19 @@ else()
--install-dir ${__AITER_MHA_PATH}
--gpu-archs "${V3_ASM_ARCHS_STR}"
--ck-tile-bf16 ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}
RESULT_VARIABLE AITER_BUILD_RESULT
)
else()
message(STATUS "[AITER-BUILD] Using pre-built AITER from ${__AITER_MHA_PATH}")
if(NOT AITER_BUILD_RESULT EQUAL 0)
message(FATAL_ERROR "[AITER-BUILD] aiter_build.sh failed with exit code ${AITER_BUILD_RESULT}")
endif()
# Validate the source build produced the expected files
foreach(REQUIRED_FILE IN LISTS AITER_PREBUILT_REQUIRED_FILES)
if(NOT EXISTS "${__AITER_MHA_PATH}/${REQUIRED_FILE}")
message(FATAL_ERROR "[AITER-BUILD] Source build completed but ${REQUIRED_FILE} is missing from ${__AITER_MHA_PATH}")
endif()
endforeach()
endif()
message(STATUS "[AITER-BUILD] Using __AITER_MHA_PATH=${__AITER_MHA_PATH}")
endif()

set(ck_fused_attn_SOURCES)
Expand Down Expand Up @@ -122,14 +131,16 @@ endif()
target_include_directories(ck_fused_attn PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha)
target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR})
target_link_options(ck_fused_attn PRIVATE -Wl,--exclude-libs,ALL)

set(__AITER_MHA_LIB "${__AITER_MHA_PATH}/libmha.a")
find_package(hip)
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so)
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_LIB})

target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS})
target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS})
set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")

install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
# copy v3 kernels to destination
foreach(ARCH IN LISTS V3_ASM_ARCHS)
Expand Down
47 changes: 46 additions & 1 deletion transformer_engine/common/ck_fused_attn/aiter_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,52 @@ CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT="${CK_TILE_BF16_DEFAULT}" \
GPU_ARCHS="${GPU_ARCHS_VAL}" \
python3 "${AITER_TEST_DIR}/compile.py"

# Check for ar and ranlib
AR_BIN="${AR:-$(command -v ar || true)}"
RANLIB_BIN="${RANLIB:-$(command -v ranlib || true)}"
if [[ -z "${AR_BIN}" ]]; then
echo "[AITER-BUILD] Could not find ar for static archive generation." >&2
exit 1
fi
if [[ -z "${RANLIB_BIN}" ]]; then
echo "[AITER-BUILD] Could not find ranlib for static archive generation." >&2
exit 1
fi

# Create a single unified static archive from both forward and backward object files
out_archive="${AITER_TEST_DIR}/libmha.a"
obj_list=$(mktemp)
rm -f "${obj_list}"

for lib in fwd bwd; do
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is built and used by us, will it be more efficient to make single lib?

src_obj_dir="${AITER_DIR}/aiter/jit/build/libmha_${lib}/build"
if [[ ! -d "${src_obj_dir}" ]]; then
echo "[AITER-BUILD] Missing object directory: ${src_obj_dir}" >&2
rm -f "${obj_list}"
exit 1
fi
find "${src_obj_dir}" -type f -name '*.o' >> "${obj_list}"
done

total_objs=$(wc -l < "${obj_list}")
if [[ "${total_objs}" -eq 0 ]]; then
echo "[AITER-BUILD] No object files found for fwd/bwd" >&2
rm -f "${obj_list}"
exit 1
fi

rm -f "${out_archive}"
# Use a file list to avoid ARG_MAX limits with thousands of object files
"${AR_BIN}" qc "${out_archive}" @"${obj_list}"

if [[ -n "${RANLIB_BIN}" ]]; then
"${RANLIB_BIN}" "${out_archive}"
fi

echo "[AITER-BUILD] Created static archive: ${out_archive} (${total_objs} objects)"
rm -f "${obj_list}"

if [ -n "${INSTALL_DIR}" ]; then
mkdir -p "${INSTALL_DIR}"
cp "${AITER_TEST_DIR}/libmha_fwd.so" "${AITER_TEST_DIR}/libmha_bwd.so" "${INSTALL_DIR}/"
cp "${out_archive}" "${INSTALL_DIR}/"
fi
30 changes: 26 additions & 4 deletions transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,25 @@ function(get_aiter_cache_key ROCM_VER_PARAM KEY_VAR CACHE_DIR_VAR)
set(${CACHE_DIR_VAR} "${AITER_CACHE_ROOT}/${_KEY}" PARENT_SCOPE)
endfunction()

# Required files that must be present in a valid AITER prebuilt cache
set(AITER_PREBUILT_REQUIRED_FILES
"libmha.a"
)

# Validate existing cache path
function(is_aiter_cache_valid ROCM_VER_PARAM CACHE_VALID)
get_aiter_cache_key("${ROCM_VER_PARAM}" KEY EXTRACT_DIR)
if(EXISTS "${EXTRACT_DIR}/libmha_fwd.so" AND EXISTS "${EXTRACT_DIR}/libmha_bwd.so")
set(${CACHE_VALID} TRUE PARENT_SCOPE)
message(STATUS "[AITER-PREBUILT] Found Cached build files at ${EXTRACT_DIR}")
if(NOT EXISTS "${EXTRACT_DIR}")
return()
endif()
foreach(REQUIRED_FILE IN LISTS AITER_PREBUILT_REQUIRED_FILES)
if(NOT EXISTS "${EXTRACT_DIR}/${REQUIRED_FILE}")
message(WARNING "[AITER-PREBUILT] Cache at ${EXTRACT_DIR} is missing ${REQUIRED_FILE}")
return()
endif()
endforeach()
set(${CACHE_VALID} TRUE PARENT_SCOPE)
message(STATUS "[AITER-PREBUILT] Found valid cached build files at ${EXTRACT_DIR}")
endfunction()

# Main function to get prebuilt aiter libs.
Expand All @@ -56,7 +68,7 @@ function(get_prebuilt_aiter PREBUILT_DIR_VAR)
endforeach()

# Cache is invalid/outdated - clean it and some build files that depend on AITER libs path
file(REMOVE_RECURSE "${AITER_CACHE_ROOT}")
file(REMOVE_RECURSE "${EXTRACT_DIR}")
file(REMOVE_RECURSE "${CMAKE_BINARY_DIR}/_deps")

#TODO: remove ROCM_VER from the check once the change is integrated to all features modifying AITER
Expand Down Expand Up @@ -107,5 +119,15 @@ function(download_aiter_prebuilt ROCM_VER_PARAM DOWNLOAD_SUCCESS)
# Download & extract prebuilt files
FetchContent_MakeAvailable(aiter_prebuilt)
message(STATUS "[AITER-PREBUILT] Successfully downloaded to ${EXTRACT_DIR}")

# Validate downloaded contents before declaring success
foreach(REQUIRED_FILE IN LISTS AITER_PREBUILT_REQUIRED_FILES)
if(NOT EXISTS "${EXTRACT_DIR}/${REQUIRED_FILE}")
message(WARNING "[AITER-PREBUILT] Downloaded cache is missing ${REQUIRED_FILE} — discarding and falling back to source build.")
file(REMOVE_RECURSE "${EXTRACT_DIR}")
return()
endif()
endforeach()
message(STATUS "[AITER-PREBUILT] Downloaded cache validated successfully.")
set(${DOWNLOAD_SUCCESS} TRUE PARENT_SCOPE)
endfunction()