Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,31 @@ struct GemmConfigComputeV6 : public GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
};

// Tile/warp configuration for the COMPUTE_V7 GEMM pipeline
// (selected via PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V7>).
template <typename PrecType>
struct GemmConfigComputeV7 : public GemmConfigBase
{
    // COMPUTE_V7 requires all three paddings enabled; the kernel-side
    // argument check rejects configurations where any of these is false.
    static constexpr bool kPadM = true;
    static constexpr bool kPadN = true;
    static constexpr bool kPadK = true;

    // Block-level tile sizes along M/N/K (elements per workgroup).
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 64;

    // Wave layout within a workgroup: 2x2 grid of waves, no K-split.
    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    // Per-wave tile dimensions -- presumably matching a 16x16x32 matrix-core
    // instruction shape; confirm against the target architecture.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 32;

    // Single LDS buffer (no double buffering) for this pipeline.
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V7;

    // Occupancy hint: target workgroups per compute unit.
    static constexpr int kBlockPerCu = 2;
};

template <typename PrecType>
struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
Expand Down Expand Up @@ -423,6 +448,15 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
};

// Maps GemmPipeline::COMPUTE_V7 to its concrete pipeline implementation
// types, mirroring the COMPUTE_V6 specialization above.
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V7>
{
    // Full pipeline implementation used by the GEMM kernel.
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV7<PipelineProblem>;
    // Base pipeline type for the universal GEMM kernel path.
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV7<PipelineProblem>;
};

template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
{
Expand Down
16 changes: 16 additions & 0 deletions include/ck_tile/core/tensor/load_tile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

// Loads from every window in `tile_window` (a tuple of tile windows) using
// scalar (vector-length-1) accesses, combining the per-window values through
// `elementwise` into a single distributed tensor.
//
// \param tile_window  tuple-like collection of tile windows to read from.
// \param elementwise  functor combining one element from each window.
// \param i_access     optional single-access index (-1 = all accesses).
// \param oob_conditional_check  enable out-of-bounds guarding on loads.
template <typename TileWindow_,
          typename ElementWise_,
          index_t i_access = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto
load_tile_with_elementwise_vectorload1(const TileWindow_& tile_window,
                                       ElementWise_ elementwise,
                                       number<i_access> = {},
                                       bool_constant<oob_conditional_check> = {})
{
    // TODO: Tile windows should work with an unknown number of params.
    // The elementwise load API only works when the input type is a tuple-type,
    // so dispatch through the first window and pass the whole tuple along.
    return tile_window[number<0>{}].load_vectorload1(
        tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
template <typename DistributedTensor_,
typename TileWindow_,
Expand Down
106 changes: 106 additions & 0 deletions include/ck_tile/core/tensor/tile_window.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,112 @@ struct tile_window_with_static_distribution
});
}

// Allocating overload of the scalar elementwise load: builds a fresh
// distributed tensor for this window's distribution, fills it via the
// in-place overload, and returns it.
template <typename TileWindow_,
          typename ElementWise_,
          index_t i_access_unsupport_ = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_vectorload1(const TileWindow_& tile_window,
                                     ElementWise_ elementwise,
                                     number<i_access_unsupport_> = {},
                                     bool_constant<oob_conditional_check> = {}) const
{
    // Destination tensor laid out according to this window's tile distribution.
    auto loaded =
        make_static_distributed_tensor<typename Base::DataType>(typename Base::TileDstr{});
    this->load_vectorload1(loaded,
                           tile_window,
                           elementwise,
                           number<i_access_unsupport_>{},
                           bool_constant<oob_conditional_check>{});
    return loaded;
}

