From 390e2ced519982f42013317977627daae1ba14ab Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 02:56:26 +0000 Subject: [PATCH 01/42] [Dev] Optimize make jobs calculation in installation scripts --- .gitmodules | 2 +- 3rdparty/tilelang | 2 +- install.sh | 10 +++++++--- install_amd.sh | 8 +++++--- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/.gitmodules b/.gitmodules index adbfcc33f..d4a0ba81e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,7 +5,7 @@ [submodule "3rdparty/tilelang"] path = 3rdparty/tilelang url = https://github.com/tile-ai/tilelang - branch = bitblas + branch = main [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/tile-ai/cutlass diff --git a/3rdparty/tilelang b/3rdparty/tilelang index b09e2b5cc..7a82c5b8d 160000 --- a/3rdparty/tilelang +++ b/3rdparty/tilelang @@ -1 +1 @@ -Subproject commit b09e2b5cc6abfe94c35249cb99ad899ef394964e +Subproject commit 7a82c5b8d781c72a4f96b19e45281dc49b780480 diff --git a/install.sh b/install.sh index 77c258706..5abdcdfda 100755 --- a/install.sh +++ b/install.sh @@ -110,8 +110,12 @@ if [ $? -ne 0 ]; then exit 1 fi +CORES=$(nproc) +MAKE_JOBS=$(( CORES * 75 / 100 )) +echo "Using $MAKE_JOBS jobs for make..." + echo "Building TVM with make..." -make -j +make -j${MAKE_JOBS} if [ $? -ne 0 ]; then echo "Error: TVM build failed." exit 1 @@ -134,7 +138,7 @@ if [ $? -ne 0 ]; then exit 1 fi -make -j +make -j${MAKE_JOBS} if [ $? -ne 0 ]; then echo "Error: TileLang build failed." exit 1 @@ -185,4 +189,4 @@ else fi # Reload ~/.bashrc to apply the changes -source ~/.bashrc +source ~/.bashrc \ No newline at end of file diff --git a/install_amd.sh b/install_amd.sh index dec3dedcf..9b562cb52 100755 --- a/install_amd.sh +++ b/install_amd.sh @@ -60,7 +60,9 @@ cp cmake/config.cmake build cd build echo "set(USE_LLVM llvm-config-16)" >> config.cmake && echo "set(USE_ROCM /opt/rocm)" >> config.cmake -cmake .. && make -j && cd ../../.. +CORES=$(nproc) +MAKE_JOBS=$(( CORES * 75 / 100 )) +cmake .. && make -j${MAKE_JOBS} && cd ../../.. TVM_PREBUILD_PATH=$(realpath .) @@ -77,7 +79,7 @@ if [ $? -ne 0 ]; then exit 1 fi -make -j +make -j${MAKE_JOBS} if [ $? -ne 0 ]; then echo "Error: TileLang build failed." exit 1 @@ -127,4 +129,4 @@ else fi # Reload ~/.bashrc to apply the changes -source ~/.bashrc +source ~/.bashrc \ No newline at end of file From d35ea7a3d950f3962d09731bc57982c34b02a81a Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 02:57:41 +0000 Subject: [PATCH 02/42] update tvm --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index d310bd5aa..accd582d3 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d310bd5aadce96145546fb7a87a6d325ea392b2b +Subproject commit accd582d3a006b6c3473187e1c155fa535343d8a From 34a2e80ec5fafc48b4df2965dbb14f3cc45f036d Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 03:01:57 +0000 Subject: [PATCH 03/42] update submodule to newest --- .gitmodules | 2 +- 3rdparty/cutlass | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index d4a0ba81e..f24f53d91 100644 --- a/.gitmodules +++ b/.gitmodules @@ -9,4 +9,4 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/tile-ai/cutlass - branch = tldev + branch = main diff --git a/3rdparty/cutlass b/3rdparty/cutlass index a2954a8fd..5e497243f 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit a2954a8fdd9a73852f2c1ddea97d0e8a579cfb25 +Subproject commit 5e497243f7ad13a2aa842143f9b10bbb23d98292 From f3062f973cd0672efa8d21fa8410196c7015493f Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 03:05:50 +0000 Subject: [PATCH 04/42] update cutlass --- .gitmodules | 4 ---- 3rdparty/cutlass | 1 - 2 files changed, 5 deletions(-) delete mode 160000 3rdparty/cutlass diff --git a/.gitmodules b/.gitmodules index f24f53d91..65f39f9bf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,7 +6,3 @@ path = 3rdparty/tilelang url = https://github.com/tile-ai/tilelang branch = main -[submodule "3rdparty/cutlass"] - path = 3rdparty/cutlass - url = https://github.com/tile-ai/cutlass - branch = main diff --git a/3rdparty/cutlass b/3rdparty/cutlass deleted file mode 160000 index 5e497243f..000000000 --- a/3rdparty/cutlass +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5e497243f7ad13a2aa842143f9b10bbb23d98292 From 7d93650ba222e82fed43d71c5999d930031b9454 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 03:06:47 +0000 Subject: [PATCH 05/42] add cutlass module --- .gitmodules | 3 +++ 3rdparty/cutlass | 1 + 2 files changed, 4 insertions(+) create mode 160000 3rdparty/cutlass diff --git a/.gitmodules b/.gitmodules index 65f39f9bf..25a3b1a8c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,3 +6,6 @@ path = 3rdparty/tilelang url = https://github.com/tile-ai/tilelang branch = main +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 000000000..5e497243f --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit 5e497243f7ad13a2aa842143f9b10bbb23d98292 From 7438991d9a3a4aecde1cc3648950b8b2cc04f1f2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 03:21:28 +0000 Subject: [PATCH 06/42] update tvm --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index accd582d3..4776d3119 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit accd582d3a006b6c3473187e1c155fa535343d8a +Subproject commit 4776d3119e43fe064d890b94be7587d53106b9e3 From 54d7d3e52c1a383b164153cc85db70cf54c80624 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 04:31:44 +0000 Subject: [PATCH 07/42] add tl.lower --- bitblas/tl/lower.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 bitblas/tl/lower.py diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py new file mode 100644 index 000000000..17a4db1af --- /dev/null +++ b/bitblas/tl/lower.py @@ -0,0 +1,23 @@ +from bitblas import tilelang as tilelang +from bitblas import tvm +from tvm.target import Target +from tvm import tir +from typing import Optional, Union + +def lower( + func_or_mod: Union[tir.PrimFunc, tvm.IRModule], + target: Union[str, Target] = "auto", + target_host: Optional[Union[str, Target]] = None, + runtime_only=False, +): + result = tilelang.lower( + func_or_mod, + target, + target_host=target_host, + runtime_only=runtime_only, + ) + if runtime_only is True: + return result.host_mod + else: + return result.host_mod, result.parms + From 8078a4dfdb15573e4dd421dd705856b330752982 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 04:46:32 +0000 Subject: [PATCH 08/42] rename tl_lower --- bitblas/tl/lower.py | 2 +- bitblas/tl/tuner.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py index 17a4db1af..cefc4f988 100644 --- a/bitblas/tl/lower.py +++ b/bitblas/tl/lower.py @@ -4,7 +4,7 @@ from tvm import tir from typing import Optional, Union -def lower( +def tl_lower( func_or_mod: Union[tir.PrimFunc, tvm.IRModule], target: Union[str, Target] = "auto", target_host: Optional[Union[str, Target]] = None, diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index 539838393..12779ecef 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -22,6 +22,7 @@ ) from bitblas.common import MAX_ERROR_MESSAGE_LENGTH from bitblas.base.base_scheduler import BaseScheduler +from bitblas.tl.lower import tl_lower logger = logging.getLogger(__name__) @@ -122,7 +123,7 @@ def tvm_callback_cuda_postproc(code, _): "tir.disable_cse_tir": True, **(config.pass_context if config.pass_context else {}) }): - rt_mod = tilelang.lower(tl_prim_func, arch.target, runtime_only=True) + rt_mod = tl_lower(tl_prim_func, arch.target, runtime_only=True) from tvm.contrib.tar import tar # Import the tar module From 0e2f7db5b1580e81c2986181ba9b5768e7e8547d Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 05:44:52 +0000 Subject: [PATCH 09/42] replace with tl_lower --- bitblas/ops/operator.py | 3 +- .../test_general_matmul_ops_backend.py | 3 +- .../test_general_matmul_tilelang_impl.py | 7 ++-- .../test_general_matmul_tilelang_kernel.py | 36 ++++++++++--------- testing/python/tilelang/test_simplifier.py | 3 +- .../python/tilelang/test_tilelang_amd_gemm.py | 3 +- .../tilelang/test_tilelang_dequantize_gemm.py | 6 ++-- .../test_tilelang_dyanmic_symbolic.py | 7 ++-- .../tilelang/test_tilelang_flash_atten.py | 7 ++-- testing/python/tilelang/test_tilelang_gemm.py | 3 +- .../tilelang/test_tilelang_gemm_s4_mma.py | 5 +-- .../tilelang/test_tilelang_gemm_simt.py | 3 +- .../tilelang/test_tilelang_mfma_macro_gemm.py | 2 +- .../tilelang/test_tilelang_mma_macro_gemm.py | 11 +++--- 14 files changed, 56 insertions(+), 43 deletions(-) diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 304723eb8..46e6226e5 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -21,6 +21,7 @@ from bitblas.builder.lib_generator import LibraryGenerator from bitblas.common import MAX_ERROR_MESSAGE_LENGTH from bitblas.utils import retrieve_func_from_module +from bitblas.tl.lower import tl_lower from dataclasses import dataclass import logging import re @@ -192,7 +193,7 @@ def tvm_callback_hip_postproc(code, _): if self.is_tir_backend(): rt_mod = tvm.build(self.scheduled_ir_module, target=target) elif self.is_tilelang_backend(): - rt_mod = tilelang.lower( + rt_mod = tl_lower( self.scheduled_ir_module, target=target, runtime_only=True) else: raise ValueError(f"Unsupported backend: {self.backend}") diff --git a/testing/python/operators/test_general_matmul_ops_backend.py b/testing/python/operators/test_general_matmul_ops_backend.py index 8d80f7d87..23c312f41 100644 --- a/testing/python/operators/test_general_matmul_ops_backend.py +++ b/testing/python/operators/test_general_matmul_ops_backend.py @@ -5,6 +5,7 @@ from bitblas import MatmulConfig, Matmul import logging from bitblas import set_log_level +from bitblas.tl.lower import tl_lower set_log_level(logging.DEBUG) @@ -36,7 +37,7 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, la matmul = Matmul(config=matmul_config, enable_tuning=False) func = matmul.prim_func import tilelang - rt_mod, params = tilelang.lower(func) + rt_mod, params = tl_lower(func) print(rt_mod) assert get_codegen_result(matmul) diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index 3f258bab8..edf5766b7 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -3,6 +3,7 @@ from bitblas import tvm as tvm from bitblas import tilelang as tilelang +from bitblas.tl.lower import tl_lower import bitblas.testing from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( matmul_blocked, @@ -47,7 +48,7 @@ def assert_matmul_blocked_correctness(M, enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -105,7 +106,7 @@ def assert_matmul_macro_tensorcore_correctness( num_stages=num_stages, enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code represents generated cuda source @@ -164,7 +165,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index dc1f1c424..2007a43f8 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -4,6 +4,8 @@ from bitblas import tvm as tvm import bitblas.testing from bitblas import tilelang as tilelang +from bitblas.tl.lower import tl_lower + from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( MatmulTileLibraryScheduler,) @@ -51,7 +53,7 @@ def assert_matmul_blocked_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -109,7 +111,7 @@ def assert_matmul_blocked_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -155,7 +157,7 @@ def assert_matmul_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -217,7 +219,7 @@ def assert_matmul_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -263,7 +265,7 @@ def assert_matmul_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -337,7 +339,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -394,7 +396,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -459,7 +461,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -509,7 +511,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -588,7 +590,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( ) print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -666,7 +668,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -765,7 +767,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -849,7 +851,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -970,7 +972,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( ) print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1078,7 +1080,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1213,7 +1215,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1344,7 +1346,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( ).with_default_config() if verbose: print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source diff --git a/testing/python/tilelang/test_simplifier.py b/testing/python/tilelang/test_simplifier.py index 18613edc9..fee7fcdf1 100644 --- a/testing/python/tilelang/test_simplifier.py +++ b/testing/python/tilelang/test_simplifier.py @@ -1,6 +1,7 @@ import tvm from bitblas import tilelang as tilelang import tilelang.language as T +from bitblas.tl.lower import tl_lower def modify( @@ -73,7 +74,7 @@ def test_matmul(): mod = tvm.IRModule({func.attrs["global_symbol"]: func}) mod = tilelang.transform.Simplify()(mod) - rt_mod, params = tilelang.lower(mod.functions_items()[0][1], runtime_only=False) + rt_mod, params = tl_lower(mod.functions_items()[0][1], runtime_only=False) # TODO Profiler only support TensorType, not dynamic variable profiler = tilelang.Profiler(rt_mod, params, result_idx=[2]) diff --git a/testing/python/tilelang/test_tilelang_amd_gemm.py b/testing/python/tilelang/test_tilelang_amd_gemm.py index 20abd415c..eafb24745 100644 --- a/testing/python/tilelang/test_tilelang_amd_gemm.py +++ b/testing/python/tilelang/test_tilelang_amd_gemm.py @@ -4,6 +4,7 @@ from bitblas import tvm as tvm import bitblas.testing from bitblas import tilelang as tilelang +from bitblas.tl.lower import tl_lower def matmul( @@ -84,7 +85,7 @@ def run_gemm( num_threads, k_pack=k_pack, ) - mod, params = tilelang.lower(program) + mod, params = tl_lower(program) mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index 2e4873f89..a592ed940 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -12,7 +12,7 @@ from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout from bitblas.tl.mma_macro_generator import ( TensorCoreIntrinEmitterWithLadderTransform,) - +from bitblas.tl.lower import tl_lower from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 torch.manual_seed(0) @@ -123,7 +123,7 @@ def run_gemm( num_threads, ) - mod, params = tilelang.lower(program) + mod, params = tl_lower(program) mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) out = mod.run_once() @@ -367,7 +367,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index ae027cbf4..47261990c 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -10,6 +10,7 @@ import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter) +from bitblas.tl.lower import tl_lower torch.manual_seed(0) @@ -178,7 +179,7 @@ def main( def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -271,7 +272,7 @@ def assert_tl_matmul_block_correctness( num_stages, num_threads, ) - mod, params = tilelang.lower(program) + mod, params = tl_lower(program) A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) @@ -370,7 +371,7 @@ def assert_tl_matmul_block_all_dynamic_correctness( num_stages, num_threads, ) - mod, params = tilelang.lower(program) + mod, params = tl_lower(program) if trans_A: A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) else: diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index 2c1c834ee..f430da419 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -12,6 +12,7 @@ import logging from bitblas import set_log_level from bitblas.ops.general_flashatten.tilelang.flashatten import flashatten_blocked +from bitblas.tl.lower import tl_lower set_log_level(logging.DEBUG) @@ -66,7 +67,7 @@ def flashattn_tilelang(batch, heads, seq_len, dim, trans_K, dtypeQKV, dtypeAccu, num_stages=num_stages, is_causal=is_causal, ) - mod, params = tilelang.lower(tl_prim_func) + mod, params = tl_lower(tl_prim_func) mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) from flash_attn.flash_attn_interface import flash_attn_func # TODO Now hack to internal function get the same input, may need to modify 3rdparty:tvm.tl.utils @@ -177,7 +178,7 @@ def main( return main - mod, params = tilelang.lower(kernel()) + mod, params = tl_lower(kernel()) mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) @@ -398,7 +399,7 @@ def main( return main - mod, params = tilelang.lower(kernel()) + mod, params = tl_lower(kernel()) mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.1, atol=0.1) diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py index bd26fcc1f..2cdfbb656 100644 --- a/testing/python/tilelang/test_tilelang_gemm.py +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -4,6 +4,7 @@ from bitblas import tvm as tvm import bitblas.testing from bitblas import tilelang as tilelang +from bitblas.tl.lower import tl_lower def matmul( @@ -81,7 +82,7 @@ def run_gemm( num_stages, num_threads, ) - mod, params = tilelang.lower(program) + mod, params = tl_lower(program) mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py index b32fd7833..db0ffd44b 100644 --- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -16,6 +16,7 @@ INT4TensorCoreIntrinEmitterWithLadderTransform, ) from bitblas.base import simplify_prim_func +from bitblas.tl.lower import tl_lower torch.manual_seed(0) @@ -173,7 +174,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -368,7 +369,7 @@ def main( def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None diff --git a/testing/python/tilelang/test_tilelang_gemm_simt.py b/testing/python/tilelang/test_tilelang_gemm_simt.py index 33e5abae6..47176c92a 100644 --- a/testing/python/tilelang/test_tilelang_gemm_simt.py +++ b/testing/python/tilelang/test_tilelang_gemm_simt.py @@ -10,6 +10,7 @@ import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.base import simplify_prim_func +from bitblas.tl.lower import tl_lower torch.manual_seed(0) @@ -142,7 +143,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() print(src_code) # src_code is the generated cuda source diff --git a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py index b1f16c207..a579efdf0 100644 --- a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py @@ -172,7 +172,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None diff --git a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py index dbdfd1034..d4157871f 100644 --- a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py @@ -15,6 +15,7 @@ ) from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 from bitblas.base import simplify_prim_func +from bitblas.tl.lower import tl_lower torch.manual_seed(0) @@ -186,7 +187,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -387,7 +388,7 @@ def main( def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_with_block_reduce(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -564,7 +565,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d matmul = tl_matmul_with_ladder_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -824,7 +825,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -1035,7 +1036,7 @@ def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_ matmul = tl_matmul_with_ladder_input_weight_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_a, transform_b) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None From 019d235332565f77cd4c4e808e74172fd1d46b8b Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 06:09:12 +0000 Subject: [PATCH 10/42] fix typo and func --- bitblas/tl/lower.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py index cefc4f988..5b36514da 100644 --- a/bitblas/tl/lower.py +++ b/bitblas/tl/lower.py @@ -15,9 +15,10 @@ def tl_lower( target, target_host=target_host, runtime_only=runtime_only, + enable_host_codegen=True, ) if runtime_only is True: - return result.host_mod + return result.rt_mod else: - return result.host_mod, result.parms + return result.rt_mod, result.params From 45be64c3a4a06152a2d4c1d758ab91ebc18ae966 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 08:29:47 +0000 Subject: [PATCH 11/42] Refactor and enhance TileLang profiling and lowering functionality - Introduced TLProfiler class for improved profiling capabilities in TileLang. - Updated various test files to utilize TLProfiler instead of the previous profiler implementation. - Enhanced the tl_lower function to support additional target optimizations and device calls. - Added CUDA and HIP compilation callbacks to handle device-specific compilation. - Improved parameter extraction and target handling in the lowering process. - Ensured compatibility with new profiling methods across multiple test cases. --- bitblas/tl/lower.py | 221 ++++++++++++++- bitblas/tl/profiler.py | 253 ++++++++++++++++++ .../test_general_matmul_tilelang_impl.py | 7 +- .../test_general_matmul_tilelang_kernel.py | 35 +-- testing/python/tilelang/test_simplifier.py | 4 +- .../python/tilelang/test_tilelang_amd_gemm.py | 4 +- .../tilelang/test_tilelang_dequantize_gemm.py | 5 +- .../test_tilelang_dyanmic_symbolic.py | 7 +- .../tilelang/test_tilelang_flash_atten.py | 7 +- testing/python/tilelang/test_tilelang_gemm.py | 3 +- .../tilelang/test_tilelang_gemm_s4_mma.py | 5 +- .../tilelang/test_tilelang_gemm_simt.py | 3 +- .../tilelang/test_tilelang_mfma_macro_gemm.py | 3 +- .../tilelang/test_tilelang_mma_macro_gemm.py | 11 +- 14 files changed, 513 insertions(+), 55 deletions(-) create mode 100644 bitblas/tl/profiler.py diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py index 5b36514da..64db656be 100644 --- a/bitblas/tl/lower.py +++ b/bitblas/tl/lower.py @@ -1,8 +1,137 @@ +import os +import os.path as osp +from typing import Union, Optional, Callable, List from bitblas import tilelang as tilelang -from bitblas import tvm -from tvm.target import Target +from tilelang import tvm as tvm from tvm import tir -from typing import Optional, Union +from tvm.ir import CallingConv +from tvm.target import Target +from tilelang.contrib import hipcc, nvcc +from tilelang.engine.param import KernelParam +from tilelang.utils.target import determine_target +from tilelang.engine.phase import ( + LowerAndLegalize, + OptimizeForTarget, +) + + +def is_cpu_device_backend(target: Target): + return target.kind.name == "c" + + +def has_device_kernel_launch(attrs) -> bool: + """Check if the attributes indicate a device kernel launch.""" + return bool(attrs and "calling_conv" in attrs and + attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH) + + +def is_device_call_c_device(func: tir.PrimFunc): + attrs = func.attrs + + # Check if it's a C target + if "target" in attrs and attrs["target"].kind.name == "c": + return True + + return has_device_kernel_launch(attrs) + + +def is_device_call(func: tir.PrimFunc): + return has_device_kernel_launch(func.attrs) + + +def get_device_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: + return is_device_call_c_device if is_device_c else is_device_call + + +def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: + return lambda func: not get_device_call(is_device_c)(func) + + +@tvm.register_func("tilelang_callback_cuda_compile", override=True) +def tilelang_callback_cuda_compile(code, target): + project_root = osp.join(osp.dirname(__file__), "../..") + if "TL_TEMPLATE_PATH" in os.environ: + tl_template_path = os.environ["TL_TEMPLATE_PATH"] + else: + tl_template_path = osp.abspath(osp.join(project_root, "src")) + # TODO(lei): this indeed should be renamed into + # TL_CUTLASS_INCLUDE_PATH in the future + if "TL_CUTLASS_PATH" in os.environ: + cutlass_path = os.environ["TL_CUTLASS_PATH"] + else: + cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include")) + compute_version = "".join(nvcc.get_target_compute_version(target).split(".")) + + # special handle for Hopper + if compute_version == "90": + arch = ["-arch=sm_90a"] + format = "cubin" + else: + arch = [f"-arch=sm_{compute_version}"] + format = "cubin" + + # printing out number of registers + debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage" + ptx = nvcc.compile_cuda( + code, + format, + arch, + options=[ + "-std=c++17", + debug_option, + "--use_fast_math", + "-I" + tl_template_path, + "-I" + cutlass_path, + ], + verbose=False, + ) + + return ptx + + +@tvm.register_func("tilelang_callback_hip_compile", override=True) +def tilelang_callback_hip_compile(code, target): + project_root = osp.join(osp.dirname(__file__), "../..") + tl_template_path = osp.abspath(osp.join(project_root, "src")) + + # TODO(lei): actually this indeed should be renamed into + # TL_COMPOSABLE_KERNEL_INCLUDE_PATH in the future + if "TL_COMPOSABLE_KERNEL_PATH" in os.environ: + ck_path = os.environ["TL_COMPOSABLE_KERNEL_PATH"] + else: + ck_path = osp.abspath(osp.join(project_root, "3rdparty/composable_kernel/include")) + + hsaco = hipcc.compile_hip( + code, + target_format="hsaco", + options=[ + "-std=c++17", + "-I" + tl_template_path, + "-I" + ck_path, + ], + verbose=False, + ) + + return hsaco + + +def extrac_params(func: tir.PrimFunc) -> List[KernelParam]: + tensor_types = [] + for var in func.params: + if var in func.buffer_map: + tensor_types.append(KernelParam.from_buffer(func.buffer_map[var])) + else: + tensor_types.append(KernelParam.from_var(var)) + return tensor_types + + +def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]): + + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + + return target_host + def tl_lower( func_or_mod: Union[tir.PrimFunc, tvm.IRModule], @@ -10,15 +139,81 @@ def tl_lower( target_host: Optional[Union[str, Target]] = None, runtime_only=False, ): - result = tilelang.lower( - func_or_mod, - target, - target_host=target_host, - runtime_only=runtime_only, - enable_host_codegen=True, - ) - if runtime_only is True: - return result.rt_mod + + mod = func_or_mod + if isinstance(func_or_mod, tir.PrimFunc): + func = func_or_mod + params = extrac_params(func) if not runtime_only else None + mod = tvm.IRModule({func.attrs["global_symbol"]: func}) + + if isinstance(target, str): + target = determine_target(target) + + target_host = canon_target_host(target, target_host) + + target_host = tvm.target.Target.canon_target(target_host) + target = tvm.target.Target(target, target_host) + + _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) + _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) + + # Phase 1: Lower and legalize the IR + mod = LowerAndLegalize(mod, target) + + # Phase 2: Optimize the IR for the target + mod = OptimizeForTarget(mod, target) + + host_mod = tir.transform.Filter(_is_host_call)(mod) + host_mod = tir.transform.BindTarget(target_host)(host_mod) + host_mod = tir.transform.FP8StorageLegalize()(host_mod) + host_mod = tir.transform.BF16StorageLegalize()(host_mod) + host_mod = tir.transform.LowerTVMBuiltin()(host_mod) + host_mod = tir.transform.LowerCustomDatatypes()(host_mod) + host_mod = tir.transform.LowerIntrin()(host_mod) + host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod) + host_mod = tir.transform.CombineContextCall()(host_mod) + + if target_host.kind.name == "llvm": + host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host) + elif target_host.kind.name == "c": + if is_cpu_device_backend(target): + host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host) + else: + host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host) else: - return result.rt_mod, result.params + raise ValueError(f"Target host {target_host.kind.name} is not supported") + device_mod = tir.transform.Filter(_is_device_call)(mod) + device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod) + device_mod = tir.transform.LowerIntrin()(device_mod) + device_mod = tir.transform.Simplify()(device_mod) + + if target.kind.name == "cuda": + # Debug comments to get the code + # code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target) + device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target) + elif target.kind.name == "hip": + device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target) + elif target.kind.name == "c": + device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target) + elif target.kind.name == "llvm": + device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target) + elif target.kind.name == "webgpu": + device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target) + else: + raise ValueError(f"Target {target.kind.name} is not supported") + + host_mod.import_module(device_mod) + + if target_host.kind.name == "c": + # cpu host should be recompiled + # TODO(lei): this is a hack to make the C host backend work + temp_dir = tvm.contrib.utils.tempdir() + tmp_lib_path = temp_dir.relpath("tmp.so") + host_mod.export_library(tmp_lib_path) + host_mod = tvm.runtime.load_module(tmp_lib_path) + + if runtime_only is True: + return host_mod + else: + return host_mod, params \ No newline at end of file diff --git a/bitblas/tl/profiler.py b/bitblas/tl/profiler.py new file mode 100644 index 000000000..e41e9f19c --- /dev/null +++ b/bitblas/tl/profiler.py @@ -0,0 +1,253 @@ +from bitblas import tilelang as tilelang +from typing import List, Literal, Optional, Callable +from functools import partial +import torch +from contextlib import suppress + +import tvm +from tvm.relay import TensorType + +from tilelang.engine import lower +from tilelang.jit.adapter import TorchDLPackKernelAdapter +from tilelang.utils.tensor import ( + get_tensor_supply, + TensorSupplyType, + torch_assert_close, + adapt_torch2tvm, +) + + +class TLProfiler(TorchDLPackKernelAdapter): + + def __init__( + self, + mod, + params: List[TensorType], + result_idx: List[int], + supply_type: TensorSupplyType = TensorSupplyType.Normal, + ): + super().__init__(mod, params, result_idx) + self.supply = get_tensor_supply(supply_type) + + def _get_inputs(self, with_output=False): + ins = [] + for i in range(len(self.params)): + if with_output or i not in self.result_idx: + ins.append(self.supply(self.params[i])) + return ins + + def assert_allclose( + self, + reference_program: Callable, + atol: float = 1e-2, + rtol: float = 1e-2, + max_mismatched_ratio=0.01, + ): + ins = self._get_inputs() + ref_outs = reference_program(*ins) + torch.cuda.synchronize() + lib_outs = self.func(*ins) + torch.cuda.synchronize() + + if isinstance(lib_outs, torch.Tensor): + lib_outs = [lib_outs] + if isinstance(ref_outs, torch.Tensor): + ref_outs = [ref_outs] + assert len(lib_outs) == len(ref_outs) + # torch.set_printoptions(edgeitems=torch.inf) + for lhs, rhs in zip(lib_outs, ref_outs): + # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) + # total_elements = lhs.numel() + # num_not_close = (~close_mask).sum().item() + # percentage_not_close = (num_not_close / total_elements) * 100 + # print(f"{percentage_not_close:.2f}% of the elements are not close.") + # print(f"Total elements: {total_elements}, Not close elements: {num_not_close}") + torch_assert_close( + lhs, + rhs, + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, + ) + + def assert_consistent(self, repeat=10): + # Used to check no race condition inside the kernel + ins = self._get_inputs() + ref_outs = self.func(*ins) + + for _ in range(repeat): + lib_outs = self.func(*ins) + for lhs, rhs in zip(lib_outs, ref_outs): + assert torch.allclose(lhs, rhs), [ + "result is not consistent", + lhs, + rhs, + ] + + def run_once(self, func: Optional[Callable] = None): + ins = self._get_inputs() + if not func: + func = self.__call__ + return func(*ins) + + def determine_profiler(self, + func: Optional[Callable] = None, + profiler: Literal["torch", "tvm", "auto"] = "auto"): + if profiler == "auto": + if func is None or isinstance(func, tvm.runtime.Module): + return "tvm" + else: + return "torch" + return profiler + + def do_bench( + self, + func: Optional[Callable] = None, + warmup: int = 25, + rep: int = 100, + n_warmup: int = 1, + n_repeat: int = 1, + profiler: Literal["torch", "tvm", "auto"] = "auto", + input_tensors: List[torch.Tensor] = None, + ) -> float: + profiler = self.determine_profiler(func, profiler) + if profiler == "torch": + ins = self._get_inputs() if input_tensors is None else input_tensors + bench_func = partial(func, *ins) + return do_bench( + bench_func, + warmup=warmup, + rep=rep, + _n_warmup=n_warmup, + _n_repeat=n_repeat, + ) + elif profiler == "tvm": + if func is None: + func = self.mod + assert isinstance( + func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" + ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors) + target = "cuda" + + with suppress(Exception): + target = self.mod.imported_modules[0].type_key + + assert target in ["cuda", "hip"], f"Unknown target: {target}" + + device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0) + time_evaluator = self.mod.time_evaluator( + self.mod.entry_name, device, number=rep, repeat=n_repeat) + tvm_inputs = [adapt_torch2tvm(inp) for inp in ins] + # Transform Latency to ms + return time_evaluator(*tvm_inputs).mean * 1e3 + else: + raise ValueError(f"Unknown profiler: {profiler}") + + +def do_bench( + fn, + warmup=25, + rep=100, + _n_warmup=0, + _n_repeat=0, + grad_to_none=None, + quantiles=None, + fast_flush=True, + return_mode="mean", +) -> float: + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + + Returns: + float: The median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + """ + assert return_mode in ["min", "max", "mean", "median"] + fn() + torch.cuda.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + if _n_warmup > 0: + n_warmup = _n_warmup + if _n_repeat > 0: + n_repeat = _n_repeat + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + + +_cached = {} + + +def cached(func, result_idx: List[int], *args): + global _cached + key = (func, tuple(result_idx), *args) + if key not in _cached: + program = func(*args) + mod, params = lower(program) + mod = TorchDLPackKernelAdapter(mod, params, result_idx) + _cached[key] = mod + return _cached[key] \ No newline at end of file diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index edf5766b7..e521aadb5 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -10,6 +10,7 @@ matmul_macro_tensorcore, matmul_macro_tensorcore_weight_propagation_level_ldmatrix, ) +from bitblas.tl.profiler import TLProfiler import torch import torch.backends @@ -58,7 +59,7 @@ def assert_matmul_blocked_correctness(M, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -116,7 +117,7 @@ def assert_matmul_macro_tensorcore_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -186,7 +187,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 2007a43f8..773aab028 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -5,6 +5,7 @@ import bitblas.testing from bitblas import tilelang as tilelang from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( MatmulTileLibraryScheduler,) @@ -63,7 +64,7 @@ def assert_matmul_blocked_with_default_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -121,7 +122,7 @@ def assert_matmul_blocked_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -165,7 +166,7 @@ def assert_matmul_fine_grained_with_default_correctness( B = (torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) if trans_B else torch.rand( K, N, device="cuda", dtype=getattr(torch, in_dtype))) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25) @@ -229,7 +230,7 @@ def assert_matmul_fine_grained_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -286,7 +287,7 @@ def assert_matmul_weight_propagation_with_default_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -360,7 +361,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -407,7 +408,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") print(latency) @@ -472,7 +473,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") print(latency) @@ -536,7 +537,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( LB = ladder_permutate(compressed_B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, LB, C) @@ -615,7 +616,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( LB = ladder_permutate(compressed_B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, LB, C) @@ -694,7 +695,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( (B[:, 3::4] & 0x03) << 6) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) print(f"{compressed_B=}") if fast_decoding: lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() @@ -793,7 +794,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( (B[:, 3::4] & 0x03) << 6) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) print(f"{compressed_B=}") if fast_decoding: lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() @@ -886,7 +887,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() ladder_shape = compressed_B_ladder.shape int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) @@ -1007,7 +1008,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() ladder_shape = compressed_B_ladder.shape int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) @@ -1137,7 +1138,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) @@ -1270,7 +1271,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) @@ -1417,7 +1418,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) diff --git a/testing/python/tilelang/test_simplifier.py b/testing/python/tilelang/test_simplifier.py index fee7fcdf1..678006ec9 100644 --- a/testing/python/tilelang/test_simplifier.py +++ b/testing/python/tilelang/test_simplifier.py @@ -2,7 +2,7 @@ from bitblas import tilelang as tilelang import tilelang.language as T from bitblas.tl.lower import tl_lower - +from bitblas.tl.profiler import TLProfiler def modify( with_B: bool = False, @@ -76,7 +76,7 @@ def test_matmul(): rt_mod, params = tl_lower(mod.functions_items()[0][1], runtime_only=False) # TODO Profiler only support TensorType, not dynamic variable - profiler = tilelang.Profiler(rt_mod, params, result_idx=[2]) + profiler = TLProfiler(rt_mod, params, result_idx=[2]) import torch a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() diff --git a/testing/python/tilelang/test_tilelang_amd_gemm.py b/testing/python/tilelang/test_tilelang_amd_gemm.py index eafb24745..a57c6adc2 100644 --- a/testing/python/tilelang/test_tilelang_amd_gemm.py +++ b/testing/python/tilelang/test_tilelang_amd_gemm.py @@ -5,7 +5,7 @@ import bitblas.testing from bitblas import tilelang as tilelang from bitblas.tl.lower import tl_lower - +from bitblas.tl.profiler import TLProfiler def matmul( M, @@ -86,7 +86,7 @@ def run_gemm( k_pack=k_pack, ) mod, params = tl_lower(program) - mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index a592ed940..3750bb2be 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -13,6 +13,7 @@ from bitblas.tl.mma_macro_generator import ( TensorCoreIntrinEmitterWithLadderTransform,) from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 torch.manual_seed(0) @@ -124,7 +125,7 @@ def run_gemm( ) mod, params = tl_lower(program) - mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [2], tilelang.TensorSupplyType.Integer) out = mod.run_once() assert out is not None @@ -406,7 +407,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index 47261990c..5adf37382 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -11,6 +11,7 @@ from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter) from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -189,7 +190,7 @@ def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -278,7 +279,7 @@ def assert_tl_matmul_block_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): @@ -382,7 +383,7 @@ def assert_tl_matmul_block_all_dynamic_correctness( B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index f430da419..e8940ca67 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -13,6 +13,7 @@ from bitblas import set_log_level from bitblas.ops.general_flashatten.tilelang.flashatten import flashatten_blocked from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler set_log_level(logging.DEBUG) @@ -68,7 +69,7 @@ def flashattn_tilelang(batch, heads, seq_len, dim, trans_K, dtypeQKV, dtypeAccu, is_causal=is_causal, ) mod, params = tl_lower(tl_prim_func) - mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) + mod = TLProfiler(mod, params, [3], tilelang.TensorSupplyType.Normal) from flash_attn.flash_attn_interface import flash_attn_func # TODO Now hack to internal function get the same input, may need to modify 3rdparty:tvm.tl.utils ins = mod._get_inputs() @@ -179,7 +180,7 @@ def main( return main mod, params = tl_lower(kernel()) - mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) + mod = TLProfiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) @@ -400,7 +401,7 @@ def main( return main mod, params = tl_lower(kernel()) - mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) + mod = TLProfiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.1, atol=0.1) diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py index 2cdfbb656..e397cee49 100644 --- a/testing/python/tilelang/test_tilelang_gemm.py +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -5,6 +5,7 @@ import bitblas.testing from bitblas import tilelang as tilelang from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler def matmul( @@ -83,7 +84,7 @@ def run_gemm( num_threads, ) mod, params = tl_lower(program) - mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py index db0ffd44b..e4da6dd7d 100644 --- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -17,6 +17,7 @@ ) from bitblas.base import simplify_prim_func from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -185,7 +186,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, compressed_B, C) print(C) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") @@ -392,7 +393,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(compressed_B.cpu()).cuda() mod(compressed_A, LB, C) diff --git a/testing/python/tilelang/test_tilelang_gemm_simt.py b/testing/python/tilelang/test_tilelang_gemm_simt.py index 47176c92a..df98df390 100644 --- a/testing/python/tilelang/test_tilelang_gemm_simt.py +++ b/testing/python/tilelang/test_tilelang_gemm_simt.py @@ -11,6 +11,7 @@ from bitblas.tl.utils import get_swizzle_layout from bitblas.base import simplify_prim_func from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -158,7 +159,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) diff --git a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py index a579efdf0..d0faf9fbe 100644 --- a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py @@ -11,6 +11,7 @@ from bitblas.tl.mfma_macro_generator import ( MatrixCoreIntrinEmitter,) from bitblas.base import simplify_prim_func +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -186,7 +187,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) diff --git a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py index d4157871f..8f8c3fcfa 100644 --- a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py @@ -16,6 +16,7 @@ from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 from bitblas.base import simplify_prim_func from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -201,7 +202,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -398,7 +399,7 @@ def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -584,7 +585,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(B.cpu()).cuda() @@ -864,7 +865,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) @@ -1069,7 +1070,7 @@ def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_ ladder_permutate_b = bitblas.ops.LadderPermutate(ladder_permutate_config_B) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) LA = ladder_permutate_a(A.cpu()).cuda() LB = ladder_permutate_b(B.cpu()).cuda() From de261049f4332d48ce43eae53cbcf8c97b919b94 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 14:21:35 +0000 Subject: [PATCH 12/42] update tilelang --- 3rdparty/tilelang | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tilelang b/3rdparty/tilelang index 7a82c5b8d..3198a2f46 160000 --- a/3rdparty/tilelang +++ b/3rdparty/tilelang @@ -1 +1 @@ -Subproject commit 7a82c5b8d781c72a4f96b19e45281dc49b780480 +Subproject commit 3198a2f46c373a6b6967bd33332145812a47b44e From 5d9887bcd2dfeed3ff5cbb5c85cadbd6afb48b57 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 14:39:06 +0000 Subject: [PATCH 13/42] bug fix for dynamic shape --- bitblas/ops/operator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 46e6226e5..55294129f 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -188,6 +188,7 @@ def tvm_callback_hip_postproc(code, _): config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, + "tl.disable_dynamic_tail_split": False, **(self.pass_context if self.pass_context else {}), }): if self.is_tir_backend(): From e287353edc10cb0c47e842f38da67d91cfc5ded5 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 14:57:43 +0000 Subject: [PATCH 14/42] fix lower --- bitblas/tl/lower.py | 84 +++----------------- testing/python/module/test_bitblas_linear.py | 22 ++--- 2 files changed, 21 insertions(+), 85 deletions(-) diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py index 64db656be..e28a9d2ad 100644 --- a/bitblas/tl/lower.py +++ b/bitblas/tl/lower.py @@ -140,80 +140,16 @@ def tl_lower( runtime_only=False, ): - mod = func_or_mod - if isinstance(func_or_mod, tir.PrimFunc): - func = func_or_mod - params = extrac_params(func) if not runtime_only else None - mod = tvm.IRModule({func.attrs["global_symbol"]: func}) - - if isinstance(target, str): - target = determine_target(target) - - target_host = canon_target_host(target, target_host) - - target_host = tvm.target.Target.canon_target(target_host) - target = tvm.target.Target(target, target_host) - - _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) - _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) - - # Phase 1: Lower and legalize the IR - mod = LowerAndLegalize(mod, target) - - # Phase 2: Optimize the IR for the target - mod = OptimizeForTarget(mod, target) - - host_mod = tir.transform.Filter(_is_host_call)(mod) - host_mod = tir.transform.BindTarget(target_host)(host_mod) - host_mod = tir.transform.FP8StorageLegalize()(host_mod) - host_mod = tir.transform.BF16StorageLegalize()(host_mod) - host_mod = tir.transform.LowerTVMBuiltin()(host_mod) - host_mod = tir.transform.LowerCustomDatatypes()(host_mod) - host_mod = tir.transform.LowerIntrin()(host_mod) - host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod) - host_mod = tir.transform.CombineContextCall()(host_mod) - - if target_host.kind.name == "llvm": - host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host) - elif target_host.kind.name == "c": - if is_cpu_device_backend(target): - host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host) - else: - host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host) - else: - raise ValueError(f"Target host {target_host.kind.name} is not supported") - - device_mod = tir.transform.Filter(_is_device_call)(mod) - device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod) - device_mod = tir.transform.LowerIntrin()(device_mod) - device_mod = tir.transform.Simplify()(device_mod) - - if target.kind.name == "cuda": - # Debug comments to get the code - # code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target) - device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target) - elif target.kind.name == "hip": - device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target) - elif target.kind.name == "c": - device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target) - elif target.kind.name == "llvm": - device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target) - elif target.kind.name == "webgpu": - device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target) - else: - raise ValueError(f"Target {target.kind.name} is not supported") - - host_mod.import_module(device_mod) - - if target_host.kind.name == "c": - # cpu host should be recompiled - # TODO(lei): this is a hack to make the C host backend work - temp_dir = tvm.contrib.utils.tempdir() - tmp_lib_path = temp_dir.relpath("tmp.so") - host_mod.export_library(tmp_lib_path) - host_mod = tvm.runtime.load_module(tmp_lib_path) + result = tilelang.lower( + func_or_mod, + target=target, + target_host=target_host, + runtime_only=runtime_only, + enable_device_compile=True, + enable_host_codegen=True, + ) if runtime_only is True: - return host_mod + return result.rt_mod else: - return host_mod, params \ No newline at end of file + return result.rt_mod, result.params \ No newline at end of file diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index 286aff237..8c94970f0 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -42,11 +42,11 @@ def correctness_consistent(m, in_features, out_features, bias): bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2) -def test_correctness_consistent(): - correctness_consistent(1, 1024, 1024, False) - correctness_consistent(1, 1024, 1024, True) - correctness_consistent(1024, 1024, 1024, True) - correctness_consistent([1, 1024], 1024, 1024, True) +# def test_correctness_consistent(): +# correctness_consistent(1, 1024, 1024, False) +# correctness_consistent(1, 1024, 1024, True) +# correctness_consistent(1024, 1024, 1024, True) +# correctness_consistent([1, 1024], 1024, 1024, True) def correctness_weight_only_dequantize( @@ -167,13 +167,13 @@ def correctness_weight_only_dequantize( def test_correctness_weight_only_dequantize(): - correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) - correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) - correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None) + # correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) + # correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) + # correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None) correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None) - correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original") - correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original") - correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale") + # correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original") + # correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original") + # correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale") def profile(model, input_data): From baeb8c9998cae0e68b994268e18f207a66b303e2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 21 Apr 2025 15:38:53 +0000 Subject: [PATCH 15/42] update lower --- bitblas/tl/lower.py | 131 +------------------------------------------- bitblas/tl/tuner.py | 1 + 2 files changed, 2 insertions(+), 130 deletions(-) diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py index e28a9d2ad..89f13c9a3 100644 --- a/bitblas/tl/lower.py +++ b/bitblas/tl/lower.py @@ -1,137 +1,8 @@ -import os -import os.path as osp -from typing import Union, Optional, Callable, List +from typing import Union, Optional from bitblas import tilelang as tilelang from tilelang import tvm as tvm from tvm import tir -from tvm.ir import CallingConv from tvm.target import Target -from tilelang.contrib import hipcc, nvcc -from tilelang.engine.param import KernelParam -from tilelang.utils.target import determine_target -from tilelang.engine.phase import ( - LowerAndLegalize, - OptimizeForTarget, -) - - -def is_cpu_device_backend(target: Target): - return target.kind.name == "c" - - -def has_device_kernel_launch(attrs) -> bool: - """Check if the attributes indicate a device kernel launch.""" - return bool(attrs and "calling_conv" in attrs and - attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH) - - -def is_device_call_c_device(func: tir.PrimFunc): - attrs = func.attrs - - # Check if it's a C target - if "target" in attrs and attrs["target"].kind.name == "c": - return True - - return has_device_kernel_launch(attrs) - - -def is_device_call(func: tir.PrimFunc): - return has_device_kernel_launch(func.attrs) - - -def get_device_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: - return is_device_call_c_device if is_device_c else is_device_call - - -def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: - return lambda func: not get_device_call(is_device_c)(func) - - -@tvm.register_func("tilelang_callback_cuda_compile", override=True) -def tilelang_callback_cuda_compile(code, target): - project_root = osp.join(osp.dirname(__file__), "../..") - if "TL_TEMPLATE_PATH" in os.environ: - tl_template_path = os.environ["TL_TEMPLATE_PATH"] - else: - tl_template_path = osp.abspath(osp.join(project_root, "src")) - # TODO(lei): this indeed should be renamed into - # TL_CUTLASS_INCLUDE_PATH in the future - if "TL_CUTLASS_PATH" in os.environ: - cutlass_path = os.environ["TL_CUTLASS_PATH"] - else: - cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include")) - compute_version = "".join(nvcc.get_target_compute_version(target).split(".")) - - # special handle for Hopper - if compute_version == "90": - arch = ["-arch=sm_90a"] - format = "cubin" - else: - arch = [f"-arch=sm_{compute_version}"] - format = "cubin" - - # printing out number of registers - debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage" - ptx = nvcc.compile_cuda( - code, - format, - arch, - options=[ - "-std=c++17", - debug_option, - "--use_fast_math", - "-I" + tl_template_path, - "-I" + cutlass_path, - ], - verbose=False, - ) - - return ptx - - -@tvm.register_func("tilelang_callback_hip_compile", override=True) -def tilelang_callback_hip_compile(code, target): - project_root = osp.join(osp.dirname(__file__), "../..") - tl_template_path = osp.abspath(osp.join(project_root, "src")) - - # TODO(lei): actually this indeed should be renamed into - # TL_COMPOSABLE_KERNEL_INCLUDE_PATH in the future - if "TL_COMPOSABLE_KERNEL_PATH" in os.environ: - ck_path = os.environ["TL_COMPOSABLE_KERNEL_PATH"] - else: - ck_path = osp.abspath(osp.join(project_root, "3rdparty/composable_kernel/include")) - - hsaco = hipcc.compile_hip( - code, - target_format="hsaco", - options=[ - "-std=c++17", - "-I" + tl_template_path, - "-I" + ck_path, - ], - verbose=False, - ) - - return hsaco - - -def extrac_params(func: tir.PrimFunc) -> List[KernelParam]: - tensor_types = [] - for var in func.params: - if var in func.buffer_map: - tensor_types.append(KernelParam.from_buffer(func.buffer_map[var])) - else: - tensor_types.append(KernelParam.from_var(var)) - return tensor_types - - -def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]): - - if not target_host: - target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - - return target_host - def tl_lower( func_or_mod: Union[tir.PrimFunc, tvm.IRModule], diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index 12779ecef..b5ea16b17 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -121,6 +121,7 @@ def tvm_callback_cuda_postproc(code, _): config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, + "tl.disable_dynamic_tail_split": False, **(config.pass_context if config.pass_context else {}) }): rt_mod = tl_lower(tl_prim_func, arch.target, runtime_only=True) From 66c0a9f89cfc768e3df6f095994b439e4c241a6c Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Apr 2025 04:30:45 +0000 Subject: [PATCH 16/42] bug dix --- bitblas/base/utils.py | 1 + .../test_general_matmul_tile_schedule.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 92822d1bd..dd61bd458 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -265,6 +265,7 @@ def tvm_callback_cuda_postproc(code, _): with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, + "tl.disable_dynamic_tail_split": False, **config.pass_context }): rt_mod = tvm.build(mod, target=arch.target) diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 38f49786e..70784d960 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -73,7 +73,8 @@ def assert_correctness_with_block_reduce( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": True + "tir.merge_static_smem": True, + "tl.disable_dynamic_tail_split": False }): ref_rt_mod = tvm.build(ref_sch.mod, target=target) @@ -97,7 +98,8 @@ def assert_correctness_with_block_reduce( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": True + "tir.merge_static_smem": True, + "tl.disable_dynamic_tail_split": False }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) @@ -173,7 +175,8 @@ def assert_correctness_with_ladder_ldmatrix_propagate( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": False + "tir.merge_static_smem": False, + "tl.disable_dynamic_tail_split": False }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) # Evaluate the correctness @@ -286,7 +289,8 @@ def assert_dequant_correctness_with_block_reduce( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": False + "tir.merge_static_smem": False, + "tl.disable_dynamic_tail_split": False, }): ref_rt_mod = tvm.build(ref_sch.mod, target=target) @@ -310,7 +314,8 @@ def assert_dequant_correctness_with_block_reduce( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": False + "tir.merge_static_smem": False, + "tl.disable_dynamic_tail_split": False }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) @@ -427,7 +432,8 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.merge_static_smem": False, - "tir.disable_cse_tir": True + "tir.disable_cse_tir": True, + "tl.disable_dynamic_tail_split": False }): rt_mod = tvm.build(block_reduce_sch.mod, target=target) src_code = rt_mod.imported_modules[0].get_source() From 0912d6125870f97a2cb1955d5dd6f19a50615b31 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Apr 2025 06:05:44 +0000 Subject: [PATCH 17/42] bug fix for lower and profiler --- bitblas/tl/lower.py | 26 ++++++++++++------- .../test_general_matmul_tilelang_kernel.py | 8 +++--- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py index 89f13c9a3..dab9844ae 100644 --- a/bitblas/tl/lower.py +++ b/bitblas/tl/lower.py @@ -10,16 +10,22 @@ def tl_lower( target_host: Optional[Union[str, Target]] = None, runtime_only=False, ): - - result = tilelang.lower( - func_or_mod, - target=target, - target_host=target_host, - runtime_only=runtime_only, - enable_device_compile=True, - enable_host_codegen=True, - ) - + with tvm.transform.PassContext( + config={ + "tir.use_async_copy": True, + "tir.disable_cse_tir": True, + "tl.disable_dynamic_tail_split": False, + }): + result = tilelang.lower( + func_or_mod, + target=target, + target_host=target_host, + runtime_only=runtime_only, + enable_host_codegen=True, + enable_device_compile=True, + ) + print("Lowering result:") + print(result.rt_mod) if runtime_only is True: return result.rt_mod else: diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 773aab028..c25d5f9c4 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -410,7 +410,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) - latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + latency = mod.do_bench(mod, warmup=25) print(latency) # Ensure that the latency is not None @@ -475,7 +475,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) - latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + latency = mod.do_bench(mod.func, warmup=25) print(latency) # Ensure that the latency is not None @@ -704,7 +704,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( print(f"{lop3_compressed_B=}") mod(compressed_A, lop3_compressed_B, C) print(C) - latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + latency = mod.do_bench(mod.func, warmup=25) print(latency) # Ensure that the latency is not None assert latency is not None @@ -803,7 +803,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( print(f"{lop3_compressed_B=}") mod(compressed_A, lop3_compressed_B, C) print(C) - latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + latency = mod.do_bench(mod.func, warmup=25) print(latency) # Ensure that the latency is not None assert latency is not None From d707f5571d61369dee917265123b3a112bb11eeb Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Apr 2025 06:18:27 +0000 Subject: [PATCH 18/42] fix lower adn update tvm --- 3rdparty/tvm | 2 +- bitblas/tl/lower.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 4776d3119..ea3cdeb17 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 4776d3119e43fe064d890b94be7587d53106b9e3 +Subproject commit ea3cdeb17755c3fe5bfd1f181380793fe3b0eca0 diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py index dab9844ae..a8a552699 100644 --- a/bitblas/tl/lower.py +++ b/bitblas/tl/lower.py @@ -12,8 +12,6 @@ def tl_lower( ): with tvm.transform.PassContext( config={ - "tir.use_async_copy": True, - "tir.disable_cse_tir": True, "tl.disable_dynamic_tail_split": False, }): result = tilelang.lower( From dd23170e5e2a1e416e567db989a757dad44ebf30 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Apr 2025 06:39:40 +0000 Subject: [PATCH 19/42] remove unnecessary change --- 3rdparty/tvm | 2 +- bitblas/base/utils.py | 1 - bitblas/ops/operator.py | 1 - bitblas/tl/tuner.py | 1 - testing/python/module/test_bitblas_linear.py | 12 ++++++------ .../operators/test_general_matmul_tile_schedule.py | 6 ------ 6 files changed, 7 insertions(+), 16 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index ea3cdeb17..b29452549 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit ea3cdeb17755c3fe5bfd1f181380793fe3b0eca0 +Subproject commit b2945254932cffa89922ec7f6e868d726aed0f6a diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index dd61bd458..92822d1bd 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -265,7 +265,6 @@ def tvm_callback_cuda_postproc(code, _): with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, - "tl.disable_dynamic_tail_split": False, **config.pass_context }): rt_mod = tvm.build(mod, target=arch.target) diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 55294129f..46e6226e5 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -188,7 +188,6 @@ def tvm_callback_hip_postproc(code, _): config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, - "tl.disable_dynamic_tail_split": False, **(self.pass_context if self.pass_context else {}), }): if self.is_tir_backend(): diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index b5ea16b17..12779ecef 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -121,7 +121,6 @@ def tvm_callback_cuda_postproc(code, _): config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, - "tl.disable_dynamic_tail_split": False, **(config.pass_context if config.pass_context else {}) }): rt_mod = tl_lower(tl_prim_func, arch.target, runtime_only=True) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index 8c94970f0..1cd9de677 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -167,13 +167,13 @@ def correctness_weight_only_dequantize( def test_correctness_weight_only_dequantize(): - # correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) - # correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) - # correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None) correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None) - # correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original") - # correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original") - # correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale") + correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original") + correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original") + correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale") def profile(model, input_data): diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 70784d960..73f83402b 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -74,7 +74,6 @@ def assert_correctness_with_block_reduce( with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.merge_static_smem": True, - "tl.disable_dynamic_tail_split": False }): ref_rt_mod = tvm.build(ref_sch.mod, target=target) @@ -99,7 +98,6 @@ def assert_correctness_with_block_reduce( with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.merge_static_smem": True, - "tl.disable_dynamic_tail_split": False }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) @@ -176,7 +174,6 @@ def assert_correctness_with_ladder_ldmatrix_propagate( with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.merge_static_smem": False, - "tl.disable_dynamic_tail_split": False }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) # Evaluate the correctness @@ -290,7 +287,6 @@ def assert_dequant_correctness_with_block_reduce( with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.merge_static_smem": False, - "tl.disable_dynamic_tail_split": False, }): ref_rt_mod = tvm.build(ref_sch.mod, target=target) @@ -315,7 +311,6 @@ def assert_dequant_correctness_with_block_reduce( with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.merge_static_smem": False, - "tl.disable_dynamic_tail_split": False }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) @@ -433,7 +428,6 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( "tir.use_async_copy": True, "tir.merge_static_smem": False, "tir.disable_cse_tir": True, - "tl.disable_dynamic_tail_split": False }): rt_mod = tvm.build(block_reduce_sch.mod, target=target) src_code = rt_mod.imported_modules[0].get_source() From 611ecd652def1134005f77c2bf6f4217a649ff2d Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Apr 2025 06:40:54 +0000 Subject: [PATCH 20/42] remove commnent --- testing/python/module/test_bitblas_linear.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index 1cd9de677..286aff237 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -42,11 +42,11 @@ def correctness_consistent(m, in_features, out_features, bias): bitblas.testing.torch_assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2) -# def test_correctness_consistent(): -# correctness_consistent(1, 1024, 1024, False) -# correctness_consistent(1, 1024, 1024, True) -# correctness_consistent(1024, 1024, 1024, True) -# correctness_consistent([1, 1024], 1024, 1024, True) +def test_correctness_consistent(): + correctness_consistent(1, 1024, 1024, False) + correctness_consistent(1, 1024, 1024, True) + correctness_consistent(1024, 1024, 1024, True) + correctness_consistent([1, 1024], 1024, 1024, True) def correctness_weight_only_dequantize( From 16cd073428231a9d4d478be69d44b5f2dad6b7f0 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 4 May 2025 23:10:49 +0800 Subject: [PATCH 21/42] fix lint --- testing/python/operators/test_general_matmul_ops_backend.py | 1 - testing/python/tilelang/test_tilelang_mfma_macro_gemm.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/python/operators/test_general_matmul_ops_backend.py b/testing/python/operators/test_general_matmul_ops_backend.py index 23c312f41..d62238a95 100644 --- a/testing/python/operators/test_general_matmul_ops_backend.py +++ b/testing/python/operators/test_general_matmul_ops_backend.py @@ -36,7 +36,6 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, la ) matmul = Matmul(config=matmul_config, enable_tuning=False) func = matmul.prim_func - import tilelang rt_mod, params = tl_lower(func) print(rt_mod) assert get_codegen_result(matmul) diff --git a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py index d0faf9fbe..d46e4286d 100644 --- a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py @@ -12,6 +12,7 @@ MatrixCoreIntrinEmitter,) from bitblas.base import simplify_prim_func from bitblas.tl.profiler import TLProfiler +from bitblas.tl.lower import tl_lower torch.manual_seed(0) From c121ab4a6d661694195b615d8884a2b657ea85b9 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 4 May 2025 23:15:12 +0800 Subject: [PATCH 22/42] fix lint --- bitblas/base/roller/bestfit.py | 8 +- bitblas/base/roller/node.py | 5 +- .../base/roller/shape_inference/__init__.py | 2 +- bitblas/base/roller/shape_inference/common.py | 12 +- bitblas/base/roller/shape_inference/tir.py | 23 ++-- bitblas/gpu/base.py | 2 +- bitblas/gpu/fallback.py | 15 +-- bitblas/gpu/general_reduction.py | 65 +++------- bitblas/gpu/reduction.py | 34 ++--- bitblas/gpu/rmsnorm.py | 8 +- bitblas/gpu/transpose.py | 4 +- bitblas/ops/impl/convolution2d_impl.py | 20 ++- bitblas/tl/lower.py | 10 +- bitblas/utils/target_detector.py | 2 + .../BitNet/benchmark_model_10k_loops.py | 1 + integration/BitNet/configuration_bitnet.py | 10 +- integration/BitNet/eval_gpu_memory.py | 13 +- integration/BitNet/eval_utils.py | 17 ++- integration/BitNet/tokenization_bitnet.py | 74 +++++------ .../mlc_llm/test_weight_only_transform.py | 120 ++++++++---------- integration/pytorch/test_bitblas_linear.py | 4 +- testing/python/tilelang/test_simplifier.py | 1 + .../python/tilelang/test_tilelang_amd_gemm.py | 1 + tools/get_available_targets.py | 6 +- 24 files changed, 217 insertions(+), 240 deletions(-) diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index ad8ec20a8..e4a5ea538 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - """Benifit For BitBLAS Schedule""" + + class Block: + def __init__(self, start, end, is_free): self.start = start self.end = end @@ -21,6 +23,7 @@ def __repr__(self) -> str: class BestFit: + def __init__(self, align=32): self.limit = 0 self.list = [] @@ -39,8 +42,7 @@ def malloc(self, size) -> Block: if remain != 0: found.end -= remain self.list.insert( - self.list.index(found) + 1, Block(found.end, found.end + remain, True) - ) + self.list.index(found) + 1, Block(found.end, found.end + remain, True)) return found elif len(self.list) > 0 and self.list[-1].is_free: add = size - self.list[-1].size() diff --git a/bitblas/base/roller/node.py b/bitblas/base/roller/node.py index c9d648019..c0b49f6bf 100644 --- a/bitblas/base/roller/node.py +++ b/bitblas/base/roller/node.py @@ -232,8 +232,9 @@ def propagate(self, tile, rstep: Optional[Dict] = None, targets=None): if rstep is None: rstep = {} shape = { - self.block_analyzer.get_output_buffers(block)[0].name: - [tvm.arith.ConstIntBound(0, val - 1) for val in tile] for block in self.schedule_stages + self.block_analyzer.get_output_buffers(block)[0].name: [ + tvm.arith.ConstIntBound(0, val - 1) for val in tile + ] for block in self.schedule_stages } return self.ana.infer(shape, rstep, targets) diff --git a/bitblas/base/roller/shape_inference/__init__.py b/bitblas/base/roller/shape_inference/__init__.py index 188aa0bb7..275287f65 100644 --- a/bitblas/base/roller/shape_inference/__init__.py +++ b/bitblas/base/roller/shape_inference/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .tir import get_analyzer_by_tir # pylint: disable=unused-import +from .tir import get_analyzer_by_tir # pylint: disable=unused-import diff --git a/bitblas/base/roller/shape_inference/common.py b/bitblas/base/roller/shape_inference/common.py index 730bbbeef..7de953cc3 100644 --- a/bitblas/base/roller/shape_inference/common.py +++ b/bitblas/base/roller/shape_inference/common.py @@ -8,16 +8,21 @@ class Statement(): - def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict): + + def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, + range_map: OrderedDict): self.output = output self.dependent_region = dependent_region self.var_map = var_map self.range_map = range_map + def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) + class InputShapeInference(): + def __init__(self, deps: List[Statement]): self.deps = deps @@ -34,7 +39,7 @@ def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, i for name, regions in dep.dependent_region.items(): for region in regions: bounds = [ana.const_int_bound(index) for index in region] - if name in shape: # simply merge two bounds + if name in shape: # simply merge two bounds bounds = [_merge_two_bounds(x, y) for x, y in zip(shape[name], bounds)] shape[name] = bounds @@ -44,7 +49,7 @@ def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, i def infer(self, shape, rstep: Dict[str, int] = {}): if isinstance(shape, (list, tuple)): - shape = {"output0" : [arith.ConstIntBound(0, val - 1) for val in shape]} + shape = {"output0": [arith.ConstIntBound(0, val - 1) for val in shape]} shape = self._infer(shape, rstep) return shape @@ -63,4 +68,3 @@ def get_input_exprs(self, output_exprs): input_expr = [ana.simplify(index) for index in region] result[name] = input_expr return result - diff --git a/bitblas/base/roller/shape_inference/tir.py b/bitblas/base/roller/shape_inference/tir.py index 35bf0b7d8..234f4cd64 100644 --- a/bitblas/base/roller/shape_inference/tir.py +++ b/bitblas/base/roller/shape_inference/tir.py @@ -8,6 +8,7 @@ class Statement: + def __init__(self, block_analyzer, block: BlockRV): self.block_analyzer = block_analyzer self.block = block @@ -79,6 +80,7 @@ def __repr__(self): class DependencyAnalysis(object): + def __init__(self, deps): self.deps = deps # issue: duplicate name when we have two same ops. @@ -90,8 +92,8 @@ def _construct_unique_name2dep(self, deps): This is a workaround for the issue that we have two same ops' fuse case. See https://github.com/apache/tvm/issues/16433 """ - _names:Set = set() - name2dep:Mapping = {} + _names: Set = set() + name2dep: Mapping = {} for dep in deps: output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0] base_name = output_buffer.name @@ -105,7 +107,7 @@ def _construct_unique_name2dep(self, deps): _names.add(base_name) name2dep[base_name] = dep return name2dep - + def get_or_create_node(self, name): if name not in self.mapping: self.mapping[name] = TensorDepNode(name) @@ -114,8 +116,7 @@ def get_or_create_node(self, name): def traverse_dependencies(self, compute): if isinstance(compute, Statement): node = self.get_or_create_node( - compute.block_analyzer.get_output_buffers(compute.block)[0].name - ) + compute.block_analyzer.get_output_buffers(compute.block)[0].name) # Loop through input tensors for input_buffer in compute.block_analyzer.get_input_buffers(compute.block): # Get the input node @@ -169,6 +170,7 @@ def _find_path_recursive(self, current_node, target_name, visited, path): class InputShapeInference: + def __init__(self, deps: List[Statement]): self.deps = deps self.target_mapping = {} @@ -242,9 +244,10 @@ def construct_dependency_target(self, targets: Tuple[str]): self.target_mapping[targets] = input_vars, mapping return input_vars, mapping - def infer( - self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int] = {}, targets=None - ): + def infer(self, + shape: Dict[str, List[arith.ConstIntBound]], + rstep: Dict[str, int] = {}, + targets=None): compute_targets = tuple(shape.keys()) input_vars, mapping = self.construct_dependency_target(compute_targets) ana = arith.Analyzer() @@ -257,8 +260,7 @@ def infer( # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value. if ax.var.name in rstep: bound = arith.ConstIntBound( - int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1) - ) + int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1)) else: bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1)) ana.update(ax.var, bound, True) @@ -318,6 +320,7 @@ def get_input_exprs(self, output_exprs): def region_exist_in_list(a, list) -> bool: + def expr_is_same(a, b) -> bool: if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm): return a.value == b.value diff --git a/bitblas/gpu/base.py b/bitblas/gpu/base.py index 3bf927244..78e06658b 100644 --- a/bitblas/gpu/base.py +++ b/bitblas/gpu/base.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # /* Modifications Copyright (c) Microsoft. */ # The code below is mostly copied from apache/tvm base.py in dlight. """Base schedule rule for GPU operators.""" diff --git a/bitblas/gpu/fallback.py b/bitblas/gpu/fallback.py index 3711d3682..420473036 100644 --- a/bitblas/gpu/fallback.py +++ b/bitblas/gpu/fallback.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm fallback.py in dlight. # pylint: disable=missing-docstring @@ -61,15 +61,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring dom_kind = block.dom_kind() block = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block) - ] - ) - or len(sch.get_loops(block)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): continue for loop, iter_type in zip(sch.get_loops(block), dom_kind): @@ -92,4 +86,3 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.decompose_reduction(block, r_loop) return sch - \ No newline at end of file diff --git a/bitblas/gpu/general_reduction.py b/bitblas/gpu/general_reduction.py index cc03acd99..96fa4b2a0 100644 --- a/bitblas/gpu/general_reduction.py +++ b/bitblas/gpu/general_reduction.py @@ -47,10 +47,8 @@ def apply( # pylint: disable=too-many-locals num_last_block_iter = len(block_infos[-1].dom_kind()) if num_last_block_iter < len(dom_kind): index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), + lambda *iters: ([tir.const(0, iters[0].dtype)] * + (len(dom_kind) - num_last_block_iter) + list(iters)), ndim=num_last_block_iter, ) sch.transform_block_layout(block_infos[-1].block_rv, index_map) @@ -117,15 +115,10 @@ def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many dom_kind = block.dom_kind() block_rv = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block_rv) - ] - ) - or len(sch.get_loops(block.block_rv)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block_rv) + ]) or len(sch.get_loops(block.block_rv)) == 0): continue for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): @@ -248,15 +241,10 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many dom_kind = block.dom_kind() block_rv = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block_rv) - ] - ) - or len(sch.get_loops(block.block_rv)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block_rv) + ]) or len(sch.get_loops(block.block_rv)) == 0): continue for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): @@ -276,10 +264,8 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many num_last_block_iter = len(block_infos[-1].dom_kind()) if num_last_block_iter < len(dom_kind): index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), + lambda *iters: ([tir.const(0, iters[0].dtype)] * + (len(dom_kind) - num_last_block_iter) + list(iters)), ndim=num_last_block_iter, ) sch.transform_block_layout(block_infos[-1].block_rv, index_map) @@ -295,13 +281,11 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many vthread_axis = [] thread_axis = [] inner_axis = [] - for s_loop, block_factor, step_factor, thread_factor in zip( - s_loops, block_factors, step_factors, thread_factors - ): + for s_loop, block_factor, step_factor, thread_factor in zip(s_loops, block_factors, + step_factors, thread_factors): block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) vthread_loop, inner_loop = sch.split( - inner_loop, factors=[None, thread_factor * step_factor] - ) + inner_loop, factors=[None, thread_factor * step_factor]) thread_loop, inner_loop = sch.split(inner_loop, factors=[None, step_factor]) block_axis.append(block_loop) vthread_axis.append(vthread_loop) @@ -317,13 +301,8 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many vthread_axis = list(reversed(vthread_axis)) # inner virtual thread first axis_order = ( - block_axis - + vthread_axis - + thread_axis - + reduce_outer_axis - + reduce_inner_axis - + inner_axis - ) + block_axis + vthread_axis + thread_axis + reduce_outer_axis + reduce_inner_axis + + inner_axis) sch.reorder(*axis_order) blck_fused = sch.fuse(*block_axis) @@ -362,8 +341,8 @@ def prod(iterable): return reduce(lambda x, y: x * y, iterable, 1) _, tx, tv = sch.split( - sch.fuse(*loops[dim_offset:]), factors=[None, int(prod(thread_factors)), vectorize] - ) + sch.fuse(*loops[dim_offset:]), factors=[None, + int(prod(thread_factors)), vectorize]) sch.vectorize(tv) sch.bind(tx, "threadIdx.x") @@ -407,10 +386,8 @@ def prod(iterable): num_last_block_iter = len(block_infos[-1].dom_kind()) if num_last_block_iter < len(dom_kind): index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), + lambda *iters: ([tir.const(0, iters[0].dtype)] * + (len(dom_kind) - num_last_block_iter) + list(iters)), ndim=num_last_block_iter, ) sch.transform_block_layout(block_infos[-1].block_rv, index_map) diff --git a/bitblas/gpu/reduction.py b/bitblas/gpu/reduction.py index 9d6aada75..a90da0558 100644 --- a/bitblas/gpu/reduction.py +++ b/bitblas/gpu/reduction.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm reduction.py in dlight. """A rule for reduction. """ @@ -43,9 +43,9 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: if not isinstance(buffer_store.value, tir.Add): return None if not ir.structural_equal( - buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), - map_free_vars=True, + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, ): return None return buffer_store.value.b @@ -81,11 +81,8 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- block_stmt = sch.get(block) # Step 1. Check reduction block - if ( - (not block_info.is_reduction()) - or len(block_stmt.writes) != 1 - or _get_reduction_expr(block_stmt) is None - ): + if ((not block_info.is_reduction()) or len(block_stmt.writes) != 1 or + _get_reduction_expr(block_stmt) is None): return None # Step 2. Normalize the block, merge spatial and reduction iters is_inner_reduction, c_factor, loop_order, s_split_index = self._normalize( @@ -100,13 +97,11 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- return None # Step 3. Do the scheduling if is_inner_reduction: - self._sch_inner_reduction( - sch, target, block, c_factor, epilogue, loop_order, s_split_index - ) + self._sch_inner_reduction(sch, target, block, c_factor, epilogue, loop_order, + s_split_index) else: - self._sch_inner_spatial( - sch, target, block, block_info, c_factor, epilogue, loop_order, s_split_index - ) + self._sch_inner_spatial(sch, target, block, block_info, c_factor, epilogue, loop_order, + s_split_index) return sch def _normalize( # pylint: disable=too-many-branches @@ -185,8 +180,7 @@ def _sch_inner_reduction( # pylint: disable=too-many-arguments # pylint: disable=invalid-name _, r, _ = sch.get_loops(block) (len_tx,) = utils.suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking - target, [sch.get(r)] - ) + target, [sch.get(r)]) _, tx = sch.split(r, factors=[None, len_tx]) # Schedule the RF block @@ -208,8 +202,7 @@ def _sch_inner_reduction( # pylint: disable=too-many-arguments new_order_s = [s[loop_order[i]] for i in range(len(s))] sch.reorder(*new_order_s) new_order_s[s_split_index], c = sch.split( - new_order_s[s_split_index], factors=[None, unroll_spatial_factor] - ) + new_order_s[s_split_index], factors=[None, unroll_spatial_factor]) sch.reorder(*new_order_s, c) s = sch.fuse(*new_order_s) sch.reorder(s, tx, c) @@ -270,8 +263,7 @@ def _sch_inner_spatial( new_order_s = [s[loop_order[i]] for i in range(len(s))] sch.reorder(*new_order_s) new_order_s[s_split_index], c = sch.split( - new_order_s[s_split_index], factors=[None, unroll_spatial_factor] - ) + new_order_s[s_split_index], factors=[None, unroll_spatial_factor]) sch.reorder(*new_order_s, c) s = sch.fuse(*new_order_s) sch.reorder(s, c, r) diff --git a/bitblas/gpu/rmsnorm.py b/bitblas/gpu/rmsnorm.py index 0d8b37998..bd27c0219 100644 --- a/bitblas/gpu/rmsnorm.py +++ b/bitblas/gpu/rmsnorm.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm rmsnorm.py in dlight. # pylint: disable=missing-docstring @@ -117,8 +117,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring block_loop, loops = sch.get_loops(block=read) thread_loop, _, _ = sch.split( - loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True - ) + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True) sch.bind(block_loop, thread_axis="blockIdx.x") sch.bind(thread_loop, thread_axis="threadIdx.x") sch.vectorize(sch.get_loops(block=read)[-1]) @@ -129,8 +128,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) block_loop, loops = sch.get_loops(block=norm) thread_loop, _, _ = sch.split( - loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True - ) + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True) sch.bind(thread_loop, thread_axis="threadIdx.x") sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) diff --git a/bitblas/gpu/transpose.py b/bitblas/gpu/transpose.py index 6dc025c07..8dc04f10c 100644 --- a/bitblas/gpu/transpose.py +++ b/bitblas/gpu/transpose.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm transpose.py in dlight. """Reduction rule for operators including softmax, layer norm, RMS norm, etc""" @@ -82,7 +82,7 @@ def apply( # pylint: disable=too-many-locals prologue = None # the optional decoding block if transpose_block_idx > 0: - spatials = try_inline_contiguous_spatial(sch, blocks[: transpose_block_idx - 1]) + spatials = try_inline_contiguous_spatial(sch, blocks[:transpose_block_idx - 1]) assert len(spatials) == 0 prologue = blocks[transpose_block_idx - 1].block_rv diff --git a/bitblas/ops/impl/convolution2d_impl.py b/bitblas/ops/impl/convolution2d_impl.py index c7d21d7c8..f637fd8c1 100644 --- a/bitblas/ops/impl/convolution2d_impl.py +++ b/bitblas/ops/impl/convolution2d_impl.py @@ -52,9 +52,13 @@ def conv2d_nhwc_ohwi( C = te.compute( out_shape, lambda n, h, w, f: te.sum( - pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), - c,].astype(accum_dtype) * B[f, kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( - dilation_w), c].astype(accum_dtype), + pad[ + n, + h * stride_h + kh * tir.any(dilation_h), + w * stride_w + kw * tir.any(dilation_w), + c, + ].astype(accum_dtype) * B[f, kh - 1 - tir.any(dilation_h), kw - 1 - tir.any(dilation_w), + c].astype(accum_dtype), axis=[kh, kw, c], ), name="C", @@ -117,9 +121,13 @@ def conv2d_nhwc_hwio( C = te.compute( out_shape, lambda n, h, w, f: te.sum( - pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), - c,].astype(accum_dtype) * B[kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( - dilation_w), c, f].astype(accum_dtype), + pad[ + n, + h * stride_h + kh * tir.any(dilation_h), + w * stride_w + kw * tir.any(dilation_w), + c, + ].astype(accum_dtype) * B[kh - 1 - tir.any(dilation_h), kw - 1 - tir.any(dilation_w), c, + f].astype(accum_dtype), axis=[kh, kw, c], ), name="C", diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py index a8a552699..86507cf3e 100644 --- a/bitblas/tl/lower.py +++ b/bitblas/tl/lower.py @@ -4,16 +4,16 @@ from tvm import tir from tvm.target import Target + def tl_lower( func_or_mod: Union[tir.PrimFunc, tvm.IRModule], target: Union[str, Target] = "auto", target_host: Optional[Union[str, Target]] = None, runtime_only=False, ): - with tvm.transform.PassContext( - config={ - "tl.disable_dynamic_tail_split": False, - }): + with tvm.transform.PassContext(config={ + "tl.disable_dynamic_tail_split": False, + }): result = tilelang.lower( func_or_mod, target=target, @@ -27,4 +27,4 @@ def tl_lower( if runtime_only is True: return result.rt_mod else: - return result.rt_mod, result.params \ No newline at end of file + return result.rt_mod, result.params diff --git a/bitblas/utils/target_detector.py b/bitblas/utils/target_detector.py index 71d6dcc1f..9e3427817 100644 --- a/bitblas/utils/target_detector.py +++ b/bitblas/utils/target_detector.py @@ -23,6 +23,7 @@ "NVIDIA PG506-232": "NVIDIA A100", } + def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): """ Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU. @@ -52,6 +53,7 @@ def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): return gpus[gpu_id] + def find_best_match(tags, query): """ Finds the best match for a query within a list of tags using fuzzy string matching. diff --git a/integration/BitNet/benchmark_model_10k_loops.py b/integration/BitNet/benchmark_model_10k_loops.py index 3a838c2ca..47f423b70 100644 --- a/integration/BitNet/benchmark_model_10k_loops.py +++ b/integration/BitNet/benchmark_model_10k_loops.py @@ -18,6 +18,7 @@ seq_len = args.seq_len batch_size = args.batch_size + def profile(model, input_data): import time diff --git a/integration/BitNet/configuration_bitnet.py b/integration/BitNet/configuration_bitnet.py index 2f7f7aa7f..5f4937b87 100644 --- a/integration/BitNet/configuration_bitnet.py +++ b/integration/BitNet/configuration_bitnet.py @@ -22,7 +22,6 @@ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging - logger = logging.get_logger(__name__) LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} @@ -183,13 +182,14 @@ def _rope_scaling_validation(self): if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}" - ) + f"got {self.rope_scaling}") rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") \ No newline at end of file + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, + float) or rope_scaling_factor <= 1.0: + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/integration/BitNet/eval_gpu_memory.py b/integration/BitNet/eval_gpu_memory.py index b67106eeb..92ce14930 100644 --- a/integration/BitNet/eval_gpu_memory.py +++ b/integration/BitNet/eval_gpu_memory.py @@ -11,6 +11,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) + def profile(model, input_data): import time @@ -34,22 +35,20 @@ def get_runtime(num_repeats=1): times = get_runtime(num_repeats) return np.mean(times) + def main(): model = BitnetForCausalLM.from_pretrained( '1bitLLM/bitnet_b1_58-3B', device_map='auto', - low_cpu_mem_usage=True, + low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, ).half() - print( - f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB" - ) + print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB") with torch.no_grad(): model._post_process_weights() - print( - f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB" - ) + print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + if __name__ == '__main__': main() diff --git a/integration/BitNet/eval_utils.py b/integration/BitNet/eval_utils.py index a7a57dd8a..01621c89a 100644 --- a/integration/BitNet/eval_utils.py +++ b/integration/BitNet/eval_utils.py @@ -11,17 +11,24 @@ def set_seed(seed): np.random.seed(seed) torch.random.manual_seed(seed) + def get_test_dataset(dataset_name, tokenizer, seqlen=2048): if dataset_name == "wikitext2": testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') testdata = "".join(testdata['text']).split('\n') elif dataset_name == "c4": - testdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')['text'] + testdata = load_dataset( + 'allenai/c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation')['text'] else: raise NotImplementedError - + testdata = [item for item in testdata if item != ""] - tokenized_text = [tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] for item in testdata] + tokenized_text = [ + tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] + for item in testdata + ] data, doc = [], [tokenizer.bos_token_id] for sen in tokenized_text: @@ -37,6 +44,7 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048): class LMEvalAdaptor(BaseLM): + def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): super().__init__() @@ -129,5 +137,4 @@ def _model_call(self, inps): def _model_generate(self, context, max_length, eos_token_id): return self.model.generate( - context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False - ) \ No newline at end of file + context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) diff --git a/integration/BitNet/tokenization_bitnet.py b/integration/BitNet/tokenization_bitnet.py index 09b482f72..202559fae 100644 --- a/integration/BitNet/tokenization_bitnet.py +++ b/integration/BitNet/tokenization_bitnet.py @@ -17,7 +17,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tokenization classes for LLaMA.""" import os from shutil import copyfile @@ -29,7 +28,6 @@ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer from transformers.utils import logging - if TYPE_CHECKING: from transformers.tokenization_utils_base import TextInput @@ -39,10 +37,12 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-tokenizer": + "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-tokenizer": + "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -159,10 +159,14 @@ def __init__( **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + bos_token = AddedToken( + bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken( + eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken( + unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken( + pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token if legacy is None: logger.warning_once( @@ -170,8 +174,7 @@ def __init__( " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" " means, and thoroughly read the reason why this was added as explained in" - " https://github.com/huggingface/transformers/pull/24565" - ) + " https://github.com/huggingface/transformers/pull/24565") legacy = True self.legacy = legacy @@ -211,7 +214,8 @@ def get_spm_processor(self, from_slow=False): with open(self.vocab_file, "rb") as f: sp_model = f.read() - model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model_pb2 = import_protobuf( + f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") model = model_pb2.ModelProto.FromString(sp_model) normalizer_spec = model_pb2.NormalizerSpec() normalizer_spec.add_dummy_prefix = False @@ -257,7 +261,8 @@ def tokenize(self, text: "TextInput", **kwargs) -> List[str]: tokens = super().tokenize(text, **kwargs) - if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + if len(tokens + ) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: tokens = tokens[1:] return tokens @@ -279,7 +284,7 @@ def _tokenize(self, text, **kwargs): # 1. Encode string + prefix ex: " Hey" tokens = self.sp_model.encode(self.unk_token + text, out_type=str) # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] - return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" @@ -327,11 +332,12 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) + out_vocab_file = os.path.join(save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"]) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile( + self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -351,9 +357,10 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): return output - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: + def get_special_tokens_mask(self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. @@ -371,26 +378,19 @@ def get_special_tokens_mask( """ if already_has_special_tokens: return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) bos_token_id = [1] if self.add_bos_token else [] eos_token_id = [1] if self.add_eos_token else [] if token_ids_1 is None: return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return ( - bos_token_id - + ([0] * len(token_ids_0)) - + eos_token_id - + bos_token_id - + ([0] * len(token_ids_1)) - + eos_token_id - ) + return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + + ([0] * len(token_ids_1)) + eos_token_id) - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + def create_token_type_ids_from_sequences(self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT sequence pair mask has the following format: @@ -473,10 +473,10 @@ def default_chat_template(self): "{% elif message['role'] == 'assistant' %}" "{{ ' ' + content.strip() + ' ' + eos_token }}" "{% endif %}" - "{% endfor %}" - ) - template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + "{% endfor %}") + template = template.replace("USE_DEFAULT_PROMPT", + "true" if self.use_default_system_prompt else "false") default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) - return template \ No newline at end of file + return template diff --git a/integration/mlc_llm/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py index 7f6887277..272314ec8 100644 --- a/integration/mlc_llm/test_weight_only_transform.py +++ b/integration/mlc_llm/test_weight_only_transform.py @@ -36,7 +36,8 @@ def get_default_result(ref_mod, input_tensors, target, device): bitblas.gpu.Reduction(), bitblas.gpu.GeneralReduction(), bitblas.gpu.Fallback(), - )(ref_mod) + )( + ref_mod) ref_mod = tvm.tir.transform.MakePackedAPI()(ref_mod) ex = relax.build(ref_mod, target) vm = relax.VirtualMachine(ex, device) @@ -55,8 +56,10 @@ def get_fast_tune_result(ref_mod, input_tensors, target, device): def test_lop3_transform(): + @I.ir_module class Before: + @T.prim_func(private=True) def fused_fused_decode3_fused_NT_matmul8_add1( lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), @@ -67,31 +70,26 @@ def fused_fused_decode3_fused_NT_matmul8_add1( T.func_attr({"tir.noalias": T.bool(True)}) n = T.int64() lv41 = T.match_buffer(p_lv41, (T.int64(1), 1, T.int64(4096)), "float16") - NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), 1, T.int64(4096)), "float16" - ) + NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), 1, T.int64(4096)), + "float16") # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) + decode_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), + "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): with T.block("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), + decode_intermediate_intermediate[v_i, v_j] = (T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] + T.uint32(15), + ), + ) - T.float16(7)) * lv48[v_i, v_j // T.int64(32)] for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) @@ -103,10 +101,8 @@ def fused_fused_decode3_fused_NT_matmul8_add1( with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] - * decode_intermediate_intermediate[v_i2, v_k] - ) + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k]) @R.function def main( @@ -134,8 +130,8 @@ def main( relax_mod = AnnotateDecodeInformation()(relax_mod) with dispatch_target: relax_mod = WeightOnlyLayoutPropagation( - transform_level=0, faster_conversion=False - )(relax_mod) + transform_level=0, faster_conversion=False)( + relax_mod) input_tensors = get_dummy_input_arrays(ref_mod["main"], tvm.cpu()) @@ -157,22 +153,22 @@ def main( print("relax ", res) -def test_matmul_transform(transform_level = 2): +def test_matmul_transform(transform_level=2): @I.ir_module class Before: + @T.prim_func(private=True) def fused_fused_decode3_fused_NT_matmul8_add1( - p_lv41: T.handle, - lv47: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), - p_output0: T.handle, + p_lv41: T.handle, + lv47: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + p_output0: T.handle, ): T.func_attr({"tir.noalias": T.bool(True)}) n = T.int64() lv41 = T.match_buffer(p_lv41, (T.int64(1), 1, T.int64(4096)), "float16") - NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), 1, T.int64(4096)), "float16" - ) + NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), 1, T.int64(4096)), + "float16") for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): with T.block("NT_matmul"): @@ -182,9 +178,8 @@ def fused_fused_decode3_fused_NT_matmul8_add1( with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] * lv47[v_i2, v_k] - ) + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * lv47[v_i2, v_k]) @R.function def main( @@ -210,8 +205,8 @@ def main( relax_mod = AnnotateDecodeInformation()(relax_mod) with dispatch_target: relax_mod = WeightOnlyLayoutPropagation( - transform_level=transform_level, faster_conversion=False - )(relax_mod) + transform_level=transform_level, faster_conversion=False)( + relax_mod) input_tensors = get_dummy_input_arrays(ref_mod["main"], tvm.cpu()) @@ -237,6 +232,7 @@ def test_dequantize_matmul_transform(transform_level=2): @I.ir_module class Before: + @T.prim_func(private=True) def fused_fused_decode3_fused_NT_matmul8_add1( lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), @@ -246,34 +242,28 @@ def fused_fused_decode3_fused_NT_matmul8_add1( ): T.func_attr({"tir.noalias": T.bool(True)}) n = T.int64() - lv41 = T.match_buffer( - p_lv41, (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ) - NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ) + lv41 = T.match_buffer(p_lv41, (T.int64(1), T.int64(1), T.int64(4096)), "float16") + NT_matmul_intermediate = T.match_buffer(p_output0, + (T.int64(1), T.int64(1), T.int64(4096)), + "float16") # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) + decode_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), + "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): with T.block("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), + decode_intermediate_intermediate[v_i, v_j] = (T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] + T.uint32(15), + ), + ) - T.float16(7)) * lv48[v_i, v_j // T.int64(32)] for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) @@ -285,10 +275,8 @@ def fused_fused_decode3_fused_NT_matmul8_add1( with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] - * decode_intermediate_intermediate[v_i2, v_k] - ) + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k]) @R.function def main( @@ -319,8 +307,8 @@ def main( relax_mod = AnnotateDecodeInformation()(relax_mod) with dispatch_target: relax_mod = WeightOnlyLayoutPropagation( - transform_level=transform_level, faster_conversion=False - )(relax_mod) + transform_level=transform_level, faster_conversion=False)( + relax_mod) input_tensors = get_dummy_input_arrays(ref_mod["main"], device) print(relax_mod) print("=======================ref llvm result=======================") @@ -337,9 +325,7 @@ def main( ref_res = get_fast_tune_result(ref_mod, input_tensors, dispatch_target, device) print("ref_mod", ref_res) print(relax_mod) - bitblas_res = get_fast_tune_result( - relax_mod, input_tensors, dispatch_target, device - ) + bitblas_res = get_fast_tune_result(relax_mod, input_tensors, dispatch_target, device) print("bitblas_res", bitblas_res) diff --git a/integration/pytorch/test_bitblas_linear.py b/integration/pytorch/test_bitblas_linear.py index 18a969fc5..fed036839 100644 --- a/integration/pytorch/test_bitblas_linear.py +++ b/integration/pytorch/test_bitblas_linear.py @@ -92,8 +92,8 @@ def test_profile_performance(m, in_features, out_features, bias): torch_latency = profile(linear_bitblas, input_data) bitblas_latency = linear_bitblas.bitblas_matmul.profile_latency() print(f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}") - assert (abs(torch_latency - bitblas_latency) / torch_latency < - 0.1), f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}" + assert (abs(torch_latency - bitblas_latency) / torch_latency + < 0.1), f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}" if __name__ == "__main__": diff --git a/testing/python/tilelang/test_simplifier.py b/testing/python/tilelang/test_simplifier.py index 678006ec9..3b202ad50 100644 --- a/testing/python/tilelang/test_simplifier.py +++ b/testing/python/tilelang/test_simplifier.py @@ -4,6 +4,7 @@ from bitblas.tl.lower import tl_lower from bitblas.tl.profiler import TLProfiler + def modify( with_B: bool = False, with_bias: bool = False, diff --git a/testing/python/tilelang/test_tilelang_amd_gemm.py b/testing/python/tilelang/test_tilelang_amd_gemm.py index a57c6adc2..e63a78370 100644 --- a/testing/python/tilelang/test_tilelang_amd_gemm.py +++ b/testing/python/tilelang/test_tilelang_amd_gemm.py @@ -7,6 +7,7 @@ from bitblas.tl.lower import tl_lower from bitblas.tl.profiler import TLProfiler + def matmul( M, N, diff --git a/tools/get_available_targets.py b/tools/get_available_targets.py index 2c2753d7a..fb685b1b4 100644 --- a/tools/get_available_targets.py +++ b/tools/get_available_targets.py @@ -4,14 +4,16 @@ from bitblas.utils import get_all_nvidia_targets from tabulate import tabulate + def main(): # Get all available Nvidia targets targets = get_all_nvidia_targets() - + # Print available targets to console in a table format table = [[i + 1, target] for i, target in enumerate(targets)] headers = ["Index", "Target"] print(tabulate(table, headers, tablefmt="pretty")) + if __name__ == "__main__": - main() \ No newline at end of file + main() From 72d13e788eccd20ca9b03d066c416c52cbbff6e8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 4 May 2025 23:23:43 +0800 Subject: [PATCH 23/42] fix lint --- README.md | 2 +- bitblas/base/roller/bestfit.py | 2 +- bitblas/base/roller/shape_inference/tir.py | 2 +- docs/ExtendOperatorsWithDSL.md | 2 +- docs/Installation.md | 2 +- docs/PythonAPI.md | 2 +- docs/QuickStart.md | 2 +- integration/GPTQModel/README.md | 2 +- maint/scripts/apply_mit_license.sh | 2 +- pyproject.toml | 2 +- testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu | 4 ++-- testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp | 10 +++++----- .../lop3_type_conversion/lowprecision_to_float16.cu | 4 ++-- .../cpp/lop3_type_conversion/lowprecision_to_int4.cu | 4 ++-- .../cpp/lop3_type_conversion/lowprecision_to_int8.cu | 4 ++-- tutorials/1.fast_and_efficient_codegen.ipynb | 2 +- tutorials/3.fast_decoding.ipynb | 2 +- tutorials/6.tile-language.ipynb | 12 +++++++++--- 18 files changed, 34 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index ecf76ba5b..a6f7cc673 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ We are continuously expanding the support matrix. If you have any specific requi - **Python Version**: >= 3.8 - **CUDA Version**: >= 11.0 -The easiest way to install BitBLAS is direcly from the PyPi using pip. To install the latest version, run the following command in your terminal. +The easiest way to install BitBLAS is directly from the PyPi using pip. To install the latest version, run the following command in your terminal. ```bash pip install bitblas diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index e4a5ea538..3006d4e2e 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Benifit For BitBLAS Schedule""" +"""Benefit For BitBLAS Schedule""" class Block: diff --git a/bitblas/base/roller/shape_inference/tir.py b/bitblas/base/roller/shape_inference/tir.py index 234f4cd64..eb5acd0cc 100644 --- a/bitblas/base/roller/shape_inference/tir.py +++ b/bitblas/base/roller/shape_inference/tir.py @@ -384,7 +384,7 @@ def fvisit(x): with T.init(): T_dense_reindex[T.int64(0), v0, v1] = T.float16(0) T_dense_reindex[T.int64(0), v0, v1] = T_dense_reindex[T.int64(0), v0, v1] + A_reindex[T.int64(0), v0, v2] * B_reindex[T.int64(0), v1, v2] - For exmaple, the T_dense_reindex has three dims, however there're only two spatial loops. + For example, the T_dense_reindex has three dims, however there're only two spatial loops. """ continue index.append(expr) diff --git a/docs/ExtendOperatorsWithDSL.md b/docs/ExtendOperatorsWithDSL.md index ec62356b5..9a29b1528 100644 --- a/docs/ExtendOperatorsWithDSL.md +++ b/docs/ExtendOperatorsWithDSL.md @@ -144,7 +144,7 @@ scheduled_ir_module = fast_tune_with_dynamic_range( } ) -# fianlly, we will generate a dispatch func to dispatch the kernel with dynamic symbolic. +# finally, we will generate a dispatch func to dispatch the kernel with dynamic symbolic. ''' @IRModule class MatmulNT: diff --git a/docs/Installation.md b/docs/Installation.md index a50d478ef..ab341c4f7 100644 --- a/docs/Installation.md +++ b/docs/Installation.md @@ -7,7 +7,7 @@ - **Python Version**: >= 3.8 - **CUDA Version**: >= 11.0 -The easiest way to install BitBLAS is direcly from the PyPi using pip. To install the latest version, run the following command in your terminal. +The easiest way to install BitBLAS is directly from the PyPi using pip. To install the latest version, run the following command in your terminal. **Note**: Currently, BitBLAS whl is only supported on Ubuntu 20.04 or later version as we build the whl files on this platform. Currently we only provide whl files for CUDA>=11.0 and with Python>=3.8. **If you are using a different platform or environment, you may need to [build BitBLAS from source](https://github.com/microsoft/BitBLAS/blob/main/docs/Installation.md#building-from-source).** diff --git a/docs/PythonAPI.md b/docs/PythonAPI.md index 45f23df2b..1e248671f 100644 --- a/docs/PythonAPI.md +++ b/docs/PythonAPI.md @@ -184,7 +184,7 @@ Returns: The output tensor. #### `init_params()` -Initializes parameters handles (convert constant params into ctypes void pointer) for the computation. We currently put this fuction in the forward function, so you do not need to call it manually. But if you lift this function out of the forward function, you can call it manually to aoid the transformation. +Initializes parameters handles (convert constant params into ctypes void pointer) for the computation. We currently put this function in the forward function, so you do not need to call it manually. But if you lift this function out of the forward function, you can call it manually to aoid the transformation. #### `load_and_transform_weight(weight, scales=None, zeros=None, bias=None)` diff --git a/docs/QuickStart.md b/docs/QuickStart.md index e808f1ce2..b2e8d6fa8 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -116,7 +116,7 @@ print("BitBLAS output:", output_tensor) torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-2) ``` -The init stage of the ```bitblas.Matmul``` class will take minutes to finish, as it will use hardware informations to do a one-time kernel library initialization. +The init stage of the ```bitblas.Matmul``` class will take minutes to finish, as it will use hardware information to do a one-time kernel library initialization. ## Example: bitblas.Linear module for PyTorch diff --git a/integration/GPTQModel/README.md b/integration/GPTQModel/README.md index ee4c14399..ebeb60b77 100644 --- a/integration/GPTQModel/README.md +++ b/integration/GPTQModel/README.md @@ -1,3 +1,3 @@ -BitBLAS has been fully integraded into [GPTQModel](https://github.com/ModelCloud/GPTQModel) since v0.9.1. +BitBLAS has been fully integrated into [GPTQModel](https://github.com/ModelCloud/GPTQModel) since v0.9.1. Please reference [sample code](https://github.com/ModelCloud/GPTQModel/blob/main/examples/inference/run_with_different_backends.py) for usage on using `backend=BACKEND.BITBLAS`within GPTQModel. \ No newline at end of file diff --git a/maint/scripts/apply_mit_license.sh b/maint/scripts/apply_mit_license.sh index 3d27d6cf0..c7fd7d42d 100755 --- a/maint/scripts/apply_mit_license.sh +++ b/maint/scripts/apply_mit_license.sh @@ -3,7 +3,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -echo "Add MIT liscense boilerplate..." +echo "Add MIT license boilerplate..." PWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # TO source code root pushd "${PWD}/../../" > /dev/null diff --git a/pyproject.toml b/pyproject.toml index 85fd7db04..190353ded 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ ignore = [ "E402", # star imports "F405", "F403", - # ambigous name + # ambiguous name "E741", # line too long "E501", diff --git a/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu index 257f49a31..9e67953df 100644 --- a/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu +++ b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu @@ -6,9 +6,9 @@ #include #include "i4matmul.hpp" -#define cudaCheckLastError(ans) \ +#define cudaCheckLastError(and) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ + gpuAssert((and), __FILE__, __LINE__); \ } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { diff --git a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp index a12a57dcd..0c3e2d0cc 100644 --- a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp +++ b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp @@ -362,7 +362,7 @@ __global__ void Marlin( a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); } - // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependencies between // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. const int4* B_ptr[b_sh_wr_iters]; #pragma unroll @@ -429,7 +429,7 @@ __global__ void Marlin( auto fetch_to_registers = [&] (int k, int pipe) { // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the - // compiler and correspondingly a noticable drop in performance. + // compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; @@ -513,7 +513,7 @@ __global__ void Marlin( }; // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over - // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather + // the results. As the striped partitioning minimizes the number of such reductions and our outputs are usually rather // small, we perform this reduction serially in L2 cache. auto global_reduce = [&] (bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. @@ -656,7 +656,7 @@ __global__ void Marlin( a_gl_rd += a_gl_rd_delta_o * stages; // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most - // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. + // readable, other ways of writing the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; @@ -757,7 +757,7 @@ int marlin_cuda( cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); if (thread_k == -1 || thread_n == -1) { if (prob_m <= 16) { - // For small batchizes, better partioning is slightly more important than better compute utilization + // For small batchizes, better partitioning is slightly more important than better compute utilization thread_k = 128; thread_n = 128; } else { diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu index 7307ad1fe..42b1649ea 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu @@ -6,9 +6,9 @@ #include #include "fast_decoding.hpp" -#define cudaCheckLastError(ans) \ +#define cudaCheckLastError(and) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ + gpuAssert((and), __FILE__, __LINE__); \ } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu index d39a85dcd..ba095ec65 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu @@ -6,9 +6,9 @@ #include #include "fast_decoding.hpp" -#define cudaCheckLastError(ans) \ +#define cudaCheckLastError(and) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ + gpuAssert((and), __FILE__, __LINE__); \ } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu index 0a3b45a77..1d48358ee 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu @@ -6,9 +6,9 @@ #include #include "fast_decoding.hpp" -#define cudaCheckLastError(ans) \ +#define cudaCheckLastError(and) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ + gpuAssert((and), __FILE__, __LINE__); \ } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { diff --git a/tutorials/1.fast_and_efficient_codegen.ipynb b/tutorials/1.fast_and_efficient_codegen.ipynb index 4c34d06e6..7275ffc09 100644 --- a/tutorials/1.fast_and_efficient_codegen.ipynb +++ b/tutorials/1.fast_and_efficient_codegen.ipynb @@ -104,7 +104,7 @@ } ], "source": [ - "# import fast tunning related toolkits\n", + "# import fast tuning related toolkits\n", "from bitblas.base.roller.policy import DefaultPolicy\n", "from bitblas.base.arch import CUDA\n", "from bitblas.base.utils import apply_and_build\n", diff --git a/tutorials/3.fast_decoding.ipynb b/tutorials/3.fast_decoding.ipynb index f8b8b093e..2bf26393f 100644 --- a/tutorials/3.fast_decoding.ipynb +++ b/tutorials/3.fast_decoding.ipynb @@ -8,7 +8,7 @@ "source": [ "# Fast Dequantizatoin\n", "\n", - "How to enbale fast dequantization (INT4/2/1 -> FP16/INT8) or (FP8 -> FP16)?\n", + "How to enable fast dequantization (INT4/2/1 -> FP16/INT8) or (FP8 -> FP16)?\n", "\n", "![image.png](./img/FastDequantization.png)" ] diff --git a/tutorials/6.tile-language.ipynb b/tutorials/6.tile-language.ipynb index 7826f0743..eb96f5161 100644 --- a/tutorials/6.tile-language.ipynb +++ b/tutorials/6.tile-language.ipynb @@ -7,7 +7,7 @@ "source": [ "# Tile Language in BitBLAS\n", "\n", - "More flexiable, More Efficient Tile Programming Languange compared with Triton\n", + "More flexiable, More Efficient Tile Programming Language compared with Triton\n", "\n", "## Features\n", "\n", @@ -26,6 +26,12 @@ "\n" ] }, + { + "cell_type": "markdown", + "id": "267bf3b1", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": 2, @@ -192,7 +198,7 @@ "id": "aec3ac4e-385f-44cb-b439-2357901e1d86", "metadata": {}, "source": [ - "TL also provide interface for users to manupulate the memory layout, pipeline and enable rasterization for better L2 Cache Locality. Here is an example of how to use the memory layout and rasterization:" + "TL also provide interface for users to manipulate the memory layout, pipeline and enable rasterization for better L2 Cache Locality. Here is an example of how to use the memory layout and rasterization:" ] }, { @@ -318,7 +324,7 @@ "id": "48f49319-ad63-4673-8130-b010dc8ba22e", "metadata": {}, "source": [ - "## If you want fine-grained control over dequantization at the thread leve" + "## If you want fine-grained control over dequantization at the thread level" ] }, { From 8f1c9d6b30780ee3dac562da8b40e67b4671b6b1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 4 May 2025 23:28:13 +0800 Subject: [PATCH 24/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=9D=A1=E4=BB=B6?= =?UTF-8?q?=E5=88=A4=E6=96=AD=EF=BC=8C=E5=90=88=E5=B9=B6=E7=9B=B8=E4=BC=BC?= =?UTF-8?q?=E7=9A=84=E7=B1=BB=E5=9E=8B=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/shape_inference/tir.py | 4 +--- bitblas/gpu/gemv.py | 2 +- bitblas/gpu/utils.py | 4 +--- integration/BitNet/eval_utils.py | 5 +---- tutorials/2.auto_tensorization.ipynb | 1 - tutorials/4.dynamic_shape_codegen.ipynb | 1 - tutorials/5.ladder_end2end.ipynb | 4 +--- 7 files changed, 5 insertions(+), 16 deletions(-) diff --git a/bitblas/base/roller/shape_inference/tir.py b/bitblas/base/roller/shape_inference/tir.py index eb5acd0cc..1c614d0d3 100644 --- a/bitblas/base/roller/shape_inference/tir.py +++ b/bitblas/base/roller/shape_inference/tir.py @@ -343,9 +343,7 @@ def walk_indice(expr): return expr else: return None - elif isinstance(expr, tir.expr.ConstExpr): - return expr - elif isinstance(expr, tir.Var): + elif isinstance(expr, tir.expr.ConstExpr) or isinstance(expr, tir.Var): return expr elif isinstance(expr, tir.ProducerLoad): return None diff --git a/bitblas/gpu/gemv.py b/bitblas/gpu/gemv.py index 60a290a81..bed06a055 100644 --- a/bitblas/gpu/gemv.py +++ b/bitblas/gpu/gemv.py @@ -500,7 +500,7 @@ def apply( if not isinstance(len_S, int): TS, TR = 1, 64 - while TS * TR > target.max_num_threads: + while target.max_num_threads < TS * TR: if TS > 1: TS //= 2 else: diff --git a/bitblas/gpu/utils.py b/bitblas/gpu/utils.py index e3a5b6098..575624832 100644 --- a/bitblas/gpu/utils.py +++ b/bitblas/gpu/utils.py @@ -38,9 +38,7 @@ def suggest_threads_per_block( ) -> List[int]: if target.kind.name == "cuda": threads = 1024 - elif target.kind.name == "rocm": - threads = 256 - elif target.kind.name == "metal": + elif target.kind.name == "rocm" or target.kind.name == "metal": threads = 256 else: threads = 64 diff --git a/integration/BitNet/eval_utils.py b/integration/BitNet/eval_utils.py index 01621c89a..fc24fdee7 100644 --- a/integration/BitNet/eval_utils.py +++ b/integration/BitNet/eval_utils.py @@ -1,7 +1,6 @@ import torch import numpy as np -import torch.nn.functional as F from lm_eval.base import BaseLM from datasets import load_dataset @@ -81,9 +80,7 @@ def max_length(self): return 2048 elif "llama" in self.model_name: return 2048 # TODO: did not check this - elif "mpt" in self.model_name: - return 2048 - elif "falcon" in self.model_name: + elif "mpt" in self.model_name or "falcon" in self.model_name: return 2048 else: print(self.model.config) diff --git a/tutorials/2.auto_tensorization.ipynb b/tutorials/2.auto_tensorization.ipynb index ebf072959..cfd438852 100644 --- a/tutorials/2.auto_tensorization.ipynb +++ b/tutorials/2.auto_tensorization.ipynb @@ -21,7 +21,6 @@ "outputs": [], "source": [ "import bitblas\n", - "from bitblas import tvm\n", "from tvm import te, tir" ] }, diff --git a/tutorials/4.dynamic_shape_codegen.ipynb b/tutorials/4.dynamic_shape_codegen.ipynb index 69554f91b..27f382c59 100644 --- a/tutorials/4.dynamic_shape_codegen.ipynb +++ b/tutorials/4.dynamic_shape_codegen.ipynb @@ -19,7 +19,6 @@ "outputs": [], "source": [ "import bitblas\n", - "import torch\n", "\n", "# enabling debug output\n", "\n", diff --git a/tutorials/5.ladder_end2end.ipynb b/tutorials/5.ladder_end2end.ipynb index 50441da32..5fbb0125c 100644 --- a/tutorials/5.ladder_end2end.ipynb +++ b/tutorials/5.ladder_end2end.ipynb @@ -19,9 +19,7 @@ "id": "5bfda684-1ca2-4999-8609-505a3c974b7e", "metadata": {}, "outputs": [], - "source": [ - "import ladder" - ] + "source": [] }, { "cell_type": "code", From 3787236e891532fd009d6422e4194dbe6f7af3c7 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 4 May 2025 23:46:29 +0800 Subject: [PATCH 25/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20infer=20=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E7=9A=84=20rstep=20=E5=8F=82=E6=95=B0=E5=A4=84?= =?UTF-8?q?=E7=90=86=EF=BC=8C=E7=AE=80=E5=8C=96=E4=BB=A3=E7=A0=81=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/shape_inference/common.py | 4 +++- bitblas/base/roller/shape_inference/tir.py | 11 +++++------ bitblas/gpu/general_reduction.py | 10 ++-------- bitblas/gpu/rmsnorm.py | 11 ++--------- integration/mlc_llm/test_weight_only_transform.py | 6 +++--- 5 files changed, 15 insertions(+), 27 deletions(-) diff --git a/bitblas/base/roller/shape_inference/common.py b/bitblas/base/roller/shape_inference/common.py index 7de953cc3..6cc6d228b 100644 --- a/bitblas/base/roller/shape_inference/common.py +++ b/bitblas/base/roller/shape_inference/common.py @@ -47,7 +47,9 @@ def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, i shape[name] = [c.max_value - c.min_value + 1 for c in bounds] return shape - def infer(self, shape, rstep: Dict[str, int] = {}): + def infer(self, shape, rstep: Dict[str, int] = None): + if rstep is None: + rstep = {} if isinstance(shape, (list, tuple)): shape = {"output0": [arith.ConstIntBound(0, val - 1) for val in shape]} shape = self._infer(shape, rstep) diff --git a/bitblas/base/roller/shape_inference/tir.py b/bitblas/base/roller/shape_inference/tir.py index 1c614d0d3..19d70857e 100644 --- a/bitblas/base/roller/shape_inference/tir.py +++ b/bitblas/base/roller/shape_inference/tir.py @@ -246,8 +246,10 @@ def construct_dependency_target(self, targets: Tuple[str]): def infer(self, shape: Dict[str, List[arith.ConstIntBound]], - rstep: Dict[str, int] = {}, + rstep: Dict[str, int] = None, targets=None): + if rstep is None: + rstep = {} compute_targets = tuple(shape.keys()) input_vars, mapping = self.construct_dependency_target(compute_targets) ana = arith.Analyzer() @@ -327,10 +329,7 @@ def expr_is_same(a, b) -> bool: return structural_equal(a, b) def region_is_same(a, b) -> bool: - for indice_a, indice_b in zip(a, b): - if not expr_is_same(indice_a, indice_b): - return False - return True + return all(expr_is_same(indice_a, indice_b) for indice_a, indice_b in zip(a, b)) return any([region_is_same(a, x) for x in list]) @@ -343,7 +342,7 @@ def walk_indice(expr): return expr else: return None - elif isinstance(expr, tir.expr.ConstExpr) or isinstance(expr, tir.Var): + elif isinstance(expr, (tir.expr.ConstExpr, tir.Var)): return expr elif isinstance(expr, tir.ProducerLoad): return None diff --git a/bitblas/gpu/general_reduction.py b/bitblas/gpu/general_reduction.py index 96fa4b2a0..fd6da8a71 100644 --- a/bitblas/gpu/general_reduction.py +++ b/bitblas/gpu/general_reduction.py @@ -188,10 +188,7 @@ def prod(iterable): dim_offset = ( len(reduce_inner_axis) + len(reduce_outer_axis) + 2 ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis - if input_region.buffer.name in config.vectorize: - vectorize = config.vectorize[input_region.buffer.name] - else: - vectorize = 1 + vectorize = config.vectorize.get(input_region.buffer.name, 1) loops = sch.get_loops(cache_shared) if len(loops) == dim_offset: @@ -326,10 +323,7 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many dim_offset = ( len(vthread_axis) + len(reduce_outer_axis) + 2 ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis - if input_region.buffer.name in config.vectorize: - vectorize = config.vectorize[input_region.buffer.name] - else: - vectorize = 1 + vectorize = config.vectorize.get(input_region.buffer.name, 1) loops = sch.get_loops(cache_shared) if len(loops) == dim_offset: diff --git a/bitblas/gpu/rmsnorm.py b/bitblas/gpu/rmsnorm.py index bd27c0219..941cf37d2 100644 --- a/bitblas/gpu/rmsnorm.py +++ b/bitblas/gpu/rmsnorm.py @@ -52,11 +52,7 @@ def identify_cast_or_load_block(block: Block) -> bool: if len(load.indices) != len(store.indices): return False - for lhs, rhs in zip(load.indices, store.indices): - if not lhs.same_as(rhs): - return False - - return True + return all(lhs.same_as(rhs) for lhs, rhs in zip(load.indices, store.indices)) def identify_rsqrt_block(block: Block) -> bool: @@ -84,10 +80,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring target: Target, _: bool, ) -> tir.Schedule: - if target.kind.name == "cuda": - num_tx = 512 - else: - num_tx = 64 + num_tx = 512 if target.kind.name == "cuda" else 64 sch = tir.Schedule(func) root = sch.get_block(name="root", func_name="main") diff --git a/integration/mlc_llm/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py index 272314ec8..bebf7b301 100644 --- a/integration/mlc_llm/test_weight_only_transform.py +++ b/integration/mlc_llm/test_weight_only_transform.py @@ -68,7 +68,7 @@ def fused_fused_decode3_fused_NT_matmul8_add1( p_output0: T.handle, ): T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() + T.int64() lv41 = T.match_buffer(p_lv41, (T.int64(1), 1, T.int64(4096)), "float16") NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), 1, T.int64(4096)), "float16") @@ -165,7 +165,7 @@ def fused_fused_decode3_fused_NT_matmul8_add1( p_output0: T.handle, ): T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() + T.int64() lv41 = T.match_buffer(p_lv41, (T.int64(1), 1, T.int64(4096)), "float16") NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), 1, T.int64(4096)), "float16") @@ -241,7 +241,7 @@ def fused_fused_decode3_fused_NT_matmul8_add1( p_output0: T.handle, ): T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() + T.int64() lv41 = T.match_buffer(p_lv41, (T.int64(1), T.int64(1), T.int64(4096)), "float16") NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), T.int64(4096)), From 24c414c0cc5e4eefa574f6b7cf36164b3369ee31 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 4 May 2025 23:49:57 +0800 Subject: [PATCH 26/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20BestFit=20=E7=B1=BB?= =?UTF-8?q?=E7=9A=84=20malloc=20=E6=96=B9=E6=B3=95=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E6=9D=A1=E4=BB=B6=E5=88=A4=E6=96=AD=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E7=AE=80=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/bestfit.py | 3 +-- bitblas/base/roller/policy/__init__.py | 4 ++-- bitblas/base/roller/shape_inference/__init__.py | 2 +- bitblas/gpu/reduction.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index 3006d4e2e..b87d385d0 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -33,8 +33,7 @@ def malloc(self, size) -> Block: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if block.is_free and block.size() >= size: - if not found or found.size() > block.size(): + if block.is_free and block.size() >= size and not found or found.size() > block.size(): found = block if found: found.is_free = False diff --git a/bitblas/base/roller/policy/__init__.py b/bitblas/base/roller/policy/__init__.py index 09ed1d51b..d8c5c0b19 100644 --- a/bitblas/base/roller/policy/__init__.py +++ b/bitblas/base/roller/policy/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .default import DefaultPolicy -from .tensorcore import TensorCorePolicy +from .default import DefaultPolicy # noqa: F401 +# from .tensorcore import TensorCorePolicy diff --git a/bitblas/base/roller/shape_inference/__init__.py b/bitblas/base/roller/shape_inference/__init__.py index 275287f65..2a709d434 100644 --- a/bitblas/base/roller/shape_inference/__init__.py +++ b/bitblas/base/roller/shape_inference/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .tir import get_analyzer_by_tir # pylint: disable=unused-import +# from .tir import get_analyzer_by_tir # pylint: disable=unused-import diff --git a/bitblas/gpu/reduction.py b/bitblas/gpu/reduction.py index a90da0558..a7430231b 100644 --- a/bitblas/gpu/reduction.py +++ b/bitblas/gpu/reduction.py @@ -135,7 +135,7 @@ def _normalize( # pylint: disable=too-many-branches s_loops.append(loop) if iter_to_info: - for var, info in iter_to_info.items(): + for _var, info in iter_to_info.items(): if info.kind == "S" and info.dom.extent == 1: s_loops.append(info.loop_rv) else: From f88b23825f40211bab4be33b062fcaf56c60342e Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 4 May 2025 23:54:21 +0800 Subject: [PATCH 27/42] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=9C=AA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E7=9A=84=20TensorCorePolicy=20=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=EF=BC=8C=E4=BF=9D=E6=8C=81=E4=BB=A3=E7=A0=81=E6=95=B4=E6=B4=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/policy/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bitblas/base/roller/policy/__init__.py b/bitblas/base/roller/policy/__init__.py index d8c5c0b19..750bca05e 100644 --- a/bitblas/base/roller/policy/__init__.py +++ b/bitblas/base/roller/policy/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .default import DefaultPolicy # noqa: F401 -# from .tensorcore import TensorCorePolicy +from .default import DefaultPolicy # noqa: F401 \ No newline at end of file From 9e0551c64844d532224bb238190e6bd7902d6d68 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 4 May 2025 23:55:04 +0800 Subject: [PATCH 28/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20BestFit=20=E7=B1=BB?= =?UTF-8?q?=E7=9A=84=20malloc=20=E6=96=B9=E6=B3=95=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E6=9D=A1=E4=BB=B6=E5=88=A4=E6=96=AD=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=80=BB=E8=BE=91=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/bestfit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index b87d385d0..dd21aa080 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -33,7 +33,7 @@ def malloc(self, size) -> Block: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if block.is_free and block.size() >= size and not found or found.size() > block.size(): + if block.is_free and block.size() >= size and (not found or found.size() > block.size()): found = block if found: found.is_free = False From a10cabf831de10a2dffc11a198e7f4b0662cde84 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 00:03:17 +0800 Subject: [PATCH 29/42] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A0=BC=E5=BC=8F?= =?UTF-8?q?=E5=8C=96=E8=84=9A=E6=9C=AC=E4=BB=A5=E6=98=BE=E7=A4=BA=E6=9B=B4?= =?UTF-8?q?=E8=AF=A6=E7=BB=86=E7=9A=84=E6=9C=AA=E6=9A=82=E5=AD=98=E6=9B=B4?= =?UTF-8?q?=E6=94=B9=EF=BC=8C=E4=BE=BF=E4=BA=8E=E5=AE=A1=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- format.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/format.sh b/format.sh index 5d3056123..f8455f621 100755 --- a/format.sh +++ b/format.sh @@ -193,7 +193,7 @@ if ! git diff --quiet &>/dev/null; then echo 'Reformatted files. Please review and stage the changes.' echo 'Changes not staged for commit:' echo - git --no-pager diff --name-only + git --no-pager diff exit 1 fi From f42318f9acbd22fe26a40f849e3fea7cd318d344 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 00:06:13 +0800 Subject: [PATCH 30/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20BestFit=20=E7=B1=BB?= =?UTF-8?q?=E7=9A=84=20malloc=20=E6=96=B9=E6=B3=95=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E6=9D=A1=E4=BB=B6=E5=88=A4=E6=96=AD=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=8F=90=E5=8D=87=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/bestfit.py | 3 ++- bitblas/base/roller/policy/__init__.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index dd21aa080..8d5a55eda 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -33,7 +33,8 @@ def malloc(self, size) -> Block: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if block.is_free and block.size() >= size and (not found or found.size() > block.size()): + if block.is_free and block.size() >= size and (not found or + found.size() > block.size()): found = block if found: found.is_free = False diff --git a/bitblas/base/roller/policy/__init__.py b/bitblas/base/roller/policy/__init__.py index 750bca05e..ec8c69c00 100644 --- a/bitblas/base/roller/policy/__init__.py +++ b/bitblas/base/roller/policy/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .default import DefaultPolicy # noqa: F401 \ No newline at end of file +from .default import DefaultPolicy # noqa: F401 From 3afbe9c9ceb8aaab0ee22b76f448f5b737bca9cc Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 00:11:39 +0800 Subject: [PATCH 31/42] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20BestFit=20=E7=B1=BB?= =?UTF-8?q?=20malloc=20=E6=96=B9=E6=B3=95=E4=B8=AD=E7=9A=84=E5=BE=AA?= =?UTF-8?q?=E7=8E=AF=E5=B5=8C=E5=A5=97=EF=BC=8C=E6=8F=90=E5=8D=87=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E9=80=BB=E8=BE=91=E6=B8=85=E6=99=B0=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/bestfit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index 8d5a55eda..0accd5d6c 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -33,9 +33,9 @@ def malloc(self, size) -> Block: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if block.is_free and block.size() >= size and (not found or - found.size() > block.size()): - found = block + if block.is_free and block.size() >= size and (not found or + found.size() > block.size()): + found = block if found: found.is_free = False remain = found.size() - size From b1a68512f6b40728d8ca432989df046dc9066a95 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 00:13:47 +0800 Subject: [PATCH 32/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20BestFit=20=E7=B1=BB?= =?UTF-8?q?=20malloc=20=E6=96=B9=E6=B3=95=E4=B8=AD=E7=9A=84=E6=9D=A1?= =?UTF-8?q?=E4=BB=B6=E5=88=A4=E6=96=AD=E9=80=BB=E8=BE=91=EF=BC=8C=E6=8F=90?= =?UTF-8?q?=E5=8D=87=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/bestfit.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index 0accd5d6c..08bdd32c6 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -33,9 +33,10 @@ def malloc(self, size) -> Block: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if block.is_free and block.size() >= size and (not found or - found.size() > block.size()): - found = block + if (block.is_free and + block.size() >= size and + (not found or block.size() < found.size())): + found = block if found: found.is_free = False remain = found.size() - size From d319d1fa6d39aa210e89c79c0ce31280f3086fcc Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 00:16:28 +0800 Subject: [PATCH 33/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20BestFit=20=E7=B1=BB?= =?UTF-8?q?=20malloc=20=E6=96=B9=E6=B3=95=E4=B8=AD=E7=9A=84=E6=9D=A1?= =?UTF-8?q?=E4=BB=B6=E5=88=A4=E6=96=AD=E9=80=BB=E8=BE=91=EF=BC=8C=E6=8F=90?= =?UTF-8?q?=E5=8D=87=E4=BB=A3=E7=A0=81=E5=8F=AF=E8=AF=BB=E6=80=A7=EF=BC=9B?= =?UTF-8?q?=E8=B0=83=E6=95=B4=E5=AF=BC=E5=85=A5=E8=AF=AD=E5=8F=A5=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/bestfit.py | 3 +-- bitblas/base/roller/policy/__init__.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index 08bdd32c6..c40fd7f67 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -33,8 +33,7 @@ def malloc(self, size) -> Block: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if (block.is_free and - block.size() >= size and + if (block.is_free and block.size() >= size and (not found or block.size() < found.size())): found = block if found: diff --git a/bitblas/base/roller/policy/__init__.py b/bitblas/base/roller/policy/__init__.py index ec8c69c00..2b7cb4af4 100644 --- a/bitblas/base/roller/policy/__init__.py +++ b/bitblas/base/roller/policy/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .default import DefaultPolicy # noqa: F401 +from .default import DefaultPolicy # noqa: F401 From f19482082d65494e454c32fc42556edfe18ea3ba Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 00:25:18 +0800 Subject: [PATCH 34/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20BestFit=20=E7=B1=BB?= =?UTF-8?q?=20malloc=20=E6=96=B9=E6=B3=95=E4=B8=AD=E7=9A=84=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E6=A0=BC=E5=BC=8F=EF=BC=8C=E6=8F=90=E5=8D=87=E5=8F=AF?= =?UTF-8?q?=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/bestfit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index c40fd7f67..e938e4d07 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -33,9 +33,9 @@ def malloc(self, size) -> Block: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if (block.is_free and block.size() >= size and + if (block.is_free and block.size() >= size and (not found or block.size() < found.size())): - found = block + found = block if found: found.is_free = False remain = found.size() - size From 376afa20073f7cee2f568d8f2524ce5378fafe5a Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 00:26:30 +0800 Subject: [PATCH 35/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20format.sh=20?= =?UTF-8?q?=E8=84=9A=E6=9C=AC=EF=BC=8C=E6=9B=B4=E6=96=B0=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E8=AF=B4=E6=98=8E=E5=92=8C=E5=90=88=E5=B9=B6=E5=9F=BA=E5=87=86?= =?UTF-8?q?=E7=9A=84=E7=A1=AE=E5=AE=9A=E9=80=BB=E8=BE=91=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=BC=BA=E5=B7=A5=E5=85=B7=E7=89=88=E6=9C=AC=E6=A3=80=E6=9F=A5?= =?UTF-8?q?=E7=9A=84=E9=B2=81=E6=A3=92=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- format.sh | 265 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 172 insertions(+), 93 deletions(-) diff --git a/format.sh b/format.sh index f8455f621..c09777503 100755 --- a/format.sh +++ b/format.sh @@ -6,13 +6,14 @@ # Usage: # # Do work and commit your work. -# # Format files that differ from origin/main. +# # Format files that differ from the determined merge base (upstream/main, origin/main, or local main). # bash format.sh # # Commit changed files with message 'Run yapf and ruff' # # -# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# YAPF + Ruff + Codespell. This script formats, lints, and spell-checks changed files +# based on the merge-base with upstream/main, origin/main, or local main. # You are encouraged to run this locally before pushing changes for review. # Cause the script to exit if a single command fails @@ -23,29 +24,133 @@ builtin cd "$(dirname "${BASH_SOURCE:-$0}")" ROOT="$(git rev-parse --show-toplevel)" builtin cd "$ROOT" || exit 1 +# --- Tool Version Checks --- YAPF_VERSION=$(yapf --version | awk '{print $2}') RUFF_VERSION=$(ruff --version | awk '{print $2}') -CODESPELL_VERSION=$(codespell --version) +# Handle potential variations in codespell version output +CODESPELL_RAW_VERSION=$(codespell --version) +if [[ "$CODESPELL_RAW_VERSION" == codespell* ]]; then + CODESPELL_VERSION=$(echo "$CODESPELL_RAW_VERSION" | awk '{print $2}') # Assuming format "codespell x.y.z" +else + CODESPELL_VERSION="$CODESPELL_RAW_VERSION" # Use as is if format is different +fi + -# # params: tool name, tool version, required version +# params: tool name, tool version, required version from file tool_version_check() { - if [[ $2 != $3 ]]; then - echo "Wrong $1 version installed: $3 is required, not $2." - exit 1 + local tool_name=$1 + local installed_version=$2 + local requirement_line + local required_version + + # Find the requirement line robustly (handles == and ===) + requirement_line=$(grep "^${tool_name}[=]=" requirements-dev.txt) || requirement_line=$(grep "^${tool_name}=" requirements-dev.txt) + + if [ -z "$requirement_line" ]; then + echo "Warning: Could not find requirement for '$tool_name' in requirements-dev.txt." + return # Don't exit, just warn if requirement is missing + fi + + # Extract version after the last '=' + required_version=$(echo "$requirement_line" | rev | cut -d'=' -f1 | rev) + + # Special handling for codespell if it only prints version number + if [[ "$tool_name" == "codespell" ]] && [[ "$installed_version" != codespell* ]]; then + # If installed_version is just the number, compare directly + if [[ "$installed_version" != "$required_version" ]]; then + echo "Wrong $tool_name version installed: $required_version is required, not $installed_version." + echo "Requirement line: $requirement_line" + exit 1 + fi + else + # Standard comparison (handles 'tool x.y.z' or just 'x.y.z' if awk worked) + # Extract version number from installed_version if needed + local installed_version_num=$installed_version + if [[ "$installed_version" == ${tool_name}* ]]; then + installed_version_num=$(echo "$installed_version" | awk '{print $2}') + fi + + if [[ "$installed_version_num" != "$required_version" ]]; then + echo "Wrong $tool_name version installed: $required_version is required, not $installed_version_num (from '$installed_version')." + echo "Requirement line: $requirement_line" + exit 1 + fi fi } -tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "yapf" "$YAPF_VERSION" +tool_version_check "ruff" "$RUFF_VERSION" +tool_version_check "codespell" "$CODESPELL_VERSION" + +# --- Determine Merge Base --- +# Define the upstream repository URL to compare against +UPSTREAM_REPO="https://github.com/microsoft/BitBLAS" +MERGEBASE="" # Initialize MERGEBASE variable + +echo "Determining merge base for diff..." + +# 1. Try to compare directly with the main branch of the upstream repository +if git ls-remote --exit-code "$UPSTREAM_REPO" main &>/dev/null; then + echo "Attempting to find merge base with upstream: $UPSTREAM_REPO main" + MERGEBASE_CMD_OUTPUT=$(git fetch "$UPSTREAM_REPO" main --quiet --no-tags 2>/dev/null && git merge-base FETCH_HEAD HEAD) + FETCH_STATUS=$? + if [ $FETCH_STATUS -eq 0 ] && [ -n "$MERGEBASE_CMD_OUTPUT" ]; then + MERGEBASE="$MERGEBASE_CMD_OUTPUT" + echo "Successfully found merge base with upstream: $MERGEBASE" + else + echo "Warning: Could not determine merge base with $UPSTREAM_REPO main (fetch/merge-base failed or no common ancestor). Falling back..." + fi +fi + +# 2. If MERGEBASE could not be obtained from upstream, try using origin/main +if [ -z "$MERGEBASE" ] && git show-ref --verify --quiet refs/remotes/origin/main; then + echo "Falling back to merge base with origin/main" + BASE_BRANCH="origin/main" + MERGEBASE_CMD_OUTPUT=$(git merge-base "$BASE_BRANCH" HEAD) + MERGEBASE_STATUS=$? + if [ $MERGEBASE_STATUS -eq 0 ] && [ -n "$MERGEBASE_CMD_OUTPUT" ]; then + MERGEBASE="$MERGEBASE_CMD_OUTPUT" + echo "Successfully found merge base with $BASE_BRANCH: $MERGEBASE" + else + echo "Warning: Could not determine merge base with $BASE_BRANCH. Falling back..." + fi +fi + +# 3. If even origin/main doesn't work, try using the local main branch +if [ -z "$MERGEBASE" ]; then + echo "Falling back to merge base with local main" + BASE_BRANCH="main" + if git show-ref --verify --quiet "refs/heads/$BASE_BRANCH"; then + MERGEBASE_CMD_OUTPUT=$(git merge-base "$BASE_BRANCH" HEAD) + MERGEBASE_STATUS=$? + if [ $MERGEBASE_STATUS -eq 0 ] && [ -n "$MERGEBASE_CMD_OUTPUT" ]; then + MERGEBASE="$MERGEBASE_CMD_OUTPUT" + echo "Successfully found merge base with $BASE_BRANCH: $MERGEBASE" + else + echo "Warning: Could not determine merge base with local $BASE_BRANCH." + fi + else + echo "Warning: Local branch '$BASE_BRANCH' not found." + fi +fi + +# 4. Final check for MERGEBASE +if [ -z "$MERGEBASE" ]; then + echo "Error: Could not determine a suitable merge base. Unable to proceed with diffing changed files." + exit 1 +fi + +echo "Using final merge base: $MERGEBASE" +# --- Merge Base Determined --- -echo 'bitblas yapf: Check Start' + +# --- YAPF Formatting --- +echo '--- bitblas yapf: Check Start ---' YAPF_FLAGS=( '--recursive' '--parallel' ) - YAPF_EXCLUDES=( '--exclude' 'build/**' ) @@ -55,149 +160,123 @@ format() { yapf --in-place "${YAPF_FLAGS[@]}" "$@" } -# Format files that differ from main branch. Ignores dirs that are not slated -# for autoformat yet. +# Format files that differ from the determined merge base. format_changed() { - # The `if` guard ensures that the list of filenames is not empty, which - # could cause yapf to receive 0 positional arguments, making it hang - # waiting for STDIN. - # - # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that - # exist on both branches. - if git show-ref --verify --quiet refs/remotes/origin/main; then - BASE_BRANCH="origin/main" - else - BASE_BRANCH="main" - fi - - MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" - + # Use the globally determined $MERGEBASE if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + echo "Running yapf on changed Python files..." git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \ yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" + else + echo "No Python files changed according to yapf." fi - } # Format all files format_all() { + echo "Running yapf on all Python files..." yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" . } -## This flag formats individual files. --files *must* be the first command line -## arg to use this option. +# YAPF Execution Logic if [[ "$1" == '--files' ]]; then format "${@:2}" - # If `--all` is passed, then any further arguments are ignored and the - # entire python directory is formatted. elif [[ "$1" == '--all' ]]; then format_all else - # Format only the files that changed in last commit. format_changed fi -echo 'bitblas yapf: Done' +echo '--- bitblas yapf: Done ---' + -echo 'bitblas codespell: Check Start' -# check spelling of specified files +# --- Codespell Check --- +echo '--- bitblas codespell: Check Start ---' + +# Check spelling of specified files spell_check() { codespell "$@" } +# Check spelling based on pyproject.toml config (usually checks all relevant files) spell_check_all(){ + echo "Running codespell based on pyproject.toml..." codespell --toml pyproject.toml } -# Spelling check of files that differ from main branch. +# Check spelling of files that differ from the determined merge base. spell_check_changed() { - # The `if` guard ensures that the list of filenames is not empty, which - # could cause ruff to receive 0 positional arguments, making it hang - # waiting for STDIN. - # - # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that - # exist on both branches. - if git show-ref --verify --quiet refs/remotes/origin/main; then - BASE_BRANCH="origin/main" + # Use the globally determined $MERGEBASE + # Check Python and potentially other relevant text files (adjust patterns as needed) + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' '*.md' '*.rst' &>/dev/null; then + echo "Running codespell on changed text files..." + # Note: Consider filtering for files codespell actually handles if needed + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' '*.md' '*.rst' | xargs \ + codespell --quiet-level 3 # Adjust quiet level as needed else - BASE_BRANCH="main" - fi - - MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" - - if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then - git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - codespell + echo "No relevant text files changed according to codespell." fi } -# Run Codespell -## This flag runs spell check of individual files. --files *must* be the first command line -## arg to use this option. +# Codespell Execution Logic if [[ "$1" == '--files' ]]; then spell_check "${@:2}" - # If `--all` is passed, then any further arguments are ignored and the - # entire python directory is linted. elif [[ "$1" == '--all' ]]; then spell_check_all else - # Check spelling only of the files that changed in last commit. spell_check_changed fi -echo 'bitblas codespell: Done' +echo '--- bitblas codespell: Done ---' + + +# --- Ruff Linting --- +echo '--- bitblas ruff: Check Start ---' -echo 'bitblas ruff: Check Start' # Lint specified files lint() { ruff check "$@" } -# Lint files that differ from main branch. Ignores dirs that are not slated -# for autolint yet. +# Lint files that differ from the determined merge base. lint_changed() { - # The `if` guard ensures that the list of filenames is not empty, which - # could cause ruff to receive 0 positional arguments, making it hang - # waiting for STDIN. - # - # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that - # exist on both branches. - if git show-ref --verify --quiet refs/remotes/origin/main; then - BASE_BRANCH="origin/main" - else - BASE_BRANCH="main" - fi - - MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" - + # Use the globally determined $MERGEBASE if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + echo "Running ruff check on changed Python files..." git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ ruff check + else + echo "No Python files changed according to ruff." fi - } -# Run Ruff -### This flag lints individual files. --files *must* be the first command line -### arg to use this option. +# Ruff Execution Logic if [[ "$1" == '--files' ]]; then lint "${@:2}" - # If `--all` is passed, then any further arguments are ignored and the - # entire python directory is linted. elif [[ "$1" == '--all' ]]; then - lint BitBLAS tests + echo "Running ruff check on specified directories..." + # Adjust directories as needed for your project structure + lint BitBLAS tests # Assuming these are the main directories else - # Format only the files that changed in last commit. lint_changed fi +echo '--- bitblas ruff: Done ---' +# --- Final Check for Changes --- +# Check if yapf (or potentially other tools if they modify files) made changes if ! git diff --quiet &>/dev/null; then - echo 'Reformatted files. Please review and stage the changes.' - echo 'Changes not staged for commit:' echo - git --no-pager diff - + echo '-----------------------------------------------------------------------' + echo 'Detected changes made by the formatting/linting tools.' + echo 'Please review and stage these changes before committing:' + echo '-----------------------------------------------------------------------' + echo + git --no-pager diff --color=always # Show colored diff directly + echo + echo '-----------------------------------------------------------------------' + echo 'Exiting with status 1 due to needed changes.' + echo '-----------------------------------------------------------------------' exit 1 fi -echo 'bitblas ruff: Done' - -echo 'bitblas: All checks passed' +echo +echo '--- bitblas: All checks passed ---' +exit 0 \ No newline at end of file From e9b604b9270cdb87793dcc2b6a96a23ecef70262 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 00:42:06 +0800 Subject: [PATCH 36/42] =?UTF-8?q?=E5=9C=A8=20requirements-test.txt=20?= =?UTF-8?q?=E4=B8=AD=E6=B7=BB=E5=8A=A0=20cmake=20=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements-test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-test.txt b/requirements-test.txt index a06a6dd87..b196109ac 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -6,6 +6,7 @@ ruff==0.6.5 codespell==2.3.0 cffi +cmake cpplint Cython decorator From 5fe0c4fd4394dc61f8bb6d45e86d821ca23a1f42 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 19:30:26 +0800 Subject: [PATCH 37/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20CI=20=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=EF=BC=8C=E7=A7=BB=E9=99=A4=20cmake=20=E8=BD=AF?= =?UTF-8?q?=E9=93=BE=E6=8E=A5=E5=B9=B6=E6=9B=B4=E6=96=B0=E5=93=88=E5=B8=8C?= =?UTF-8?q?=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index be3688ee5..1f281e8cf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,6 +55,8 @@ jobs: - name: Activate virtual environment and install dependencies run: | source bitblas_ci/bin/activate + rm bitblas_ci/bin/cmake + hash -r python -m pip install --upgrade pip if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi From ee25df09ca6f114f6855d0aa485eb510e0b80b89 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 19:34:48 +0800 Subject: [PATCH 38/42] =?UTF-8?q?=E4=BC=98=E5=8C=96=20CI=20=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=EF=BC=8C=E8=B0=83=E6=95=B4=20cmake=20=E8=BD=AF?= =?UTF-8?q?=E9=93=BE=E6=8E=A5=E7=A7=BB=E9=99=A4=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/ci.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f281e8cf..6b3c72ea1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,10 +55,12 @@ jobs: - name: Activate virtual environment and install dependencies run: | source bitblas_ci/bin/activate - rm bitblas_ci/bin/cmake - hash -r python -m pip install --upgrade pip if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi + if [ -f bitblas_ci/bin/cmake ]; then + rm bitblas_ci/bin/cmake + hash -r + fi - name: Install project in wheel mode run: | From ca5b51a312a0564c39ae7f4abe6cf194eea96bd8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 19:35:25 +0800 Subject: [PATCH 39/42] =?UTF-8?q?=E7=A7=BB=E9=99=A4=20requirements-test.tx?= =?UTF-8?q?t=20=E4=B8=AD=E7=9A=84=20cmake=20=E4=BE=9D=E8=B5=96=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements-test.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index b196109ac..a06a6dd87 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -6,7 +6,6 @@ ruff==0.6.5 codespell==2.3.0 cffi -cmake cpplint Cython decorator From de8a72e264ef9aaed2f83ba4d48ca1ffe821d77e Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 19:46:12 +0800 Subject: [PATCH 40/42] =?UTF-8?q?=E5=9C=A8=20CI=20=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E4=B8=AD=E6=B7=BB=E5=8A=A0=20PATH=20=E7=8E=AF=E5=A2=83?= =?UTF-8?q?=E5=8F=98=E9=87=8F=E8=AE=BE=E7=BD=AE=EF=BC=8C=E4=BB=A5=E7=A1=AE?= =?UTF-8?q?=E4=BF=9D=E6=AD=A3=E7=A1=AE=E5=AE=89=E8=A3=85=E9=A1=B9=E7=9B=AE?= =?UTF-8?q?=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6b3c72ea1..3934c15e9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,6 +65,7 @@ jobs: - name: Install project in wheel mode run: | source bitblas_ci/bin/activate + export PATH="/usr/bin:$PATH" python -m pip install . - name: Run tests From efbe524f658c7a1e4b5d3217f77404ea6efd458b Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 20:00:20 +0800 Subject: [PATCH 41/42] =?UTF-8?q?=E5=9C=A8=E5=AE=89=E8=A3=85=E6=AD=A5?= =?UTF-8?q?=E9=AA=A4=E4=B8=AD=E6=B7=BB=E5=8A=A0=E6=89=A7=E8=A1=8C=20instal?= =?UTF-8?q?l.sh=20=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3934c15e9..22a65d371 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,6 +66,7 @@ jobs: run: | source bitblas_ci/bin/activate export PATH="/usr/bin:$PATH" + bash install.sh python -m pip install . - name: Run tests From e94911a780ceafab5a7fbeb4c9000d23316a8973 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 May 2025 20:10:21 +0800 Subject: [PATCH 42/42] =?UTF-8?q?=E5=9C=A8=20policy=20=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E4=B8=AD=E6=B7=BB=E5=8A=A0=20TensorCorePolicy=20=E5=AF=BC?= =?UTF-8?q?=E5=85=A5=EF=BC=8C=E5=B9=B6=E5=9C=A8=20shape=5Finference=20?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E4=B8=AD=E4=BF=AE=E5=A4=8D=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bitblas/base/roller/policy/__init__.py | 1 + bitblas/base/roller/shape_inference/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bitblas/base/roller/policy/__init__.py b/bitblas/base/roller/policy/__init__.py index 2b7cb4af4..786eb9156 100644 --- a/bitblas/base/roller/policy/__init__.py +++ b/bitblas/base/roller/policy/__init__.py @@ -2,3 +2,4 @@ # Licensed under the MIT License. from .default import DefaultPolicy # noqa: F401 +from .tensorcore import TensorCorePolicy # noqa: F401 diff --git a/bitblas/base/roller/shape_inference/__init__.py b/bitblas/base/roller/shape_inference/__init__.py index 2a709d434..3f9d0ed7f 100644 --- a/bitblas/base/roller/shape_inference/__init__.py +++ b/bitblas/base/roller/shape_inference/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# from .tir import get_analyzer_by_tir # pylint: disable=unused-import +from .tir import get_analyzer_by_tir # pylint: disable=unused-import # noqa: F401