diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index be3688ee5..22a65d371 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,10 +57,16 @@ jobs: source bitblas_ci/bin/activate python -m pip install --upgrade pip if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi + if [ -f bitblas_ci/bin/cmake ]; then + rm bitblas_ci/bin/cmake + hash -r + fi - name: Install project in wheel mode run: | source bitblas_ci/bin/activate + export PATH="/usr/bin:$PATH" + bash install.sh python -m pip install . - name: Run tests diff --git a/.gitmodules b/.gitmodules index adbfcc33f..25a3b1a8c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,8 +5,7 @@ [submodule "3rdparty/tilelang"] path = 3rdparty/tilelang url = https://github.com/tile-ai/tilelang - branch = bitblas + branch = main [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass - url = https://github.com/tile-ai/cutlass - branch = tldev + url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass index a2954a8fd..5e497243f 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit a2954a8fdd9a73852f2c1ddea97d0e8a579cfb25 +Subproject commit 5e497243f7ad13a2aa842143f9b10bbb23d98292 diff --git a/3rdparty/tilelang b/3rdparty/tilelang index b09e2b5cc..3198a2f46 160000 --- a/3rdparty/tilelang +++ b/3rdparty/tilelang @@ -1 +1 @@ -Subproject commit b09e2b5cc6abfe94c35249cb99ad899ef394964e +Subproject commit 3198a2f46c373a6b6967bd33332145812a47b44e diff --git a/3rdparty/tvm b/3rdparty/tvm index d310bd5aa..b29452549 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d310bd5aadce96145546fb7a87a6d325ea392b2b +Subproject commit b2945254932cffa89922ec7f6e868d726aed0f6a diff --git a/README.md b/README.md index ecf76ba5b..a6f7cc673 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ We are continuously expanding the support matrix. If you have any specific requi - **Python Version**: >= 3.8 - **CUDA Version**: >= 11.0 -The easiest way to install BitBLAS is direcly from the PyPi using pip. To install the latest version, run the following command in your terminal. +The easiest way to install BitBLAS is directly from the PyPi using pip. To install the latest version, run the following command in your terminal. ```bash pip install bitblas diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py index ad8ec20a8..e938e4d07 100644 --- a/bitblas/base/roller/bestfit.py +++ b/bitblas/base/roller/bestfit.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +"""Benefit For BitBLAS Schedule""" + -"""Benifit For BitBLAS Schedule""" class Block: + def __init__(self, start, end, is_free): self.start = start self.end = end @@ -21,6 +23,7 @@ def __repr__(self) -> str: class BestFit: + def __init__(self, align=32): self.limit = 0 self.list = [] @@ -30,17 +33,16 @@ def malloc(self, size) -> Block: size = (size + self.align - 1) // self.align * self.align found = None for block in self.list: - if block.is_free and block.size() >= size: - if not found or found.size() > block.size(): - found = block + if (block.is_free and block.size() >= size and + (not found or block.size() < found.size())): + found = block if found: found.is_free = False remain = found.size() - size if remain != 0: found.end -= remain self.list.insert( - self.list.index(found) + 1, Block(found.end, found.end + remain, True) - ) + self.list.index(found) + 1, Block(found.end, found.end + remain, True)) return found elif len(self.list) > 0 and self.list[-1].is_free: add = size - self.list[-1].size() diff --git a/bitblas/base/roller/node.py b/bitblas/base/roller/node.py index c9d648019..c0b49f6bf 100644 --- a/bitblas/base/roller/node.py +++ b/bitblas/base/roller/node.py @@ -232,8 +232,9 @@ def propagate(self, tile, rstep: Optional[Dict] = None, targets=None): if rstep is None: rstep = {} shape = { - self.block_analyzer.get_output_buffers(block)[0].name: - [tvm.arith.ConstIntBound(0, val - 1) for val in tile] for block in self.schedule_stages + self.block_analyzer.get_output_buffers(block)[0].name: [ + tvm.arith.ConstIntBound(0, val - 1) for val in tile + ] for block in self.schedule_stages } return self.ana.infer(shape, rstep, targets) diff --git a/bitblas/base/roller/policy/__init__.py b/bitblas/base/roller/policy/__init__.py index 09ed1d51b..786eb9156 100644 --- a/bitblas/base/roller/policy/__init__.py +++ b/bitblas/base/roller/policy/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .default import DefaultPolicy -from .tensorcore import TensorCorePolicy +from .default import DefaultPolicy # noqa: F401 +from .tensorcore import TensorCorePolicy # noqa: F401 diff --git a/bitblas/base/roller/shape_inference/__init__.py b/bitblas/base/roller/shape_inference/__init__.py index 188aa0bb7..3f9d0ed7f 100644 --- a/bitblas/base/roller/shape_inference/__init__.py +++ b/bitblas/base/roller/shape_inference/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .tir import get_analyzer_by_tir # pylint: disable=unused-import +from .tir import get_analyzer_by_tir # pylint: disable=unused-import # noqa: F401 diff --git a/bitblas/base/roller/shape_inference/common.py b/bitblas/base/roller/shape_inference/common.py index 730bbbeef..6cc6d228b 100644 --- a/bitblas/base/roller/shape_inference/common.py +++ b/bitblas/base/roller/shape_inference/common.py @@ -8,16 +8,21 @@ class Statement(): - def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict): + + def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, + range_map: OrderedDict): self.output = output self.dependent_region = dependent_region self.var_map = var_map self.range_map = range_map + def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) + class InputShapeInference(): + def __init__(self, deps: List[Statement]): self.deps = deps @@ -34,7 +39,7 @@ def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, i for name, regions in dep.dependent_region.items(): for region in regions: bounds = [ana.const_int_bound(index) for index in region] - if name in shape: # simply merge two bounds + if name in shape: # simply merge two bounds bounds = [_merge_two_bounds(x, y) for x, y in zip(shape[name], bounds)] shape[name] = bounds @@ -42,9 +47,11 @@ def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, i shape[name] = [c.max_value - c.min_value + 1 for c in bounds] return shape - def infer(self, shape, rstep: Dict[str, int] = {}): + def infer(self, shape, rstep: Dict[str, int] = None): + if rstep is None: + rstep = {} if isinstance(shape, (list, tuple)): - shape = {"output0" : [arith.ConstIntBound(0, val - 1) for val in shape]} + shape = {"output0": [arith.ConstIntBound(0, val - 1) for val in shape]} shape = self._infer(shape, rstep) return shape @@ -63,4 +70,3 @@ def get_input_exprs(self, output_exprs): input_expr = [ana.simplify(index) for index in region] result[name] = input_expr return result - diff --git a/bitblas/base/roller/shape_inference/tir.py b/bitblas/base/roller/shape_inference/tir.py index 35bf0b7d8..19d70857e 100644 --- a/bitblas/base/roller/shape_inference/tir.py +++ b/bitblas/base/roller/shape_inference/tir.py @@ -8,6 +8,7 @@ class Statement: + def __init__(self, block_analyzer, block: BlockRV): self.block_analyzer = block_analyzer self.block = block @@ -79,6 +80,7 @@ def __repr__(self): class DependencyAnalysis(object): + def __init__(self, deps): self.deps = deps # issue: duplicate name when we have two same ops. @@ -90,8 +92,8 @@ def _construct_unique_name2dep(self, deps): This is a workaround for the issue that we have two same ops' fuse case. See https://github.com/apache/tvm/issues/16433 """ - _names:Set = set() - name2dep:Mapping = {} + _names: Set = set() + name2dep: Mapping = {} for dep in deps: output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0] base_name = output_buffer.name @@ -105,7 +107,7 @@ def _construct_unique_name2dep(self, deps): _names.add(base_name) name2dep[base_name] = dep return name2dep - + def get_or_create_node(self, name): if name not in self.mapping: self.mapping[name] = TensorDepNode(name) @@ -114,8 +116,7 @@ def get_or_create_node(self, name): def traverse_dependencies(self, compute): if isinstance(compute, Statement): node = self.get_or_create_node( - compute.block_analyzer.get_output_buffers(compute.block)[0].name - ) + compute.block_analyzer.get_output_buffers(compute.block)[0].name) # Loop through input tensors for input_buffer in compute.block_analyzer.get_input_buffers(compute.block): # Get the input node @@ -169,6 +170,7 @@ def _find_path_recursive(self, current_node, target_name, visited, path): class InputShapeInference: + def __init__(self, deps: List[Statement]): self.deps = deps self.target_mapping = {} @@ -242,9 +244,12 @@ def construct_dependency_target(self, targets: Tuple[str]): self.target_mapping[targets] = input_vars, mapping return input_vars, mapping - def infer( - self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int] = {}, targets=None - ): + def infer(self, + shape: Dict[str, List[arith.ConstIntBound]], + rstep: Dict[str, int] = None, + targets=None): + if rstep is None: + rstep = {} compute_targets = tuple(shape.keys()) input_vars, mapping = self.construct_dependency_target(compute_targets) ana = arith.Analyzer() @@ -257,8 +262,7 @@ def infer( # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value. if ax.var.name in rstep: bound = arith.ConstIntBound( - int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1) - ) + int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1)) else: bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1)) ana.update(ax.var, bound, True) @@ -318,16 +322,14 @@ def get_input_exprs(self, output_exprs): def region_exist_in_list(a, list) -> bool: + def expr_is_same(a, b) -> bool: if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm): return a.value == b.value return structural_equal(a, b) def region_is_same(a, b) -> bool: - for indice_a, indice_b in zip(a, b): - if not expr_is_same(indice_a, indice_b): - return False - return True + return all(expr_is_same(indice_a, indice_b) for indice_a, indice_b in zip(a, b)) return any([region_is_same(a, x) for x in list]) @@ -340,9 +342,7 @@ def walk_indice(expr): return expr else: return None - elif isinstance(expr, tir.expr.ConstExpr): - return expr - elif isinstance(expr, tir.Var): + elif isinstance(expr, (tir.expr.ConstExpr, tir.Var)): return expr elif isinstance(expr, tir.ProducerLoad): return None @@ -381,7 +381,7 @@ def fvisit(x): with T.init(): T_dense_reindex[T.int64(0), v0, v1] = T.float16(0) T_dense_reindex[T.int64(0), v0, v1] = T_dense_reindex[T.int64(0), v0, v1] + A_reindex[T.int64(0), v0, v2] * B_reindex[T.int64(0), v1, v2] - For exmaple, the T_dense_reindex has three dims, however there're only two spatial loops. + For example, the T_dense_reindex has three dims, however there're only two spatial loops. """ continue index.append(expr) diff --git a/bitblas/gpu/base.py b/bitblas/gpu/base.py index 3bf927244..78e06658b 100644 --- a/bitblas/gpu/base.py +++ b/bitblas/gpu/base.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # /* Modifications Copyright (c) Microsoft. */ # The code below is mostly copied from apache/tvm base.py in dlight. """Base schedule rule for GPU operators.""" diff --git a/bitblas/gpu/fallback.py b/bitblas/gpu/fallback.py index 3711d3682..420473036 100644 --- a/bitblas/gpu/fallback.py +++ b/bitblas/gpu/fallback.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm fallback.py in dlight. # pylint: disable=missing-docstring @@ -61,15 +61,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring dom_kind = block.dom_kind() block = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block) - ] - ) - or len(sch.get_loops(block)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): continue for loop, iter_type in zip(sch.get_loops(block), dom_kind): @@ -92,4 +86,3 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.decompose_reduction(block, r_loop) return sch - \ No newline at end of file diff --git a/bitblas/gpu/gemv.py b/bitblas/gpu/gemv.py index 60a290a81..bed06a055 100644 --- a/bitblas/gpu/gemv.py +++ b/bitblas/gpu/gemv.py @@ -500,7 +500,7 @@ def apply( if not isinstance(len_S, int): TS, TR = 1, 64 - while TS * TR > target.max_num_threads: + while target.max_num_threads < TS * TR: if TS > 1: TS //= 2 else: diff --git a/bitblas/gpu/general_reduction.py b/bitblas/gpu/general_reduction.py index cc03acd99..fd6da8a71 100644 --- a/bitblas/gpu/general_reduction.py +++ b/bitblas/gpu/general_reduction.py @@ -47,10 +47,8 @@ def apply( # pylint: disable=too-many-locals num_last_block_iter = len(block_infos[-1].dom_kind()) if num_last_block_iter < len(dom_kind): index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), + lambda *iters: ([tir.const(0, iters[0].dtype)] * + (len(dom_kind) - num_last_block_iter) + list(iters)), ndim=num_last_block_iter, ) sch.transform_block_layout(block_infos[-1].block_rv, index_map) @@ -117,15 +115,10 @@ def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many dom_kind = block.dom_kind() block_rv = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block_rv) - ] - ) - or len(sch.get_loops(block.block_rv)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block_rv) + ]) or len(sch.get_loops(block.block_rv)) == 0): continue for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): @@ -195,10 +188,7 @@ def prod(iterable): dim_offset = ( len(reduce_inner_axis) + len(reduce_outer_axis) + 2 ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis - if input_region.buffer.name in config.vectorize: - vectorize = config.vectorize[input_region.buffer.name] - else: - vectorize = 1 + vectorize = config.vectorize.get(input_region.buffer.name, 1) loops = sch.get_loops(cache_shared) if len(loops) == dim_offset: @@ -248,15 +238,10 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many dom_kind = block.dom_kind() block_rv = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block_rv) - ] - ) - or len(sch.get_loops(block.block_rv)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block_rv) + ]) or len(sch.get_loops(block.block_rv)) == 0): continue for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): @@ -276,10 +261,8 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many num_last_block_iter = len(block_infos[-1].dom_kind()) if num_last_block_iter < len(dom_kind): index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), + lambda *iters: ([tir.const(0, iters[0].dtype)] * + (len(dom_kind) - num_last_block_iter) + list(iters)), ndim=num_last_block_iter, ) sch.transform_block_layout(block_infos[-1].block_rv, index_map) @@ -295,13 +278,11 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many vthread_axis = [] thread_axis = [] inner_axis = [] - for s_loop, block_factor, step_factor, thread_factor in zip( - s_loops, block_factors, step_factors, thread_factors - ): + for s_loop, block_factor, step_factor, thread_factor in zip(s_loops, block_factors, + step_factors, thread_factors): block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) vthread_loop, inner_loop = sch.split( - inner_loop, factors=[None, thread_factor * step_factor] - ) + inner_loop, factors=[None, thread_factor * step_factor]) thread_loop, inner_loop = sch.split(inner_loop, factors=[None, step_factor]) block_axis.append(block_loop) vthread_axis.append(vthread_loop) @@ -317,13 +298,8 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many vthread_axis = list(reversed(vthread_axis)) # inner virtual thread first axis_order = ( - block_axis - + vthread_axis - + thread_axis - + reduce_outer_axis - + reduce_inner_axis - + inner_axis - ) + block_axis + vthread_axis + thread_axis + reduce_outer_axis + reduce_inner_axis + + inner_axis) sch.reorder(*axis_order) blck_fused = sch.fuse(*block_axis) @@ -347,10 +323,7 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many dim_offset = ( len(vthread_axis) + len(reduce_outer_axis) + 2 ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis - if input_region.buffer.name in config.vectorize: - vectorize = config.vectorize[input_region.buffer.name] - else: - vectorize = 1 + vectorize = config.vectorize.get(input_region.buffer.name, 1) loops = sch.get_loops(cache_shared) if len(loops) == dim_offset: @@ -362,8 +335,8 @@ def prod(iterable): return reduce(lambda x, y: x * y, iterable, 1) _, tx, tv = sch.split( - sch.fuse(*loops[dim_offset:]), factors=[None, int(prod(thread_factors)), vectorize] - ) + sch.fuse(*loops[dim_offset:]), factors=[None, + int(prod(thread_factors)), vectorize]) sch.vectorize(tv) sch.bind(tx, "threadIdx.x") @@ -407,10 +380,8 @@ def prod(iterable): num_last_block_iter = len(block_infos[-1].dom_kind()) if num_last_block_iter < len(dom_kind): index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), + lambda *iters: ([tir.const(0, iters[0].dtype)] * + (len(dom_kind) - num_last_block_iter) + list(iters)), ndim=num_last_block_iter, ) sch.transform_block_layout(block_infos[-1].block_rv, index_map) diff --git a/bitblas/gpu/reduction.py b/bitblas/gpu/reduction.py index 9d6aada75..a7430231b 100644 --- a/bitblas/gpu/reduction.py +++ b/bitblas/gpu/reduction.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm reduction.py in dlight. """A rule for reduction. """ @@ -43,9 +43,9 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: if not isinstance(buffer_store.value, tir.Add): return None if not ir.structural_equal( - buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), - map_free_vars=True, + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, ): return None return buffer_store.value.b @@ -81,11 +81,8 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- block_stmt = sch.get(block) # Step 1. Check reduction block - if ( - (not block_info.is_reduction()) - or len(block_stmt.writes) != 1 - or _get_reduction_expr(block_stmt) is None - ): + if ((not block_info.is_reduction()) or len(block_stmt.writes) != 1 or + _get_reduction_expr(block_stmt) is None): return None # Step 2. Normalize the block, merge spatial and reduction iters is_inner_reduction, c_factor, loop_order, s_split_index = self._normalize( @@ -100,13 +97,11 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- return None # Step 3. Do the scheduling if is_inner_reduction: - self._sch_inner_reduction( - sch, target, block, c_factor, epilogue, loop_order, s_split_index - ) + self._sch_inner_reduction(sch, target, block, c_factor, epilogue, loop_order, + s_split_index) else: - self._sch_inner_spatial( - sch, target, block, block_info, c_factor, epilogue, loop_order, s_split_index - ) + self._sch_inner_spatial(sch, target, block, block_info, c_factor, epilogue, loop_order, + s_split_index) return sch def _normalize( # pylint: disable=too-many-branches @@ -140,7 +135,7 @@ def _normalize( # pylint: disable=too-many-branches s_loops.append(loop) if iter_to_info: - for var, info in iter_to_info.items(): + for _var, info in iter_to_info.items(): if info.kind == "S" and info.dom.extent == 1: s_loops.append(info.loop_rv) else: @@ -185,8 +180,7 @@ def _sch_inner_reduction( # pylint: disable=too-many-arguments # pylint: disable=invalid-name _, r, _ = sch.get_loops(block) (len_tx,) = utils.suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking - target, [sch.get(r)] - ) + target, [sch.get(r)]) _, tx = sch.split(r, factors=[None, len_tx]) # Schedule the RF block @@ -208,8 +202,7 @@ def _sch_inner_reduction( # pylint: disable=too-many-arguments new_order_s = [s[loop_order[i]] for i in range(len(s))] sch.reorder(*new_order_s) new_order_s[s_split_index], c = sch.split( - new_order_s[s_split_index], factors=[None, unroll_spatial_factor] - ) + new_order_s[s_split_index], factors=[None, unroll_spatial_factor]) sch.reorder(*new_order_s, c) s = sch.fuse(*new_order_s) sch.reorder(s, tx, c) @@ -270,8 +263,7 @@ def _sch_inner_spatial( new_order_s = [s[loop_order[i]] for i in range(len(s))] sch.reorder(*new_order_s) new_order_s[s_split_index], c = sch.split( - new_order_s[s_split_index], factors=[None, unroll_spatial_factor] - ) + new_order_s[s_split_index], factors=[None, unroll_spatial_factor]) sch.reorder(*new_order_s, c) s = sch.fuse(*new_order_s) sch.reorder(s, c, r) diff --git a/bitblas/gpu/rmsnorm.py b/bitblas/gpu/rmsnorm.py index 0d8b37998..941cf37d2 100644 --- a/bitblas/gpu/rmsnorm.py +++ b/bitblas/gpu/rmsnorm.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm rmsnorm.py in dlight. # pylint: disable=missing-docstring @@ -52,11 +52,7 @@ def identify_cast_or_load_block(block: Block) -> bool: if len(load.indices) != len(store.indices): return False - for lhs, rhs in zip(load.indices, store.indices): - if not lhs.same_as(rhs): - return False - - return True + return all(lhs.same_as(rhs) for lhs, rhs in zip(load.indices, store.indices)) def identify_rsqrt_block(block: Block) -> bool: @@ -84,10 +80,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring target: Target, _: bool, ) -> tir.Schedule: - if target.kind.name == "cuda": - num_tx = 512 - else: - num_tx = 64 + num_tx = 512 if target.kind.name == "cuda" else 64 sch = tir.Schedule(func) root = sch.get_block(name="root", func_name="main") @@ -117,8 +110,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring block_loop, loops = sch.get_loops(block=read) thread_loop, _, _ = sch.split( - loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True - ) + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True) sch.bind(block_loop, thread_axis="blockIdx.x") sch.bind(thread_loop, thread_axis="threadIdx.x") sch.vectorize(sch.get_loops(block=read)[-1]) @@ -129,8 +121,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) block_loop, loops = sch.get_loops(block=norm) thread_loop, _, _ = sch.split( - loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True - ) + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True) sch.bind(thread_loop, thread_axis="threadIdx.x") sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) diff --git a/bitblas/gpu/transpose.py b/bitblas/gpu/transpose.py index 6dc025c07..8dc04f10c 100644 --- a/bitblas/gpu/transpose.py +++ b/bitblas/gpu/transpose.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm transpose.py in dlight. """Reduction rule for operators including softmax, layer norm, RMS norm, etc""" @@ -82,7 +82,7 @@ def apply( # pylint: disable=too-many-locals prologue = None # the optional decoding block if transpose_block_idx > 0: - spatials = try_inline_contiguous_spatial(sch, blocks[: transpose_block_idx - 1]) + spatials = try_inline_contiguous_spatial(sch, blocks[:transpose_block_idx - 1]) assert len(spatials) == 0 prologue = blocks[transpose_block_idx - 1].block_rv diff --git a/bitblas/gpu/utils.py b/bitblas/gpu/utils.py index e3a5b6098..575624832 100644 --- a/bitblas/gpu/utils.py +++ b/bitblas/gpu/utils.py @@ -38,9 +38,7 @@ def suggest_threads_per_block( ) -> List[int]: if target.kind.name == "cuda": threads = 1024 - elif target.kind.name == "rocm": - threads = 256 - elif target.kind.name == "metal": + elif target.kind.name == "rocm" or target.kind.name == "metal": threads = 256 else: threads = 64 diff --git a/bitblas/ops/impl/convolution2d_impl.py b/bitblas/ops/impl/convolution2d_impl.py index c7d21d7c8..f637fd8c1 100644 --- a/bitblas/ops/impl/convolution2d_impl.py +++ b/bitblas/ops/impl/convolution2d_impl.py @@ -52,9 +52,13 @@ def conv2d_nhwc_ohwi( C = te.compute( out_shape, lambda n, h, w, f: te.sum( - pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), - c,].astype(accum_dtype) * B[f, kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( - dilation_w), c].astype(accum_dtype), + pad[ + n, + h * stride_h + kh * tir.any(dilation_h), + w * stride_w + kw * tir.any(dilation_w), + c, + ].astype(accum_dtype) * B[f, kh - 1 - tir.any(dilation_h), kw - 1 - tir.any(dilation_w), + c].astype(accum_dtype), axis=[kh, kw, c], ), name="C", @@ -117,9 +121,13 @@ def conv2d_nhwc_hwio( C = te.compute( out_shape, lambda n, h, w, f: te.sum( - pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), - c,].astype(accum_dtype) * B[kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( - dilation_w), c, f].astype(accum_dtype), + pad[ + n, + h * stride_h + kh * tir.any(dilation_h), + w * stride_w + kw * tir.any(dilation_w), + c, + ].astype(accum_dtype) * B[kh - 1 - tir.any(dilation_h), kw - 1 - tir.any(dilation_w), c, + f].astype(accum_dtype), axis=[kh, kw, c], ), name="C", diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 304723eb8..46e6226e5 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -21,6 +21,7 @@ from bitblas.builder.lib_generator import LibraryGenerator from bitblas.common import MAX_ERROR_MESSAGE_LENGTH from bitblas.utils import retrieve_func_from_module +from bitblas.tl.lower import tl_lower from dataclasses import dataclass import logging import re @@ -192,7 +193,7 @@ def tvm_callback_hip_postproc(code, _): if self.is_tir_backend(): rt_mod = tvm.build(self.scheduled_ir_module, target=target) elif self.is_tilelang_backend(): - rt_mod = tilelang.lower( + rt_mod = tl_lower( self.scheduled_ir_module, target=target, runtime_only=True) else: raise ValueError(f"Unsupported backend: {self.backend}") diff --git a/bitblas/tl/lower.py b/bitblas/tl/lower.py new file mode 100644 index 000000000..86507cf3e --- /dev/null +++ b/bitblas/tl/lower.py @@ -0,0 +1,30 @@ +from typing import Union, Optional +from bitblas import tilelang as tilelang +from tilelang import tvm as tvm +from tvm import tir +from tvm.target import Target + + +def tl_lower( + func_or_mod: Union[tir.PrimFunc, tvm.IRModule], + target: Union[str, Target] = "auto", + target_host: Optional[Union[str, Target]] = None, + runtime_only=False, +): + with tvm.transform.PassContext(config={ + "tl.disable_dynamic_tail_split": False, + }): + result = tilelang.lower( + func_or_mod, + target=target, + target_host=target_host, + runtime_only=runtime_only, + enable_host_codegen=True, + enable_device_compile=True, + ) + print("Lowering result:") + print(result.rt_mod) + if runtime_only is True: + return result.rt_mod + else: + return result.rt_mod, result.params diff --git a/bitblas/tl/profiler.py b/bitblas/tl/profiler.py new file mode 100644 index 000000000..e41e9f19c --- /dev/null +++ b/bitblas/tl/profiler.py @@ -0,0 +1,253 @@ +from bitblas import tilelang as tilelang +from typing import List, Literal, Optional, Callable +from functools import partial +import torch +from contextlib import suppress + +import tvm +from tvm.relay import TensorType + +from tilelang.engine import lower +from tilelang.jit.adapter import TorchDLPackKernelAdapter +from tilelang.utils.tensor import ( + get_tensor_supply, + TensorSupplyType, + torch_assert_close, + adapt_torch2tvm, +) + + +class TLProfiler(TorchDLPackKernelAdapter): + + def __init__( + self, + mod, + params: List[TensorType], + result_idx: List[int], + supply_type: TensorSupplyType = TensorSupplyType.Normal, + ): + super().__init__(mod, params, result_idx) + self.supply = get_tensor_supply(supply_type) + + def _get_inputs(self, with_output=False): + ins = [] + for i in range(len(self.params)): + if with_output or i not in self.result_idx: + ins.append(self.supply(self.params[i])) + return ins + + def assert_allclose( + self, + reference_program: Callable, + atol: float = 1e-2, + rtol: float = 1e-2, + max_mismatched_ratio=0.01, + ): + ins = self._get_inputs() + ref_outs = reference_program(*ins) + torch.cuda.synchronize() + lib_outs = self.func(*ins) + torch.cuda.synchronize() + + if isinstance(lib_outs, torch.Tensor): + lib_outs = [lib_outs] + if isinstance(ref_outs, torch.Tensor): + ref_outs = [ref_outs] + assert len(lib_outs) == len(ref_outs) + # torch.set_printoptions(edgeitems=torch.inf) + for lhs, rhs in zip(lib_outs, ref_outs): + # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) + # total_elements = lhs.numel() + # num_not_close = (~close_mask).sum().item() + # percentage_not_close = (num_not_close / total_elements) * 100 + # print(f"{percentage_not_close:.2f}% of the elements are not close.") + # print(f"Total elements: {total_elements}, Not close elements: {num_not_close}") + torch_assert_close( + lhs, + rhs, + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, + ) + + def assert_consistent(self, repeat=10): + # Used to check no race condition inside the kernel + ins = self._get_inputs() + ref_outs = self.func(*ins) + + for _ in range(repeat): + lib_outs = self.func(*ins) + for lhs, rhs in zip(lib_outs, ref_outs): + assert torch.allclose(lhs, rhs), [ + "result is not consistent", + lhs, + rhs, + ] + + def run_once(self, func: Optional[Callable] = None): + ins = self._get_inputs() + if not func: + func = self.__call__ + return func(*ins) + + def determine_profiler(self, + func: Optional[Callable] = None, + profiler: Literal["torch", "tvm", "auto"] = "auto"): + if profiler == "auto": + if func is None or isinstance(func, tvm.runtime.Module): + return "tvm" + else: + return "torch" + return profiler + + def do_bench( + self, + func: Optional[Callable] = None, + warmup: int = 25, + rep: int = 100, + n_warmup: int = 1, + n_repeat: int = 1, + profiler: Literal["torch", "tvm", "auto"] = "auto", + input_tensors: List[torch.Tensor] = None, + ) -> float: + profiler = self.determine_profiler(func, profiler) + if profiler == "torch": + ins = self._get_inputs() if input_tensors is None else input_tensors + bench_func = partial(func, *ins) + return do_bench( + bench_func, + warmup=warmup, + rep=rep, + _n_warmup=n_warmup, + _n_repeat=n_repeat, + ) + elif profiler == "tvm": + if func is None: + func = self.mod + assert isinstance( + func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" + ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors) + target = "cuda" + + with suppress(Exception): + target = self.mod.imported_modules[0].type_key + + assert target in ["cuda", "hip"], f"Unknown target: {target}" + + device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0) + time_evaluator = self.mod.time_evaluator( + self.mod.entry_name, device, number=rep, repeat=n_repeat) + tvm_inputs = [adapt_torch2tvm(inp) for inp in ins] + # Transform Latency to ms + return time_evaluator(*tvm_inputs).mean * 1e3 + else: + raise ValueError(f"Unknown profiler: {profiler}") + + +def do_bench( + fn, + warmup=25, + rep=100, + _n_warmup=0, + _n_repeat=0, + grad_to_none=None, + quantiles=None, + fast_flush=True, + return_mode="mean", +) -> float: + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + + Returns: + float: The median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + """ + assert return_mode in ["min", "max", "mean", "median"] + fn() + torch.cuda.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + if _n_warmup > 0: + n_warmup = _n_warmup + if _n_repeat > 0: + n_repeat = _n_repeat + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + + +_cached = {} + + +def cached(func, result_idx: List[int], *args): + global _cached + key = (func, tuple(result_idx), *args) + if key not in _cached: + program = func(*args) + mod, params = lower(program) + mod = TorchDLPackKernelAdapter(mod, params, result_idx) + _cached[key] = mod + return _cached[key] \ No newline at end of file diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index 539838393..12779ecef 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -22,6 +22,7 @@ ) from bitblas.common import MAX_ERROR_MESSAGE_LENGTH from bitblas.base.base_scheduler import BaseScheduler +from bitblas.tl.lower import tl_lower logger = logging.getLogger(__name__) @@ -122,7 +123,7 @@ def tvm_callback_cuda_postproc(code, _): "tir.disable_cse_tir": True, **(config.pass_context if config.pass_context else {}) }): - rt_mod = tilelang.lower(tl_prim_func, arch.target, runtime_only=True) + rt_mod = tl_lower(tl_prim_func, arch.target, runtime_only=True) from tvm.contrib.tar import tar # Import the tar module diff --git a/bitblas/utils/target_detector.py b/bitblas/utils/target_detector.py index 71d6dcc1f..9e3427817 100644 --- a/bitblas/utils/target_detector.py +++ b/bitblas/utils/target_detector.py @@ -23,6 +23,7 @@ "NVIDIA PG506-232": "NVIDIA A100", } + def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): """ Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU. @@ -52,6 +53,7 @@ def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): return gpus[gpu_id] + def find_best_match(tags, query): """ Finds the best match for a query within a list of tags using fuzzy string matching. diff --git a/docs/ExtendOperatorsWithDSL.md b/docs/ExtendOperatorsWithDSL.md index ec62356b5..9a29b1528 100644 --- a/docs/ExtendOperatorsWithDSL.md +++ b/docs/ExtendOperatorsWithDSL.md @@ -144,7 +144,7 @@ scheduled_ir_module = fast_tune_with_dynamic_range( } ) -# fianlly, we will generate a dispatch func to dispatch the kernel with dynamic symbolic. +# finally, we will generate a dispatch func to dispatch the kernel with dynamic symbolic. ''' @IRModule class MatmulNT: diff --git a/docs/Installation.md b/docs/Installation.md index a50d478ef..ab341c4f7 100644 --- a/docs/Installation.md +++ b/docs/Installation.md @@ -7,7 +7,7 @@ - **Python Version**: >= 3.8 - **CUDA Version**: >= 11.0 -The easiest way to install BitBLAS is direcly from the PyPi using pip. To install the latest version, run the following command in your terminal. +The easiest way to install BitBLAS is directly from the PyPi using pip. To install the latest version, run the following command in your terminal. **Note**: Currently, BitBLAS whl is only supported on Ubuntu 20.04 or later version as we build the whl files on this platform. Currently we only provide whl files for CUDA>=11.0 and with Python>=3.8. **If you are using a different platform or environment, you may need to [build BitBLAS from source](https://github.com/microsoft/BitBLAS/blob/main/docs/Installation.md#building-from-source).** diff --git a/docs/PythonAPI.md b/docs/PythonAPI.md index 45f23df2b..1e248671f 100644 --- a/docs/PythonAPI.md +++ b/docs/PythonAPI.md @@ -184,7 +184,7 @@ Returns: The output tensor. #### `init_params()` -Initializes parameters handles (convert constant params into ctypes void pointer) for the computation. We currently put this fuction in the forward function, so you do not need to call it manually. But if you lift this function out of the forward function, you can call it manually to aoid the transformation. +Initializes parameters handles (convert constant params into ctypes void pointer) for the computation. We currently put this function in the forward function, so you do not need to call it manually. But if you lift this function out of the forward function, you can call it manually to aoid the transformation. #### `load_and_transform_weight(weight, scales=None, zeros=None, bias=None)` diff --git a/docs/QuickStart.md b/docs/QuickStart.md index e808f1ce2..b2e8d6fa8 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -116,7 +116,7 @@ print("BitBLAS output:", output_tensor) torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-2) ``` -The init stage of the ```bitblas.Matmul``` class will take minutes to finish, as it will use hardware informations to do a one-time kernel library initialization. +The init stage of the ```bitblas.Matmul``` class will take minutes to finish, as it will use hardware information to do a one-time kernel library initialization. ## Example: bitblas.Linear module for PyTorch diff --git a/format.sh b/format.sh index 5d3056123..c09777503 100755 --- a/format.sh +++ b/format.sh @@ -6,13 +6,14 @@ # Usage: # # Do work and commit your work. -# # Format files that differ from origin/main. +# # Format files that differ from the determined merge base (upstream/main, origin/main, or local main). # bash format.sh # # Commit changed files with message 'Run yapf and ruff' # # -# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# YAPF + Ruff + Codespell. This script formats, lints, and spell-checks changed files +# based on the merge-base with upstream/main, origin/main, or local main. # You are encouraged to run this locally before pushing changes for review. # Cause the script to exit if a single command fails @@ -23,29 +24,133 @@ builtin cd "$(dirname "${BASH_SOURCE:-$0}")" ROOT="$(git rev-parse --show-toplevel)" builtin cd "$ROOT" || exit 1 +# --- Tool Version Checks --- YAPF_VERSION=$(yapf --version | awk '{print $2}') RUFF_VERSION=$(ruff --version | awk '{print $2}') -CODESPELL_VERSION=$(codespell --version) +# Handle potential variations in codespell version output +CODESPELL_RAW_VERSION=$(codespell --version) +if [[ "$CODESPELL_RAW_VERSION" == codespell* ]]; then + CODESPELL_VERSION=$(echo "$CODESPELL_RAW_VERSION" | awk '{print $2}') # Assuming format "codespell x.y.z" +else + CODESPELL_VERSION="$CODESPELL_RAW_VERSION" # Use as is if format is different +fi + -# # params: tool name, tool version, required version +# params: tool name, tool version, required version from file tool_version_check() { - if [[ $2 != $3 ]]; then - echo "Wrong $1 version installed: $3 is required, not $2." - exit 1 + local tool_name=$1 + local installed_version=$2 + local requirement_line + local required_version + + # Find the requirement line robustly (handles == and ===) + requirement_line=$(grep "^${tool_name}[=]=" requirements-dev.txt) || requirement_line=$(grep "^${tool_name}=" requirements-dev.txt) + + if [ -z "$requirement_line" ]; then + echo "Warning: Could not find requirement for '$tool_name' in requirements-dev.txt." + return # Don't exit, just warn if requirement is missing + fi + + # Extract version after the last '=' + required_version=$(echo "$requirement_line" | rev | cut -d'=' -f1 | rev) + + # Special handling for codespell if it only prints version number + if [[ "$tool_name" == "codespell" ]] && [[ "$installed_version" != codespell* ]]; then + # If installed_version is just the number, compare directly + if [[ "$installed_version" != "$required_version" ]]; then + echo "Wrong $tool_name version installed: $required_version is required, not $installed_version." + echo "Requirement line: $requirement_line" + exit 1 + fi + else + # Standard comparison (handles 'tool x.y.z' or just 'x.y.z' if awk worked) + # Extract version number from installed_version if needed + local installed_version_num=$installed_version + if [[ "$installed_version" == ${tool_name}* ]]; then + installed_version_num=$(echo "$installed_version" | awk '{print $2}') + fi + + if [[ "$installed_version_num" != "$required_version" ]]; then + echo "Wrong $tool_name version installed: $required_version is required, not $installed_version_num (from '$installed_version')." + echo "Requirement line: $requirement_line" + exit 1 + fi fi } -tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "yapf" "$YAPF_VERSION" +tool_version_check "ruff" "$RUFF_VERSION" +tool_version_check "codespell" "$CODESPELL_VERSION" + +# --- Determine Merge Base --- +# Define the upstream repository URL to compare against +UPSTREAM_REPO="https://github.com/microsoft/BitBLAS" +MERGEBASE="" # Initialize MERGEBASE variable + +echo "Determining merge base for diff..." + +# 1. Try to compare directly with the main branch of the upstream repository +if git ls-remote --exit-code "$UPSTREAM_REPO" main &>/dev/null; then + echo "Attempting to find merge base with upstream: $UPSTREAM_REPO main" + MERGEBASE_CMD_OUTPUT=$(git fetch "$UPSTREAM_REPO" main --quiet --no-tags 2>/dev/null && git merge-base FETCH_HEAD HEAD) + FETCH_STATUS=$? + if [ $FETCH_STATUS -eq 0 ] && [ -n "$MERGEBASE_CMD_OUTPUT" ]; then + MERGEBASE="$MERGEBASE_CMD_OUTPUT" + echo "Successfully found merge base with upstream: $MERGEBASE" + else + echo "Warning: Could not determine merge base with $UPSTREAM_REPO main (fetch/merge-base failed or no common ancestor). Falling back..." + fi +fi + +# 2. If MERGEBASE could not be obtained from upstream, try using origin/main +if [ -z "$MERGEBASE" ] && git show-ref --verify --quiet refs/remotes/origin/main; then + echo "Falling back to merge base with origin/main" + BASE_BRANCH="origin/main" + MERGEBASE_CMD_OUTPUT=$(git merge-base "$BASE_BRANCH" HEAD) + MERGEBASE_STATUS=$? + if [ $MERGEBASE_STATUS -eq 0 ] && [ -n "$MERGEBASE_CMD_OUTPUT" ]; then + MERGEBASE="$MERGEBASE_CMD_OUTPUT" + echo "Successfully found merge base with $BASE_BRANCH: $MERGEBASE" + else + echo "Warning: Could not determine merge base with $BASE_BRANCH. Falling back..." + fi +fi + +# 3. If even origin/main doesn't work, try using the local main branch +if [ -z "$MERGEBASE" ]; then + echo "Falling back to merge base with local main" + BASE_BRANCH="main" + if git show-ref --verify --quiet "refs/heads/$BASE_BRANCH"; then + MERGEBASE_CMD_OUTPUT=$(git merge-base "$BASE_BRANCH" HEAD) + MERGEBASE_STATUS=$? + if [ $MERGEBASE_STATUS -eq 0 ] && [ -n "$MERGEBASE_CMD_OUTPUT" ]; then + MERGEBASE="$MERGEBASE_CMD_OUTPUT" + echo "Successfully found merge base with $BASE_BRANCH: $MERGEBASE" + else + echo "Warning: Could not determine merge base with local $BASE_BRANCH." + fi + else + echo "Warning: Local branch '$BASE_BRANCH' not found." + fi +fi + +# 4. Final check for MERGEBASE +if [ -z "$MERGEBASE" ]; then + echo "Error: Could not determine a suitable merge base. Unable to proceed with diffing changed files." + exit 1 +fi + +echo "Using final merge base: $MERGEBASE" +# --- Merge Base Determined --- -echo 'bitblas yapf: Check Start' + +# --- YAPF Formatting --- +echo '--- bitblas yapf: Check Start ---' YAPF_FLAGS=( '--recursive' '--parallel' ) - YAPF_EXCLUDES=( '--exclude' 'build/**' ) @@ -55,149 +160,123 @@ format() { yapf --in-place "${YAPF_FLAGS[@]}" "$@" } -# Format files that differ from main branch. Ignores dirs that are not slated -# for autoformat yet. +# Format files that differ from the determined merge base. format_changed() { - # The `if` guard ensures that the list of filenames is not empty, which - # could cause yapf to receive 0 positional arguments, making it hang - # waiting for STDIN. - # - # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that - # exist on both branches. - if git show-ref --verify --quiet refs/remotes/origin/main; then - BASE_BRANCH="origin/main" - else - BASE_BRANCH="main" - fi - - MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" - + # Use the globally determined $MERGEBASE if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + echo "Running yapf on changed Python files..." git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \ yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" + else + echo "No Python files changed according to yapf." fi - } # Format all files format_all() { + echo "Running yapf on all Python files..." yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" . } -## This flag formats individual files. --files *must* be the first command line -## arg to use this option. +# YAPF Execution Logic if [[ "$1" == '--files' ]]; then format "${@:2}" - # If `--all` is passed, then any further arguments are ignored and the - # entire python directory is formatted. elif [[ "$1" == '--all' ]]; then format_all else - # Format only the files that changed in last commit. format_changed fi -echo 'bitblas yapf: Done' +echo '--- bitblas yapf: Done ---' + -echo 'bitblas codespell: Check Start' -# check spelling of specified files +# --- Codespell Check --- +echo '--- bitblas codespell: Check Start ---' + +# Check spelling of specified files spell_check() { codespell "$@" } +# Check spelling based on pyproject.toml config (usually checks all relevant files) spell_check_all(){ + echo "Running codespell based on pyproject.toml..." codespell --toml pyproject.toml } -# Spelling check of files that differ from main branch. +# Check spelling of files that differ from the determined merge base. spell_check_changed() { - # The `if` guard ensures that the list of filenames is not empty, which - # could cause ruff to receive 0 positional arguments, making it hang - # waiting for STDIN. - # - # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that - # exist on both branches. - if git show-ref --verify --quiet refs/remotes/origin/main; then - BASE_BRANCH="origin/main" + # Use the globally determined $MERGEBASE + # Check Python and potentially other relevant text files (adjust patterns as needed) + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' '*.md' '*.rst' &>/dev/null; then + echo "Running codespell on changed text files..." + # Note: Consider filtering for files codespell actually handles if needed + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' '*.md' '*.rst' | xargs \ + codespell --quiet-level 3 # Adjust quiet level as needed else - BASE_BRANCH="main" - fi - - MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" - - if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then - git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - codespell + echo "No relevant text files changed according to codespell." fi } -# Run Codespell -## This flag runs spell check of individual files. --files *must* be the first command line -## arg to use this option. +# Codespell Execution Logic if [[ "$1" == '--files' ]]; then spell_check "${@:2}" - # If `--all` is passed, then any further arguments are ignored and the - # entire python directory is linted. elif [[ "$1" == '--all' ]]; then spell_check_all else - # Check spelling only of the files that changed in last commit. spell_check_changed fi -echo 'bitblas codespell: Done' +echo '--- bitblas codespell: Done ---' + + +# --- Ruff Linting --- +echo '--- bitblas ruff: Check Start ---' -echo 'bitblas ruff: Check Start' # Lint specified files lint() { ruff check "$@" } -# Lint files that differ from main branch. Ignores dirs that are not slated -# for autolint yet. +# Lint files that differ from the determined merge base. lint_changed() { - # The `if` guard ensures that the list of filenames is not empty, which - # could cause ruff to receive 0 positional arguments, making it hang - # waiting for STDIN. - # - # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that - # exist on both branches. - if git show-ref --verify --quiet refs/remotes/origin/main; then - BASE_BRANCH="origin/main" - else - BASE_BRANCH="main" - fi - - MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" - + # Use the globally determined $MERGEBASE if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + echo "Running ruff check on changed Python files..." git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ ruff check + else + echo "No Python files changed according to ruff." fi - } -# Run Ruff -### This flag lints individual files. --files *must* be the first command line -### arg to use this option. +# Ruff Execution Logic if [[ "$1" == '--files' ]]; then lint "${@:2}" - # If `--all` is passed, then any further arguments are ignored and the - # entire python directory is linted. elif [[ "$1" == '--all' ]]; then - lint BitBLAS tests + echo "Running ruff check on specified directories..." + # Adjust directories as needed for your project structure + lint BitBLAS tests # Assuming these are the main directories else - # Format only the files that changed in last commit. lint_changed fi +echo '--- bitblas ruff: Done ---' +# --- Final Check for Changes --- +# Check if yapf (or potentially other tools if they modify files) made changes if ! git diff --quiet &>/dev/null; then - echo 'Reformatted files. Please review and stage the changes.' - echo 'Changes not staged for commit:' echo - git --no-pager diff --name-only - + echo '-----------------------------------------------------------------------' + echo 'Detected changes made by the formatting/linting tools.' + echo 'Please review and stage these changes before committing:' + echo '-----------------------------------------------------------------------' + echo + git --no-pager diff --color=always # Show colored diff directly + echo + echo '-----------------------------------------------------------------------' + echo 'Exiting with status 1 due to needed changes.' + echo '-----------------------------------------------------------------------' exit 1 fi -echo 'bitblas ruff: Done' - -echo 'bitblas: All checks passed' +echo +echo '--- bitblas: All checks passed ---' +exit 0 \ No newline at end of file diff --git a/install.sh b/install.sh index 77c258706..5abdcdfda 100755 --- a/install.sh +++ b/install.sh @@ -110,8 +110,12 @@ if [ $? -ne 0 ]; then exit 1 fi +CORES=$(nproc) +MAKE_JOBS=$(( CORES * 75 / 100 )) +echo "Using $MAKE_JOBS jobs for make..." + echo "Building TVM with make..." -make -j +make -j${MAKE_JOBS} if [ $? -ne 0 ]; then echo "Error: TVM build failed." exit 1 @@ -134,7 +138,7 @@ if [ $? -ne 0 ]; then exit 1 fi -make -j +make -j${MAKE_JOBS} if [ $? -ne 0 ]; then echo "Error: TileLang build failed." exit 1 @@ -185,4 +189,4 @@ else fi # Reload ~/.bashrc to apply the changes -source ~/.bashrc +source ~/.bashrc \ No newline at end of file diff --git a/install_amd.sh b/install_amd.sh index dec3dedcf..9b562cb52 100755 --- a/install_amd.sh +++ b/install_amd.sh @@ -60,7 +60,9 @@ cp cmake/config.cmake build cd build echo "set(USE_LLVM llvm-config-16)" >> config.cmake && echo "set(USE_ROCM /opt/rocm)" >> config.cmake -cmake .. && make -j && cd ../../.. +CORES=$(nproc) +MAKE_JOBS=$(( CORES * 75 / 100 )) +cmake .. && make -j${MAKE_JOBS} && cd ../../.. TVM_PREBUILD_PATH=$(realpath .) @@ -77,7 +79,7 @@ if [ $? -ne 0 ]; then exit 1 fi -make -j +make -j${MAKE_JOBS} if [ $? -ne 0 ]; then echo "Error: TileLang build failed." exit 1 @@ -127,4 +129,4 @@ else fi # Reload ~/.bashrc to apply the changes -source ~/.bashrc +source ~/.bashrc \ No newline at end of file diff --git a/integration/BitNet/benchmark_model_10k_loops.py b/integration/BitNet/benchmark_model_10k_loops.py index 3a838c2ca..47f423b70 100644 --- a/integration/BitNet/benchmark_model_10k_loops.py +++ b/integration/BitNet/benchmark_model_10k_loops.py @@ -18,6 +18,7 @@ seq_len = args.seq_len batch_size = args.batch_size + def profile(model, input_data): import time diff --git a/integration/BitNet/configuration_bitnet.py b/integration/BitNet/configuration_bitnet.py index 2f7f7aa7f..5f4937b87 100644 --- a/integration/BitNet/configuration_bitnet.py +++ b/integration/BitNet/configuration_bitnet.py @@ -22,7 +22,6 @@ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging - logger = logging.get_logger(__name__) LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} @@ -183,13 +182,14 @@ def _rope_scaling_validation(self): if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}" - ) + f"got {self.rope_scaling}") rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") \ No newline at end of file + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, + float) or rope_scaling_factor <= 1.0: + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/integration/BitNet/eval_gpu_memory.py b/integration/BitNet/eval_gpu_memory.py index b67106eeb..92ce14930 100644 --- a/integration/BitNet/eval_gpu_memory.py +++ b/integration/BitNet/eval_gpu_memory.py @@ -11,6 +11,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) + def profile(model, input_data): import time @@ -34,22 +35,20 @@ def get_runtime(num_repeats=1): times = get_runtime(num_repeats) return np.mean(times) + def main(): model = BitnetForCausalLM.from_pretrained( '1bitLLM/bitnet_b1_58-3B', device_map='auto', - low_cpu_mem_usage=True, + low_cpu_mem_usage=True, use_flash_attention_2=True, torch_dtype=torch.float16, ).half() - print( - f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB" - ) + print(f"gpu memory: {torch.cuda.memory_allocated() / 1024 ** 3} GB") with torch.no_grad(): model._post_process_weights() - print( - f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB" - ) + print(f"gpu memory BitBLAS: {torch.cuda.memory_allocated() / 1024 ** 3} GB") + if __name__ == '__main__': main() diff --git a/integration/BitNet/eval_utils.py b/integration/BitNet/eval_utils.py index a7a57dd8a..fc24fdee7 100644 --- a/integration/BitNet/eval_utils.py +++ b/integration/BitNet/eval_utils.py @@ -1,7 +1,6 @@ import torch import numpy as np -import torch.nn.functional as F from lm_eval.base import BaseLM from datasets import load_dataset @@ -11,17 +10,24 @@ def set_seed(seed): np.random.seed(seed) torch.random.manual_seed(seed) + def get_test_dataset(dataset_name, tokenizer, seqlen=2048): if dataset_name == "wikitext2": testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') testdata = "".join(testdata['text']).split('\n') elif dataset_name == "c4": - testdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')['text'] + testdata = load_dataset( + 'allenai/c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation')['text'] else: raise NotImplementedError - + testdata = [item for item in testdata if item != ""] - tokenized_text = [tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] for item in testdata] + tokenized_text = [ + tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] + for item in testdata + ] data, doc = [], [tokenizer.bos_token_id] for sen in tokenized_text: @@ -37,6 +43,7 @@ def get_test_dataset(dataset_name, tokenizer, seqlen=2048): class LMEvalAdaptor(BaseLM): + def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1): super().__init__() @@ -73,9 +80,7 @@ def max_length(self): return 2048 elif "llama" in self.model_name: return 2048 # TODO: did not check this - elif "mpt" in self.model_name: - return 2048 - elif "falcon" in self.model_name: + elif "mpt" in self.model_name or "falcon" in self.model_name: return 2048 else: print(self.model.config) @@ -129,5 +134,4 @@ def _model_call(self, inps): def _model_generate(self, context, max_length, eos_token_id): return self.model.generate( - context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False - ) \ No newline at end of file + context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False) diff --git a/integration/BitNet/tokenization_bitnet.py b/integration/BitNet/tokenization_bitnet.py index 09b482f72..202559fae 100644 --- a/integration/BitNet/tokenization_bitnet.py +++ b/integration/BitNet/tokenization_bitnet.py @@ -17,7 +17,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Tokenization classes for LLaMA.""" import os from shutil import copyfile @@ -29,7 +28,6 @@ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer from transformers.utils import logging - if TYPE_CHECKING: from transformers.tokenization_utils_base import TextInput @@ -39,10 +37,12 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-tokenizer": + "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-tokenizer": + "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -159,10 +159,14 @@ def __init__( **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + bos_token = AddedToken( + bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken( + eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken( + unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken( + pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token if legacy is None: logger.warning_once( @@ -170,8 +174,7 @@ def __init__( " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" " means, and thoroughly read the reason why this was added as explained in" - " https://github.com/huggingface/transformers/pull/24565" - ) + " https://github.com/huggingface/transformers/pull/24565") legacy = True self.legacy = legacy @@ -211,7 +214,8 @@ def get_spm_processor(self, from_slow=False): with open(self.vocab_file, "rb") as f: sp_model = f.read() - model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model_pb2 = import_protobuf( + f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") model = model_pb2.ModelProto.FromString(sp_model) normalizer_spec = model_pb2.NormalizerSpec() normalizer_spec.add_dummy_prefix = False @@ -257,7 +261,8 @@ def tokenize(self, text: "TextInput", **kwargs) -> List[str]: tokens = super().tokenize(text, **kwargs) - if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + if len(tokens + ) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: tokens = tokens[1:] return tokens @@ -279,7 +284,7 @@ def _tokenize(self, text, **kwargs): # 1. Encode string + prefix ex: " Hey" tokens = self.sp_model.encode(self.unk_token + text, out_type=str) # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] - return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" @@ -327,11 +332,12 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) + out_vocab_file = os.path.join(save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"]) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile( + self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: @@ -351,9 +357,10 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): return output - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: + def get_special_tokens_mask(self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer `prepare_for_model` method. @@ -371,26 +378,19 @@ def get_special_tokens_mask( """ if already_has_special_tokens: return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True) bos_token_id = [1] if self.add_bos_token else [] eos_token_id = [1] if self.add_eos_token else [] if token_ids_1 is None: return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return ( - bos_token_id - + ([0] * len(token_ids_0)) - + eos_token_id - + bos_token_id - + ([0] * len(token_ids_1)) - + eos_token_id - ) + return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + + ([0] * len(token_ids_1)) + eos_token_id) - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + def create_token_type_ids_from_sequences(self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT sequence pair mask has the following format: @@ -473,10 +473,10 @@ def default_chat_template(self): "{% elif message['role'] == 'assistant' %}" "{{ ' ' + content.strip() + ' ' + eos_token }}" "{% endif %}" - "{% endfor %}" - ) - template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + "{% endfor %}") + template = template.replace("USE_DEFAULT_PROMPT", + "true" if self.use_default_system_prompt else "false") default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) - return template \ No newline at end of file + return template diff --git a/integration/GPTQModel/README.md b/integration/GPTQModel/README.md index ee4c14399..ebeb60b77 100644 --- a/integration/GPTQModel/README.md +++ b/integration/GPTQModel/README.md @@ -1,3 +1,3 @@ -BitBLAS has been fully integraded into [GPTQModel](https://github.com/ModelCloud/GPTQModel) since v0.9.1. +BitBLAS has been fully integrated into [GPTQModel](https://github.com/ModelCloud/GPTQModel) since v0.9.1. Please reference [sample code](https://github.com/ModelCloud/GPTQModel/blob/main/examples/inference/run_with_different_backends.py) for usage on using `backend=BACKEND.BITBLAS`within GPTQModel. \ No newline at end of file diff --git a/integration/mlc_llm/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py index 7f6887277..bebf7b301 100644 --- a/integration/mlc_llm/test_weight_only_transform.py +++ b/integration/mlc_llm/test_weight_only_transform.py @@ -36,7 +36,8 @@ def get_default_result(ref_mod, input_tensors, target, device): bitblas.gpu.Reduction(), bitblas.gpu.GeneralReduction(), bitblas.gpu.Fallback(), - )(ref_mod) + )( + ref_mod) ref_mod = tvm.tir.transform.MakePackedAPI()(ref_mod) ex = relax.build(ref_mod, target) vm = relax.VirtualMachine(ex, device) @@ -55,8 +56,10 @@ def get_fast_tune_result(ref_mod, input_tensors, target, device): def test_lop3_transform(): + @I.ir_module class Before: + @T.prim_func(private=True) def fused_fused_decode3_fused_NT_matmul8_add1( lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), @@ -65,33 +68,28 @@ def fused_fused_decode3_fused_NT_matmul8_add1( p_output0: T.handle, ): T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() + T.int64() lv41 = T.match_buffer(p_lv41, (T.int64(1), 1, T.int64(4096)), "float16") - NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), 1, T.int64(4096)), "float16" - ) + NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), 1, T.int64(4096)), + "float16") # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) + decode_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), + "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): with T.block("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), + decode_intermediate_intermediate[v_i, v_j] = (T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] + T.uint32(15), + ), + ) - T.float16(7)) * lv48[v_i, v_j // T.int64(32)] for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) @@ -103,10 +101,8 @@ def fused_fused_decode3_fused_NT_matmul8_add1( with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] - * decode_intermediate_intermediate[v_i2, v_k] - ) + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k]) @R.function def main( @@ -134,8 +130,8 @@ def main( relax_mod = AnnotateDecodeInformation()(relax_mod) with dispatch_target: relax_mod = WeightOnlyLayoutPropagation( - transform_level=0, faster_conversion=False - )(relax_mod) + transform_level=0, faster_conversion=False)( + relax_mod) input_tensors = get_dummy_input_arrays(ref_mod["main"], tvm.cpu()) @@ -157,22 +153,22 @@ def main( print("relax ", res) -def test_matmul_transform(transform_level = 2): +def test_matmul_transform(transform_level=2): @I.ir_module class Before: + @T.prim_func(private=True) def fused_fused_decode3_fused_NT_matmul8_add1( - p_lv41: T.handle, - lv47: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), - p_output0: T.handle, + p_lv41: T.handle, + lv47: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + p_output0: T.handle, ): T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() + T.int64() lv41 = T.match_buffer(p_lv41, (T.int64(1), 1, T.int64(4096)), "float16") - NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), 1, T.int64(4096)), "float16" - ) + NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), 1, T.int64(4096)), + "float16") for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): with T.block("NT_matmul"): @@ -182,9 +178,8 @@ def fused_fused_decode3_fused_NT_matmul8_add1( with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] * lv47[v_i2, v_k] - ) + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * lv47[v_i2, v_k]) @R.function def main( @@ -210,8 +205,8 @@ def main( relax_mod = AnnotateDecodeInformation()(relax_mod) with dispatch_target: relax_mod = WeightOnlyLayoutPropagation( - transform_level=transform_level, faster_conversion=False - )(relax_mod) + transform_level=transform_level, faster_conversion=False)( + relax_mod) input_tensors = get_dummy_input_arrays(ref_mod["main"], tvm.cpu()) @@ -237,6 +232,7 @@ def test_dequantize_matmul_transform(transform_level=2): @I.ir_module class Before: + @T.prim_func(private=True) def fused_fused_decode3_fused_NT_matmul8_add1( lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), @@ -245,35 +241,29 @@ def fused_fused_decode3_fused_NT_matmul8_add1( p_output0: T.handle, ): T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv41 = T.match_buffer( - p_lv41, (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ) - NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ) + T.int64() + lv41 = T.match_buffer(p_lv41, (T.int64(1), T.int64(1), T.int64(4096)), "float16") + NT_matmul_intermediate = T.match_buffer(p_output0, + (T.int64(1), T.int64(1), T.int64(4096)), + "float16") # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) + decode_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), + "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): with T.block("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), + decode_intermediate_intermediate[v_i, v_j] = (T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] + T.uint32(15), + ), + ) - T.float16(7)) * lv48[v_i, v_j // T.int64(32)] for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) @@ -285,10 +275,8 @@ def fused_fused_decode3_fused_NT_matmul8_add1( with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] - * decode_intermediate_intermediate[v_i2, v_k] - ) + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k]) @R.function def main( @@ -319,8 +307,8 @@ def main( relax_mod = AnnotateDecodeInformation()(relax_mod) with dispatch_target: relax_mod = WeightOnlyLayoutPropagation( - transform_level=transform_level, faster_conversion=False - )(relax_mod) + transform_level=transform_level, faster_conversion=False)( + relax_mod) input_tensors = get_dummy_input_arrays(ref_mod["main"], device) print(relax_mod) print("=======================ref llvm result=======================") @@ -337,9 +325,7 @@ def main( ref_res = get_fast_tune_result(ref_mod, input_tensors, dispatch_target, device) print("ref_mod", ref_res) print(relax_mod) - bitblas_res = get_fast_tune_result( - relax_mod, input_tensors, dispatch_target, device - ) + bitblas_res = get_fast_tune_result(relax_mod, input_tensors, dispatch_target, device) print("bitblas_res", bitblas_res) diff --git a/integration/pytorch/test_bitblas_linear.py b/integration/pytorch/test_bitblas_linear.py index 18a969fc5..fed036839 100644 --- a/integration/pytorch/test_bitblas_linear.py +++ b/integration/pytorch/test_bitblas_linear.py @@ -92,8 +92,8 @@ def test_profile_performance(m, in_features, out_features, bias): torch_latency = profile(linear_bitblas, input_data) bitblas_latency = linear_bitblas.bitblas_matmul.profile_latency() print(f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}") - assert (abs(torch_latency - bitblas_latency) / torch_latency < - 0.1), f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}" + assert (abs(torch_latency - bitblas_latency) / torch_latency + < 0.1), f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}" if __name__ == "__main__": diff --git a/maint/scripts/apply_mit_license.sh b/maint/scripts/apply_mit_license.sh index 3d27d6cf0..c7fd7d42d 100755 --- a/maint/scripts/apply_mit_license.sh +++ b/maint/scripts/apply_mit_license.sh @@ -3,7 +3,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -echo "Add MIT liscense boilerplate..." +echo "Add MIT license boilerplate..." PWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # TO source code root pushd "${PWD}/../../" > /dev/null diff --git a/pyproject.toml b/pyproject.toml index 85fd7db04..190353ded 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ ignore = [ "E402", # star imports "F405", "F403", - # ambigous name + # ambiguous name "E741", # line too long "E501", diff --git a/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu index 257f49a31..9e67953df 100644 --- a/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu +++ b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu @@ -6,9 +6,9 @@ #include #include "i4matmul.hpp" -#define cudaCheckLastError(ans) \ +#define cudaCheckLastError(and) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ + gpuAssert((and), __FILE__, __LINE__); \ } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { diff --git a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp index a12a57dcd..0c3e2d0cc 100644 --- a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp +++ b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp @@ -362,7 +362,7 @@ __global__ void Marlin( a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); } - // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependencies between // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. const int4* B_ptr[b_sh_wr_iters]; #pragma unroll @@ -429,7 +429,7 @@ __global__ void Marlin( auto fetch_to_registers = [&] (int k, int pipe) { // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the - // compiler and correspondingly a noticable drop in performance. + // compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; @@ -513,7 +513,7 @@ __global__ void Marlin( }; // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over - // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather + // the results. As the striped partitioning minimizes the number of such reductions and our outputs are usually rather // small, we perform this reduction serially in L2 cache. auto global_reduce = [&] (bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. @@ -656,7 +656,7 @@ __global__ void Marlin( a_gl_rd += a_gl_rd_delta_o * stages; // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most - // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. + // readable, other ways of writing the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; @@ -757,7 +757,7 @@ int marlin_cuda( cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); if (thread_k == -1 || thread_n == -1) { if (prob_m <= 16) { - // For small batchizes, better partioning is slightly more important than better compute utilization + // For small batchizes, better partitioning is slightly more important than better compute utilization thread_k = 128; thread_n = 128; } else { diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu index 7307ad1fe..42b1649ea 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu @@ -6,9 +6,9 @@ #include #include "fast_decoding.hpp" -#define cudaCheckLastError(ans) \ +#define cudaCheckLastError(and) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ + gpuAssert((and), __FILE__, __LINE__); \ } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu index d39a85dcd..ba095ec65 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu @@ -6,9 +6,9 @@ #include #include "fast_decoding.hpp" -#define cudaCheckLastError(ans) \ +#define cudaCheckLastError(and) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ + gpuAssert((and), __FILE__, __LINE__); \ } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu index 0a3b45a77..1d48358ee 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu @@ -6,9 +6,9 @@ #include #include "fast_decoding.hpp" -#define cudaCheckLastError(ans) \ +#define cudaCheckLastError(and) \ { \ - gpuAssert((ans), __FILE__, __LINE__); \ + gpuAssert((and), __FILE__, __LINE__); \ } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { diff --git a/testing/python/operators/test_general_matmul_ops_backend.py b/testing/python/operators/test_general_matmul_ops_backend.py index 8d80f7d87..d62238a95 100644 --- a/testing/python/operators/test_general_matmul_ops_backend.py +++ b/testing/python/operators/test_general_matmul_ops_backend.py @@ -5,6 +5,7 @@ from bitblas import MatmulConfig, Matmul import logging from bitblas import set_log_level +from bitblas.tl.lower import tl_lower set_log_level(logging.DEBUG) @@ -35,8 +36,7 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, la ) matmul = Matmul(config=matmul_config, enable_tuning=False) func = matmul.prim_func - import tilelang - rt_mod, params = tilelang.lower(func) + rt_mod, params = tl_lower(func) print(rt_mod) assert get_codegen_result(matmul) diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 38f49786e..73f83402b 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -73,7 +73,7 @@ def assert_correctness_with_block_reduce( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": True + "tir.merge_static_smem": True, }): ref_rt_mod = tvm.build(ref_sch.mod, target=target) @@ -97,7 +97,7 @@ def assert_correctness_with_block_reduce( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": True + "tir.merge_static_smem": True, }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) @@ -173,7 +173,7 @@ def assert_correctness_with_ladder_ldmatrix_propagate( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": False + "tir.merge_static_smem": False, }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) # Evaluate the correctness @@ -286,7 +286,7 @@ def assert_dequant_correctness_with_block_reduce( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": False + "tir.merge_static_smem": False, }): ref_rt_mod = tvm.build(ref_sch.mod, target=target) @@ -310,7 +310,7 @@ def assert_dequant_correctness_with_block_reduce( ) with tvm.transform.PassContext(config={ "tir.use_async_copy": True, - "tir.merge_static_smem": False + "tir.merge_static_smem": False, }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) @@ -427,7 +427,7 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( with tvm.transform.PassContext(config={ "tir.use_async_copy": True, "tir.merge_static_smem": False, - "tir.disable_cse_tir": True + "tir.disable_cse_tir": True, }): rt_mod = tvm.build(block_reduce_sch.mod, target=target) src_code = rt_mod.imported_modules[0].get_source() diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index 3f258bab8..e521aadb5 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -3,12 +3,14 @@ from bitblas import tvm as tvm from bitblas import tilelang as tilelang +from bitblas.tl.lower import tl_lower import bitblas.testing from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( matmul_blocked, matmul_macro_tensorcore, matmul_macro_tensorcore_weight_propagation_level_ldmatrix, ) +from bitblas.tl.profiler import TLProfiler import torch import torch.backends @@ -47,7 +49,7 @@ def assert_matmul_blocked_correctness(M, enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -57,7 +59,7 @@ def assert_matmul_blocked_correctness(M, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -105,7 +107,7 @@ def assert_matmul_macro_tensorcore_correctness( num_stages=num_stages, enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code represents generated cuda source @@ -115,7 +117,7 @@ def assert_matmul_macro_tensorcore_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -164,7 +166,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -185,7 +187,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index dc1f1c424..c25d5f9c4 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -4,6 +4,9 @@ from bitblas import tvm as tvm import bitblas.testing from bitblas import tilelang as tilelang +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler + from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( MatmulTileLibraryScheduler,) @@ -51,7 +54,7 @@ def assert_matmul_blocked_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -61,7 +64,7 @@ def assert_matmul_blocked_with_default_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -109,7 +112,7 @@ def assert_matmul_blocked_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -119,7 +122,7 @@ def assert_matmul_blocked_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -155,7 +158,7 @@ def assert_matmul_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -163,7 +166,7 @@ def assert_matmul_fine_grained_with_default_correctness( B = (torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) if trans_B else torch.rand( K, N, device="cuda", dtype=getattr(torch, in_dtype))) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25) @@ -217,7 +220,7 @@ def assert_matmul_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -227,7 +230,7 @@ def assert_matmul_fine_grained_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -263,7 +266,7 @@ def assert_matmul_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -284,7 +287,7 @@ def assert_matmul_weight_propagation_with_default_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -337,7 +340,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -358,7 +361,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -394,7 +397,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -405,9 +408,9 @@ def assert_matmul_int4_fine_grained_with_default_correctness( compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) - latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + latency = mod.do_bench(mod, warmup=25) print(latency) # Ensure that the latency is not None @@ -459,7 +462,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -470,9 +473,9 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) - latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + latency = mod.do_bench(mod.func, warmup=25) print(latency) # Ensure that the latency is not None @@ -509,7 +512,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -534,7 +537,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( LB = ladder_permutate(compressed_B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, LB, C) @@ -588,7 +591,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( ) print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -613,7 +616,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( LB = ladder_permutate(compressed_B.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, LB, C) @@ -666,7 +669,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -692,7 +695,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( (B[:, 3::4] & 0x03) << 6) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) print(f"{compressed_B=}") if fast_decoding: lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() @@ -701,7 +704,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( print(f"{lop3_compressed_B=}") mod(compressed_A, lop3_compressed_B, C) print(C) - latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + latency = mod.do_bench(mod.func, warmup=25) print(latency) # Ensure that the latency is not None assert latency is not None @@ -765,7 +768,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -791,7 +794,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( (B[:, 3::4] & 0x03) << 6) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) print(f"{compressed_B=}") if fast_decoding: lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() @@ -800,7 +803,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( print(f"{lop3_compressed_B=}") mod(compressed_A, lop3_compressed_B, C) print(C) - latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + latency = mod.do_bench(mod.func, warmup=25) print(latency) # Ensure that the latency is not None assert latency is not None @@ -849,7 +852,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -884,7 +887,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() ladder_shape = compressed_B_ladder.shape int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) @@ -970,7 +973,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( ) print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1005,7 +1008,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() ladder_shape = compressed_B_ladder.shape int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) @@ -1078,7 +1081,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1135,7 +1138,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) @@ -1213,7 +1216,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1268,7 +1271,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) @@ -1344,7 +1347,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( ).with_default_config() if verbose: print(matmul) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -1415,7 +1418,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) diff --git a/testing/python/tilelang/test_simplifier.py b/testing/python/tilelang/test_simplifier.py index 18613edc9..3b202ad50 100644 --- a/testing/python/tilelang/test_simplifier.py +++ b/testing/python/tilelang/test_simplifier.py @@ -1,6 +1,8 @@ import tvm from bitblas import tilelang as tilelang import tilelang.language as T +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler def modify( @@ -73,9 +75,9 @@ def test_matmul(): mod = tvm.IRModule({func.attrs["global_symbol"]: func}) mod = tilelang.transform.Simplify()(mod) - rt_mod, params = tilelang.lower(mod.functions_items()[0][1], runtime_only=False) + rt_mod, params = tl_lower(mod.functions_items()[0][1], runtime_only=False) # TODO Profiler only support TensorType, not dynamic variable - profiler = tilelang.Profiler(rt_mod, params, result_idx=[2]) + profiler = TLProfiler(rt_mod, params, result_idx=[2]) import torch a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() diff --git a/testing/python/tilelang/test_tilelang_amd_gemm.py b/testing/python/tilelang/test_tilelang_amd_gemm.py index 20abd415c..e63a78370 100644 --- a/testing/python/tilelang/test_tilelang_amd_gemm.py +++ b/testing/python/tilelang/test_tilelang_amd_gemm.py @@ -4,6 +4,8 @@ from bitblas import tvm as tvm import bitblas.testing from bitblas import tilelang as tilelang +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler def matmul( @@ -84,8 +86,8 @@ def run_gemm( num_threads, k_pack=k_pack, ) - mod, params = tilelang.lower(program) - mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) + mod, params = tl_lower(program) + mod = TLProfiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index 2e4873f89..3750bb2be 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -12,7 +12,8 @@ from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout from bitblas.tl.mma_macro_generator import ( TensorCoreIntrinEmitterWithLadderTransform,) - +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 torch.manual_seed(0) @@ -123,8 +124,8 @@ def run_gemm( num_threads, ) - mod, params = tilelang.lower(program) - mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) + mod, params = tl_lower(program) + mod = TLProfiler(mod, params, [2], tilelang.TensorSupplyType.Integer) out = mod.run_once() assert out is not None @@ -367,7 +368,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -406,7 +407,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index ae027cbf4..5adf37382 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -10,6 +10,8 @@ import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter) +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -178,7 +180,7 @@ def main( def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -188,7 +190,7 @@ def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -271,13 +273,13 @@ def assert_tl_matmul_block_correctness( num_stages, num_threads, ) - mod, params = tilelang.lower(program) + mod, params = tl_lower(program) A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): @@ -370,7 +372,7 @@ def assert_tl_matmul_block_all_dynamic_correctness( num_stages, num_threads, ) - mod, params = tilelang.lower(program) + mod, params = tl_lower(program) if trans_A: A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) else: @@ -381,7 +383,7 @@ def assert_tl_matmul_block_all_dynamic_correctness( B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index 2c1c834ee..e8940ca67 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -12,6 +12,8 @@ import logging from bitblas import set_log_level from bitblas.ops.general_flashatten.tilelang.flashatten import flashatten_blocked +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler set_log_level(logging.DEBUG) @@ -66,8 +68,8 @@ def flashattn_tilelang(batch, heads, seq_len, dim, trans_K, dtypeQKV, dtypeAccu, num_stages=num_stages, is_causal=is_causal, ) - mod, params = tilelang.lower(tl_prim_func) - mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) + mod, params = tl_lower(tl_prim_func) + mod = TLProfiler(mod, params, [3], tilelang.TensorSupplyType.Normal) from flash_attn.flash_attn_interface import flash_attn_func # TODO Now hack to internal function get the same input, may need to modify 3rdparty:tvm.tl.utils ins = mod._get_inputs() @@ -177,8 +179,8 @@ def main( return main - mod, params = tilelang.lower(kernel()) - mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) + mod, params = tl_lower(kernel()) + mod = TLProfiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) @@ -398,8 +400,8 @@ def main( return main - mod, params = tilelang.lower(kernel()) - mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) + mod, params = tl_lower(kernel()) + mod = TLProfiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.1, atol=0.1) diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py index bd26fcc1f..e397cee49 100644 --- a/testing/python/tilelang/test_tilelang_gemm.py +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -4,6 +4,8 @@ from bitblas import tvm as tvm import bitblas.testing from bitblas import tilelang as tilelang +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler def matmul( @@ -81,8 +83,8 @@ def run_gemm( num_stages, num_threads, ) - mod, params = tilelang.lower(program) - mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) + mod, params = tl_lower(program) + mod = TLProfiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py index b32fd7833..e4da6dd7d 100644 --- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -16,6 +16,8 @@ INT4TensorCoreIntrinEmitterWithLadderTransform, ) from bitblas.base import simplify_prim_func +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -173,7 +175,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -184,7 +186,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, compressed_B, C) print(C) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") @@ -368,7 +370,7 @@ def main( def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -391,7 +393,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(compressed_B.cpu()).cuda() mod(compressed_A, LB, C) diff --git a/testing/python/tilelang/test_tilelang_gemm_simt.py b/testing/python/tilelang/test_tilelang_gemm_simt.py index 33e5abae6..df98df390 100644 --- a/testing/python/tilelang/test_tilelang_gemm_simt.py +++ b/testing/python/tilelang/test_tilelang_gemm_simt.py @@ -10,6 +10,8 @@ import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.base import simplify_prim_func +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -142,7 +144,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() print(src_code) # src_code is the generated cuda source @@ -157,7 +159,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) diff --git a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py index b1f16c207..d46e4286d 100644 --- a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py @@ -11,6 +11,8 @@ from bitblas.tl.mfma_macro_generator import ( MatrixCoreIntrinEmitter,) from bitblas.base import simplify_prim_func +from bitblas.tl.profiler import TLProfiler +from bitblas.tl.lower import tl_lower torch.manual_seed(0) @@ -172,7 +174,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -186,7 +188,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) diff --git a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py index dbdfd1034..8f8c3fcfa 100644 --- a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py @@ -15,6 +15,8 @@ ) from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 from bitblas.base import simplify_prim_func +from bitblas.tl.lower import tl_lower +from bitblas.tl.profiler import TLProfiler torch.manual_seed(0) @@ -186,7 +188,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -200,7 +202,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -387,7 +389,7 @@ def main( def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_with_block_reduce(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -397,7 +399,7 @@ def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -564,7 +566,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d matmul = tl_matmul_with_ladder_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -583,7 +585,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(B.cpu()).cuda() @@ -824,7 +826,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -863,7 +865,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) @@ -1035,7 +1037,7 @@ def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_ matmul = tl_matmul_with_ladder_input_weight_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_a, transform_b) - mod, params = tilelang.lower(matmul) + mod, params = tl_lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1068,7 +1070,7 @@ def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_ ladder_permutate_b = bitblas.ops.LadderPermutate(ladder_permutate_config_B) - mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) + mod = TLProfiler(mod, params, [], tilelang.TensorSupplyType.Integer) LA = ladder_permutate_a(A.cpu()).cuda() LB = ladder_permutate_b(B.cpu()).cuda() diff --git a/tools/get_available_targets.py b/tools/get_available_targets.py index 2c2753d7a..fb685b1b4 100644 --- a/tools/get_available_targets.py +++ b/tools/get_available_targets.py @@ -4,14 +4,16 @@ from bitblas.utils import get_all_nvidia_targets from tabulate import tabulate + def main(): # Get all available Nvidia targets targets = get_all_nvidia_targets() - + # Print available targets to console in a table format table = [[i + 1, target] for i, target in enumerate(targets)] headers = ["Index", "Target"] print(tabulate(table, headers, tablefmt="pretty")) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tutorials/1.fast_and_efficient_codegen.ipynb b/tutorials/1.fast_and_efficient_codegen.ipynb index 4c34d06e6..7275ffc09 100644 --- a/tutorials/1.fast_and_efficient_codegen.ipynb +++ b/tutorials/1.fast_and_efficient_codegen.ipynb @@ -104,7 +104,7 @@ } ], "source": [ - "# import fast tunning related toolkits\n", + "# import fast tuning related toolkits\n", "from bitblas.base.roller.policy import DefaultPolicy\n", "from bitblas.base.arch import CUDA\n", "from bitblas.base.utils import apply_and_build\n", diff --git a/tutorials/2.auto_tensorization.ipynb b/tutorials/2.auto_tensorization.ipynb index ebf072959..cfd438852 100644 --- a/tutorials/2.auto_tensorization.ipynb +++ b/tutorials/2.auto_tensorization.ipynb @@ -21,7 +21,6 @@ "outputs": [], "source": [ "import bitblas\n", - "from bitblas import tvm\n", "from tvm import te, tir" ] }, diff --git a/tutorials/3.fast_decoding.ipynb b/tutorials/3.fast_decoding.ipynb index f8b8b093e..2bf26393f 100644 --- a/tutorials/3.fast_decoding.ipynb +++ b/tutorials/3.fast_decoding.ipynb @@ -8,7 +8,7 @@ "source": [ "# Fast Dequantizatoin\n", "\n", - "How to enbale fast dequantization (INT4/2/1 -> FP16/INT8) or (FP8 -> FP16)?\n", + "How to enable fast dequantization (INT4/2/1 -> FP16/INT8) or (FP8 -> FP16)?\n", "\n", "![image.png](./img/FastDequantization.png)" ] diff --git a/tutorials/4.dynamic_shape_codegen.ipynb b/tutorials/4.dynamic_shape_codegen.ipynb index 69554f91b..27f382c59 100644 --- a/tutorials/4.dynamic_shape_codegen.ipynb +++ b/tutorials/4.dynamic_shape_codegen.ipynb @@ -19,7 +19,6 @@ "outputs": [], "source": [ "import bitblas\n", - "import torch\n", "\n", "# enabling debug output\n", "\n", diff --git a/tutorials/5.ladder_end2end.ipynb b/tutorials/5.ladder_end2end.ipynb index 50441da32..5fbb0125c 100644 --- a/tutorials/5.ladder_end2end.ipynb +++ b/tutorials/5.ladder_end2end.ipynb @@ -19,9 +19,7 @@ "id": "5bfda684-1ca2-4999-8609-505a3c974b7e", "metadata": {}, "outputs": [], - "source": [ - "import ladder" - ] + "source": [] }, { "cell_type": "code", diff --git a/tutorials/6.tile-language.ipynb b/tutorials/6.tile-language.ipynb index 7826f0743..eb96f5161 100644 --- a/tutorials/6.tile-language.ipynb +++ b/tutorials/6.tile-language.ipynb @@ -7,7 +7,7 @@ "source": [ "# Tile Language in BitBLAS\n", "\n", - "More flexiable, More Efficient Tile Programming Languange compared with Triton\n", + "More flexiable, More Efficient Tile Programming Language compared with Triton\n", "\n", "## Features\n", "\n", @@ -26,6 +26,12 @@ "\n" ] }, + { + "cell_type": "markdown", + "id": "267bf3b1", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": 2, @@ -192,7 +198,7 @@ "id": "aec3ac4e-385f-44cb-b439-2357901e1d86", "metadata": {}, "source": [ - "TL also provide interface for users to manupulate the memory layout, pipeline and enable rasterization for better L2 Cache Locality. Here is an example of how to use the memory layout and rasterization:" + "TL also provide interface for users to manipulate the memory layout, pipeline and enable rasterization for better L2 Cache Locality. Here is an example of how to use the memory layout and rasterization:" ] }, { @@ -318,7 +324,7 @@ "id": "48f49319-ad63-4673-8130-b010dc8ba22e", "metadata": {}, "source": [ - "## If you want fine-grained control over dequantization at the thread leve" + "## If you want fine-grained control over dequantization at the thread level" ] }, {