161 changes: 161 additions & 0 deletions csrc/bf16_gemm_cutlass.cu
@@ -0,0 +1,161 @@
/*
* Copyright (c) 2025, FlashInfer.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_fp16.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>

#include "flashinfer/gemm/bf16_gemm_cutlass.h"
#include "flashinfer/gemm/bf16_gemm_cutlass_template.h"
#include "flashinfer/gemm/cutlass_gemm_configs.h"
#include "tvm_ffi_utils.h"

using flashinfer::gemm::ClusterShape;
using flashinfer::gemm::CutlassBf16GemmRunner;
using flashinfer::gemm::CutlassBf16GemmRunnerInterface;
using flashinfer::gemm::CutlassGemmConfig;
using flashinfer::gemm::CutlassTileConfigSM100;
using flashinfer::gemm::EpilogueScheduleType;
using flashinfer::gemm::MainloopScheduleType;

namespace flashinfer {
namespace gemm {
template class CutlassBf16GemmRunner<__nv_bfloat16>;
template class CutlassBf16GemmRunner<half>;
} // namespace gemm
} // namespace flashinfer

namespace torch_ext {

namespace {

CutlassGemmConfig getBf16GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) {
auto getCutlassBf16GemmConfigs = []() {
CutlassBf16GemmRunner<__nv_bfloat16> gemmRunner;
return gemmRunner.getConfigs();
};
static std::vector<CutlassGemmConfig> globalConfigs = getCutlassBf16GemmConfigs();
TVM_FFI_ICHECK(tactic >= 0 && tactic < static_cast<int64_t>(globalConfigs.size()))
<< "tactic must be between 0 and " << globalConfigs.size();
return globalConfigs[tactic];
}

template <typename T>
void runGemm(TensorView out, TensorView mat1, TensorView mat2, int64_t m, int64_t n, int64_t k,
int64_t b, CutlassGemmConfig const& gemmConfig, TensorView workspace_buffer) {
CutlassBf16GemmRunner<T> gemmRunner;

int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k);
int64_t const provided_workspace_size =
workspace_buffer.numel() * get_element_size(workspace_buffer);

auto runKernel = [&](void* workspace) {
gemmRunner.gemm(static_cast<__nv_bfloat16*>(mat1.data_ptr()),
static_cast<__nv_bfloat16*>(mat2.data_ptr()), out.data_ptr(), m, n, k, b,
gemmConfig, static_cast<char*>(workspace), required_workspace_size,
get_stream(mat1.device()));
};

if (provided_workspace_size < required_workspace_size) {
Tensor new_workspace =
alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device());
runKernel(new_workspace.data_ptr());
} else {
runKernel(workspace_buffer.data_ptr());
}
}

void bf16_bmm_impl(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer,
int64_t tactic) {
CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);

int64_t m, n, k, b;
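// Shape convention: mat1 is (m, k) and mat2 is (n, k), i.e. the second operand
// arrives pre-transposed, so out = mat1 * mat2^T with shape (m, n) ((b, m, n) when batched).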
if (mat1.ndim() == 2) {
TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix";
TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1))
<< "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1)
<< " and " << mat2.size(0) << "x" << mat2.size(1) << ")";
m = mat1.size(0);
n = mat2.size(0);
k = mat2.size(1);
b = 1;
} else if (mat1.ndim() == 3) {
TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices";
TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0)) << "mat1 and mat2 must have the same batch size ("
<< mat1.size(0) << " and " << mat2.size(0) << ")";
TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2))
<< "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2)
<< " and " << mat2.size(1) << "x" << mat2.size(2) << ")";
m = mat1.size(1);
n = mat2.size(1);
k = mat2.size(2);
b = mat1.size(0);
} else {
TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices";
}

if (tactic == -1) {
tactic = 0;
}
auto config = getBf16GemmConfig(m, n, k, tactic);

std::vector<int64_t> out_shape =
mat1.ndim() == 2 ? std::vector<int64_t>{m, n} : std::vector<int64_t>{b, m, n};
TVM_FFI_ICHECK_EQ(out.ndim(), static_cast<int>(out_shape.size()))
<< "out must have " << out_shape.size() << " dimensions, but got " << out.ndim();
for (int i = 0; i < static_cast<int>(out_shape.size()); ++i) {
TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i])
<< "out shape mismatch at dimension " << i << ": expected " << out_shape[i] << ", got "
<< out.size(i);
}

switch (encode_dlpack_dtype(out.dtype())) {
case float16_code:
runGemm<half>(out, mat1, mat2, m, n, k, b, config, workspace_buffer);
break;
case bfloat16_code:
runGemm<__nv_bfloat16>(out, mat1, mat2, m, n, k, b, config, workspace_buffer);
break;
default:
TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of fp16/bf16.";
}
}

} // namespace

void bf16_gemm(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer,
int64_t tactic) {
bf16_bmm_impl(mat1, mat2, out, workspace_buffer, tactic);
}

int64_t bf16_gemm_tactic_num() {
auto getCutlassConfigs = []() {
CutlassBf16GemmRunner<__nv_bfloat16> gemmRunner;
return gemmRunner.getConfigs();
};
static int64_t totalTactics = getCutlassConfigs().size();
return totalTactics;
}

} // namespace torch_ext

TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm, torch_ext::bf16_gemm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_tactic_num, torch_ext::bf16_gemm_tactic_num);
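The two exports above surface the raw kernel and its config count to Python. A minimal sketch of driving them directly, under the shape convention checked in bf16_bmm_impl (mat1 is (m, k), mat2 is (n, k)); the load_bf16_gemm_module() loader below is hypothetical, standing in for flashinfer's JIT build-and-load machinery:

import torch

mod = load_bf16_gemm_module()  # hypothetical: however the compiled ops are loaded

m, n, k = 128, 256, 512
mat1 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
mat2 = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)  # (n, k): B pre-transposed
out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)
# An undersized workspace is tolerated: runGemm falls back to allocating its own.
workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.int8)

# Valid tactics are [0, bf16_gemm_tactic_num()); tactic == -1 falls back to config 0.
for tactic in range(mod.bf16_gemm_tactic_num()):
    mod.bf16_gemm(mat1, mat2, out, workspace, tactic)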
27 changes: 27 additions & 0 deletions csrc/bf16_gemm_cutlass.jinja
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2025, FlashInfer.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "flashinfer/gemm/bf16_gemm_template_sm100.h"

namespace flashinfer {
namespace gemm {
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM);
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
} // namespace gemm
} // namespace flashinfer
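This template is stamped out once per (dtype, CTA tile) combination by the JIT generator; each render yields the five instantiations above, covering the 1SM and 2SM cluster shapes. A hedged sketch of one render, with illustrative tile sizes rather than the project's exact generation list:

import jinja2

with open("csrc/bf16_gemm_cutlass.jinja") as f:
    template = jinja2.Template(f.read())

# The type/cta_m/cta_n/cta_k values are plausible SM100 choices, not authoritative.
print(template.render(type="__nv_bfloat16", cta_m=128, cta_n=128, cta_k=64))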
9 changes: 9 additions & 0 deletions docs/api/gemm.rst
@@ -7,6 +7,15 @@ flashinfer.gemm

This module provides a set of GEMM operations.

BF16 GEMM
---------

.. autosummary::
:toctree: ../generated

mm_bf16
bmm_bf16

FP4 GEMM
--------

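For reference, the two documented entry points can be exercised roughly as below. This is a hedged sketch: the exact wrapper signatures, and whether the wrappers take the second operand row-major (as here) or pre-transposed like the underlying FFI call, follow the Python docstrings rather than this snippet:

import torch
import flashinfer

a = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
c = flashinfer.mm_bf16(a, b)  # expected: (128, 256)

a3 = torch.randn(4, 128, 512, device="cuda", dtype=torch.bfloat16)
b3 = torch.randn(4, 512, 256, device="cuda", dtype=torch.bfloat16)
c3 = flashinfer.bmm_bf16(a3, b3)  # expected: (4, 128, 256)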
2 changes: 2 additions & 0 deletions flashinfer/__init__.py
@@ -85,7 +85,9 @@
trtllm_fp8_per_tensor_scale_moe,
)
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm import bmm_bf16 as bmm_bf16
from .gemm import bmm_fp8 as bmm_fp8
from .gemm import mm_bf16 as mm_bf16
from .gemm import mm_fp4 as mm_fp4
from .gemm import mm_fp8 as mm_fp8
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
4 changes: 4 additions & 0 deletions flashinfer/gemm/__init__.py
@@ -1,5 +1,7 @@
from .gemm_base import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm_base import bmm_bf16 as bmm_bf16
from .gemm_base import bmm_fp8 as bmm_fp8
from .gemm_base import mm_bf16 as mm_bf16
from .gemm_base import mm_fp4 as mm_fp4
from .gemm_base import mm_fp8 as mm_fp8
from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100
@@ -21,7 +23,9 @@

__all__ = [
"SegmentGEMMWrapper",
"bmm_bf16",
"bmm_fp8",
"mm_bf16",
"mm_fp4",
"mm_fp8",
"tgv_gemm_sm100",