Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
using GemmConfig = GemmConfigABQuantPrefill<T>;

template <typename T>
using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;

// template <typename T>
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;

void abquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
Expand Down Expand Up @@ -78,7 +84,7 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -93,7 +99,7 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -108,7 +114,7 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -123,7 +129,7 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down
29 changes: 29 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,28 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
static constexpr bool PreshuffleQuant = true;
};

// Prefill-phase GEMM configuration for kernels using both A- and B-side
// (AB) grouped quantization with a preshuffled B matrix. Inherits the
// BQuant prefill defaults and overrides only the warp partitioning and
// output options below.
template <typename PrecType>
struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
{
// 2x2 warp grid over the M/N tile, no warp split along K.
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;

// NOTE(review): kPadK=false presumably disables K-dimension padding and
// TransposeC=true requests a transposed C output path — confirm against
// the pipeline that consumes these flags.
static constexpr bool kPadK = false;
static constexpr bool TransposeC = true;
};

// Decode-phase variant of the AB-quant preshuffled-B GEMM configuration.
// Decode typically has a small M (batch of single tokens), reflected in the
// narrow 16-row M tile; inherits everything else from the BQuant prefill
// config.
template <typename PrecType>
struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 128;
// K tile is sized in elements so that the tile spans 256 bytes regardless
// of the precision type (e.g. 256 elements for 1-byte fp8/bf8).
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

// NOTE(review): same flag overrides as the prefill AB-quant config —
// no K padding, transposed C output. Confirm with the consuming pipeline.
static constexpr bool kPadK = false;
static constexpr bool TransposeC = true;
};

template <typename PrecType>
struct GemmConfigQuantPrefill : public GemmConfigBase
{
Expand All @@ -209,6 +231,13 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};

// Prefill-phase configuration for AB (A- and B-side) grouped quantization
// without B preshuffling. Reuses all tile/warp settings from the generic
// quant prefill config and only overrides the two flags below.
template <typename PrecType>
struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
// NOTE(review): kPadK=false presumably disables K-dimension padding and
// TransposeC=true selects the transposed-C output path used by the
// AB-quant pipeline — confirm against the kernel that reads these.
static constexpr bool kPadK = false;
static constexpr bool TransposeC = true;
};

template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ template <typename GemmConfig,
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped;
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
Expand All @@ -57,7 +58,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
QuantMode,
AQLayout, // for AQLayout
BQLayout, // for BQLayout
false,
transpose_c,
GemmConfig::DoubleSmemBuffer>;

using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
Expand Down Expand Up @@ -88,7 +89,6 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;

// row-col and tensor quants use the regular pipeline, A/B/AB quants use their own
using PipelineProblem = std::conditional_t<
Expand Down
12 changes: 6 additions & 6 deletions include/ck_tile/core/tensor/sweep_tile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ template <typename TileDistributedSpan_, // tile_distributed_span<...>
>
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
{
    // NOTE: the scraped diff had merged the pre- and post-change bodies here
    // (both the old `using DstrSpan = ...; static_ford<...>` traversal and the
    // new 0-dim-aware one), which is not valid C++. This is the post-change
    // implementation only.
    //
    // Impl is the compile-time sequence describing the span's per-dimension
    // lengths; it drives the static (compile-time) sweep below.
    using DstrSpanImpl = typename remove_cvref_t<TileDistributedSpan_>::Impl;

    if constexpr(DstrSpanImpl::size() == 0)
    {
        // 0-dim span: static_ford over an empty sequence would never invoke
        // the functor, but a rank-0 span still has exactly one (empty)
        // coordinate — call f once with an empty distributed index.
        f(detail::make_tile_distributed_index(sequence<>{}));
    }
    else
    {
        // Visit every coordinate of the span at compile time, wrapping each
        // raw index pack into a tile_distributed_index before invoking f.
        static_ford<DstrSpanImpl>{}(
            [&](auto dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)); });
    }
}

// unpacked span, this version support span with unpack(multi-arg) functor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,19 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
});
});
};

auto q_block_tensor = aq_block_tensor;
if constexpr(Traits::NQPerBlock == 1)
{
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
sweep_tile_span(aq_spans[I0], [&](auto im) {
sweep_tile_span(aq_spans[I1], [&](auto ik) {
q_block_tensor(make_tuple(im, ik)) *=
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
});
});
}
// hot loop:
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
zero_accumulators();
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
Expand Down Expand Up @@ -243,9 +256,29 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
}
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(aq_block_tensor);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
auto nIter) {
if constexpr(Traits::NQPerBlock == 1)
{
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};

constexpr auto block_idx_m = tile_distributed_index<mIter>{};
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
});
}
else
{
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
Expand Down Expand Up @@ -273,7 +306,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});
});
}
});
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,37 +285,64 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();

// Start from AQ block tensor and then scale it using BQ; this represents
// the combined A/B quantization scales for the block.
auto q_block_tensor = aq_block_tensor;
if constexpr(Traits::NQPerBlock == 1)
{
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
sweep_tile_span(aq_spans[I0{}], [&](auto im) {
sweep_tile_span(aq_spans[I1{}], [&](auto ik) {
q_block_tensor(make_tuple(im, ik)) *=
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
});
});
}

// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
auto nIter) {
CWarpTensor c_warp_tensor;
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;

AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});

static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;

AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
if constexpr(Traits::NQPerBlock == 1)
{
constexpr auto cw_spans = CWarpTensor::get_distributed_spans();
static_assert(cw_spans[I0{}].impl_.size() == 0);
sweep_tile_span(cw_spans[I1{}], [&](auto in) {
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
constexpr auto block_idx_n = detail::make_tile_distributed_index(
merge_sequences(sequence<nIter>{}, in.impl_));
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
constexpr auto empty_idx = tile_distributed_index<>{};
c_block_tensor(make_tuple(block_idx_m, block_idx_n)) +=
c_warp_tensor(make_tuple(empty_idx, in)) *
q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
});

}
else
{
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
Expand Down Expand Up @@ -387,7 +414,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
b_scale_reg_f);
});
}
});
}
});
});
}
Expand Down
Loading