From 9611d2f3715787a27299deb5b1a3d899f8306230 Mon Sep 17 00:00:00 2001 From: Michal Kulikowski Date: Fri, 21 Nov 2025 11:18:51 +0100 Subject: [PATCH] [CK]Refactoring threadwise_tensor_slice_transfer_v3r1.hpp Signed-off-by: Michal Kulikowski --- .../threadwise_tensor_slice_transfer_v3r1.hpp | 539 +++++++----------- 1 file changed, 210 insertions(+), 329 deletions(-) diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 4a6ed62c0e2..1c82ce27c10 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -142,66 +142,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto src_forward_steps = ComputeForwardSteps(src_desc, src_scalar_per_access); + const auto src_backward_steps = ComputeBackwardSteps(src_desc, src_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = + ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto src_data_idx = ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -308,20 +264,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 .template SetAsType(src_data_idx_seq, op_r_v.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -382,37 +326,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // loop over tensor and copy static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = + ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto src_data_idx = ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -547,66 +469,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto dst_forward_steps = ComputeForwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_backward_steps = ComputeBackwardSteps(dst_desc, dst_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_dst_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = + ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] - : ordered_dst_access_lengths[i] - 1 - - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); + constexpr auto dst_data_idx = ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths, + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -634,20 +512,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_dst_valid, dst_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -679,183 +545,190 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto GetSrcCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + return ComputeCoordinateResetStep(); + } - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; + 
__device__ static constexpr auto GetDstCoordinateResetStep() + { + return ComputeCoordinateResetStep(); + } - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; - }); + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + MoveSliceWindow( + src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep); + } - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep); + } - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + return ComputeThreadScratchDescriptor(); + } - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); - return reset_src_data_step_; - }(); + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - return reset_src_data_step; + return make_naive_tensor_descriptor_packed(src_access_lengths); } - __device__ static constexpr auto GetDstCoordinateResetStep() + __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + return ComputeThreadScratchDescriptor(); + } - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + protected: + // Helper function to compute forward sweep pattern + // I.e. 
if we should move forward or backward in each of tensor's dimensions + template + __device__ static constexpr auto + ComputeForwardSweep(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths) + { + StaticallyIndexedArray forward_sweep_; - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + forward_sweep_(I0) = true; - forward_sweep_(I0) = true; + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; + static_for<1, i, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); + forward_sweep_(i) = tmp % 2 == 0; + }); - forward_sweep_(i) = tmp % 2 == 0; - }); + return forward_sweep_; + } - return forward_sweep_; - }(); + // Compute which dimensions should have their coordinates updated during iteration + // A dimension moves when it hasn't reached its end and all higher priority dimensions + // have completed their ranges + template + __device__ static constexpr auto + ComputeMoveOnDim(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths) + { + StaticallyIndexedArray move_on_dim_; - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; }); + }); - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; + return move_on_dim_; + } - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + // Compute data index from ordered access index, converting back to natural order + template + __device__ static constexpr auto + ComputeDataIndex(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths, + const ForwardSweep& forward_sweep, + const DimAccessOrder& dim_access_order, + const ScalarPerAccess& scalar_per_access) + { + Index ordered_idx; - return reset_dst_data_step_; - }(); + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); - return reset_dst_data_step; + return container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access; } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx) + // Compute forward coordinate steps for each dimension + template + __device__ static constexpr auto ComputeForwardSteps(const Desc& desc, + const ScalarPerAccess& scalar_per_access) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? 
src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + return generate_tuple( + [&](auto i) { + Index forward_step_idx; - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0; + }); - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + return make_tensor_coordinate_step(desc, forward_step_idx); + }, + Number{}); } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, - const Index& dst_slice_origin_step_idx) + // Compute backward coordinate steps for each dimension + template + __device__ static constexpr auto ComputeBackwardSteps(const Desc& desc, + const ScalarPerAccess& scalar_per_access) { - // if dst coord was not reset by RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + return generate_tuple( + [&](auto i) { + Index backward_step_idx; - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0; + }); - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + return make_tensor_coordinate_step(desc, backward_step_idx); + }, + Number{}); } - __device__ static constexpr auto GetSrcThreadScratchDescriptor() + // Generic helper to compute thread scratch descriptor + template + __device__ static constexpr auto ComputeThreadScratchDescriptor() { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); + constexpr auto access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(access_lengths), Number{}); // 1st stage of transforms constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + make_naive_tensor_descriptor_packed(access_lengths_and_vector_length); // 2nd stage of transforms constexpr auto transforms = generate_tuple( [&](auto i) { - if constexpr(i == SrcVectorDim) + if constexpr(i == VectorDim) { return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); + make_tuple(access_lengths_and_vector_length[i], + access_lengths_and_vector_length[Number{}])); } else { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + return make_pass_through_transform(access_lengths_and_vector_length[i]); } }, Number{}); constexpr auto low_dim_idss = generate_tuple( [&](auto i) { - if constexpr(i == SrcVectorDim) + if constexpr(i == VectorDim) { return Sequence{}; } @@ -872,63 +745,71 @@ struct ThreadwiseTensorSliceTransfer_v3r1 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } 
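// --- Illustrative sketch (editor's note, not part of the patch) -----------------------------
// The generic ComputeThreadScratchDescriptor above packs {access_lengths..., ScalarPerVector}
// and then merges the vector dimension back with the trailing vector length, so the merged
// view is indexed with the original SliceLengths again. The standalone program below spells
// out the division/mod arithmetic the merge transform performs, using made-up toy values
// (SliceLengths = {2, 8}, VectorDim = 1, ScalarPerVector = 4); none of the names below are
// CK identifiers.
#include <array>
#include <cassert>
#include <cstdio>

int main()
{
    constexpr std::array<int, 2> slice_lengths{2, 8};
    constexpr int vector_dim        = 1;
    constexpr int scalar_per_vector = 4;

    // access_lengths = SliceLengths / scalar_per_access (only the vector dim is divided)
    std::array<int, 2> access_lengths = slice_lengths;
    access_lengths[vector_dim] /= scalar_per_vector; // {2, 2}

    // 1st stage: packed (row-major) descriptor over {access_lengths..., ScalarPerVector} = {2, 2, 4}
    // 2nd stage: merge the vector-dim access length with the trailing vector length, so the
    // merged view is indexed with the original SliceLengths = {2, 8}
    for(int d0 = 0; d0 < slice_lengths[0]; ++d0)
        for(int d1 = 0; d1 < slice_lengths[1]; ++d1)
        {
            const int a1 = d1 / scalar_per_vector; // division part of the merge transform
            const int v  = d1 % scalar_per_vector; // mod part (intra-vector offset)

            // packed offset in the {2, 2, 4} scratch buffer
            const int offset = (d0 * access_lengths[1] + a1) * scalar_per_vector + v;

            // with this packed layout, the merged view is plain row-major over {2, 8}
            assert(offset == d0 * slice_lengths[1] + d1);
        }

    std::printf("merged (division/mod) view matches the packed scratch layout\n");
    return 0;
}
// --------------------------------------------------------------------------------------------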
- __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() + // Generic helper to move slice window + template + __device__ static void MoveSliceWindow(const Desc& desc, + Coord& coord, + const Index& slice_origin_step_idx, + GetCoordinateResetStepFunc get_reset_step) { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + // if coord was not reset by RunRead/RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = ResetCoordinateAfterRun + ? slice_origin_step_idx + : slice_origin_step_idx + get_reset_step(); - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(desc, adjusted_step_idx); - return make_naive_tensor_descriptor_packed(src_access_lengths); + move_tensor_coordinate(desc, coord, adjusted_step); } - __device__ static constexpr auto GetDstThreadScratchDescriptor() + // Generic helper to compute coordinate reset step + template + __device__ static constexpr auto ComputeCoordinateResetStep() { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); + constexpr auto dim_access_order = DimAccessOrder{}; - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); + // judge move forward or move backward during the last iteration + constexpr auto ordered_access_lengths_minus_1 = generate_tuple( + [&](auto i) { return Number{}; }, Number{}); + constexpr auto forward_sweep = + ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_access_lengths); - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); + // calculate data index after last iteration, if it has not being reset + constexpr auto data_idx = [&]() { + Index ordered_idx; - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_access_lengths[i] - 1 : 0;
+            });
+
+            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
+                   scalar_per_access;
+        }();
+
+        //
+        constexpr auto reset_data_step = [&]() {
+            Index reset_data_step_;
+
+            static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; });
+
+            return reset_data_step_;
+        }();
+
+        return reset_data_step;
     }

     private:
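// --- Illustrative sketch (editor's note, not part of the patch) -----------------------------
// ComputeForwardSweep, ComputeDataIndex, ComputeMoveOnDim and ComputeCoordinateResetStep
// together implement a zig-zag ("snake") walk over the access grid: a dimension reverses
// direction whenever the dimensions above it advance, so consecutive accesses differ by a
// single +/-1 step in exactly one dimension, and negating the last visited data index brings
// the coordinate back to the slice window origin. The standalone program below reproduces
// that logic with plain arrays and made-up toy lengths; identity access order and
// scalar_per_access == 1 are assumed, and none of the names are CK identifiers.
#include <array>
#include <cassert>
#include <cstdio>
#include <cstdlib>

constexpr int nDim = 3;
using Idx = std::array<int, nDim>;

// Same rule as ComputeForwardSweep: the parity of the flattened index of the
// dimensions above dim i decides whether dim i currently sweeps forward.
std::array<bool, nDim> forward_sweep(const Idx& idx, const Idx& len)
{
    std::array<bool, nDim> fs{};
    fs[0] = true;
    for(int i = 1; i < nDim; ++i)
    {
        int tmp = idx[0];
        for(int j = 1; j < i; ++j)
            tmp = tmp * len[j] + idx[j];
        fs[i] = (tmp % 2 == 0);
    }
    return fs;
}

// Same rule as ComputeDataIndex (identity access order, scalar_per_access == 1).
Idx data_index(const Idx& idx, const Idx& len)
{
    const auto fs = forward_sweep(idx, len);
    Idx d{};
    for(int i = 0; i < nDim; ++i)
        d[i] = fs[i] ? idx[i] : len[i] - 1 - idx[i];
    return d;
}

// Same rule as ComputeMoveOnDim: dim i moves when it has room left and every
// lower-priority dimension has reached the end of its range.
std::array<bool, nDim> move_on_dim(const Idx& idx, const Idx& len)
{
    std::array<bool, nDim> m{};
    for(int i = 0; i < nDim; ++i)
    {
        m[i] = idx[i] < len[i] - 1;
        for(int j = i + 1; j < nDim; ++j)
            m[i] = m[i] && (idx[j] == len[j] - 1);
    }
    return m;
}

int main()
{
    const Idx len{2, 3, 4};          // toy access lengths
    Idx idx{};                       // ordered access index, starts at the origin
    Idx prev = data_index(idx, len); // data index of the first access

    const int total = len[0] * len[1] * len[2];
    for(int step = 1; step < total; ++step)
    {
        const auto move = move_on_dim(idx, len); // decided before advancing, as in RunRead

        // advance idx lexicographically, like static_ford over the access lengths
        for(int i = nDim - 1; i >= 0; --i)
        {
            if(++idx[i] < len[i]) break;
            idx[i] = 0;
        }

        const Idx cur = data_index(idx, len);

        // zig-zag property: exactly one dimension changes, by exactly +/-1,
        // and it is a dimension flagged by move_on_dim
        int changed = -1, dist = 0;
        for(int i = 0; i < nDim; ++i)
        {
            dist += std::abs(cur[i] - prev[i]);
            if(cur[i] != prev[i]) changed = i;
        }
        assert(dist == 1 && changed >= 0 && move[changed]);
        prev = cur;
    }

    // Same rule as ComputeCoordinateResetStep: negate the last data index to
    // return the coordinate to the slice window origin.
    std::printf("reset step = (%d, %d, %d)\n", -prev[0], -prev[1], -prev[2]);
    return 0;
}
// --------------------------------------------------------------------------------------------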