diff --git a/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp
index a58caba3706..b37da6cefe4 100644
--- a/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp
+++ b/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp
@@ -49,18 +49,20 @@ struct MultiReduce2d
     {
         using S = typename Problem::BlockShape;
 
         constexpr index_t memory_vector_size = 16 / sizeof(XDataType); // Vectorization
-        constexpr index_t thread_tile_vector_size =
-            S::ThreadTile_N; // In the continuous dimension, within the tile
         constexpr auto innermost_reduce_dim = ReduceDims{}.at(number<ReduceDims{}.size() - 1>{});
         constexpr bool is_innermost_contiguous =
             (innermost_reduce_dim == InputShape{}.size() - 1);
-        constexpr index_t stride_based_vector_size =
-            is_innermost_contiguous
-                ? ck_tile::min(memory_vector_size, thread_tile_vector_size)
-                : 1; // Move at "vectorization" steps if continuous otherwise 1 step
-
-        return stride_based_vector_size;
+        if constexpr(is_innermost_contiguous)
+        {
+            constexpr index_t thread_tile_vector_size = S::ThreadTile_N;
+            return ck_tile::min(memory_vector_size, thread_tile_vector_size);
+        }
+        else
+        {
+            constexpr index_t thread_tile_vector_size = S::ThreadTile_M;
+            return ck_tile::min(memory_vector_size, thread_tile_vector_size);
+        }
     }
 
     static constexpr index_t CalculateOutputVectorSize()
@@ -192,12 +194,6 @@ struct MultiReduce2d
         const auto reduce_merge_transform =
             make_merge_transform(reduce_lens); // Dimension(s) to reduce are being flattened
 
-        const auto custom_padding_values = ck_tile::apply(
-            [](auto... args) {
-                return ck_tile::make_tuple(args.template GetIdentityValue<XDataType>()...);
-            },
-            reduce_ops); // Get the identity element for each operation
-
         constexpr auto x_tensor_vector_size = CalculateInputVectorSize();
 
         auto desc = make_naive_tensor_descriptor(
@@ -213,44 +209,54 @@ struct MultiReduce2d
         auto [m_offset, n_offset] = partitioner.GetInputTileOffsets(
             block_global_id, block_group_size, num_n_tile_iteration);
 
-        static_for<0, number_operations, 1>{}([&](auto i) {
-            auto buffer_view = make_buffer_view<address_space_enum::global>(
-                p_x, desc.get_element_space_size(), custom_padding_values.get(number<i>{}));
-
-            const auto x_tensor =
-                tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
-            const auto transformed_x_tensor = pad_tensor_view(
-                transform_tensor_view(x_tensor,
-                                      make_tuple(kept_merge_transform, reduce_merge_transform),
-                                      make_tuple(kept_dim, reduce_dims),
-                                      make_tuple(sequence<0>{}, sequence<1>{})),
-                make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
-                sequence<0, 1>{});
-
-            auto x_window =
-                make_tile_window(transformed_x_tensor,
-                                 make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
-                                 {m_offset, n_offset},
-                                 Policy::template MakeXBlockTileDistribution<Problem>());
-
-            using ComputeDataTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
-
-            auto y_compute = block_reduce2d.template MakeYBlockTile<ComputeDataTensorType>();
-
-            set_tile(y_compute,
-                     reduce_ops.get(number<i>{}).template GetIdentityValue<ComputeDataType>());
-
-            // Reduction loop
-            for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
-            {
-                auto x = load_tile(x_window);
-                auto x_compute = cast_tile<ComputeDataType>(x);
+        const auto padding_value =
+            reduce_ops.get(number<0>{}).template GetIdentityValue<XDataType>();
+        auto buffer_view = make_buffer_view<address_space_enum::global>(
+            p_x, desc.get_element_space_size(), padding_value);
+
+        const auto x_tensor = tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
+        const auto transformed_x_tensor = pad_tensor_view(
+            transform_tensor_view(x_tensor,
+                                  make_tuple(kept_merge_transform, reduce_merge_transform),
+                                  make_tuple(kept_dim, reduce_dims),
+                                  make_tuple(sequence<0>{}, sequence<1>{})),
+            make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
+            sequence<0, 1>{});
+
+        auto x_window = make_tile_window(transformed_x_tensor,
+                                         make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
+                                         {m_offset, n_offset},
+                                         Policy::template MakeXBlockTileDistribution<Problem>());
+
+        using ComputeDataTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
+
+        // Initialize all accumulator buffers (one per operation)
+        auto y_compute_tuple = generate_tuple(
+            [&](auto i) {
+                auto y_compute = block_reduce2d.template MakeYBlockTile<ComputeDataTensorType>();
+                set_tile(y_compute, reduce_ops.get(i).template GetIdentityValue<ComputeDataType>());
+                return y_compute;
+            },
+            number<number_operations>{});
 
-                tile_elementwise_inout(elementwise_ops.get(number<i>{}), x_compute, x_compute);
-                block_reduce2d(x_compute, y_compute, reduce_ops.get(number<i>{}));
+        // Reduction loop
+        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
+        {
+            auto x = load_tile(x_window);
+            auto x_compute = cast_tile<ComputeDataType>(x);
 
-                move_tile_window(x_window, {0, S::Block_N});
-            }
+            static_for<0, number_operations, 1>{}([&](auto i) {
+                auto x_temp = x_compute;
+                tile_elementwise_inout(elementwise_ops.get(number<i>{}), x_temp, x_temp);
+                block_reduce2d(x_temp, y_compute_tuple[i], reduce_ops.get(number<i>{}));
+            });
+
+            move_tile_window(x_window, {0, S::Block_N});
+        }
+
+        // Synchronize and output all results
+        static_for<0, number_operations, 1>{}([&](auto i) {
+            auto& y_compute = y_compute_tuple[i];
 
             block_reduce2d_sync(y_compute, reduce_ops.get(number<i>{}));
             block_reduce2d_cross_warp_sync(
@@ -331,6 +337,7 @@ struct MultiReduce2d
     /// @note Requirements:
     /// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution)
    /// - input_strides[-1] == 1 (for contiguous memory access)
+    /// - All reduce operations must have the same identity value
     template <typename InputStrides>
     CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
                                                  InputStrides input_strides)
@@ -356,6 +363,39 @@ struct MultiReduce2d
             return false;
         }
 
+        // Check that all reduce operations have the same identity value
+        auto reduce_ops = typename Problem::ReduceOp{};
+        constexpr auto number_operations = reduce_ops.size();
+
+        if constexpr(number_operations > 1)
+        {
+            const auto first_identity =
+                reduce_ops.get(number<0>{}).template GetIdentityValue<XDataType>();
+            bool all_same = true;
+
+            static_for<1, number_operations, 1>{}([&](auto i) {
+                const auto current_identity =
+                    reduce_ops.get(i).template GetIdentityValue<XDataType>();
+
+                // Identity values must compare exactly. They are never the result of a
+                // computation, so a bitwise comparison is meaningful here; memcmp also
+                // avoids -Wfloat-equal diagnostics, which -Werror promotes to errors.
+                if(__builtin_memcmp(&current_identity, &first_identity, sizeof(XDataType)) != 0)
+                {
+                    all_same = false;
+                }
+            });
+
+            if(!all_same)
+            {
+                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                {
+                    CK_TILE_ERROR("All reduce operations must have the same identity value!");
+                }
+                return false;
+            }
+        }
+
         return true;
     }
 };
diff --git a/test/ck_tile/reduce/test_multi_reduce2d_threadwise.cpp b/test/ck_tile/reduce/test_multi_reduce2d_threadwise.cpp
index 95850c47efc..807588649bb 100644
--- a/test/ck_tile/reduce/test_multi_reduce2d_threadwise.cpp
+++ b/test/ck_tile/reduce/test_multi_reduce2d_threadwise.cpp
@@ -39,26 +39,20 @@
 using TestConfig_F16_Add = std::tuple<...>;
 
-using TestConfig_F16_Add_Max = std::tuple<
+using TestConfig_F16_Add_SumSquare = std::tuple<
     ck_tile::half_t,
     float,
    ck_tile::half_t,
-    ck_tile::tuple<...>,
-    ck_tile::tuple<...>,
-    ck_tile::tuple<...>,
-    ck_tile::tuple<...>,
+    ck_tile::tuple<...>,
+    ck_tile::tuple<...>,
+    ck_tile::tuple<...>,
+    ck_tile::tuple<...>,
     Shape1_BlockWarps,
     Shape1_BlockTile,
     Shape1_WarpTile,
     Shape1_ThreadTile>;
 
-using TestTypes = ::testing::Types<TestConfig_F16_Add, TestConfig_F16_Add_Max>;
+using TestTypes = ::testing::Types<TestConfig_F16_Add, TestConfig_F16_Add_SumSquare>;
 
 TYPED_TEST_SUITE(TestCkTileMultiReduceThreadwise, TestTypes);
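---

Context for the identity-value requirement: the patch now pads the whole input buffer with a single padding value taken from reduce operation 0, so combining operations whose identity elements differ (e.g. Add's 0 and Max's -inf) would corrupt the padded lanes of every operation after the first. That is presumably also why the test drops the Add+Max configuration in favor of Add+SumSquare, where the squaring is done by the elementwise op and both reductions remain plain Add sharing the identity 0. Below is a minimal standalone sketch of the new check; the Add/Max structs are hypothetical stand-ins modeling only GetIdentityValue(), not ck_tile types.

    #include <cstring>
    #include <iostream>
    #include <limits>

    // Hypothetical stand-ins for reduce ops; only the identity value matters here.
    struct Add
    {
        template <typename T>
        static constexpr T GetIdentityValue()
        {
            return T{0}; // identity of a sum
        }
    };

    struct Max
    {
        template <typename T>
        static constexpr T GetIdentityValue()
        {
            return -std::numeric_limits<T>::infinity(); // identity of a max over floats
        }
    };

    // Mirrors the bitwise comparison added to IsSupportedArgument(): identity values
    // are compile-time constants, never computed results, so exact equality is the
    // right test, and memcmp sidesteps -Wfloat-equal under -Werror.
    template <typename XDataType, typename Op0, typename Op1>
    bool identities_match()
    {
        const XDataType first   = Op0::template GetIdentityValue<XDataType>();
        const XDataType current = Op1::template GetIdentityValue<XDataType>();
        return std::memcmp(&current, &first, sizeof(XDataType)) == 0;
    }

    int main()
    {
        // Add + Add (the SumSquare shape: squaring happens in the elementwise op,
        // the reduction itself stays Add) shares the identity 0 -> supported.
        std::cout << identities_match<float, Add, Add>() << '\n'; // prints 1
        // Add + Max mixes identities 0 and -inf -> rejected by the new check.
        std::cout << identities_match<float, Add, Max>() << '\n'; // prints 0
    }

On the CalculateInputVectorSize() change: a non-contiguous innermost reduce dimension no longer forces scalar access. For example, with ck_tile::half_t input, memory_vector_size = 16 / sizeof(half_t) = 8, so a ThreadTile_M of 4 now yields min(8, 4) = 4 instead of the previous hard-coded 1.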