// Scalar elementwise load into an existing distributed tensor.
//
// Walks the thread-local Y-space one element per access (vector length 1),
// reads the corresponding scalar from every window in the `tile_window`
// tuple at the same thread coordinate, and combines them through
// `elementwise` into `dst_tensor`.
//
// \param dst_tensor   static distributed tensor receiving the combined result.
// \param tile_window  tuple-like collection of tile windows read in lockstep.
// \param elementwise  functor invoked as elementwise(dst_element&, src_scalar...).
template <typename DistributedTensor,
          typename TileWindow_,
          typename ElementWise_,
          index_t i_access_unsupport_ = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE void load_vectorload1(DistributedTensor& dst_tensor,
                                     const TileWindow_& tile_window,
                                     ElementWise_ elementwise,
                                     number<i_access_unsupport_> = {},
                                     bool_constant<oob_conditional_check> = {}) const
{
    // OldTraits describes the regular (vectorized) access pattern; Traits is
    // the scalar variant whose space-filling curve advances one element at a
    // time (see TraitsVectorload1).
    using OldTraits = typename Base::Traits;
    using Traits = typename Base::TraitsVectorload1;
    // Each access reads exactly one element of DataType.
    using vector_t = thread_buffer<typename Base::DataType, 1>;
    using SFC_Ys = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = typename Base::TileDstr{};
    constexpr auto sizeOfTuple = TileWindow_::size();
    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structure binding (to be captured later) if compiled in C++20
        // NOTE(review): thread coordinates are taken from window 0 only, and
        // reused for reads from every window in the tuple -- this assumes all
        // windows share the same distribution/precomputed coords; confirm.
        auto window_adaptor_thread_coord =
            tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord =
            tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];

        // One vectorized access of the regular traits corresponds to
        // ScalarPerVector scalar accesses on this path.
        static_for<0, NumAccessPerCoord * OldTraits::ScalarPerVector, 1>{}(
            [&](auto iCoordAccess) {
                constexpr auto iAccess =
                    number<iCoord * NumAccessPerCoord * OldTraits::ScalarPerVector +
                           iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from bottom tensor: one scalar from each window of the
                // tuple, all at the same thread coordinate
                const auto idx_vec_value = generate_tuple(
                    [&](auto jj) {
                        return tile_window[number<jj>{}]
                            .get_bottom_tensor_view()
                            .template get_vectorized_elements<vector_t>(
                                bottom_tensor_thread_coord,
                                0, // linear offset
                                bool_constant<oob_conditional_check>{});
                    },
                    number<sizeOfTuple>{});

                // Runs exactly once (j == 0) for any PackedSize >= 1; kept in
                // static_for form to mirror the vectorized load path.
                static_for<0, 1, Traits::PackedSize>{}([&](auto j) {
                    // write into distributed tensor
                    constexpr auto idx_ys = generate_tuple(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<Base::NDimY>{});

                    // flat destination offset within the thread buffer
                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                        Traits::PackedSize;

                    // combine the per-window scalars into the destination element
                    ck_tile::apply(
                        [&](auto&&... t) {
                            elementwise(dst_tensor.get_thread_buffer().template at<d>(),
                                        t.template get_as<typename Base::DataType>()[0]...);
                        },
                        idx_vec_value);
                });
                // move thread coordinate to the next scalar access (skipped
                // after the last access of this coordinate)
                if constexpr(iCoordAccess !=
                             (NumAccessPerCoord * OldTraits::ScalarPerVector - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(generate_tuple([&](auto) { return number<0>{}; },
                                                        number<Base::NDimP>{}),
                                         idx_diff_ys);

                    Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord,
                        bottom_tensor_thread_coord,
                        idx_diff_ps_ys);
                }
            });
    });
}

template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
Expand Down
62 changes: 62 additions & 0 deletions include/ck_tile/core/tensor/tile_window_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,68 @@ struct tile_window_with_tile_dstr_base
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
};

// Traits for the scalar load path (load_vectorload1): same structure as the
// regular Traits above, but every access reads exactly one element
// (ScalarPerVector == 1) and the space-filling curve is not snaked.
struct TraitsVectorload1
{
    public:
    // Number of logical elements packed into one storage unit of DataType
    // (e.g. > 1 for sub-byte types).
    static constexpr index_t PackedSize =
        ck_tile::numeric_traits<remove_cvref_t<typename TileWindowBase::DataType>>::PackedSize;

