Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -866,8 +866,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -925,8 +924,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -894,8 +894,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -944,8 +943,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -993,8 +991,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -892,8 +891,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -744,8 +743,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -563,8 +562,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -706,8 +705,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -598,8 +597,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
22 changes: 22 additions & 0 deletions include/ck/utility/tuple_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}

// Optimized helper for the common pattern
//   generate_tuple([](auto i) { return Sequence<i.value>{}; }, Number<N>{})
// Creates Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>> without lambda instantiation.
namespace detail {
// Expands the index pack carried by the Sequence argument into one
// single-element Sequence per index and bundles them with make_tuple:
// Sequence<I0, I1, ...> -> make_tuple(Sequence<I0>{}, Sequence<I1>{}, ...).
template <index_t... Idxs>
__host__ __device__ constexpr auto make_identity_sequences_impl(Sequence<Idxs...>)
{
    return make_tuple(Sequence<Idxs>{}...);
}
} // namespace detail

// Builds the tuple of single-index sequences by handing
// make_index_sequence<N> to detail::make_identity_sequences_impl.
template <index_t N>
__host__ __device__ constexpr auto generate_identity_sequences()
{
    using index_pack = make_index_sequence<N>;
    return detail::make_identity_sequences_impl(index_pack{});
}

// Overload that takes the count as a Number<N> tag argument; the argument's
// value is unused and the call forwards to generate_identity_sequences<N>().
template <index_t N>
__host__ __device__ constexpr auto generate_identity_sequences(Number<N>)
{
    constexpr auto identity_seqs = generate_identity_sequences<N>();
    return identity_seqs;
}

// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
Expand Down
3 changes: 1 addition & 2 deletions include/ck/wrapper/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,7 @@ struct Layout
const auto lower_dims =
generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = generate_identity_sequences<Tuple<ShapeDims...>::Size()>();

return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
Expand Down
3 changes: 1 addition & 2 deletions include/ck/wrapper/operations/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,7 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor());

const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
const auto lower_upper_dims = generate_identity_sequences<8>();

auto sliced_desc = transform_tensor_descriptor(
partition_desc,
Expand Down
3 changes: 1 addition & 2 deletions include/ck/wrapper/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
const auto transforms = GenerateSliceTransforms(idx, shape);
using TransformsTupleType = decltype(transforms);

const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto lower_dims = generate_identity_sequences<old_shape_dims>();
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
Expand Down
6 changes: 2 additions & 4 deletions include/ck/wrapper/utils/layout_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,7 @@ __host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout
},
Number<old_shape_dims>{});

const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto lower_dims = generate_identity_sequences<old_shape_dims>();
const auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
Expand Down Expand Up @@ -492,8 +491,7 @@ __host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& la
},
Number<dims>{});

constexpr auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
constexpr auto lower_dims = generate_identity_sequences<dims>();
constexpr auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
Expand Down
3 changes: 1 addition & 2 deletions include/ck/wrapper/utils/tensor_partition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,7 @@ make_local_partition(TensorType& tensor,
},
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
generate_identity_sequences<remove_reference_t<decltype(tensor_shape)>::Size()>();
auto sliced_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
// Create layout
Expand Down