Add build time optimization documentation #3608
Open: tenpercent wants to merge 3 commits into `mpodkory/find-transform-optimization` from `mpodkory/build-time-docs`.
+207 −0
# Build Time Optimization

Tracking issue: [#3575](https://github.com/ROCm/composable_kernel/issues/3575)

This document describes techniques for reducing C++ template instantiation overhead in the Composable Kernel codebase.

## Why Build Time Matters

Composable Kernel relies heavily on C++ template metaprogramming to achieve GPU kernels with no runtime abstraction penalty. However, deep template instantiation can significantly impact build times. A single translation unit may trigger hundreds of thousands of template instantiations, with each instantiation adding to compile time.

## Optimization Techniques

### 1. Replace Recursive Templates with Pack Expansion

Recursive template patterns create O(N) instantiation depth: the compiler must instantiate each level before proceeding to the next:

```
sequence_gen_impl<5, F>
  → sequence_gen_impl<4, F>
    → sequence_gen_impl<3, F>
      → ...
```

Using `__make_integer_seq` (Clang/MSVC) combined with pack expansion reduces this to constant depth: the compiler generates the entire sequence in one step internally, without recursive template instantiation.

**Before** (O(N) recursive instantiation):

```cpp
template <index_t N, typename F, index_t... Is>
struct sequence_gen_impl
{
    using type = typename sequence_gen_impl<N - 1, F, F{}(Number<N - 1>{}), Is...>::type;
};

template <typename F, index_t... Is>
struct sequence_gen_impl<0, F, Is...>
{
    using type = Sequence<Is...>;
};
```

**After** (constant depth using compiler intrinsic + pack expansion):

```cpp
namespace detail {

template <typename T, T... Is>
struct sequence_gen_helper
{
    // Apply functor F to all indices via pack expansion:
    // F{}(Number<0>{}), F{}(Number<1>{}), ..., F{}(Number<N-1>{})
    template <typename F>
    using apply = Sequence<F{}(Number<Is>{})...>;
};

} // namespace detail

template <index_t N, typename F>
struct sequence_gen
{
    // __make_integer_seq<sequence_gen_helper, index_t, N> produces
    // sequence_gen_helper<index_t, 0, 1, ..., N-1> with constant depth
    using type =
        typename __make_integer_seq<detail::sequence_gen_helper, index_t, N>::template apply<F>;
};
```

Note: While `std::make_integer_sequence` is the standard C++14 way to generate integer sequences, it only produces `std::integer_sequence<T, ...>`. We use `__make_integer_seq` directly because it accepts any template as its first argument, enabling this pattern where the helper class receives the index pack directly.
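For compilers without the intrinsic, a similar shape can be reproduced portably by letting `std::make_integer_sequence` deduce the index pack in a helper function. The sketch below is a minimal standalone illustration, not CK code: `index_t`, `Number`, and `Sequence` are simplified stand-ins for the CK types, and `Doubler` is a hypothetical example functor.

```cpp
#include <cstdint>
#include <type_traits>
#include <utility>

using index_t = std::int32_t;

// Simplified stand-ins for the CK types (assumptions for this sketch).
template <index_t I>
struct Number : std::integral_constant<index_t, I>
{
};

template <index_t... Is>
struct Sequence
{
};

// Deduce the pack from std::integer_sequence, then expand it in one step;
// no O(N) recursion appears in user code. Declared only: used via decltype.
template <typename F, index_t... Is>
constexpr auto sequence_gen_impl(std::integer_sequence<index_t, Is...>)
    -> Sequence<F{}(Number<Is>{})...>;

template <index_t N, typename F>
struct sequence_gen
{
    using type = decltype(sequence_gen_impl<F>(std::make_integer_sequence<index_t, N>{}));
};

// Hypothetical example functor: maps each index i to 2*i.
struct Doubler
{
    template <index_t I>
    constexpr index_t operator()(Number<I>) const
    {
        return 2 * I;
    }
};

static_assert(std::is_same_v<typename sequence_gen<4, Doubler>::type, Sequence<0, 2, 4, 6>>);
```

Note that `std::make_integer_sequence` is itself typically implemented with the same intrinsic on Clang, so the depth characteristics are comparable; the difference is purely in which template ultimately receives the pack.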
### 2. Replace Lambdas with Named Functors

Each lambda expression creates a unique closure type, causing separate template instantiations at every call site. Named functors share a single type across all uses.

**Before** (lambda creates unique instantiations at each call site):

```cpp
// The lambda inside transform_tensor_descriptor:
generate_tuple([](auto i) { return Sequence<i>{}; }, Number<N>{});
```

**After** (named functor shares instantiations):

```cpp
// Define functor once
struct generate_identity_sequence
{
    template <index_t I>
    __host__ __device__ constexpr auto operator()(Number<I>) const
    {
        return Sequence<I>{};
    }
};

// Use everywhere - shares instantiations
generate_tuple(generate_identity_sequence{}, Number<N>{});
```

This reduced `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction).
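The mechanism can be seen directly in standard C++. The snippet below is a minimal demonstration (not CK code): `plus_functor` and `apply_twice` are hypothetical names, and the `static_assert` shows that even textually identical lambdas have distinct closure types, so every lambda call site forces a fresh instantiation of the generic algorithm it is passed to.

```cpp
#include <type_traits>

// A named functor: one type, shared by every call site.
struct plus_functor
{
    template <typename A, typename B>
    constexpr auto operator()(A a, B b) const
    {
        return a + b;
    }
};

// Hypothetical generic algorithm: instantiated once per distinct functor type F.
template <typename F>
constexpr int apply_twice(F f)
{
    return f(1, 2) + f(3, 4);
}

// Two textually identical lambdas still have distinct closure types, so
// passing each to apply_twice would instantiate it twice; the named functor
// above would instantiate it only once.
inline auto lambda_a = [](int a, int b) { return a + b; };
inline auto lambda_b = [](int a, int b) { return a + b; };
static_assert(!std::is_same_v<decltype(lambda_a), decltype(lambda_b)>);
```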
**Example: container_concat**

```cpp
// Before: lambda creates unique type per call site
// (unpack2 applies a functor to all elements from both tuples)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    return unpack2([](auto&&... zs) { return make_tuple(forward<decltype(zs)>(zs)...); }, tx, ty);
}

// After: named functor shares instantiations
struct make_tuple_functor
{
    template <typename... Ts>
    __host__ __device__ constexpr auto operator()(Ts&&... xs) const
    {
        return make_tuple(forward<Ts>(xs)...);
    }
};

template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    return unpack2(make_tuple_functor{}, tx, ty);
}
```

This reduced `container_concat` instantiations from 186 to 93 (50% reduction).
**Example: make_uniform_tuple**

For patterns that create tuples with repeated values:

```cpp
// Before: unique lambda type at each call site
generate_tuple([](auto) { return some_value; }, Number<N>{});

// After: dedicated helper function
template <index_t N, typename T>
__host__ __device__ constexpr auto make_uniform_tuple(T&& value)
{
    return detail::make_uniform_tuple_impl(static_cast<T&&>(value), make_index_sequence<N>{});
}

// Usage
make_uniform_tuple<N>(some_value);
```
### 3. Use Constexpr Loops Instead of Template Recursion

Template recursion creates N template instantiations for N iterations. A constexpr loop executes at compile time but requires only a single template instantiation. While both are O(N) in complexity, constexpr loops are significantly faster because they avoid the overhead of template instantiation.

**Before** (O(N) template instantiations):

```cpp
template <index_t Target, typename Seq, index_t Pos>
struct find_source_index_impl
{
    static constexpr index_t value =
        (Seq::template At<Pos>() == Target) ? Pos : find_source_index_impl<Target, Seq, Pos + 1>::value;
};
```

**After** (single instantiation with constexpr loop):

```cpp
template <index_t Target, index_t... Is>
__host__ __device__ constexpr index_t find_source_index(Sequence<Is...>)
{
    constexpr index_t values[] = {Is...};
    for(index_t i = 0; i < sizeof...(Is); ++i)
        if(values[i] == Target)
            return i;
    return 0;
}
```

This reduced `sequence_map_inverse` instantiations from 45 to 10 (78% reduction) and wall-clock time by 95%.
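The same technique can be exercised standalone. The sketch below is an assumption-level illustration, not CK code: `Sequence` is a simplified stand-in, and the pack is copied into a `constexpr` array and searched with an ordinary loop, so the search itself triggers no further template instantiations.

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Simplified stand-in for CK's Sequence (assumption for this sketch).
template <index_t... Is>
struct Sequence
{
};

// Constexpr-loop search: one function-template instantiation per
// (Target, Sequence) pair, regardless of sequence length.
template <index_t Target, index_t... Is>
constexpr index_t find_source_index(Sequence<Is...>)
{
    constexpr index_t values[] = {Is...};
    for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
        if(values[i] == Target)
            return i;
    return 0;
}

static_assert(find_source_index<7>(Sequence<4, 7, 2>{}) == 1);
static_assert(find_source_index<2>(Sequence<4, 7, 2>{}) == 2);
```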
### 4. Use Fold Expressions for Accumulation

Fold expressions (C++17) can replace recursive template patterns for accumulation operations.

**Before** (implicit recursion through generate_tuple and container_reduce):

```cpp
const auto element_space_size = container_reduce(
    generate_tuple([&](auto i) { return (lengths[i] - I1) * strides[i]; }, Number<N>{}),
    math::plus{},
    LongNumber<1>{});
```

**After** (single fold expression):

```cpp
template <typename... Lengths, typename... Strides, index_t... Is>
__host__ __device__ constexpr auto compute_element_space_size(const Tuple<Lengths...>& lengths,
                                                              const Tuple<Strides...>& strides,
                                                              Sequence<Is...>)
{
    return (LongNumber<1>{} + ... +
            ((lengths[Number<Is>{}] - Number<1>{}) * strides[Number<Is>{}]));
}
```

This reduced `calculate_element_space_size` instantiations from 24 to 10 (58% reduction) and wall-clock time by 73%.
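The fold pattern can be checked in isolation with plain types. The sketch below is an assumption-level illustration in which arrays and `std::index_sequence` replace CK's `Tuple`/`Number`/`LongNumber`: it computes `1 + Σᵢ (lengths[i] - 1) * strides[i]` in a single left fold.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>

using index_t = std::int32_t;

// Standalone sketch of the fold-expression accumulation (not CK code):
// plain arrays stand in for CK's tuple types.
template <std::size_t N, std::size_t... Is>
constexpr long compute_element_space_size(const index_t (&lengths)[N],
                                          const index_t (&strides)[N],
                                          std::index_sequence<Is...>)
{
    // Left fold: ((1 + term_0) + term_1) + ... -- one expression, no recursion
    // visible to the user, and no per-element helper instantiations.
    return (1L + ... + (static_cast<long>(lengths[Is] - 1) * strides[Is]));
}

// A 2x3 row-major tensor has strides {3, 1}: size = 1 + (2-1)*3 + (3-1)*1 = 6.
constexpr index_t lengths_2x3[] = {2, 3};
constexpr index_t strides_2x3[] = {3, 1};
static_assert(compute_element_space_size(lengths_2x3, strides_2x3,
                                         std::make_index_sequence<2>{}) == 6);
```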
Let's move this down into the source, where we are making the code changes. It's not customer documentation, it's aimed at the developers.
Can we align on the goal of this doc? This is kind of all over the place. If it's general info, it should probably go in the tracking bug. In fact, the cleanest way is some comments in the tracking bug that link to documented changes in the source. Then the only need for a markdown file is to track files we need to work on and what has been optimized. The scripts should be documented, so that's not needed here.
Yep, let's move it to, say, include/ck.
The high-level goal is to document the optimization attempts for the metaprogramming constructs that we have, as well as to collect the techniques in an accessible way.
When relying on the tracking issue, we need to keep in mind that the source code and the GitHub infrastructure are different sources of information. From recent discussions, I had the impression that we wanted to start storing design documentation in the source, and this file would be a start for that.