    // Finds the Y dimension with unit stride and the longest contiguous run.
    // Only element 0 of the returned tuple (the dimension) is consumed below;
    // the computed scalar-per-vector (element 1) is intentionally ignored
    // because this trait pins ScalarPerVector to 1.
    static constexpr auto get_vector_dim_y_scalar_per_vector()
    {
        const auto [ys_vector_lengths, ys_vector_strides] =
            tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides();

        index_t VectorDimY_ = 0;
        index_t ScalarPerVector_ = 1;

        for(index_t i = 0; i < NDimY; ++i)
        {
            if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
            {
                ScalarPerVector_ = ys_vector_lengths[i];
                VectorDimY_ = i;
            }
        }

        return make_tuple(VectorDimY_, ScalarPerVector_);
    }

    static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
    // Forced to 1: each access loads a single scalar.
    static constexpr index_t ScalarPerVector = 1;
    // NOTE(review): ScalarPerVector / PackedSize evaluates to 0 when
    // PackedSize > 1 (packed sub-byte types), yielding a zero-length
    // thread_buffer -- confirm this trait is only instantiated with
    // PackedSize == 1.
    using vector_t =
        thread_buffer<typename TileWindowBase::DataType, ScalarPerVector / PackedSize>;

    // Step size of the curve per Y dimension: all ones here, since
    // ScalarPerVector is 1 even on the vector dimension.
    static constexpr auto scalars_per_access_ = [] {
        constexpr auto scalars_per_access_arr = generate_array(
            [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});

        /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
        constexpr auto NDimY_ = NDimY;

        return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
    }();

    // Space-filling curve over the thread-local Y lengths, advancing one
    // element per access in plain (non-snaked) order.
    static constexpr auto get_space_filling_curve()
    {
        constexpr auto thread_tensor_lengths_ys =
            to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());

        // FIXME: need logic to judge dim access order
        using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;

        return space_filling_curve<decltype(thread_tensor_lengths_ys),
                                   DimAccessOrder,
                                   decltype(scalars_per_access_),
                                   false /*!!! no snaked curve! */>{};
    }

    using SFC_Ys = decltype(get_space_filling_curve());

    // Total number of scalar accesses each thread performs.
    static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();

    static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
};

// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
Expand Down
1 change: 1 addition & 0 deletions include/ck_tile/ops/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v7.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
Expand Down
22 changes: 18 additions & 4 deletions include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,19 @@ struct UniversalGemmKernel
return false;
}

if(GemmPipeline::GetPipelineName() == "COMPUTE_V7")
{
if(GemmPipeline::kPadK == false || GemmPipeline::kPadM == false ||
GemmPipeline::kPadN == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Compute pipeline v7 needs all paddings enabled!");
}
return false;
}
}

const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
: GemmPipeline::template GetVectorSizeA<false>();
bool AsTensorIsValid = {true};
Expand All @@ -439,7 +452,7 @@ struct UniversalGemmKernel
}
AsTensorIsValid = false;
}
if(kargs.K % vectorSizeA != 0)
if(kargs.K % vectorSizeA != 0 && GemmPipeline::GetPipelineName() != "COMPUTE_V7")
{
const auto remainder = kargs.K % vectorSizeA;
constexpr ck_tile::index_t APackedSize =
Expand Down Expand Up @@ -471,7 +484,7 @@ struct UniversalGemmKernel
}
AsTensorIsValid = false;
}
if(kargs.M % vectorSizeA != 0)
if(kargs.M % vectorSizeA != 0 && GemmPipeline::GetPipelineName() != "COMPUTE_V7")
{
const auto remainder = kargs.M % vectorSizeA;
constexpr ck_tile::index_t APackedSize =
Expand Down Expand Up @@ -511,7 +524,7 @@ struct UniversalGemmKernel
}
BsTensorIsValid = false;
}
if(kargs.N % vectorSizeB != 0)
if(kargs.N % vectorSizeB != 0 && GemmPipeline::GetPipelineName() != "COMPUTE_V7")
{
const auto remainder = kargs.N % vectorSizeB;
constexpr ck_tile::index_t BPackedSize =
Expand Down Expand Up @@ -544,7 +557,8 @@ struct UniversalGemmKernel
}
BsTensorIsValid = false;
}
if(kargs.K % vectorSizeB != 0)
if(kargs.K % vectorSizeB != 0 &&
GemmPipeline::GetPipelineName() != "COMPUTE_V7")
{
const auto remainder = kargs.K % vectorSizeB;
constexpr ck_tile::index_t BPackedSize =
Expand Down
Loading